diff --git a/build/moz.configure/old.configure b/build/moz.configure/old.configure index 2124785d9..6ce8be9e8 100644 --- a/build/moz.configure/old.configure +++ b/build/moz.configure/old.configure @@ -159,6 +159,7 @@ def old_configure_options(*options): '--enable-alsa', '--enable-android-omx', '--enable-av1', + '--enable-jxl', '--enable-b2g-bt', '--enable-b2g-camera', '--enable-b2g-ril', diff --git a/config/external/moz.build b/config/external/moz.build index fbf6da089..0a9443b89 100644 --- a/config/external/moz.build +++ b/config/external/moz.build @@ -53,6 +53,9 @@ if CONFIG['MOZ_WEBSPEECH_POCKETSPHINX']: if CONFIG['MOZ_FFVPX']: external_dirs += ['media/ffvpx'] +if CONFIG["MOZ_JXL"]: + external_dirs += ["media/libjxl", "media/highway"] + external_dirs += [ 'media/kiss_fft', 'media/libcubeb', diff --git a/layout/media/symbols.def.in b/layout/media/symbols.def.in index 4a66c20f0..4a8bc2e02 100644 --- a/layout/media/symbols.def.in +++ b/layout/media/symbols.def.in @@ -80,6 +80,24 @@ aom_codec_peek_stream_info aom_img_alloc aom_img_free #endif +#ifdef MOZ_JXL +JxlDecoderCreate +JxlDecoderDestroy +JxlDecoderSetParallelRunner +JxlDecoderSubscribeEvents +JxlDecoderProcessInput +JxlDecoderSetInput +JxlDecoderReleaseInput +JxlDecoderGetBasicInfo +JxlDecoderImageOutBufferSize +JxlDecoderSetImageOutBuffer +JxlDecoderGetFrameHeader +JxlDecoderFlushImage +JxlThreadParallelRunner +JxlThreadParallelRunnerCreate +JxlThreadParallelRunnerDestroy +JxlThreadParallelRunnerDefaultNumWorkerThreads +#endif #ifdef MOZ_VORBIS ogg_page_bos ogg_page_granulepos diff --git a/media/highway/README_MCP b/media/highway/README_MCP new file mode 100644 index 000000000..c656a0ec1 --- /dev/null +++ b/media/highway/README_MCP @@ -0,0 +1,12 @@ +This directory contains build files for the Highway C++ +SIMD library. + +Any patches or additional configuration to be applied to the +upstream source should be kept here in the media/highway +directory. + +The upstream highway git repository is: + + https://github.com/google/highway + +The version used was tagged 1.0.2. \ No newline at end of file diff --git a/media/highway/moz.build b/media/highway/moz.build new file mode 100644 index 000000000..da8876c5b --- /dev/null +++ b/media/highway/moz.build @@ -0,0 +1,48 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +LOCAL_INCLUDES += [ + "/media/highway/src/", +] + +SOURCES += [ + "/media/highway/src/hwy/aligned_allocator.cc", + "/media/highway/src/hwy/contrib/image/image.cc", + "/media/highway/src/hwy/per_target.cc", + "/media/highway/src/hwy/targets.cc", +] + +EXPORTS.hwy += [ + "/media/highway/src/hwy/aligned_allocator.h", + "/media/highway/src/hwy/base.h", + "/media/highway/src/hwy/cache_control.h", + "/media/highway/src/hwy/detect_compiler_arch.h", + "/media/highway/src/hwy/detect_targets.h", + "/media/highway/src/hwy/foreach_target.h", + "/media/highway/src/hwy/highway.h", + "/media/highway/src/hwy/highway_export.h", + "/media/highway/src/hwy/targets.h", +] + +EXPORTS.hwy.ops += [ + "/media/highway/src/hwy/ops/arm_neon-inl.h", + "/media/highway/src/hwy/ops/arm_sve-inl.h", + "/media/highway/src/hwy/ops/emu128-inl.h", + "/media/highway/src/hwy/ops/generic_ops-inl.h", + "/media/highway/src/hwy/ops/rvv-inl.h", + "/media/highway/src/hwy/ops/scalar-inl.h", + "/media/highway/src/hwy/ops/set_macros-inl.h", + "/media/highway/src/hwy/ops/shared-inl.h", + "/media/highway/src/hwy/ops/wasm_128-inl.h", + "/media/highway/src/hwy/ops/x86_128-inl.h", + "/media/highway/src/hwy/ops/x86_256-inl.h", + "/media/highway/src/hwy/ops/x86_512-inl.h", +] + +FINAL_LIBRARY = "gkmedias" + +# We allow warnings for third-party code that can be updated from upstream. +ALLOW_COMPILER_WARNINGS = True diff --git a/media/highway/src/BUILD b/media/highway/src/BUILD new file mode 100644 index 000000000..1928c3275 --- /dev/null +++ b/media/highway/src/BUILD @@ -0,0 +1,413 @@ +load("@bazel_skylib//lib:selects.bzl", "selects") + +load("@rules_cc//cc:defs.bzl", "cc_test") +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +# Detect compiler: +config_setting( + name = "compiler_clang", + flag_values = {"@bazel_tools//tools/cpp:compiler": "clang"}, +) + +config_setting( + name = "compiler_clangcl", + flag_values = {"@bazel_tools//tools/cpp:compiler": "lexan"}, +) + +config_setting( + name = "compiler_msvc_actual", + flag_values = {"@bazel_tools//tools/cpp:compiler": "msvc"}, +) + +# The above is insufficient for Bazel on Windows, which does not seem to +# detect/set a compiler flag. This workaround prevents compile errors due to +# passing clang-only warning flags to MSVC. +config_setting( + name = "compiler_msvc_cpu", + values = { + "cpu": "x64_windows", + }, +) + +selects.config_setting_group( + name = "compiler_msvc", + match_any = [ + ":compiler_msvc_actual", + ":compiler_msvc_cpu", + ], +) + +config_setting( + name = "compiler_emscripten", + values = {"cpu": "wasm32"}, +) + +# See https://github.com/bazelbuild/bazel/issues/12707 +config_setting( + name = "compiler_gcc_bug", + flag_values = { + "@bazel_tools//tools/cpp:compiler": "compiler", + }, +) + +config_setting( + name = "compiler_gcc_actual", + flag_values = { + "@bazel_tools//tools/cpp:compiler": "gcc", + }, +) + +selects.config_setting_group( + name = "compiler_gcc", + match_any = [ + ":compiler_gcc_bug", + ":compiler_gcc_actual", + ], +) + +# Additional warnings for Clang OR GCC (skip for MSVC) +CLANG_GCC_COPTS = [ + "-Wunused-parameter", + "-Wunused-variable", + "-Wextra-semi", + "-Wunreachable-code", +] + +# Warnings supported by Clang and Clang-cl +CLANG_OR_CLANGCL_OPTS = CLANG_GCC_COPTS + [ + "-Wfloat-overflow-conversion", + "-Wfloat-zero-conversion", + "-Wfor-loop-analysis", + "-Wgnu-redeclared-enum", + "-Winfinite-recursion", + "-Wliteral-conversion", + "-Wno-c++98-compat", + "-Wno-unused-command-line-argument", + "-Wprivate-header", + "-Wself-assign", + "-Wstring-conversion", + "-Wtautological-overlap-compare", + "-Wthread-safety-analysis", + "-Wundefined-func-template", + "-Wunused-comparison", +] + +# Warnings only supported by Clang, but not Clang-cl +CLANG_ONLY_COPTS = CLANG_OR_CLANGCL_OPTS + [ + # Do not treat the third_party headers as system headers when building + # highway - the errors are pertinent. + "--no-system-header-prefix=third_party/highway", +] + +COPTS = select({ + ":compiler_msvc": [], + ":compiler_gcc": CLANG_GCC_COPTS, + ":compiler_clangcl": CLANG_OR_CLANGCL_OPTS, + # Default to clang because compiler detection only works in Bazel + "//conditions:default": CLANG_ONLY_COPTS, +}) + select({ + "@platforms//cpu:riscv64": [ + "-march=rv64gcv1p0", + "-menable-experimental-extensions", + ], + "//conditions:default": [ + ], +}) + +DEFINES = select({ + ":compiler_msvc": ["HWY_SHARED_DEFINE"], + ":compiler_clangcl": ["HWY_SHARED_DEFINE"], + "//conditions:default": [], +}) + +# Unused on Bazel builds, where this is not defined/known; Copybara replaces +# usages with an empty list. +COMPAT = [ + "//buildenv/target:non_prod", # includes mobile/vendor. +] + +# WARNING: changing flags such as HWY_DISABLED_TARGETS may break users without +# failing integration tests, if the machine running tests does not support the +# newly enabled instruction set, or the failure is only caught by sanitizers +# which do not run in CI. + +cc_library( + name = "hwy", + srcs = [ + "hwy/aligned_allocator.cc", + "hwy/per_target.cc", + "hwy/print.cc", + "hwy/targets.cc", + ], + # Normal headers with include guards + hdrs = [ + "hwy/aligned_allocator.h", + "hwy/base.h", + "hwy/cache_control.h", + "hwy/detect_compiler_arch.h", # private + "hwy/print.h", + ], + compatible_with = [], + copts = COPTS, + defines = DEFINES, + local_defines = ["hwy_EXPORTS"], + textual_hdrs = [ + # These are textual because config macros influence them: + "hwy/detect_targets.h", # private + "hwy/targets.h", + # This .cc file #includes itself through foreach_target.h + "hwy/per_target.cc", + # End of list + "hwy/highway.h", # public + "hwy/foreach_target.h", # public + "hwy/per_target.h", # public + "hwy/print-inl.h", # public + "hwy/highway_export.h", # public + "hwy/ops/arm_neon-inl.h", + "hwy/ops/arm_sve-inl.h", + "hwy/ops/emu128-inl.h", + "hwy/ops/generic_ops-inl.h", + "hwy/ops/scalar-inl.h", + "hwy/ops/set_macros-inl.h", + "hwy/ops/shared-inl.h", + "hwy/ops/x86_128-inl.h", + "hwy/ops/x86_256-inl.h", + "hwy/ops/x86_512-inl.h", + # Select avoids recompiling native arch if only non-native changed + ] + select({ + ":compiler_emscripten": ["hwy/ops/wasm_128-inl.h"], + "//conditions:default": [], + }) + select({ + "@platforms//cpu:riscv64": ["hwy/ops/rvv-inl.h"], + "//conditions:default": [], + }), +) + +cc_library( + name = "algo", + compatible_with = [], + copts = COPTS, + textual_hdrs = [ + "hwy/contrib/algo/copy-inl.h", + "hwy/contrib/algo/find-inl.h", + "hwy/contrib/algo/transform-inl.h", + ], + deps = [ + ":hwy", + ], +) + +cc_library( + name = "dot", + compatible_with = [], + copts = COPTS, + textual_hdrs = [ + "hwy/contrib/dot/dot-inl.h", + ], + deps = [ + ":hwy", + ], +) + +cc_library( + name = "image", + srcs = [ + "hwy/contrib/image/image.cc", + ], + hdrs = [ + "hwy/contrib/image/image.h", + ], + compatible_with = [], + copts = COPTS, + local_defines = ["hwy_contrib_EXPORTS"], + deps = [ + ":hwy", + ], +) + +cc_library( + name = "math", + compatible_with = [], + copts = COPTS, + textual_hdrs = [ + "hwy/contrib/math/math-inl.h", + ], + deps = [ + ":hwy", + ], +) + +# Everything required for tests that use Highway. +cc_library( + name = "hwy_test_util", + srcs = ["hwy/tests/test_util.cc"], + hdrs = ["hwy/tests/test_util.h"], + compatible_with = [], + copts = COPTS, + local_defines = ["hwy_test_EXPORTS"], + textual_hdrs = [ + "hwy/tests/test_util-inl.h", + "hwy/tests/hwy_gtest.h", + ], + # Must not depend on a gtest variant, which can conflict with the + # GUNIT_INTERNAL_BUILD_MODE defined by the test. + deps = [ + ":hwy", + ], +) + +cc_library( + name = "nanobenchmark", + srcs = ["hwy/nanobenchmark.cc"], + hdrs = ["hwy/nanobenchmark.h"], + compatible_with = [], + copts = COPTS, + local_defines = ["hwy_EXPORTS"], + deps = [":hwy"], +) + +cc_binary( + name = "benchmark", + srcs = ["hwy/examples/benchmark.cc"], + copts = COPTS, + deps = [ + ":hwy", + ":nanobenchmark", + ], +) + +cc_library( + name = "skeleton", + srcs = ["hwy/examples/skeleton.cc"], + hdrs = ["hwy/examples/skeleton.h"], + copts = COPTS, + local_defines = ["hwy_EXPORTS"], + textual_hdrs = ["hwy/examples/skeleton-inl.h"], + deps = [ + ":hwy", + ], +) + +cc_binary( + name = "list_targets", + srcs = ["hwy/tests/list_targets.cc"], + deps = [":hwy"], +) + +# path, name +HWY_TESTS = [ + ("hwy/contrib/algo/", "copy_test"), + ("hwy/contrib/algo/", "find_test"), + ("hwy/contrib/algo/", "transform_test"), + ("hwy/contrib/dot/", "dot_test"), + ("hwy/contrib/image/", "image_test"), + ("hwy/contrib/math/", "math_test"), + # contrib/sort has its own BUILD, we add it to GUITAR_TESTS. + ("hwy/examples/", "skeleton_test"), + ("hwy/", "nanobenchmark_test"), + ("hwy/", "aligned_allocator_test"), + ("hwy/", "base_test"), + ("hwy/", "highway_test"), + ("hwy/", "targets_test"), + ("hwy/tests/", "arithmetic_test"), + ("hwy/tests/", "blockwise_test"), + ("hwy/tests/", "blockwise_shift_test"), + ("hwy/tests/", "combine_test"), + ("hwy/tests/", "compare_test"), + ("hwy/tests/", "compress_test"), + ("hwy/tests/", "convert_test"), + ("hwy/tests/", "crypto_test"), + ("hwy/tests/", "demote_test"), + ("hwy/tests/", "float_test"), + ("hwy/tests/", "if_test"), + ("hwy/tests/", "interleaved_test"), + ("hwy/tests/", "logical_test"), + ("hwy/tests/", "mask_test"), + ("hwy/tests/", "mask_mem_test"), + ("hwy/tests/", "memory_test"), + ("hwy/tests/", "mul_test"), + ("hwy/tests/", "reduction_test"), + ("hwy/tests/", "reverse_test"), + ("hwy/tests/", "shift_test"), + ("hwy/tests/", "swizzle_test"), + ("hwy/tests/", "test_util_test"), +] + +HWY_TEST_COPTS = select({ + ":compiler_msvc": [], + "//conditions:default": [ + # gTest triggers this warning (which is enabled by the + # extra-semi in COPTS), so we need to disable it here, + # but it's still enabled for :hwy. + "-Wno-c++98-compat-extra-semi", + ], +}) + +HWY_TEST_DEPS = [ + ":algo", + ":dot", + ":hwy", + ":hwy_test_util", + ":image", + ":math", + ":nanobenchmark", + ":skeleton", + "//hwy/contrib/sort:vqsort", + "@com_google_googletest//:gtest_main", +] + +[ + [ + cc_test( + name = test, + size = "medium", + timeout = "long", # default moderate is not enough for math_test + srcs = [ + subdir + test + ".cc", + ], + copts = COPTS + HWY_TEST_COPTS, + features = select({ + "@platforms//cpu:riscv64": ["fully_static_link"], + "//conditions:default": [], + }), + linkopts = select({ + ":compiler_emscripten": [ + "-s ASSERTIONS=2", + "-s ENVIRONMENT=node,shell,web", + "-s ERROR_ON_UNDEFINED_SYMBOLS=1", + "-s DEMANGLE_SUPPORT=1", + "-s EXIT_RUNTIME=1", + "-s ALLOW_MEMORY_GROWTH=1", + "--pre-js $(location :preamble.js.lds)", + ], + "//conditions:default": [], + }), + linkstatic = select({ + "@platforms//cpu:riscv64": True, + "//conditions:default": False, + }), + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = HWY_TEST_DEPS + select({ + ":compiler_emscripten": [":preamble.js.lds"], + "//conditions:default": [], + }), + ), + ] + for subdir, test in HWY_TESTS +] + +# For manually building the tests we define here (:all does not work in --config=msvc) +test_suite( + name = "hwy_ops_tests", + tags = ["hwy_ops_test"], +) + +# Placeholder for integration test, do not remove diff --git a/media/highway/src/CMakeLists.txt b/media/highway/src/CMakeLists.txt new file mode 100644 index 000000000..b6b14ab83 --- /dev/null +++ b/media/highway/src/CMakeLists.txt @@ -0,0 +1,580 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.10) + +# Set PIE flags for POSITION_INDEPENDENT_CODE targets, added in 3.14. +if(POLICY CMP0083) + cmake_policy(SET CMP0083 NEW) +endif() + +# Workaround for 3.19 raising error 'IMPORTED_LOCATION not set for imported +# target "GTest::gtest_main"'. +if(POLICY CMP0111) + cmake_policy(SET CMP0111 OLD) +endif() + +project(hwy VERSION 1.0.2) # Keep in sync with highway.h version + +# Directly define the ABI version from the cmake project() version values: +set(LIBRARY_VERSION "${hwy_VERSION}") +set(LIBRARY_SOVERSION ${hwy_VERSION_MAJOR}) + +set(CMAKE_CXX_EXTENSIONS OFF) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +# Search for Atomics implementation: +find_package(Atomics REQUIRED) + +# Enabled PIE binaries by default if supported. +include(CheckPIESupported OPTIONAL RESULT_VARIABLE CHECK_PIE_SUPPORTED) +if(CHECK_PIE_SUPPORTED) + check_pie_supported(LANGUAGES CXX) + if(CMAKE_CXX_LINK_PIE_SUPPORTED) + set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) + endif() +endif() + +include(GNUInstallDirs) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE RelWithDebInfo) +endif() + +set(HWY_CMAKE_ARM7 OFF CACHE BOOL "Set copts for ARMv7 with NEON (requires vfpv4)?") + +# Unconditionally adding -Werror risks breaking the build when new warnings +# arise due to compiler/platform changes. Enable this in CI/tests. +set(HWY_WARNINGS_ARE_ERRORS OFF CACHE BOOL "Add -Werror flag?") + +set(HWY_ENABLE_CONTRIB ON CACHE BOOL "Include contrib/") +set(HWY_ENABLE_EXAMPLES ON CACHE BOOL "Build examples") +set(HWY_ENABLE_INSTALL ON CACHE BOOL "Install library") +set(HWY_ENABLE_TESTS ON CACHE BOOL "Enable HWY tests") + +include(CheckCXXSourceCompiles) +check_cxx_source_compiles( + "int main() { + #if !defined(__EMSCRIPTEN__) + static_assert(false, \"__EMSCRIPTEN__ is not defined\"); + #endif + return 0; + }" + HWY_EMSCRIPTEN +) + +check_cxx_source_compiles( + "int main() { + #if !defined(__riscv) + static_assert(false, \"__riscv is not defined\"); + #endif + return 0; + }" + HWY_RISCV +) + +if (HWY_ENABLE_CONTRIB) +# Glob all the traits so we don't need to modify this file when adding +# additional special cases. +file(GLOB HWY_CONTRIB_SOURCES "hwy/contrib/sort/vqsort_*.cc") +list(APPEND HWY_CONTRIB_SOURCES + hwy/contrib/dot/dot-inl.h + hwy/contrib/image/image.cc + hwy/contrib/image/image.h + hwy/contrib/math/math-inl.h + hwy/contrib/sort/shared-inl.h + hwy/contrib/sort/sorting_networks-inl.h + hwy/contrib/sort/traits-inl.h + hwy/contrib/sort/traits128-inl.h + hwy/contrib/sort/vqsort-inl.h + hwy/contrib/sort/vqsort.cc + hwy/contrib/sort/vqsort.h + hwy/contrib/algo/copy-inl.h + hwy/contrib/algo/find-inl.h + hwy/contrib/algo/transform-inl.h +) +endif() # HWY_ENABLE_CONTRIB + +set(HWY_SOURCES + hwy/aligned_allocator.cc + hwy/aligned_allocator.h + hwy/base.h + hwy/cache_control.h + hwy/detect_compiler_arch.h # private + hwy/detect_targets.h # private + hwy/foreach_target.h + hwy/highway.h + hwy/highway_export.h + hwy/nanobenchmark.cc + hwy/nanobenchmark.h + hwy/ops/arm_neon-inl.h + hwy/ops/arm_sve-inl.h + hwy/ops/emu128-inl.h + hwy/ops/generic_ops-inl.h + hwy/ops/rvv-inl.h + hwy/ops/scalar-inl.h + hwy/ops/set_macros-inl.h + hwy/ops/shared-inl.h + hwy/ops/wasm_128-inl.h + hwy/ops/x86_128-inl.h + hwy/ops/x86_256-inl.h + hwy/ops/x86_512-inl.h + hwy/per_target.cc + hwy/per_target.h + hwy/print-inl.h + hwy/print.cc + hwy/print.h + hwy/targets.cc + hwy/targets.h +) + +set(HWY_TEST_SOURCES + hwy/tests/hwy_gtest.h + hwy/tests/test_util-inl.h + hwy/tests/test_util.cc + hwy/tests/test_util.h +) + +if (MSVC) + set(HWY_FLAGS + # fix build error C1128 in blockwise*_test & arithmetic_test + /bigobj + ) +else() + set(HWY_FLAGS + # Avoid changing binaries based on the current time and date. + -Wno-builtin-macro-redefined + -D__DATE__="redacted" + -D__TIMESTAMP__="redacted" + -D__TIME__="redacted" + + # Optimizations + -fmerge-all-constants + + # Warnings + -Wall + -Wextra + # These are not included in Wall nor Wextra: + -Wconversion + -Wsign-conversion + -Wvla + -Wnon-virtual-dtor + ) + + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + list(APPEND HWY_FLAGS + -Wfloat-overflow-conversion + -Wfloat-zero-conversion + -Wfor-loop-analysis + -Wgnu-redeclared-enum + -Winfinite-recursion + -Wself-assign + -Wstring-conversion + -Wtautological-overlap-compare + -Wthread-safety-analysis + -Wundefined-func-template + + -fno-cxx-exceptions + -fno-slp-vectorize + -fno-vectorize + + # Use color in messages + -fdiagnostics-show-option -fcolor-diagnostics + ) + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 6.0) + list(APPEND HWY_FLAGS -Wc++2a-extensions) + endif() + endif() + + if (WIN32) + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + list(APPEND HWY_FLAGS + -Wno-global-constructors + -Wno-language-extension-token + -Wno-used-but-marked-unused + -Wno-shadow-field-in-constructor + -Wno-unused-member-function + -Wno-unused-template + -Wno-c++98-compat-pedantic + -Wno-used-but-marked-unused + -Wno-zero-as-null-pointer-constant + ) + endif() + + list(APPEND HWY_FLAGS + -Wno-cast-align + -Wno-double-promotion + -Wno-float-equal + -Wno-format-nonliteral + -Wno-shadow + -Wno-sign-conversion + ) + else() + list(APPEND HWY_FLAGS + -fmath-errno + -fno-exceptions + ) + endif() # WIN32 + + if (HWY_CMAKE_ARM7) + list(APPEND HWY_FLAGS + -march=armv7-a + -mfpu=neon-vfpv4 + -mfloat-abi=hard # must match the toolchain specified as CXX= + -mfp16-format=ieee # required for vcvt_f32_f16 + ) + endif() # HWY_CMAKE_ARM7 + + if(HWY_RISCV) + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + # Not yet supported by GCC. When runtime dispatch is supported and + # implemented, we will remove v from the required flags. Until then, using + # clang for RISC-V will require the CPU to support the V extension (1.0). + list(APPEND HWY_FLAGS -march=rv64gcv1p0) + list(APPEND HWY_FLAGS -menable-experimental-extensions) + endif() + endif() + + if (HWY_WARNINGS_ARE_ERRORS) + list(APPEND HWY_FLAGS -Werror) + endif() + + # Prevent "wasm-ld: error: --shared-memory is disallowed by targets.cc.o + # because it was not compiled with 'atomics' or 'bulk-memory' features." + if (HWY_EMSCRIPTEN) + list(APPEND HWY_FLAGS -matomics) + endif() + +endif() # !MSVC + +# By default prefer STATIC build (legacy behavior) +option(BUILD_SHARED_LIBS "Build shared libraries" OFF) +option(HWY_FORCE_STATIC_LIBS "Ignore BUILD_SHARED_LIBS" OFF) +# only expose shared/static options to advanced users: +mark_as_advanced(BUILD_SHARED_LIBS) +mark_as_advanced(HWY_FORCE_STATIC_LIBS) +# Define visibility settings globally: +set(CMAKE_CXX_VISIBILITY_PRESET hidden) +set(CMAKE_VISIBILITY_INLINES_HIDDEN 1) + +# Copy-cat "add_library" logic + add override. +set(HWY_LIBRARY_TYPE "SHARED") +if (NOT BUILD_SHARED_LIBS OR HWY_FORCE_STATIC_LIBS) + set(HWY_LIBRARY_TYPE "STATIC") +endif() + +# This preprocessor define will drive the build, also used in the *.pc files: +if("${HWY_LIBRARY_TYPE}" STREQUAL "SHARED") + set(DLLEXPORT_TO_DEFINE "HWY_SHARED_DEFINE") +else() + set(DLLEXPORT_TO_DEFINE "HWY_STATIC_DEFINE") +endif() + +add_library(hwy ${HWY_LIBRARY_TYPE} ${HWY_SOURCES}) +target_compile_definitions(hwy PUBLIC "${DLLEXPORT_TO_DEFINE}") +target_compile_options(hwy PRIVATE ${HWY_FLAGS}) +set_property(TARGET hwy PROPERTY POSITION_INDEPENDENT_CODE ON) +set_target_properties(hwy PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION}) +target_include_directories(hwy PUBLIC + $ + $) +target_compile_features(hwy PUBLIC cxx_std_11) +set_target_properties(hwy PROPERTIES + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version) +# For GCC __atomic_store_8, see #887 +target_link_libraries(hwy PRIVATE ${ATOMICS_LIBRARIES}) +if(UNIX AND NOT APPLE) + # not supported by MSVC/Clang, safe to skip (we use DLLEXPORT annotations) + set_property(TARGET hwy APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version") +endif() + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "unknown") + # uname -p is broken on this system. Try uname -m + EXECUTE_PROCESS( COMMAND uname -m + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + OUTPUT_VARIABLE HWY_ARCH) +else (CMAKE_SYSTEM_PROCESSOR MATCHES "unknown") + set(HWY_ARCH ${CMAKE_SYSTEM_PROCESSOR}) +endif (CMAKE_SYSTEM_PROCESSOR MATCHES "unknown") +message(STATUS "Architecture: " ${HWY_ARCH}) +if (HWY_ARCH MATCHES "mips") + target_link_options(hwy PUBLIC "LINKER:-z,noexecstack") +endif (HWY_ARCH MATCHES "mips") + + +if (HWY_ENABLE_CONTRIB) +add_library(hwy_contrib ${HWY_LIBRARY_TYPE} ${HWY_CONTRIB_SOURCES}) +target_link_libraries(hwy_contrib hwy) +target_compile_options(hwy_contrib PRIVATE ${HWY_FLAGS}) +set_property(TARGET hwy_contrib PROPERTY POSITION_INDEPENDENT_CODE ON) +set_target_properties(hwy_contrib PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION}) +target_include_directories(hwy_contrib PUBLIC + $ + $) +target_compile_features(hwy_contrib PUBLIC cxx_std_11) +set_target_properties(hwy_contrib PROPERTIES + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version) +# not supported by MSVC/Clang, safe to skip (we use DLLEXPORT annotations) +if(UNIX AND NOT APPLE) + set_property(TARGET hwy_contrib APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version") +endif() +endif() # HWY_ENABLE_CONTRIB + +add_library(hwy_test ${HWY_LIBRARY_TYPE} ${HWY_TEST_SOURCES}) +target_link_libraries(hwy_test hwy) +target_compile_options(hwy_test PRIVATE ${HWY_FLAGS}) +set_property(TARGET hwy_test PROPERTY POSITION_INDEPENDENT_CODE ON) +set_target_properties(hwy_test PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION}) +target_include_directories(hwy_test PUBLIC + $ + $) +target_compile_features(hwy_test PUBLIC cxx_std_11) +set_target_properties(hwy_test PROPERTIES + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version) +# not supported by MSVC/Clang, safe to skip (we use DLLEXPORT annotations) +if(UNIX AND NOT APPLE) + set_property(TARGET hwy_test APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version") +endif() + +# -------------------------------------------------------- hwy_list_targets +# Generate a tool to print the compiled-in targets as defined by the current +# flags. This tool will print to stderr at build time, after building hwy. +add_executable(hwy_list_targets hwy/tests/list_targets.cc) +target_compile_options(hwy_list_targets PRIVATE ${HWY_FLAGS}) +target_link_libraries(hwy_list_targets hwy) +target_include_directories(hwy_list_targets PRIVATE + $) +# TARGET_FILE always returns the path to executable +# Naked target also not always could be run (due to the lack of '.\' prefix) +# Thus effective command to run should contain the full path +# and emulator prefix (if any). +if (NOT CMAKE_CROSSCOMPILING OR CMAKE_CROSSCOMPILING_EMULATOR) +add_custom_command(TARGET hwy_list_targets POST_BUILD + COMMAND ${CMAKE_CROSSCOMPILING_EMULATOR} $ || (exit 0)) +endif() + +# -------------------------------------------------------- +# Allow skipping the following sections for projects that do not need them: +# tests, examples, benchmarks and installation. + +# -------------------------------------------------------- install library +if (HWY_ENABLE_INSTALL) + +install(TARGETS hwy + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") +# Install all the headers keeping the relative path to the current directory +# when installing them. +foreach (source ${HWY_SOURCES}) + if ("${source}" MATCHES "\.h$") + get_filename_component(dirname "${source}" DIRECTORY) + install(FILES "${source}" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dirname}") + endif() +endforeach() + +if (HWY_ENABLE_CONTRIB) +install(TARGETS hwy_contrib + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") +# Install all the headers keeping the relative path to the current directory +# when installing them. +foreach (source ${HWY_CONTRIB_SOURCES}) + if ("${source}" MATCHES "\.h$") + get_filename_component(dirname "${source}" DIRECTORY) + install(FILES "${source}" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dirname}") + endif() +endforeach() +endif() # HWY_ENABLE_CONTRIB + +install(TARGETS hwy_test + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") +# Install all the headers keeping the relative path to the current directory +# when installing them. +foreach (source ${HWY_TEST_SOURCES}) + if ("${source}" MATCHES "\.h$") + get_filename_component(dirname "${source}" DIRECTORY) + install(FILES "${source}" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dirname}") + endif() +endforeach() + +# Add a pkg-config file for libhwy and the contrib/test libraries. +set(HWY_LIBRARY_VERSION "${CMAKE_PROJECT_VERSION}") +set(HWY_PC_FILES libhwy.pc libhwy-test.pc) +if (HWY_ENABLE_CONTRIB) +list(APPEND HWY_PC_FILES libhwy-contrib.pc) +endif() # HWY_ENABLE_CONTRIB +foreach (pc ${HWY_PC_FILES}) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/${pc}.in" "${pc}" @ONLY) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/${pc}" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") +endforeach() + +endif() # HWY_ENABLE_INSTALL +# -------------------------------------------------------- Examples +if (HWY_ENABLE_EXAMPLES) + +# Avoids mismatch between GTest's static CRT and our dynamic. +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +# Programming exercise with integrated benchmark +add_executable(hwy_benchmark hwy/examples/benchmark.cc) +target_sources(hwy_benchmark PRIVATE + hwy/nanobenchmark.h) +# Try adding one of -DHWY_COMPILE_ONLY_SCALAR, -DHWY_COMPILE_ONLY_EMU128 or +# -DHWY_COMPILE_ONLY_STATIC to observe the difference in targets printed. +target_compile_options(hwy_benchmark PRIVATE ${HWY_FLAGS}) +target_link_libraries(hwy_benchmark hwy) +set_target_properties(hwy_benchmark + PROPERTIES RUNTIME_OUTPUT_DIRECTORY "examples/") + +endif() # HWY_ENABLE_EXAMPLES +# -------------------------------------------------------- Tests + +include(CTest) + +if(BUILD_TESTING AND HWY_ENABLE_TESTS) +enable_testing() +include(GoogleTest) + +set(HWY_SYSTEM_GTEST OFF CACHE BOOL "Use pre-installed googletest?") +if(HWY_SYSTEM_GTEST) +find_package(GTest REQUIRED) +else() +# Download and unpack googletest at configure time +configure_file(CMakeLists.txt.in googletest-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "CMake step for googletest failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "Build step for googletest failed: ${result}") +endif() + +# Prevent overriding the parent project's compiler/linker +# settings on Windows +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +# Add googletest directly to our build. This defines +# the gtest and gtest_main targets. +add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/googletest-src + ${CMAKE_CURRENT_BINARY_DIR}/googletest-build + EXCLUDE_FROM_ALL) +endif() # HWY_SYSTEM_GTEST + +set(HWY_TEST_FILES + hwy/contrib/algo/copy_test.cc + hwy/contrib/algo/find_test.cc + hwy/contrib/algo/transform_test.cc + hwy/aligned_allocator_test.cc + hwy/base_test.cc + hwy/highway_test.cc + hwy/nanobenchmark_test.cc + hwy/targets_test.cc + hwy/examples/skeleton_test.cc + hwy/tests/arithmetic_test.cc + hwy/tests/blockwise_test.cc + hwy/tests/blockwise_shift_test.cc + hwy/tests/combine_test.cc + hwy/tests/compare_test.cc + hwy/tests/compress_test.cc + hwy/tests/convert_test.cc + hwy/tests/crypto_test.cc + hwy/tests/demote_test.cc + hwy/tests/float_test.cc + hwy/tests/if_test.cc + hwy/tests/interleaved_test.cc + hwy/tests/logical_test.cc + hwy/tests/mask_test.cc + hwy/tests/mask_mem_test.cc + hwy/tests/memory_test.cc + hwy/tests/mul_test.cc + hwy/tests/reduction_test.cc + hwy/tests/reverse_test.cc + hwy/tests/shift_test.cc + hwy/tests/swizzle_test.cc + hwy/tests/test_util_test.cc +) + +set(HWY_TEST_LIBS hwy hwy_test) + +if (HWY_ENABLE_CONTRIB) +list(APPEND HWY_TEST_LIBS hwy_contrib) + +list(APPEND HWY_TEST_FILES + hwy/contrib/dot/dot_test.cc + hwy/contrib/image/image_test.cc + # Disabled due to SIGILL in clang7 debug build during gtest discovery phase, + # not reproducible locally. Still tested via bazel build. + # hwy/contrib/math/math_test.cc + hwy/contrib/sort/sort_test.cc +) +endif() # HWY_ENABLE_CONTRIB + +if(HWY_SYSTEM_GTEST) + if (CMAKE_VERSION VERSION_LESS 3.20) + set(HWY_GTEST_LIBS GTest::GTest GTest::Main) + else() + set(HWY_GTEST_LIBS GTest::gtest GTest::gtest_main) + endif() +else() + set(HWY_GTEST_LIBS gtest gtest_main) +endif() + +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests) +foreach (TESTFILE IN LISTS HWY_TEST_FILES) + # The TESTNAME is the name without the extension or directory. + get_filename_component(TESTNAME ${TESTFILE} NAME_WE) + add_executable(${TESTNAME} ${TESTFILE}) + target_compile_options(${TESTNAME} PRIVATE ${HWY_FLAGS}) + # Test all targets, not just the best/baseline. This changes the default + # policy to all-attainable; note that setting -DHWY_COMPILE_* directly can + # cause compile errors because only one may be set, and other CMakeLists.txt + # that include us may set them. + target_compile_options(${TESTNAME} PRIVATE -DHWY_IS_TEST=1) + + target_link_libraries(${TESTNAME} ${HWY_TEST_LIBS} ${HWY_GTEST_LIBS}) + # Output test targets in the test directory. + set_target_properties(${TESTNAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "tests") + + if (HWY_EMSCRIPTEN) + set_target_properties(${TESTNAME} PROPERTIES LINK_FLAGS "-s SINGLE_FILE=1") + endif() + + if(${CMAKE_VERSION} VERSION_LESS "3.10.3") + gtest_discover_tests(${TESTNAME} TIMEOUT 60) + else () + gtest_discover_tests(${TESTNAME} DISCOVERY_TIMEOUT 60) + endif () +endforeach () + +# The skeleton test uses the skeleton library code. +target_sources(skeleton_test PRIVATE hwy/examples/skeleton.cc) + +endif() # BUILD_TESTING diff --git a/media/highway/src/CMakeLists.txt.in b/media/highway/src/CMakeLists.txt.in new file mode 100644 index 000000000..a0260b82f --- /dev/null +++ b/media/highway/src/CMakeLists.txt.in @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 2.8.12) + +project(googletest-download NONE) + +include(ExternalProject) +ExternalProject_Add(googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG 43efa0a4efd40c78b9210d15373112081899a97c + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/media/highway/src/CONTRIBUTING b/media/highway/src/CONTRIBUTING new file mode 100644 index 000000000..8b7d4d253 --- /dev/null +++ b/media/highway/src/CONTRIBUTING @@ -0,0 +1,33 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Testing + +This repository is used by JPEG XL, so major API changes will require +coordination. Please get in touch with us beforehand, e.g. by raising an issue. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). diff --git a/media/highway/src/LICENSE b/media/highway/src/LICENSE new file mode 100644 index 000000000..f49a4e16e --- /dev/null +++ b/media/highway/src/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/media/highway/src/README.md b/media/highway/src/README.md new file mode 100644 index 000000000..969f32950 --- /dev/null +++ b/media/highway/src/README.md @@ -0,0 +1,322 @@ +# Efficient and performance-portable vector software + +[//]: # (placeholder, do not remove) + +Highway is a C++ library that provides portable SIMD/vector intrinsics. + +## Why + +We are passionate about high-performance software. We see major untapped +potential in CPUs (servers, mobile, desktops). Highway is for engineers who want +to reliably and economically push the boundaries of what is possible in +software. + +## How + +CPUs provide SIMD/vector instructions that apply the same operation to multiple +data items. This can reduce energy usage e.g. *fivefold* because fewer +instructions are executed. We also often see *5-10x* speedups. + +Highway makes SIMD/vector programming practical and workable according to these +guiding principles: + +**Does what you expect**: Highway is a C++ library with carefully-chosen +functions that map well to CPU instructions without extensive compiler +transformations. The resulting code is more predictable and robust to code +changes/compiler updates than autovectorization. + +**Works on widely-used platforms**: Highway supports four architectures; the +same application code can target eight instruction sets, including those with +'scalable' vectors (size unknown at compile time). Highway only requires C++11 +and supports four families of compilers. If you would like to use Highway on +other platforms, please raise an issue. + +**Flexible to deploy**: Applications using Highway can run on heterogeneous +clouds or client devices, choosing the best available instruction set at +runtime. Alternatively, developers may choose to target a single instruction set +without any runtime overhead. In both cases, the application code is the same +except for swapping `HWY_STATIC_DISPATCH` with `HWY_DYNAMIC_DISPATCH` plus one +line of code. + +**Suitable for a variety of domains**: Highway provides an extensive set of +operations, used for image processing (floating-point), compression, video +analysis, linear algebra, cryptography, sorting and random generation. We +recognise that new use-cases may require additional ops and are happy to add +them where it makes sense (e.g. no performance cliffs on some architectures). If +you would like to discuss, please file an issue. + +**Rewards data-parallel design**: Highway provides tools such as Gather, +MaskedLoad, and FixedTag to enable speedups for legacy data structures. However, +the biggest gains are unlocked by designing algorithms and data structures for +scalable vectors. Helpful techniques include batching, structure-of-array +layouts, and aligned/padded allocations. + +## Examples + +Online demos using Compiler Explorer: + +- [multiple targets with dynamic dispatch](https://gcc.godbolt.org/z/zP7MYe9Yf) + (recommended) +- [single target using -m flags](https://gcc.godbolt.org/z/rGnjMevKG) + +Projects using Highway: (to add yours, feel free to raise an issue or contact us +via the below email) + +* [iresearch database index](https://github.com/iresearch-toolkit/iresearch/blob/e7638e7a4b99136ca41f82be6edccf01351a7223/core/utils/simd_utils.hpp) +* [JPEG XL image codec](https://github.com/libjxl/libjxl) +* [Grok JPEG 2000 image codec](https://github.com/GrokImageCompression/grok) +* [vectorized Quicksort](https://github.com/google/highway/tree/master/hwy/contrib/sort) ([paper](https://arxiv.org/abs/2205.05982)) + +## Current status + +### Targets + +Supported targets: scalar, S-SSE3, SSE4, AVX2, AVX-512, AVX3_DL (~Icelake, +requires opt-in by defining `HWY_WANT_AVX3_DL`), NEON (ARMv7 and v8), SVE, SVE2, +WASM SIMD, RISC-V V. + +SVE was initially tested using farm_sve (see acknowledgments). + +### Versioning + +Highway releases aim to follow the semver.org system (MAJOR.MINOR.PATCH), +incrementing MINOR after backward-compatible additions and PATCH after +backward-compatible fixes. We recommend using releases (rather than the Git tip) +because they are tested more extensively, see below. + +The current version 1.0 signals an increased focus on backwards compatibility. +Applications using documented functionality will remain compatible with future +updates that have the same major version number. + +### Testing + +Continuous integration tests build with a recent version of Clang (running on +native x86, or QEMU for RVV and ARM) and MSVC 2019 (v19.28, running on native +x86). + +Before releases, we also test on x86 with Clang and GCC, and ARMv7/8 via GCC +cross-compile. See the [testing process](g3doc/release_testing_process.md) for +details. + +### Related modules + +The `contrib` directory contains SIMD-related utilities: an image class with +aligned rows, a math library (16 functions already implemented, mostly +trigonometry), and functions for computing dot products and sorting. + +## Installation + +This project uses CMake to generate and build. In a Debian-based system you can +install it via: + +```bash +sudo apt install cmake +``` + +Highway's unit tests use [googletest](https://github.com/google/googletest). +By default, Highway's CMake downloads this dependency at configuration time. +You can disable this by setting the `HWY_SYSTEM_GTEST` CMake variable to ON and +installing gtest separately: + +```bash +sudo apt install libgtest-dev +``` + +To build Highway as a shared or static library (depending on BUILD_SHARED_LIBS), +the standard CMake workflow can be used: + +```bash +mkdir -p build && cd build +cmake .. +make -j && make test +``` + +Or you can run `run_tests.sh` (`run_tests.bat` on Windows). + +Bazel is also supported for building, but it is not as widely used/tested. + +## Quick start + +You can use the `benchmark` inside examples/ as a starting point. + +A [quick-reference page](g3doc/quick_reference.md) briefly lists all operations +and their parameters, and the [instruction_matrix](g3doc/instruction_matrix.pdf) +indicates the number of instructions per operation. + +The [FAQ](g3doc/faq.md) answers questions about portability, API design and +where to find more information. + +We recommend using full SIMD vectors whenever possible for maximum performance +portability. To obtain them, pass a `ScalableTag` (or equivalently +`HWY_FULL(float)`) tag to functions such as `Zero/Set/Load`. There are two +alternatives for use-cases requiring an upper bound on the lanes: + +- For up to `N` lanes, specify `CappedTag` or the equivalent + `HWY_CAPPED(T, N)`. The actual number of lanes will be `N` rounded down to + the nearest power of two, such as 4 if `N` is 5, or 8 if `N` is 8. This is + useful for data structures such as a narrow matrix. A loop is still required + because vectors may actually have fewer than `N` lanes. + +- For exactly a power of two `N` lanes, specify `FixedTag`. The largest + supported `N` depends on the target, but is guaranteed to be at least + `16/sizeof(T)`. + +Due to ADL restrictions, user code calling Highway ops must either: +* Reside inside `namespace hwy { namespace HWY_NAMESPACE {`; or +* prefix each op with an alias such as `namespace hn = hwy::HWY_NAMESPACE; + hn::Add()`; or +* add using-declarations for each op used: `using hwy::HWY_NAMESPACE::Add;`. + +Additionally, each function that calls Highway ops (such as `Load`) must either +be prefixed with `HWY_ATTR`, OR reside between `HWY_BEFORE_NAMESPACE()` and +`HWY_AFTER_NAMESPACE()`. Lambda functions currently require `HWY_ATTR` before +their opening brace. + +The entry points into code using Highway differ slightly depending on whether +they use static or dynamic dispatch. + +* For static dispatch, `HWY_TARGET` will be the best available target among + `HWY_BASELINE_TARGETS`, i.e. those allowed for use by the compiler (see + [quick-reference](g3doc/quick_reference.md)). Functions inside + `HWY_NAMESPACE` can be called using `HWY_STATIC_DISPATCH(func)(args)` within + the same module they are defined in. You can call the function from other + modules by wrapping it in a regular function and declaring the regular + function in a header. + +* For dynamic dispatch, a table of function pointers is generated via the + `HWY_EXPORT` macro that is used by `HWY_DYNAMIC_DISPATCH(func)(args)` to + call the best function pointer for the current CPU's supported targets. A + module is automatically compiled for each target in `HWY_TARGETS` (see + [quick-reference](g3doc/quick_reference.md)) if `HWY_TARGET_INCLUDE` is + defined and `foreach_target.h` is included. + +When using dynamic dispatch, `foreach_target.h` is included from translation +units (.cc files), not headers. Headers containing vector code shared between +several translation units require a special include guard, for example the +following taken from `examples/skeleton-inl.h`: + +``` +#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#else +#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#endif + +#include "hwy/highway.h" +// Your vector code +#endif +``` + +By convention, we name such headers `-inl.h` because their contents (often +function templates) are usually inlined. + +## Compiler flags + +Applications should be compiled with optimizations enabled - without inlining, +SIMD code may slow down by factors of 10 to 100. For clang and GCC, `-O2` is +generally sufficient. + +For MSVC, we recommend compiling with `/Gv` to allow non-inlined functions to +pass vector arguments in registers. If intending to use the AVX2 target together +with half-width vectors (e.g. for `PromoteTo`), it is also important to compile +with `/arch:AVX2`. This seems to be the only way to generate VEX-encoded SSE4 +instructions on MSVC. Otherwise, mixing VEX-encoded AVX2 instructions and +non-VEX SSE4 may cause severe performance degradation. Unfortunately, the +resulting binary will then require AVX2. Note that no such flag is needed for +clang and GCC because they support target-specific attributes, which we use to +ensure proper VEX code generation for AVX2 targets. + +## Strip-mining loops + +To vectorize a loop, "strip-mining" transforms it into an outer loop and inner +loop with number of iterations matching the preferred vector width. + +In this section, let `T` denote the element type, `d = ScalableTag`, `count` +the number of elements to process, and `N = Lanes(d)` the number of lanes in a +full vector. Assume the loop body is given as a function `template void LoopBody(D d, size_t index, size_t max_n)`. + +Highway offers several ways to express loops where `N` need not divide `count`: + +* Ensure all inputs/outputs are padded. Then the loop is simply + + ``` + for (size_t i = 0; i < count; i += N) LoopBody(d, i, 0); + ``` + Here, the template parameter and second function argument are not needed. + + This is the preferred option, unless `N` is in the thousands and vector + operations are pipelined with long latencies. This was the case for + supercomputers in the 90s, but nowadays ALUs are cheap and we see most + implementations split vectors into 1, 2 or 4 parts, so there is little cost + to processing entire vectors even if we do not need all their lanes. Indeed + this avoids the (potentially large) cost of predication or partial + loads/stores on older targets, and does not duplicate code. + +* Use the `Transform*` functions in hwy/contrib/algo/transform-inl.h. This + takes care of the loop and remainder handling and you simply define a + generic lambda function (C++14) or functor which receives the current vector + from the input/output array, plus optionally vectors from up to two extra + input arrays, and returns the value to write to the input/output array. + + Here is an example implementing the BLAS function SAXPY (`alpha * x + y`): + + ``` + Transform1(d, x, n, y, [](auto d, const auto v, const auto v1) HWY_ATTR { + return MulAdd(Set(d, alpha), v, v1); + }); + ``` + +* Process whole vectors as above, followed by a scalar loop: + + ``` + size_t i = 0; + for (; i + N <= count; i += N) LoopBody(d, i, 0); + for (; i < count; ++i) LoopBody(CappedTag(), i, 0); + ``` + The template parameter and second function arguments are again not needed. + + This avoids duplicating code, and is reasonable if `count` is large. + If `count` is small, the second loop may be slower than the next option. + +* Process whole vectors as above, followed by a single call to a modified + `LoopBody` with masking: + + ``` + size_t i = 0; + for (; i + N <= count; i += N) { + LoopBody(d, i, 0); + } + if (i < count) { + LoopBody(d, i, count - i); + } + ``` + Now the template parameter and third function argument can be used inside + `LoopBody` to non-atomically 'blend' the first `num_remaining` lanes of `v` + with the previous contents of memory at subsequent locations: + `BlendedStore(v, FirstN(d, num_remaining), d, pointer);`. Similarly, + `MaskedLoad(FirstN(d, num_remaining), d, pointer)` loads the first + `num_remaining` elements and returns zero in other lanes. + + This is a good default when it is infeasible to ensure vectors are padded, + but is only safe `#if !HWY_MEM_OPS_MIGHT_FAULT`! + In contrast to the scalar loop, only a single final iteration is needed. + The increased code size from two loop bodies is expected to be worthwhile + because it avoids the cost of masking in all but the final iteration. + +## Additional resources + +* [Highway introduction (slides)](g3doc/highway_intro.pdf) +* [Overview of instructions per operation on different architectures](g3doc/instruction_matrix.pdf) +* [Design philosophy and comparison](g3doc/design_philosophy.md) +* [Implementation details](g3doc/impl_details.md) + +## Acknowledgments + +We have used [farm-sve](https://gitlab.inria.fr/bramas/farm-sve) by Berenger +Bramas; it has proved useful for checking the SVE port on an x86 development +machine. + +This is not an officially supported Google product. +Contact: janwas@google.com diff --git a/media/highway/src/WORKSPACE b/media/highway/src/WORKSPACE new file mode 100644 index 000000000..6df1f62e9 --- /dev/null +++ b/media/highway/src/WORKSPACE @@ -0,0 +1,24 @@ +workspace(name = "highway") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/609281088cfefc76f9d0ce82e1ff6c30cc3591e5.zip"], + sha256 = "5cf189eb6847b4f8fc603a3ffff3b0771c08eec7dd4bd961bfd45477dd13eb73", + strip_prefix = "googletest-609281088cfefc76f9d0ce82e1ff6c30cc3591e5", +) + +# See https://google.github.io/googletest/quickstart-bazel.html +http_archive( + name = "rules_cc", + urls = ["https://github.com/bazelbuild/rules_cc/archive/40548a2974f1aea06215272d9c2b47a14a24e556.zip"], + sha256 = "56ac9633c13d74cb71e0546f103ce1c58810e4a76aa8325da593ca4277908d72", + strip_prefix = "rules_cc-40548a2974f1aea06215272d9c2b47a14a24e556", +) + +# Need recent version for config_setting_group +http_archive( + name = "bazel_skylib", + urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz"], +) diff --git a/media/highway/src/debian/changelog b/media/highway/src/debian/changelog new file mode 100644 index 000000000..36d0c1de0 --- /dev/null +++ b/media/highway/src/debian/changelog @@ -0,0 +1,157 @@ +highway (1.0.2-1) UNRELEASED; urgency=medium + +* Add ExclusiveNeither, FindKnownFirstTrue, Ne128 +* Add 16-bit SumOfLanes/ReorderWidenMulAccumulate/ReorderDemote2To +* Faster sort for low-entropy input, improved pivot selection +* Add GN build system, Highway FAQ, k32v32 type to vqsort +* CMake: Support find_package(GTest), add rvv-inl.h, add HWY_ENABLE_TESTS +* Fix MIPS and C++20 build, Apple LLVM 10.3 detection, EMU128 AllTrue on RVV +* Fix missing exec_prefix, RVV build, warnings, libatomic linking +* Work around GCC 10.4 issue, disabled RDCYCLE, arm7 with vfpv3 +* Documentation/example improvements +* Support static dispatch to SVE2_128 and SVE_256 + + -- Jan Wassenberg Thu, 27 Oct 2022 17:00:00 +0200 + +highway (1.0.1-1) UNRELEASED; urgency=medium + +* Add Eq128, i64 Mul, unsigned->float ConvertTo +* Faster sort for few unique keys, more robust pivot selection +* Fix: floating-point generator for sort tests, Min/MaxOfLanes for i16 +* Fix: avoid always_inline in debug, link atomic +* GCC warnings: string.h, maybe-uninitialized, ignored-attributes +* GCC warnings: preprocessor int overflow, spurious use-after-free/overflow +* Doc: <=HWY_AVX3, Full32/64/128, how to use generic-inl + + -- Jan Wassenberg Tue, 23 Aug 2022 10:00:00 +0200 + +highway (1.0.0-1) UNRELEASED; urgency=medium + +* ABI change: 64-bit target values, more room for expansion +* Add CompressBlocksNot, CompressNot, Lt128Upper, Min/Max128Upper, TruncateTo +* Add HWY_SVE2_128 target +* Sort speedups especially for 128-bit +* Documentation clarifications +* Faster NEON CountTrue/FindFirstTrue/AllFalse/AllTrue +* Improved SVE codegen +* Fix u16x8 ConcatEven/Odd, SSSE3 i64 Lt +* MSVC 2017 workarounds +* Support for runtime dispatch on Arm/GCC/Linux + + -- Jan Wassenberg Wed, 27 Jul 2022 10:00:00 +0200 + +highway (0.17.0-1) UNRELEASED; urgency=medium + +* Add ExtractLane, InsertLane, IsInf, IsFinite, IsNaN +* Add StoreInterleaved2, LoadInterleaved2/3/4, BlendedStore, SafeFillN +* Add MulFixedPoint15, Or3 +* Add Copy[If], Find[If], Generate, Replace[If] algos +* Add HWY_EMU128 target (replaces HWY_SCALAR) +* HWY_RVV is feature-complete +* Add HWY_ENABLE_CONTRIB build flag, HWY_NATIVE_FMA, HWY_WANT_SSSE3/SSE4 macros +* Extend ConcatOdd/Even and StoreInterleaved* to all types +* Allow CappedTag +* Sort speedups: 2x for AVX2, 1.09x for AVX3; avoid x86 malloc +* Expand documentation +* Fix RDTSCP crash in nanobenchmark +* Fix XCR0 check (was ignoring AVX3 on ICL) +* Support Arm/RISC-V timers + + -- Jan Wassenberg Fri, 20 May 2022 10:00:00 +0200 + +highway (0.16.0-1) UNRELEASED; urgency=medium + + * Add contrib/sort (vectorized quicksort) + * Add IfNegativeThenElse, IfVecThenElse + * Add Reverse2,4,8, ReverseBlocks, DupEven/Odd, AESLastRound + * Add OrAnd, Min128, Max128, Lt128, SumsOf8 + * Support capped/partial vectors on RVV/SVE, int64 in WASM + * Support SVE2, shared library build + * Remove deprecated overloads without the required d arg (UpperHalf etc.) + + -- Jan Wassenberg Thu, 03 Feb 2022 11:00:00 +0100 + +highway (0.15.0-1) UNRELEASED; urgency=medium + + * New ops: CompressBlendedStore, ConcatOdd/Even, IndicesFromVec + * New ops: OddEvenBlocks, SwapAdjacentBlocks, Reverse, RotateRight + * Add bf16, unsigned comparisons, more lane types for Reverse/TableLookupLanes + * Contrib: add sort(ing network) and dot(product) + * Targets: update RVV for LLVM, add experimental WASM2 + * Separate library hwy_test for test utils + * Add non-macro Simd<> aliases + * Fixes: const V& for GCC, AVX3 BZHI, POPCNT with AVX on MSVC, avoid %zu + + -- Jan Wassenberg Wed, 10 Nov 2021 10:00:00 +0100 + +highway (0.14.2-1) UNRELEASED; urgency=medium + + * Add MaskedLoad + * Fix non-glibc PPC, Windows GCC, MSVC 19.14 + * Opt-in for -Werror; separate design_philosophy.md + + -- Jan Wassenberg Tue, 24 Aug 2021 15:00:00 +0200 + +highway (0.14.1-1) UNRELEASED; urgency=medium + + * Add LoadMaskBits, CompressBits[Store] + * Fix CPU feature check (AES/F16C) and warnings + * Improved DASSERT - disabled in optimized builds + + -- Jan Wassenberg Tue, 17 Aug 2021 14:00:00 +0200 + +highway (0.14.0-1) UNRELEASED; urgency=medium + + * Add SVE, S-SSE3, AVX3_DL targets + * Support partial vectors in all ops + * Add PopulationCount, FindFirstTrue, Ne, TableLookupBytesOr0 + * Add AESRound, CLMul, MulOdd, HWY_CAP_FLOAT16 + + -- Jan Wassenberg Thu, 29 Jul 2021 15:00:00 +0200 + +highway (0.12.2-1) UNRELEASED; urgency=medium + + * fix scalar-only test and Windows macro conflict with Load/StoreFence + * replace deprecated wasm intrinsics + + -- Jan Wassenberg Mon, 31 May 2021 16:00:00 +0200 + +highway (0.12.1-1) UNRELEASED; urgency=medium + + * doc updates, ARM GCC support, fix s390/ppc, complete partial vectors + * fix warnings, faster ARM div/sqrt, separate hwy_contrib library + * add Abs(i64)/FirstN/Pause, enable AVX2 on MSVC + + -- Jan Wassenberg Wed, 19 May 2021 15:00:00 +0200 + +highway (0.12.0-1) UNRELEASED; urgency=medium + + * Add Shift*8, Compress16, emulated Scatter/Gather, StoreInterleaved3/4 + * Remove deprecated HWY_*_LANES, deprecate HWY_GATHER_LANES + * Proper IEEE rounding, reduce libstdc++ usage, inlined math + + -- Jan Wassenberg Thu, 15 Apr 2021 20:00:00 +0200 + +highway (0.11.1-1) UNRELEASED; urgency=medium + + * Fix clang7 asan error, finish f16 conversions and add test + + -- Jan Wassenberg Thu, 25 Feb 2021 16:00:00 +0200 + +highway (0.11.0-1) UNRELEASED; urgency=medium + + * Add RVV+mask logical ops, allow Shl/ShiftLeftSame on all targets, more math + + -- Jan Wassenberg Thu, 18 Feb 2021 20:00:00 +0200 + +highway (0.7.0-1) UNRELEASED; urgency=medium + + * Added API stability notice, Compress[Store], contrib/, SignBit, CopySign + + -- Jan Wassenberg Tue, 5 Jan 2021 17:00:00 +0200 + +highway (0.1-1) UNRELEASED; urgency=medium + + * Initial debian package. + + -- Alex Deymo Mon, 19 Oct 2020 16:48:07 +0200 diff --git a/media/highway/src/debian/compat b/media/highway/src/debian/compat new file mode 100644 index 000000000..f599e28b8 --- /dev/null +++ b/media/highway/src/debian/compat @@ -0,0 +1 @@ +10 diff --git a/media/highway/src/debian/control b/media/highway/src/debian/control new file mode 100644 index 000000000..7c60ebc7f --- /dev/null +++ b/media/highway/src/debian/control @@ -0,0 +1,23 @@ +Source: highway +Maintainer: JPEG XL Maintainers +Section: misc +Priority: optional +Standards-Version: 3.9.8 +Build-Depends: cmake, + debhelper (>= 9), + libgtest-dev +Homepage: https://github.com/google/highway + +Package: libhwy-dev +Architecture: any +Section: libdevel +Depends: ${misc:Depends} +Description: Efficient and performance-portable SIMD wrapper (developer files) + This library provides type-safe and source-code portable wrappers over + existing platform-specific intrinsics. Its design aims for simplicity, + reliable efficiency across platforms, and immediate usability with current + compilers. + . + This package installs the development files. There's no runtime library + since most of Highway is implemented in headers and only a very small + static library is needed. diff --git a/media/highway/src/debian/copyright b/media/highway/src/debian/copyright new file mode 100644 index 000000000..53ea57aa9 --- /dev/null +++ b/media/highway/src/debian/copyright @@ -0,0 +1,20 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ +Upstream-Name: highway + +Files: * +Copyright: 2020 Google LLC +License: Apache-2.0 + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + . + http://www.apache.org/licenses/LICENSE-2.0 + . + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + . + On Debian systems, the complete text of the Apache License, Version 2 + can be found in "/usr/share/common-licenses/Apache-2.0". diff --git a/media/highway/src/debian/rules b/media/highway/src/debian/rules new file mode 100644 index 000000000..969fc120e --- /dev/null +++ b/media/highway/src/debian/rules @@ -0,0 +1,6 @@ +#!/usr/bin/make -f +%: + dh $@ --buildsystem=cmake + +override_dh_auto_configure: + dh_auto_configure -- -DHWY_SYSTEM_GTEST=ON diff --git a/media/highway/src/debian/source/format b/media/highway/src/debian/source/format new file mode 100644 index 000000000..163aaf8d8 --- /dev/null +++ b/media/highway/src/debian/source/format @@ -0,0 +1 @@ +3.0 (quilt) diff --git a/media/highway/src/hwy/aligned_allocator.cc b/media/highway/src/hwy/aligned_allocator.cc new file mode 100644 index 000000000..7b9947970 --- /dev/null +++ b/media/highway/src/hwy/aligned_allocator.cc @@ -0,0 +1,152 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/aligned_allocator.h" + +#include +#include +#include +#include // malloc + +#include +#include + +#include "hwy/base.h" + +namespace hwy { +namespace { + +#if HWY_ARCH_RVV && defined(__riscv_vector) +// Not actually an upper bound on the size, but this value prevents crossing a +// 4K boundary (relevant on Andes). +constexpr size_t kAlignment = HWY_MAX(HWY_ALIGNMENT, 4096); +#else +constexpr size_t kAlignment = HWY_ALIGNMENT; +#endif + +#if HWY_ARCH_X86 +// On x86, aliasing can only occur at multiples of 2K, but that's too wasteful +// if this is used for single-vector allocations. 256 is more reasonable. +constexpr size_t kAlias = kAlignment * 4; +#else +constexpr size_t kAlias = kAlignment; +#endif + +#pragma pack(push, 1) +struct AllocationHeader { + void* allocated; + size_t payload_size; +}; +#pragma pack(pop) + +// Returns a 'random' (cyclical) offset for AllocateAlignedBytes. +size_t NextAlignedOffset() { + static std::atomic next{0}; + constexpr uint32_t kGroups = kAlias / kAlignment; + const uint32_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups; + const size_t offset = kAlignment * group; + HWY_DASSERT((offset % kAlignment == 0) && offset <= kAlias); + return offset; +} + +} // namespace + +HWY_DLLEXPORT void* AllocateAlignedBytes(const size_t payload_size, + AllocPtr alloc_ptr, void* opaque_ptr) { + HWY_ASSERT(payload_size != 0); // likely a bug in caller + if (payload_size >= std::numeric_limits::max() / 2) { + HWY_DASSERT(false && "payload_size too large"); + return nullptr; + } + + size_t offset = NextAlignedOffset(); + + // What: | misalign | unused | AllocationHeader |payload + // Size: |<= kAlias | offset |payload_size + // ^allocated.^aligned.^header............^payload + // The header must immediately precede payload, which must remain aligned. + // To avoid wasting space, the header resides at the end of `unused`, + // which therefore cannot be empty (offset == 0). + if (offset == 0) { + offset = kAlignment; // = RoundUpTo(sizeof(AllocationHeader), kAlignment) + static_assert(sizeof(AllocationHeader) <= kAlignment, "Else: round up"); + } + + const size_t allocated_size = kAlias + offset + payload_size; + void* allocated; + if (alloc_ptr == nullptr) { + allocated = malloc(allocated_size); + } else { + allocated = (*alloc_ptr)(opaque_ptr, allocated_size); + } + if (allocated == nullptr) return nullptr; + // Always round up even if already aligned - we already asked for kAlias + // extra bytes and there's no way to give them back. + uintptr_t aligned = reinterpret_cast(allocated) + kAlias; + static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2"); + static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias"); + aligned &= ~(kAlias - 1); + + const uintptr_t payload = aligned + offset; // still aligned + + // Stash `allocated` and payload_size inside header for FreeAlignedBytes(). + // The allocated_size can be reconstructed from the payload_size. + AllocationHeader* header = reinterpret_cast(payload) - 1; + header->allocated = allocated; + header->payload_size = payload_size; + + return HWY_ASSUME_ALIGNED(reinterpret_cast(payload), kAlignment); +} + +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr) { + if (aligned_pointer == nullptr) return; + + const uintptr_t payload = reinterpret_cast(aligned_pointer); + HWY_DASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast(payload) - 1; + + if (free_ptr == nullptr) { + free(header->allocated); + } else { + (*free_ptr)(opaque_ptr, header->allocated); + } +} + +// static +HWY_DLLEXPORT void AlignedDeleter::DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter) { + if (aligned_pointer == nullptr) return; + + const uintptr_t payload = reinterpret_cast(aligned_pointer); + HWY_DASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast(payload) - 1; + + if (deleter) { + (*deleter)(aligned_pointer, header->payload_size); + } + + if (free_ptr == nullptr) { + free(header->allocated); + } else { + (*free_ptr)(opaque_ptr, header->allocated); + } +} + +} // namespace hwy diff --git a/media/highway/src/hwy/aligned_allocator.h b/media/highway/src/hwy/aligned_allocator.h new file mode 100644 index 000000000..f6bfca11e --- /dev/null +++ b/media/highway/src/hwy/aligned_allocator.h @@ -0,0 +1,212 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ +#define HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ + +// Memory allocator with support for alignment and offsets. + +#include + +#include + +#include "hwy/highway_export.h" + +namespace hwy { + +// Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which +// requires a literal. This matches typical L1 cache line sizes, which prevents +// false sharing. +#define HWY_ALIGNMENT 64 + +// Pointers to functions equivalent to malloc/free with an opaque void* passed +// to them. +using AllocPtr = void* (*)(void* opaque, size_t bytes); +using FreePtr = void (*)(void* opaque, void* memory); + +// Returns null or a pointer to at least `payload_size` (which can be zero) +// bytes of newly allocated memory, aligned to the larger of HWY_ALIGNMENT and +// the vector size. Calls `alloc` with the passed `opaque` pointer to obtain +// memory or malloc() if it is null. +HWY_DLLEXPORT void* AllocateAlignedBytes(size_t payload_size, + AllocPtr alloc_ptr, void* opaque_ptr); + +// Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it +// must have been returned from a previous call to `AllocateAlignedBytes`. +// Calls `free_ptr` with the passed `opaque_ptr` pointer to free the memory; if +// `free_ptr` function is null, uses the default free(). +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr); + +// Class that deletes the aligned pointer passed to operator() calling the +// destructor before freeing the pointer. This is equivalent to the +// std::default_delete but for aligned objects. For a similar deleter equivalent +// to free() for aligned memory see AlignedFreer(). +class AlignedDeleter { + public: + AlignedDeleter() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedDeleter(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template + void operator()(T* aligned_pointer) const { + return DeleteAlignedArray(aligned_pointer, free_, opaque_ptr_, + TypedArrayDeleter); + } + + private: + template + static void TypedArrayDeleter(void* ptr, size_t size_in_bytes) { + size_t elems = size_in_bytes / sizeof(T); + for (size_t i = 0; i < elems; i++) { + // Explicitly call the destructor on each element. + (static_cast(ptr) + i)->~T(); + } + } + + // Function prototype that calls the destructor for each element in a typed + // array. TypeArrayDeleter would match this prototype. + using ArrayDeleter = void (*)(void* t_ptr, size_t t_size); + + HWY_DLLEXPORT static void DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter); + + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to T with custom aligned deleter. This can be a single +// element U or an array of element if T is a U[]. The custom aligned deleter +// will call the destructor on U or each element of a U[] in the array case. +template +using AlignedUniquePtr = std::unique_ptr; + +// Aligned memory equivalent of make_unique using the custom allocators +// alloc/free with the passed `opaque` pointer. This function calls the +// constructor with the passed Args... and calls the destructor of the object +// when the AlignedUniquePtr is destroyed. +template +AlignedUniquePtr MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free, + void* opaque, Args&&... args) { + T* ptr = static_cast(AllocateAlignedBytes(sizeof(T), alloc, opaque)); + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter(free, opaque)); +} + +// Similar to MakeUniqueAlignedWithAlloc but using the default alloc/free +// functions. +template +AlignedUniquePtr MakeUniqueAligned(Args&&... args) { + T* ptr = static_cast(AllocateAlignedBytes( + sizeof(T), /*alloc_ptr=*/nullptr, /*opaque_ptr=*/nullptr)); + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter()); +} + +// Helpers for array allocators (avoids overflow) +namespace detail { + +// Returns x such that 1u << x == n (if n is a power of two). +static inline constexpr size_t ShiftCount(size_t n) { + return (n <= 1) ? 0 : 1 + ShiftCount(n / 2); +} + +template +T* AllocateAlignedItems(size_t items, AllocPtr alloc_ptr, void* opaque_ptr) { + constexpr size_t size = sizeof(T); + + constexpr bool is_pow2 = (size & (size - 1)) == 0; + constexpr size_t bits = ShiftCount(size); + static_assert(!is_pow2 || (1ull << bits) == size, "ShiftCount is incorrect"); + + const size_t bytes = is_pow2 ? items << bits : items * size; + const size_t check = is_pow2 ? bytes >> bits : bytes / size; + if (check != items) { + return nullptr; // overflowed + } + return static_cast(AllocateAlignedBytes(bytes, alloc_ptr, opaque_ptr)); +} + +} // namespace detail + +// Aligned memory equivalent of make_unique for array types using the +// custom allocators alloc/free. This function calls the constructor with the +// passed Args... on every created item. The destructor of each element will be +// called when the AlignedUniquePtr is destroyed. +template +AlignedUniquePtr MakeUniqueAlignedArrayWithAlloc( + size_t items, AllocPtr alloc, FreePtr free, void* opaque, Args&&... args) { + T* ptr = detail::AllocateAlignedItems(items, alloc, opaque); + if (ptr != nullptr) { + for (size_t i = 0; i < items; i++) { + new (ptr + i) T(std::forward(args)...); + } + } + return AlignedUniquePtr(ptr, AlignedDeleter(free, opaque)); +} + +template +AlignedUniquePtr MakeUniqueAlignedArray(size_t items, Args&&... args) { + return MakeUniqueAlignedArrayWithAlloc( + items, nullptr, nullptr, nullptr, std::forward(args)...); +} + +// Custom deleter for std::unique_ptr equivalent to using free() as a deleter +// but for aligned memory. +class AlignedFreer { + public: + // Pass address of this to ctor to skip deleting externally-owned memory. + static void DoNothing(void* /*opaque*/, void* /*aligned_pointer*/) {} + + AlignedFreer() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedFreer(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template + void operator()(T* aligned_pointer) const { + // TODO(deymo): assert that we are using a POD type T. + FreeAlignedBytes(aligned_pointer, free_, opaque_ptr_); + } + + private: + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to single POD, or (if T is U[]) an array of POD. For non POD +// data use AlignedUniquePtr. +template +using AlignedFreeUniquePtr = std::unique_ptr; + +// Allocate an aligned and uninitialized array of POD values as a unique_ptr. +// Upon destruction of the unique_ptr the aligned array will be freed. +template +AlignedFreeUniquePtr AllocateAligned(const size_t items, AllocPtr alloc, + FreePtr free, void* opaque) { + return AlignedFreeUniquePtr( + detail::AllocateAlignedItems(items, alloc, opaque), + AlignedFreer(free, opaque)); +} + +// Same as previous AllocateAligned(), using default allocate/free functions. +template +AlignedFreeUniquePtr AllocateAligned(const size_t items) { + return AllocateAligned(items, nullptr, nullptr, nullptr); +} + +} // namespace hwy +#endif // HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ diff --git a/media/highway/src/hwy/aligned_allocator_test.cc b/media/highway/src/hwy/aligned_allocator_test.cc new file mode 100644 index 000000000..ced08e7bd --- /dev/null +++ b/media/highway/src/hwy/aligned_allocator_test.cc @@ -0,0 +1,278 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/aligned_allocator.h" + +#include + +#include +#include +#include +#include + +#include "gtest/gtest.h" + +namespace { + +// Sample object that keeps track on an external counter of how many times was +// the explicit constructor and destructor called. +template +class SampleObject { + public: + SampleObject() { data_[0] = 'a'; } + explicit SampleObject(int* counter) : counter_(counter) { + if (counter) (*counter)++; + data_[0] = 'b'; + } + + ~SampleObject() { + if (counter_) (*counter_)--; + } + + static_assert(N > sizeof(int*), "SampleObject size too small."); + int* counter_ = nullptr; + char data_[N - sizeof(int*)]; +}; + +class FakeAllocator { + public: + // static AllocPtr and FreePtr member to be used with the alligned + // allocator. These functions calls the private non-static members. + static void* StaticAlloc(void* opaque, size_t bytes) { + return reinterpret_cast(opaque)->Alloc(bytes); + } + static void StaticFree(void* opaque, void* memory) { + return reinterpret_cast(opaque)->Free(memory); + } + + // Returns the number of pending allocations to be freed. + size_t PendingAllocs() { return allocs_.size(); } + + private: + void* Alloc(size_t bytes) { + void* ret = malloc(bytes); + allocs_.insert(ret); + return ret; + } + void Free(void* memory) { + if (!memory) return; + EXPECT_NE(allocs_.end(), allocs_.find(memory)); + allocs_.erase(memory); + free(memory); + } + + std::set allocs_; +}; + +} // namespace + +namespace hwy { + +class AlignedAllocatorTest : public testing::Test {}; + +TEST(AlignedAllocatorTest, FreeNullptr) { + // Calling free with a nullptr is always ok. + FreeAlignedBytes(/*aligned_pointer=*/nullptr, /*free_ptr=*/nullptr, + /*opaque_ptr=*/nullptr); +} + +TEST(AlignedAllocatorTest, Log2) { + EXPECT_EQ(0u, detail::ShiftCount(1)); + EXPECT_EQ(1u, detail::ShiftCount(2)); + EXPECT_EQ(3u, detail::ShiftCount(8)); +} + +// Allocator returns null when it detects overflow of items * sizeof(T). +TEST(AlignedAllocatorTest, Overflow) { + constexpr size_t max = ~size_t(0); + constexpr size_t msb = (max >> 1) + 1; + using Size5 = std::array; + using Size10 = std::array; + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems(max / 2, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems(max / 3, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems(max / 4, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems(msb, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems(msb + 1, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems(msb / 4, nullptr, nullptr)); +} + +TEST(AlignedAllocatorTest, AllocDefaultPointers) { + const size_t kSize = 7777; + void* ptr = AllocateAlignedBytes(kSize, /*alloc_ptr=*/nullptr, + /*opaque_ptr=*/nullptr); + ASSERT_NE(nullptr, ptr); + // Make sure the pointer is actually aligned. + EXPECT_EQ(0U, reinterpret_cast(ptr) % HWY_ALIGNMENT); + char* p = static_cast(ptr); + size_t ret = 0; + for (size_t i = 0; i < kSize; i++) { + // Performs a computation using p[] to prevent it being optimized away. + p[i] = static_cast(i & 0x7F); + if (i) ret += static_cast(p[i] * p[i - 1]); + } + EXPECT_NE(0U, ret); + FreeAlignedBytes(ptr, /*free_ptr=*/nullptr, /*opaque_ptr=*/nullptr); +} + +TEST(AlignedAllocatorTest, EmptyAlignedUniquePtr) { + AlignedUniquePtr> ptr(nullptr, AlignedDeleter()); + AlignedUniquePtr[]> arr(nullptr, AlignedDeleter()); +} + +TEST(AlignedAllocatorTest, EmptyAlignedFreeUniquePtr) { + AlignedFreeUniquePtr> ptr(nullptr, AlignedFreer()); + AlignedFreeUniquePtr[]> arr(nullptr, AlignedFreer()); +} + +TEST(AlignedAllocatorTest, CustomAlloc) { + FakeAllocator fake_alloc; + + const size_t kSize = 7777; + void* ptr = + AllocateAlignedBytes(kSize, &FakeAllocator::StaticAlloc, &fake_alloc); + ASSERT_NE(nullptr, ptr); + // We should have only requested one alloc from the allocator. + EXPECT_EQ(1U, fake_alloc.PendingAllocs()); + // Make sure the pointer is actually aligned. + EXPECT_EQ(0U, reinterpret_cast(ptr) % HWY_ALIGNMENT); + FreeAlignedBytes(ptr, &FakeAllocator::StaticFree, &fake_alloc); + EXPECT_EQ(0U, fake_alloc.PendingAllocs()); +} + +TEST(AlignedAllocatorTest, MakeUniqueAlignedDefaultConstructor) { + { + auto ptr = MakeUniqueAligned>(); + // Default constructor sets the data_[0] to 'a'. + EXPECT_EQ('a', ptr->data_[0]); + EXPECT_EQ(nullptr, ptr->counter_); + } +} + +TEST(AlignedAllocatorTest, MakeUniqueAligned) { + int counter = 0; + { + // Creates the object, initializes it with the explicit constructor and + // returns an unique_ptr to it. + auto ptr = MakeUniqueAligned>(&counter); + EXPECT_EQ(1, counter); + // Custom constructor sets the data_[0] to 'b'. + EXPECT_EQ('b', ptr->data_[0]); + } + EXPECT_EQ(0, counter); +} + +TEST(AlignedAllocatorTest, MakeUniqueAlignedArray) { + int counter = 0; + { + // Creates the array of objects and initializes them with the explicit + // constructor. + auto arr = MakeUniqueAlignedArray>(7, &counter); + EXPECT_EQ(7, counter); + for (size_t i = 0; i < 7; i++) { + // Custom constructor sets the data_[0] to 'b'. + EXPECT_EQ('b', arr[i].data_[0]) << "Where i = " << i; + } + } + EXPECT_EQ(0, counter); +} + +TEST(AlignedAllocatorTest, AllocSingleInt) { + auto ptr = AllocateAligned(1); + ASSERT_NE(nullptr, ptr.get()); + EXPECT_EQ(0U, reinterpret_cast(ptr.get()) % HWY_ALIGNMENT); + // Force delete of the unique_ptr now to check that it doesn't crash. + ptr.reset(nullptr); + EXPECT_EQ(nullptr, ptr.get()); +} + +TEST(AlignedAllocatorTest, AllocMultipleInt) { + const size_t kSize = 7777; + auto ptr = AllocateAligned(kSize); + ASSERT_NE(nullptr, ptr.get()); + EXPECT_EQ(0U, reinterpret_cast(ptr.get()) % HWY_ALIGNMENT); + // ptr[i] is actually (*ptr.get())[i] which will use the operator[] of the + // underlying type chosen by AllocateAligned() for the std::unique_ptr. + EXPECT_EQ(&(ptr[0]) + 1, &(ptr[1])); + + size_t ret = 0; + for (size_t i = 0; i < kSize; i++) { + // Performs a computation using ptr[] to prevent it being optimized away. + ptr[i] = static_cast(i); + if (i) ret += ptr[i] * ptr[i - 1]; + } + EXPECT_NE(0U, ret); +} + +TEST(AlignedAllocatorTest, AllocateAlignedObjectWithoutDestructor) { + int counter = 0; + { + // This doesn't call the constructor. + auto obj = AllocateAligned>(1); + obj[0].counter_ = &counter; + } + // Destroying the unique_ptr shouldn't have called the destructor of the + // SampleObject<24>. + EXPECT_EQ(0, counter); +} + +TEST(AlignedAllocatorTest, MakeUniqueAlignedArrayWithCustomAlloc) { + FakeAllocator fake_alloc; + int counter = 0; + { + // Creates the array of objects and initializes them with the explicit + // constructor. + auto arr = MakeUniqueAlignedArrayWithAlloc>( + 7, FakeAllocator::StaticAlloc, FakeAllocator::StaticFree, &fake_alloc, + &counter); + ASSERT_NE(nullptr, arr.get()); + // An array should still only call a single allocation. + EXPECT_EQ(1u, fake_alloc.PendingAllocs()); + EXPECT_EQ(7, counter); + for (size_t i = 0; i < 7; i++) { + // Custom constructor sets the data_[0] to 'b'. + EXPECT_EQ('b', arr[i].data_[0]) << "Where i = " << i; + } + } + EXPECT_EQ(0, counter); + EXPECT_EQ(0u, fake_alloc.PendingAllocs()); +} + +TEST(AlignedAllocatorTest, DefaultInit) { + // The test is whether this compiles. Default-init is useful for output params + // and per-thread storage. + std::vector> ptrs; + std::vector> free_ptrs; + ptrs.resize(128); + free_ptrs.resize(128); + // The following is to prevent elision of the pointers. + std::mt19937 rng(129); // Emscripten lacks random_device. + std::uniform_int_distribution dist(0, 127); + ptrs[dist(rng)] = MakeUniqueAlignedArray(123); + free_ptrs[dist(rng)] = AllocateAligned(456); + // "Use" pointer without resorting to printf. 0 == 0. Can't shift by 64. + const auto addr1 = reinterpret_cast(ptrs[dist(rng)].get()); + const auto addr2 = reinterpret_cast(free_ptrs[dist(rng)].get()); + constexpr size_t kBits = sizeof(uintptr_t) * 8; + EXPECT_EQ((addr1 >> (kBits - 1)) >> (kBits - 1), + (addr2 >> (kBits - 1)) >> (kBits - 1)); +} + +} // namespace hwy diff --git a/media/highway/src/hwy/base.h b/media/highway/src/hwy/base.h new file mode 100644 index 000000000..0a4491eb7 --- /dev/null +++ b/media/highway/src/hwy/base.h @@ -0,0 +1,946 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_BASE_H_ +#define HIGHWAY_HWY_BASE_H_ + +// For SIMD module implementations and their callers, target-independent. + +#include +#include + +#include "hwy/detect_compiler_arch.h" +#include "hwy/highway_export.h" + +#if HWY_COMPILER_MSVC +#include // memcpy +#endif +#if HWY_ARCH_X86 +#include +#endif + +//------------------------------------------------------------------------------ +// Compiler-specific definitions + +#define HWY_STR_IMPL(macro) #macro +#define HWY_STR(macro) HWY_STR_IMPL(macro) + +#if HWY_COMPILER_MSVC + +#include + +#define HWY_RESTRICT __restrict +#define HWY_INLINE __forceinline +#define HWY_NOINLINE __declspec(noinline) +#define HWY_FLATTEN +#define HWY_NORETURN __declspec(noreturn) +#define HWY_LIKELY(expr) (expr) +#define HWY_UNLIKELY(expr) (expr) +#define HWY_PRAGMA(tokens) __pragma(tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(warning(tokens)) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(msc) +#define HWY_MAYBE_UNUSED +#define HWY_HAS_ASSUME_ALIGNED 0 +#if (_MSC_VER >= 1700) +#define HWY_MUST_USE_RESULT _Check_return_ +#else +#define HWY_MUST_USE_RESULT +#endif + +#else + +#define HWY_RESTRICT __restrict__ +// force inlining without optimization enabled creates very inefficient code +// that can cause compiler timeout +#ifdef __OPTIMIZE__ +#define HWY_INLINE inline __attribute__((always_inline)) +#else +#define HWY_INLINE inline +#endif +#define HWY_NOINLINE __attribute__((noinline)) +#define HWY_FLATTEN __attribute__((flatten)) +#define HWY_NORETURN __attribute__((noreturn)) +#define HWY_LIKELY(expr) __builtin_expect(!!(expr), 1) +#define HWY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#define HWY_PRAGMA(tokens) _Pragma(#tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(GCC diagnostic tokens) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(gcc) +// Encountered "attribute list cannot appear here" when using the C++17 +// [[maybe_unused]], so only use the old style attribute for now. +#define HWY_MAYBE_UNUSED __attribute__((unused)) +#define HWY_MUST_USE_RESULT __attribute__((warn_unused_result)) + +#endif // !HWY_COMPILER_MSVC + +//------------------------------------------------------------------------------ +// Builtin/attributes + +// Enables error-checking of format strings. +#if HWY_HAS_ATTRIBUTE(__format__) +#define HWY_FORMAT(idx_fmt, idx_arg) \ + __attribute__((__format__(__printf__, idx_fmt, idx_arg))) +#else +#define HWY_FORMAT(idx_fmt, idx_arg) +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* HWY_RESTRICT aligned = (float*)HWY_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if HWY_HAS_BUILTIN(__builtin_assume_aligned) +#define HWY_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align)) +#else +#define HWY_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */ +#endif + +// Clang and GCC require attributes on each function into which SIMD intrinsics +// are inlined. Support both per-function annotation (HWY_ATTR) for lambdas and +// automatic annotation via pragmas. +#if HWY_COMPILER_CLANG +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(clang attribute push(__attribute__((target(targets_str))), \ + apply_to = function)) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(clang attribute pop) +#elif HWY_COMPILER_GCC +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(GCC push_options) HWY_PRAGMA(GCC target targets_str) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(GCC pop_options) +#else +#define HWY_PUSH_ATTRIBUTES(targets_str) +#define HWY_POP_ATTRIBUTES +#endif + +//------------------------------------------------------------------------------ +// Macros + +#define HWY_API static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED + +#define HWY_CONCAT_IMPL(a, b) a##b +#define HWY_CONCAT(a, b) HWY_CONCAT_IMPL(a, b) + +#define HWY_MIN(a, b) ((a) < (b) ? (a) : (b)) +#define HWY_MAX(a, b) ((a) > (b) ? (a) : (b)) + +#if HWY_COMPILER_GCC_ACTUAL +// nielskm: GCC does not support '#pragma GCC unroll' without the factor. +#define HWY_UNROLL(factor) HWY_PRAGMA(GCC unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL(4) +#elif HWY_COMPILER_CLANG || HWY_COMPILER_ICC || HWY_COMPILER_ICX +#define HWY_UNROLL(factor) HWY_PRAGMA(unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL() +#else +#define HWY_UNROLL(factor) +#define HWY_DEFAULT_UNROLL +#endif + + +// Compile-time fence to prevent undesirable code reordering. On Clang x86, the +// typical asm volatile("" : : : "memory") has no effect, whereas atomic fence +// does, without generating code. +#if HWY_ARCH_X86 +#define HWY_FENCE std::atomic_thread_fence(std::memory_order_acq_rel) +#else +// TODO(janwas): investigate alternatives. On ARM, the above generates barriers. +#define HWY_FENCE +#endif + +// 4 instances of a given literal value, useful as input to LoadDup128. +#define HWY_REP4(literal) literal, literal, literal, literal + +#define HWY_ABORT(format, ...) \ + ::hwy::Abort(__FILE__, __LINE__, format, ##__VA_ARGS__) + +// Always enabled. +#define HWY_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + HWY_ABORT("Assert %s", #condition); \ + } \ + } while (0) + +#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) +#define HWY_IS_MSAN 1 +#else +#define HWY_IS_MSAN 0 +#endif + +#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) +#define HWY_IS_ASAN 1 +#else +#define HWY_IS_ASAN 0 +#endif + +#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) +#define HWY_IS_TSAN 1 +#else +#define HWY_IS_TSAN 0 +#endif + +// MSAN may cause lengthy build times or false positives e.g. in AVX3 DemoteTo. +// You can disable MSAN by adding this attribute to the function that fails. +#if HWY_IS_MSAN +#define HWY_ATTR_NO_MSAN __attribute__((no_sanitize_memory)) +#else +#define HWY_ATTR_NO_MSAN +#endif + +// For enabling HWY_DASSERT and shortening tests in slower debug builds +#if !defined(HWY_IS_DEBUG_BUILD) +// Clang does not define NDEBUG, but it and GCC define __OPTIMIZE__, and recent +// MSVC defines NDEBUG (if not, could instead check _DEBUG). +#if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || HWY_IS_ASAN || \ + HWY_IS_MSAN || HWY_IS_TSAN || defined(__clang_analyzer__) +#define HWY_IS_DEBUG_BUILD 1 +#else +#define HWY_IS_DEBUG_BUILD 0 +#endif +#endif // HWY_IS_DEBUG_BUILD + +#if HWY_IS_DEBUG_BUILD +#define HWY_DASSERT(condition) HWY_ASSERT(condition) +#else +#define HWY_DASSERT(condition) \ + do { \ + } while (0) +#endif + +namespace hwy { + +//------------------------------------------------------------------------------ +// kMaxVectorSize (undocumented, pending removal) + +#if HWY_ARCH_X86 +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 64; // AVX-512 +#elif HWY_ARCH_RVV && defined(__riscv_vector) +// Not actually an upper bound on the size. +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 4096; +#else +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; +#endif + +//------------------------------------------------------------------------------ +// Alignment + +// Potentially useful for LoadDup128 and capped vectors. In other cases, arrays +// should be allocated dynamically via aligned_allocator.h because Lanes() may +// exceed the stack size. +#if HWY_ARCH_X86 +#define HWY_ALIGN_MAX alignas(64) +#elif HWY_ARCH_RVV && defined(__riscv_vector) +#define HWY_ALIGN_MAX alignas(8) // only elements need be aligned +#else +#define HWY_ALIGN_MAX alignas(16) +#endif + +//------------------------------------------------------------------------------ +// Lane types + +// Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name +// by concatenating base type and bits. + +#pragma pack(push, 1) + +// ACLE (https://gcc.gnu.org/onlinedocs/gcc/Half-Precision.html): +// always supported on aarch64, for v7 only if -mfp16-format is given. +#if ((HWY_ARCH_ARM_A64 || (__ARM_FP & 2)) && HWY_COMPILER_GCC) +using float16_t = __fp16; +// C11 extension ISO/IEC TS 18661-3:2015 but not supported on all targets. +// Required for Clang RVV if the float16 extension is used. +#elif HWY_ARCH_RVV && HWY_COMPILER_CLANG && defined(__riscv_zvfh) +using float16_t = _Float16; +// Otherwise emulate +#else +struct float16_t { + uint16_t bits; +}; +#endif + +struct bfloat16_t { + uint16_t bits; +}; + +#pragma pack(pop) + +using float32_t = float; +using float64_t = double; + +#pragma pack(push, 1) + +// Aligned 128-bit type. Cannot use __int128 because clang doesn't yet align it: +// https://reviews.llvm.org/D86310 +struct alignas(16) uint128_t { + uint64_t lo; // little-endian layout + uint64_t hi; +}; + +// 64 bit key plus 64 bit value. Faster than using uint128_t when only the key +// field is to be compared (Lt128Upper instead of Lt128). +struct alignas(16) K64V64 { + uint64_t value; // little-endian layout + uint64_t key; +}; + +// 32 bit key plus 32 bit value. Allows vqsort recursions to terminate earlier +// than when considering both to be a 64-bit key. +struct alignas(8) K32V32 { + uint32_t value; // little-endian layout + uint32_t key; +}; + +#pragma pack(pop) + +static inline HWY_MAYBE_UNUSED bool operator<(const uint128_t& a, + const uint128_t& b) { + return (a.hi == b.hi) ? a.lo < b.lo : a.hi < b.hi; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const uint128_t& a, + const uint128_t& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const uint128_t& a, + const uint128_t& b) { + return a.lo == b.lo && a.hi == b.hi; +} + +static inline HWY_MAYBE_UNUSED bool operator<(const K64V64& a, + const K64V64& b) { + return a.key < b.key; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const K64V64& a, + const K64V64& b) { + return b < a; +} + +static inline HWY_MAYBE_UNUSED bool operator<(const K32V32& a, + const K32V32& b) { + return a.key < b.key; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const K32V32& a, + const K32V32& b) { + return b < a; +} + +//------------------------------------------------------------------------------ +// Controlling overload resolution (SFINAE) + +template +struct EnableIfT {}; +template <> +struct EnableIfT { + using type = void; +}; + +template +using EnableIf = typename EnableIfT::type; + +template +struct IsSameT { + enum { value = 0 }; +}; + +template +struct IsSameT { + enum { value = 1 }; +}; + +template +HWY_API constexpr bool IsSame() { + return IsSameT::value; +} + +// Insert into template/function arguments to enable this overload only for +// vectors of AT MOST this many bits. +// +// Note that enabling for exactly 128 bits is unnecessary because a function can +// simply be overloaded with Vec128 and/or Full128 tag. Enabling for other +// sizes (e.g. 64 bit) can be achieved via Simd. +#define HWY_IF_LE128(T, N) hwy::EnableIf* = nullptr +#define HWY_IF_LE64(T, N) hwy::EnableIf* = nullptr +#define HWY_IF_LE32(T, N) hwy::EnableIf* = nullptr +#define HWY_IF_GE32(T, N) hwy::EnableIf= 4>* = nullptr +#define HWY_IF_GE64(T, N) hwy::EnableIf= 8>* = nullptr +#define HWY_IF_GE128(T, N) hwy::EnableIf= 16>* = nullptr +#define HWY_IF_GT128(T, N) hwy::EnableIf<(N * sizeof(T) > 16)>* = nullptr + +#define HWY_IF_UNSIGNED(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_SIGNED(T) \ + hwy::EnableIf() && !IsFloat()>* = nullptr +#define HWY_IF_FLOAT(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_FLOAT(T) hwy::EnableIf()>* = nullptr + +#define HWY_IF_LANE_SIZE(T, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_NOT_LANE_SIZE(T, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_LANE_SIZE_LT(T, bytes) \ + hwy::EnableIf* = nullptr + +#define HWY_IF_LANES_PER_BLOCK(T, N, LANES) \ + hwy::EnableIf* = nullptr + +// Empty struct used as a size tag type. +template +struct SizeTag {}; + +template +struct RemoveConstT { + using type = T; +}; +template +struct RemoveConstT { + using type = T; +}; + +template +using RemoveConst = typename RemoveConstT::type; + +//------------------------------------------------------------------------------ +// Type relations + +namespace detail { + +template +struct Relations; +template <> +struct Relations { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = uint16_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = int16_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = uint32_t; + using Narrow = uint8_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = int32_t; + using Narrow = int8_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = uint64_t; + using Narrow = uint16_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = int64_t; + using Narrow = int16_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Wide = uint128_t; + using Narrow = uint32_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = int32_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint128_t; + using Narrow = uint64_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = double; + using Narrow = float16_t; + enum { is_signed = 1, is_float = 1 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = float; + enum { is_signed = 1, is_float = 1 }; +}; + +template +struct TypeFromSize; +template <> +struct TypeFromSize<1> { + using Unsigned = uint8_t; + using Signed = int8_t; +}; +template <> +struct TypeFromSize<2> { + using Unsigned = uint16_t; + using Signed = int16_t; +}; +template <> +struct TypeFromSize<4> { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; +}; +template <> +struct TypeFromSize<8> { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; +}; +template <> +struct TypeFromSize<16> { + using Unsigned = uint128_t; +}; + +} // namespace detail + +// Aliases for types of a different category, but the same size. +template +using MakeUnsigned = typename detail::Relations::Unsigned; +template +using MakeSigned = typename detail::Relations::Signed; +template +using MakeFloat = typename detail::Relations::Float; + +// Aliases for types of the same category, but different size. +template +using MakeWide = typename detail::Relations::Wide; +template +using MakeNarrow = typename detail::Relations::Narrow; + +// Obtain type from its size [bytes]. +template +using UnsignedFromSize = typename detail::TypeFromSize::Unsigned; +template +using SignedFromSize = typename detail::TypeFromSize::Signed; +template +using FloatFromSize = typename detail::TypeFromSize::Float; + +// Avoid confusion with SizeTag where the parameter is a lane size. +using UnsignedTag = SizeTag<0>; +using SignedTag = SizeTag<0x100>; // integer +using FloatTag = SizeTag<0x200>; + +template > +constexpr auto TypeTag() -> hwy::SizeTag<((R::is_signed + R::is_float) << 8)> { + return hwy::SizeTag<((R::is_signed + R::is_float) << 8)>(); +} + +// For when we only want to distinguish FloatTag from everything else. +using NonFloatTag = SizeTag<0x400>; + +template > +constexpr auto IsFloatTag() -> hwy::SizeTag<(R::is_float ? 0x200 : 0x400)> { + return hwy::SizeTag<(R::is_float ? 0x200 : 0x400)>(); +} + +//------------------------------------------------------------------------------ +// Type traits + +template +HWY_API constexpr bool IsFloat() { + // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or + // from a float, not compared. + return IsSame() || IsSame(); +} + +template +HWY_API constexpr bool IsSigned() { + return T(0) > T(-1); +} +template <> +constexpr bool IsSigned() { + return true; +} +template <> +constexpr bool IsSigned() { + return true; +} + +// Largest/smallest representable integer values. +template +HWY_API constexpr T LimitsMax() { + static_assert(!IsFloat(), "Only for integer types"); + using TU = MakeUnsigned; + return static_cast(IsSigned() ? (static_cast(~0ull) >> 1) + : static_cast(~0ull)); +} +template +HWY_API constexpr T LimitsMin() { + static_assert(!IsFloat(), "Only for integer types"); + return IsSigned() ? T(-1) - LimitsMax() : T(0); +} + +// Largest/smallest representable value (integer or float). This naming avoids +// confusion with numeric_limits::min() (the smallest positive value). +template +HWY_API constexpr T LowestValue() { + return LimitsMin(); +} +template <> +constexpr float LowestValue() { + return -3.402823466e+38F; +} +template <> +constexpr double LowestValue() { + return -1.7976931348623158e+308; +} + +template +HWY_API constexpr T HighestValue() { + return LimitsMax(); +} +template <> +constexpr float HighestValue() { + return 3.402823466e+38F; +} +template <> +constexpr double HighestValue() { + return 1.7976931348623158e+308; +} + +// Difference between 1.0 and the next representable value. +template +HWY_API constexpr T Epsilon() { + return 1; +} +template <> +constexpr float Epsilon() { + return 1.192092896e-7f; +} +template <> +constexpr double Epsilon() { + return 2.2204460492503131e-16; +} + +// Returns width in bits of the mantissa field in IEEE binary32/64. +template +constexpr int MantissaBits() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr int MantissaBits() { + return 23; +} +template <> +constexpr int MantissaBits() { + return 52; +} + +// Returns the (left-shifted by one bit) IEEE binary32/64 representation with +// the largest possible (biased) exponent field. Used by IsInf. +template +constexpr MakeSigned MaxExponentTimes2() { + return -(MakeSigned{1} << (MantissaBits() + 1)); +} + +// Returns bitmask of the sign bit in IEEE binary32/64. +template +constexpr MakeUnsigned SignMask() { + return MakeUnsigned{1} << (sizeof(T) * 8 - 1); +} + +// Returns bitmask of the exponent field in IEEE binary32/64. +template +constexpr MakeUnsigned ExponentMask() { + return (~(MakeUnsigned{1} << MantissaBits()) + 1) & ~SignMask(); +} + +// Returns bitmask of the mantissa field in IEEE binary32/64. +template +constexpr MakeUnsigned MantissaMask() { + return (MakeUnsigned{1} << MantissaBits()) - 1; +} + +// Returns 1 << mantissa_bits as a floating-point number. All integers whose +// absolute value are less than this can be represented exactly. +template +constexpr T MantissaEnd() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr float MantissaEnd() { + return 8388608.0f; // 1 << 23 +} +template <> +constexpr double MantissaEnd() { + // floating point literal with p52 requires C++17. + return 4503599627370496.0; // 1 << 52 +} + +// Returns width in bits of the exponent field in IEEE binary32/64. +template +constexpr int ExponentBits() { + // Exponent := remaining bits after deducting sign and mantissa. + return 8 * sizeof(T) - 1 - MantissaBits(); +} + +// Returns largest value of the biased exponent field in IEEE binary32/64, +// right-shifted so that the LSB is bit zero. Example: 0xFF for float. +// This is expressed as a signed integer for more efficient comparison. +template +constexpr MakeSigned MaxExponentField() { + return (MakeSigned{1} << ExponentBits()) - 1; +} + +//------------------------------------------------------------------------------ +// Helper functions + +template +constexpr inline T1 DivCeil(T1 a, T2 b) { + return (a + b - 1) / b; +} + +// Works for any `align`; if a power of two, compiler emits ADD+AND. +constexpr inline size_t RoundUpTo(size_t what, size_t align) { + return DivCeil(what, align) * align; +} + +// Undefined results for x == 0. +HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) { +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanForward(&index, x); + return index; +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_ctz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x) { +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanForward64(&index, x); + return index; +#else // HWY_ARCH_X86_64 + // _BitScanForward64 not available + uint32_t lsb = static_cast(x & 0xFFFFFFFF); + unsigned long index; // NOLINT + if (lsb == 0) { + uint32_t msb = static_cast(x >> 32u); + _BitScanForward(&index, msb); + return 32 + index; + } else { + _BitScanForward(&index, lsb); + return index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_ctzll(x)); +#endif // HWY_COMPILER_MSVC +} + +// Undefined results for x == 0. +HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x) { +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanReverse(&index, x); + return 31 - index; +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_clz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x) { +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanReverse64(&index, x); + return 63 - index; +#else // HWY_ARCH_X86_64 + // _BitScanReverse64 not available + const uint32_t msb = static_cast(x >> 32u); + unsigned long index; // NOLINT + if (msb == 0) { + const uint32_t lsb = static_cast(x & 0xFFFFFFFF); + _BitScanReverse(&index, lsb); + return 63 - index; + } else { + _BitScanReverse(&index, msb); + return 31 - index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_clzll(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t PopCount(uint64_t x) { +#if HWY_COMPILER_GCC // includes clang + return static_cast(__builtin_popcountll(x)); + // This instruction has a separate feature flag, but is often called from + // non-SIMD code, so we don't want to require dynamic dispatch. It was first + // supported by Intel in Nehalem (SSE4.2), but MSVC only predefines a macro + // for AVX, so check for that. +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 && defined(__AVX__) + return _mm_popcnt_u64(x); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) + return _mm_popcnt_u32(static_cast(x & 0xFFFFFFFFu)) + + _mm_popcnt_u32(static_cast(x >> 32)); +#else + x -= ((x >> 1) & 0x5555555555555555ULL); + x = (((x >> 2) & 0x3333333333333333ULL) + (x & 0x3333333333333333ULL)); + x = (((x >> 4) + x) & 0x0F0F0F0F0F0F0F0FULL); + x += (x >> 8); + x += (x >> 16); + x += (x >> 32); + return static_cast(x & 0x7Fu); +#endif +} + +// Skip HWY_API due to GCC "function not considered for inlining". Previously +// such errors were caused by underlying type mismatches, but it's not clear +// what is still mismatched despite all the casts. +template +/*HWY_API*/ constexpr size_t FloorLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast(FloorLog2(static_cast(x >> 1)) + 1); +} + +template +/*HWY_API*/ constexpr size_t CeilLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast(FloorLog2(static_cast(x - 1)) + 1); +} + +#if HWY_COMPILER_MSVC && HWY_ARCH_X86_64 +#pragma intrinsic(_umul128) +#endif + +// 64 x 64 = 128 bit multiplication +HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) { +#if defined(__SIZEOF_INT128__) + __uint128_t product = (__uint128_t)a * (__uint128_t)b; + *upper = (uint64_t)(product >> 64); + return (uint64_t)(product & 0xFFFFFFFFFFFFFFFFULL); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 + return _umul128(a, b, upper); +#else + constexpr uint64_t kLo32 = 0xFFFFFFFFU; + const uint64_t lo_lo = (a & kLo32) * (b & kLo32); + const uint64_t hi_lo = (a >> 32) * (b & kLo32); + const uint64_t lo_hi = (a & kLo32) * (b >> 32); + const uint64_t hi_hi = (a >> 32) * (b >> 32); + const uint64_t t = (lo_lo >> 32) + (hi_lo & kLo32) + lo_hi; + *upper = (hi_lo >> 32) + (t >> 32) + hi_hi; + return (t << 32) | (lo_lo & kLo32); +#endif +} + +#if HWY_COMPILER_MSVC +#pragma intrinsic(memcpy) +#pragma intrinsic(memset) +#endif + +// The source/destination must not overlap/alias. +template +HWY_API void CopyBytes(const From* from, To* to) { +#if HWY_COMPILER_MSVC + memcpy(to, from, kBytes); +#else + __builtin_memcpy( + static_cast(to), static_cast(from), kBytes); +#endif +} + +// Same as CopyBytes, but for same-sized objects; avoids a size argument. +template +HWY_API void CopySameSize(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { + static_assert(sizeof(From) == sizeof(To), ""); + CopyBytes(from, to); +} + +template +HWY_API void ZeroBytes(To* to) { +#if HWY_COMPILER_MSVC + memset(to, 0, kBytes); +#else + __builtin_memset(to, 0, kBytes); +#endif +} + +HWY_API float F32FromBF16(bfloat16_t bf) { + uint32_t bits = bf.bits; + bits <<= 16; + float f; + CopySameSize(&bits, &f); + return f; +} + +HWY_API bfloat16_t BF16FromF32(float f) { + uint32_t bits; + CopySameSize(&f, &bits); + bfloat16_t bf; + bf.bits = static_cast(bits >> 16); + return bf; +} + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...); + +} // namespace hwy + +#endif // HIGHWAY_HWY_BASE_H_ diff --git a/media/highway/src/hwy/base_test.cc b/media/highway/src/hwy/base_test.cc new file mode 100644 index 000000000..baca70b6f --- /dev/null +++ b/media/highway/src/hwy/base_test.cc @@ -0,0 +1,178 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "base_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +HWY_NOINLINE void TestAllLimits() { + HWY_ASSERT_EQ(uint8_t{0}, LimitsMin()); + HWY_ASSERT_EQ(uint16_t{0}, LimitsMin()); + HWY_ASSERT_EQ(uint32_t{0}, LimitsMin()); + HWY_ASSERT_EQ(uint64_t{0}, LimitsMin()); + + HWY_ASSERT_EQ(int8_t{-128}, LimitsMin()); + HWY_ASSERT_EQ(int16_t{-32768}, LimitsMin()); + HWY_ASSERT_EQ(static_cast(0x80000000u), LimitsMin()); + HWY_ASSERT_EQ(static_cast(0x8000000000000000ull), + LimitsMin()); + + HWY_ASSERT_EQ(uint8_t{0xFF}, LimitsMax()); + HWY_ASSERT_EQ(uint16_t{0xFFFF}, LimitsMax()); + HWY_ASSERT_EQ(uint32_t{0xFFFFFFFFu}, LimitsMax()); + HWY_ASSERT_EQ(uint64_t{0xFFFFFFFFFFFFFFFFull}, LimitsMax()); + + HWY_ASSERT_EQ(int8_t{0x7F}, LimitsMax()); + HWY_ASSERT_EQ(int16_t{0x7FFF}, LimitsMax()); + HWY_ASSERT_EQ(int32_t{0x7FFFFFFFu}, LimitsMax()); + HWY_ASSERT_EQ(int64_t{0x7FFFFFFFFFFFFFFFull}, LimitsMax()); +} + +struct TestLowestHighest { + template + HWY_NOINLINE void operator()(T /*unused*/) const { + HWY_ASSERT_EQ(std::numeric_limits::lowest(), LowestValue()); + HWY_ASSERT_EQ(std::numeric_limits::max(), HighestValue()); + } +}; + +HWY_NOINLINE void TestAllLowestHighest() { ForAllTypes(TestLowestHighest()); } +struct TestIsUnsigned { + template + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(!IsFloat(), "Expected !IsFloat"); + static_assert(!IsSigned(), "Expected !IsSigned"); + } +}; + +struct TestIsSigned { + template + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(!IsFloat(), "Expected !IsFloat"); + static_assert(IsSigned(), "Expected IsSigned"); + } +}; + +struct TestIsFloat { + template + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(IsFloat(), "Expected IsFloat"); + static_assert(IsSigned(), "Floats are also considered signed"); + } +}; + +HWY_NOINLINE void TestAllType() { + ForUnsignedTypes(TestIsUnsigned()); + ForSignedTypes(TestIsSigned()); + ForFloatTypes(TestIsFloat()); + + static_assert(sizeof(MakeUnsigned) == 16, ""); + static_assert(sizeof(MakeWide) == 16, "Expected uint128_t"); + static_assert(sizeof(MakeNarrow) == 8, "Expected uint64_t"); +} + +struct TestIsSame { + template + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(IsSame(), "T == T"); + static_assert(!IsSame, MakeUnsigned>(), "S != U"); + static_assert(!IsSame, MakeSigned>(), "U != S"); + } +}; + +HWY_NOINLINE void TestAllIsSame() { ForAllTypes(TestIsSame()); } + +HWY_NOINLINE void TestAllBitScan() { + HWY_ASSERT_EQ(size_t{0}, Num0BitsAboveMS1Bit_Nonzero32(0x80000000u)); + HWY_ASSERT_EQ(size_t{0}, Num0BitsAboveMS1Bit_Nonzero32(0xFFFFFFFFu)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsAboveMS1Bit_Nonzero32(0x40000000u)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsAboveMS1Bit_Nonzero32(0x40108210u)); + HWY_ASSERT_EQ(size_t{30}, Num0BitsAboveMS1Bit_Nonzero32(2u)); + HWY_ASSERT_EQ(size_t{30}, Num0BitsAboveMS1Bit_Nonzero32(3u)); + HWY_ASSERT_EQ(size_t{31}, Num0BitsAboveMS1Bit_Nonzero32(1u)); + + HWY_ASSERT_EQ(size_t{0}, + Num0BitsAboveMS1Bit_Nonzero64(0x8000000000000000ull)); + HWY_ASSERT_EQ(size_t{0}, + Num0BitsAboveMS1Bit_Nonzero64(0xFFFFFFFFFFFFFFFFull)); + HWY_ASSERT_EQ(size_t{1}, + Num0BitsAboveMS1Bit_Nonzero64(0x4000000000000000ull)); + HWY_ASSERT_EQ(size_t{1}, + Num0BitsAboveMS1Bit_Nonzero64(0x4010821004200011ull)); + HWY_ASSERT_EQ(size_t{62}, Num0BitsAboveMS1Bit_Nonzero64(2ull)); + HWY_ASSERT_EQ(size_t{62}, Num0BitsAboveMS1Bit_Nonzero64(3ull)); + HWY_ASSERT_EQ(size_t{63}, Num0BitsAboveMS1Bit_Nonzero64(1ull)); + + HWY_ASSERT_EQ(size_t{0}, Num0BitsBelowLS1Bit_Nonzero32(1u)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsBelowLS1Bit_Nonzero32(2u)); + HWY_ASSERT_EQ(size_t{30}, Num0BitsBelowLS1Bit_Nonzero32(0xC0000000u)); + HWY_ASSERT_EQ(size_t{31}, Num0BitsBelowLS1Bit_Nonzero32(0x80000000u)); + + HWY_ASSERT_EQ(size_t{0}, Num0BitsBelowLS1Bit_Nonzero64(1ull)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsBelowLS1Bit_Nonzero64(2ull)); + HWY_ASSERT_EQ(size_t{62}, + Num0BitsBelowLS1Bit_Nonzero64(0xC000000000000000ull)); + HWY_ASSERT_EQ(size_t{63}, + Num0BitsBelowLS1Bit_Nonzero64(0x8000000000000000ull)); +} + +HWY_NOINLINE void TestAllPopCount() { + HWY_ASSERT_EQ(size_t{0}, PopCount(0u)); + HWY_ASSERT_EQ(size_t{1}, PopCount(1u)); + HWY_ASSERT_EQ(size_t{1}, PopCount(2u)); + HWY_ASSERT_EQ(size_t{2}, PopCount(3u)); + HWY_ASSERT_EQ(size_t{1}, PopCount(0x80000000u)); + HWY_ASSERT_EQ(size_t{31}, PopCount(0x7FFFFFFFu)); + HWY_ASSERT_EQ(size_t{32}, PopCount(0xFFFFFFFFu)); + + HWY_ASSERT_EQ(size_t{1}, PopCount(0x80000000ull)); + HWY_ASSERT_EQ(size_t{31}, PopCount(0x7FFFFFFFull)); + HWY_ASSERT_EQ(size_t{32}, PopCount(0xFFFFFFFFull)); + HWY_ASSERT_EQ(size_t{33}, PopCount(0x10FFFFFFFFull)); + HWY_ASSERT_EQ(size_t{63}, PopCount(0xFFFEFFFFFFFFFFFFull)); + HWY_ASSERT_EQ(size_t{64}, PopCount(0xFFFFFFFFFFFFFFFFull)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(BaseTest); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllLimits); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllLowestHighest); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllType); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllIsSame); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllBitScan); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllPopCount); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/cache_control.h b/media/highway/src/hwy/cache_control.h new file mode 100644 index 000000000..b124e5707 --- /dev/null +++ b/media/highway/src/hwy/cache_control.h @@ -0,0 +1,110 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CACHE_CONTROL_H_ +#define HIGHWAY_HWY_CACHE_CONTROL_H_ + +#include +#include + +#include "hwy/base.h" + +// Requires SSE2; fails to compile on 32-bit Clang 7 (see +// https://github.com/gperftools/gperftools/issues/946). +#if !defined(__SSE2__) || (HWY_COMPILER_CLANG && HWY_ARCH_X86_32) +#undef HWY_DISABLE_CACHE_CONTROL +#define HWY_DISABLE_CACHE_CONTROL +#endif + +// intrin.h is sufficient on MSVC and already included by base.h. +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC +#include // SSE2 +#endif + +// Windows.h #defines these, which causes infinite recursion. Temporarily +// undefine them in this header; these functions are anyway deprecated. +// TODO(janwas): remove when these functions are removed. +#pragma push_macro("LoadFence") +#undef LoadFence + +namespace hwy { + +// Even if N*sizeof(T) is smaller, Stream may write a multiple of this size. +#define HWY_STREAM_MULTIPLE 16 + +// The following functions may also require an attribute. +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC +#define HWY_ATTR_CACHE __attribute__((target("sse2"))) +#else +#define HWY_ATTR_CACHE +#endif + +// Delays subsequent loads until prior loads are visible. Beware of potentially +// differing behavior across architectures and vendors: on Intel but not +// AMD CPUs, also serves as a full fence (waits for all prior instructions to +// complete). +HWY_INLINE HWY_ATTR_CACHE void LoadFence() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_lfence(); +#endif +} + +// Ensures values written by previous `Stream` calls are visible on the current +// core. This is NOT sufficient for synchronizing across cores; when `Stream` +// outputs are to be consumed by other core(s), the producer must publish +// availability (e.g. via mutex or atomic_flag) after `FlushStream`. +HWY_INLINE HWY_ATTR_CACHE void FlushStream() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_sfence(); +#endif +} + +// Optionally begins loading the cache line containing "p" to reduce latency of +// subsequent actual loads. +template +HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_prefetch(reinterpret_cast(p), _MM_HINT_T0); +#elif HWY_COMPILER_GCC // includes clang + // Hint=0 (NTA) behavior differs, but skipping outer caches is probably not + // desirable, so use the default 3 (keep in caches). + __builtin_prefetch(p, /*write=*/0, /*hint=*/3); +#else + (void)p; +#endif +} + +// Invalidates and flushes the cache line containing "p", if possible. +HWY_INLINE HWY_ATTR_CACHE void FlushCacheline(const void* p) { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_clflush(p); +#else + (void)p; +#endif +} + +// When called inside a spin-loop, may reduce power consumption. +HWY_INLINE HWY_ATTR_CACHE void Pause() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_pause(); +#endif +} + +} // namespace hwy + +// TODO(janwas): remove when these functions are removed. (See above.) +#pragma pop_macro("LoadFence") + +#endif // HIGHWAY_HWY_CACHE_CONTROL_H_ diff --git a/media/highway/src/hwy/contrib/algo/copy-inl.h b/media/highway/src/hwy/contrib/algo/copy-inl.h new file mode 100644 index 000000000..033cf8a62 --- /dev/null +++ b/media/highway/src/hwy/contrib/algo/copy-inl.h @@ -0,0 +1,136 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a CopyAlignedPadded because it +// would be more verbose than such a loop. + +// Fills `to`[0, `count`) with `value`. +template > +void Fill(D d, T value, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + const Vec v = Set(d, value); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + StoreU(v, d, to + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeFillN(remaining, value, d, to + idx); +} + +// Copies `from`[0, `count`) to `to`, which must not overlap `from`. +template > +void Copy(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec v = LoadU(d, from + idx); + StoreU(v, d, to + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeCopyN(remaining, d, from + idx, to + idx); +} + +// For idx in [0, count) in ascending order, appends `from[idx]` to `to` if the +// corresponding mask element of `func(d, v)` is true. Returns the STL-style end +// of the newly written elements in `to`. +// +// `func` is either a functor with a templated operator()(d, v) returning a +// mask, or a generic lambda if using C++14. Due to apparent limitations of +// Clang on Windows, it is currently necessary to add HWY_ATTR before the +// opening { of the lambda to avoid errors about "function .. requires target". +// +// NOTE: this is only supported for 16-, 32- or 64-bit types. +// NOTE: Func may be called a second time for elements it has already seen, but +// these elements will not be written to `to` again. +template > +T* CopyIf(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec v = LoadU(d, from + idx); + to += CompressBlendedStore(v, func(d, v), d, to); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return to; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + for (; idx < count; ++idx) { + using V1 = Vec; + // Workaround for -Waggressive-loop-optimizations on GCC 8 + // (iteration 2305843009213693951 invokes undefined behavior for T=i64) + const uintptr_t addr = reinterpret_cast(from); + const T* HWY_RESTRICT from_idx = + reinterpret_cast(addr + (idx * sizeof(T))); + const V1 v = LoadU(d1, from_idx); + // Avoid storing to `to` unless we know it should be kept - otherwise, we + // might overrun the end if it was allocated for the exact count. + if (CountTrue(d1, func(d1, v)) == 0) continue; + StoreU(v, d1, to); + to += 1; + } +#else + // Start index of the last unaligned whole vector, ending at the array end. + const size_t last = count - N; + // Number of elements before `from` or already written. + const size_t invalid = idx - last; + HWY_DASSERT(0 != invalid && invalid < N); + const Mask mask = Not(FirstN(d, invalid)); + const Vec v = MaskedLoad(mask, d, from + last); + to += CompressBlendedStore(v, And(mask, func(d, v)), d, to); +#endif + return to; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ diff --git a/media/highway/src/hwy/contrib/algo/copy_test.cc b/media/highway/src/hwy/contrib/algo/copy_test.cc new file mode 100644 index 000000000..e2675a39d --- /dev/null +++ b/media/highway/src/hwy/contrib/algo/copy_test.cc @@ -0,0 +1,199 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/aligned_allocator.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/algo/copy_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/algo/copy-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// If your project requires C++14 or later, you can ignore this and pass lambdas +// directly to Transform, without requiring an lvalue as we do here for C++11. +#if __cplusplus < 201402L +#define HWY_GENERIC_LAMBDA 0 +#else +#define HWY_GENERIC_LAMBDA 1 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns random integer in [0, 128), which fits in any lane type. +template +T Random7Bit(RandomState& rng) { + return static_cast(Random32(&rng) & 127); +} + +// In C++14, we can instead define these as generic lambdas next to where they +// are invoked. +#if !HWY_GENERIC_LAMBDA + +struct IsOdd { + template + Mask operator()(D d, V v) const { + return TestBit(v, Set(d, TFromD{1})); + } +}; + +#endif // !HWY_GENERIC_LAMBDA + +// Invokes Test (e.g. TestCopyIf) with all arg combinations. T comes from +// ForFloatTypes. +template +struct ForeachCountAndMisalign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + for (size_t count = 0; count < 2 * N; ++count) { + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test()(d, count, ma, mb, rng); + } + } + } + } +}; + +struct TestFill { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD; + // HWY_MAX prevents error when misalign == count == 0. + AlignedFreeUniquePtr pa = + AllocateAligned(HWY_MAX(1, misalign_a + count)); + T* expected = pa.get() + misalign_a; + const T value = Random7Bit(rng); + for (size_t i = 0; i < count; ++i) { + expected[i] = value; + } + AlignedFreeUniquePtr pb = AllocateAligned(misalign_b + count + 1); + T* actual = pb.get() + misalign_b; + + actual[count] = T{0}; // sentinel + Fill(d, value, count, actual); + HWY_ASSERT_EQ(T{0}, actual[count]); // did not write past end + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected, actual, count, target_name, + __FILE__, __LINE__); + } +}; + +void TestAllFill() { + ForAllTypes(ForPartialVectors>()); +} + +struct TestCopy { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr pa = + AllocateAligned(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random7Bit(rng); + } + AlignedFreeUniquePtr pb = + AllocateAligned(HWY_MAX(1, misalign_b + count)); + T* b = pb.get() + misalign_b; + + Copy(d, a, count, b); + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, a, b, count, target_name, __FILE__, + __LINE__); + } +}; + +void TestAllCopy() { + ForAllTypes(ForPartialVectors>()); +} + +struct TestCopyIf { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr pa = + AllocateAligned(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random7Bit(rng); + } + const size_t padding = Lanes(ScalableTag()); + AlignedFreeUniquePtr pb = + AllocateAligned(HWY_MAX(1, misalign_b + count + padding)); + T* b = pb.get() + misalign_b; + + AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); + size_t num_odd = 0; + for (size_t i = 0; i < count; ++i) { + if (a[i] & 1) { + expected[num_odd++] = a[i]; + } + } + +#if HWY_GENERIC_LAMBDA + const auto is_odd = [](const auto d, const auto v) HWY_ATTR { + return TestBit(v, Set(d, TFromD{1})); + }; +#else + const IsOdd is_odd; +#endif + T* end = CopyIf(d, a, count, b, is_odd); + const size_t num_written = static_cast(end - b); + HWY_ASSERT_EQ(num_odd, num_written); + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), b, num_odd, target_name, + __FILE__, __LINE__); + } +}; + +void TestAllCopyIf() { + ForUI163264(ForPartialVectors>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(CopyTest); +HWY_EXPORT_AND_TEST_P(CopyTest, TestAllFill); +HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopy); +HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopyIf); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/contrib/algo/find-inl.h b/media/highway/src/hwy/contrib/algo/find-inl.h new file mode 100644 index 000000000..388842e98 --- /dev/null +++ b/media/highway/src/hwy/contrib/algo/find-inl.h @@ -0,0 +1,109 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns index of the first element equal to `value` in `in[0, count)`, or +// `count` if not found. +template > +size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) { + const size_t N = Lanes(d); + const Vec broadcasted = Set(d, value); + + size_t i = 0; + for (; i + N <= count; i += N) { + const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag d1; + using V1 = Vec; + const V1 broadcasted1 = Set(d1, GetLane(broadcasted)); + for (; i < count; ++i) { + if (AllTrue(d1, Eq(broadcasted1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. + const intptr_t pos = FindFirstTrue(d, And(Eq(broadcasted, v), mask)); + if (pos >= 0) return i + static_cast(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// Returns index of the first element in `in[0, count)` for which `func(d, vec)` +// returns true, otherwise `count`. +template > +size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t i = 0; + for (; i + N <= count; i += N) { + const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag d1; + for (; i < count; ++i) { + if (AllTrue(d1, func(d1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. + const intptr_t pos = FindFirstTrue(d, And(func(d, v), mask)); + if (pos >= 0) return i + static_cast(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ diff --git a/media/highway/src/hwy/contrib/algo/find_test.cc b/media/highway/src/hwy/contrib/algo/find_test.cc new file mode 100644 index 000000000..da13c475d --- /dev/null +++ b/media/highway/src/hwy/contrib/algo/find_test.cc @@ -0,0 +1,219 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/print.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/algo/find_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/algo/find-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// If your project requires C++14 or later, you can ignore this and pass lambdas +// directly to FindIf, without requiring an lvalue as we do here for C++11. +#if __cplusplus < 201402L +#define HWY_GENERIC_LAMBDA 0 +#else +#define HWY_GENERIC_LAMBDA 1 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns random number in [-8, 8) - we use knowledge of the range to Find() +// values we know are not present. +template +T Random(RandomState& rng) { + const int32_t bits = static_cast(Random32(&rng)) & 1023; + const double val = (bits - 512) / 64.0; + // Clamp negative to zero for unsigned types. + return static_cast(HWY_MAX(hwy::LowestValue(), val)); +} + +// In C++14, we can instead define these as generic lambdas next to where they +// are invoked. +#if !HWY_GENERIC_LAMBDA + +class GreaterThan { + public: + GreaterThan(int val) : val_(val) {} + template + Mask operator()(D d, V v) const { + return Gt(v, Set(d, static_cast>(val_))); + } + + private: + int val_; +}; + +#endif // !HWY_GENERIC_LAMBDA + +// Invokes Test (e.g. TestFind) with all arg combinations. +template +struct ForeachCountAndMisalign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + // Find() checks 8 vectors at a time, so we want to cover a fairly large + // range without oversampling (checking every possible count). + std::vector counts(AdjustedReps(512)); + for (size_t& count : counts) { + count = static_cast(rng()) % (16 * N + 1); + } + counts[0] = 0; // ensure we test count=0. + + for (size_t count : counts) { + for (size_t m : misalignments) { + Test()(d, count, m, rng); + } + } + } +}; + +struct TestFind { + template + void operator()(D d, size_t count, size_t misalign, RandomState& rng) { + using T = TFromD; + // Must allocate at least one even if count is zero. + AlignedFreeUniquePtr storage = + AllocateAligned(HWY_MAX(1, misalign + count)); + T* in = storage.get() + misalign; + for (size_t i = 0; i < count; ++i) { + in[i] = Random(rng); + } + + // For each position, search for that element (which we know is there) + for (size_t pos = 0; pos < count; ++pos) { + const size_t actual = Find(d, in[pos], in, count); + + // We may have found an earlier occurrence of the same value; ensure the + // value is the same, and that it is the first. + if (!IsEqual(in[pos], in[actual])) { + fprintf(stderr, "%s count %d, found %.15f at %d but wanted %.15f\n", + hwy::TypeName(T(), Lanes(d)).c_str(), static_cast(count), + static_cast(in[actual]), static_cast(actual), + static_cast(in[pos])); + HWY_ASSERT(false); + } + for (size_t i = 0; i < actual; ++i) { + if (IsEqual(in[i], in[pos])) { + fprintf(stderr, "%s count %d, found %f at %d but Find returned %d\n", + hwy::TypeName(T(), Lanes(d)).c_str(), static_cast(count), + static_cast(in[i]), static_cast(i), + static_cast(actual)); + HWY_ASSERT(false); + } + } + } + + // Also search for values we know not to be present (out of range) + HWY_ASSERT_EQ(count, Find(d, T{9}, in, count)); + HWY_ASSERT_EQ(count, Find(d, static_cast(-9), in, count)); + } +}; + +void TestAllFind() { + ForAllTypes(ForPartialVectors>()); +} + +struct TestFindIf { + template + void operator()(D d, size_t count, size_t misalign, RandomState& rng) { + using T = TFromD; + using TI = MakeSigned; + // Must allocate at least one even if count is zero. + AlignedFreeUniquePtr storage = + AllocateAligned(HWY_MAX(1, misalign + count)); + T* in = storage.get() + misalign; + for (size_t i = 0; i < count; ++i) { + in[i] = Random(rng); + HWY_ASSERT(in[i] < 8); + HWY_ASSERT(!hwy::IsSigned() || static_cast(in[i]) >= -8); + } + + bool found_any = false; + bool not_found_any = false; + + // unsigned T would be promoted to signed and compare greater than any + // negative val, whereas Set() would just cast to an unsigned value and the + // comparison remains unsigned, so avoid negative numbers there. + const int min_val = IsSigned() ? -9 : 0; + // Includes out-of-range value 9 to test the not-found path. + for (int val = min_val; val <= 9; ++val) { +#if HWY_GENERIC_LAMBDA + const auto greater = [val](const auto d, const auto v) HWY_ATTR { + return Gt(v, Set(d, static_cast(val))); + }; +#else + const GreaterThan greater(val); +#endif + const size_t actual = FindIf(d, in, count, greater); + found_any |= actual < count; + not_found_any |= actual == count; + + const auto pos = std::find_if( + in, in + count, [val](T x) { return x > static_cast(val); }); + // Convert returned iterator to index. + const size_t expected = static_cast(pos - in); + if (expected != actual) { + fprintf(stderr, "%s count %d val %d, expected %d actual %d\n", + hwy::TypeName(T(), Lanes(d)).c_str(), static_cast(count), + val, static_cast(expected), static_cast(actual)); + hwy::detail::PrintArray(hwy::detail::MakeTypeInfo(), "in", in, count, + 0, count); + HWY_ASSERT(false); + } + } + + // We will always not-find something due to val=9. + HWY_ASSERT(not_found_any); + // We'll find something unless the input is empty or {0} - because 0 > i + // is false for all i=[0,9]. + if (count != 0 && in[0] != 0) { + HWY_ASSERT(found_any); + } + } +}; + +void TestAllFindIf() { + ForAllTypes(ForPartialVectors>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(FindTest); +HWY_EXPORT_AND_TEST_P(FindTest, TestAllFind); +HWY_EXPORT_AND_TEST_P(FindTest, TestAllFindIf); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/contrib/algo/transform-inl.h b/media/highway/src/hwy/contrib/algo/transform-inl.h new file mode 100644 index 000000000..3e830acb4 --- /dev/null +++ b/media/highway/src/hwy/contrib/algo/transform-inl.h @@ -0,0 +1,262 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a TransformAlignedPadded because it +// would be more verbose than such a loop. +// +// Func is either a functor with a templated operator()(d, v[, v1[, v2]]), or a +// generic lambda if using C++14. Due to apparent limitations of Clang on +// Windows, it is currently necessary to add HWY_ATTR before the opening { of +// the lambda to avoid errors about "always_inline function .. requires target". +// +// If HWY_MEM_OPS_MIGHT_FAULT, we use scalar code instead of masking. Otherwise, +// we used `MaskedLoad` and `BlendedStore` to read/write the final partial +// vector. + +// Fills `out[0, count)` with the vectors returned by `func(d, index_vec)`, +// where `index_vec` is `Vec>`. On the first call to `func`, +// the value of its lane i is i, and increases by `Lanes(d)` after every call. +// Note that some of these indices may be `>= count`, but the elements that +// `func` returns in those lanes will not be written to `out`. +template > +void Generate(D d, T* HWY_RESTRICT out, size_t count, const Func& func) { + const RebindToUnsigned du; + using TU = TFromD; + const size_t N = Lanes(d); + + size_t idx = 0; + Vec vidx = Iota(du, 0); + for (; idx + N <= count; idx += N) { + StoreU(func(d, vidx), d, out + idx); + vidx = Add(vidx, Set(du, static_cast(N))); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + const RebindToUnsigned du1; + for (; idx < count; ++idx) { + StoreU(func(d1, Set(du1, static_cast(idx))), d1, out + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + BlendedStore(func(d, vidx), mask, d, out + idx); +#endif +} + +// Replaces `inout[idx]` with `func(d, inout[idx])`. Example usage: multiplying +// array elements by a constant. +template > +void Transform(D d, T* HWY_RESTRICT inout, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec v = LoadU(d, inout + idx); + StoreU(func(d, v), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + for (; idx < count; ++idx) { + using V1 = Vec; + const V1 v = LoadU(d1, inout + idx); + StoreU(func(d1, v), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, inout + idx); + BlendedStore(func(d, v), mask, d, inout + idx); +#endif +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx])`. Example usage: +// multiplying array elements by those of another array. +template > +void Transform1(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + StoreU(func(d, v, v1), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + for (; idx < count; ++idx) { + using V1 = Vec; + const V1 v = LoadU(d1, inout + idx); + const V1 v1 = LoadU(d1, in1 + idx); + StoreU(func(d1, v, v1), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, inout + idx); + const Vec v1 = MaskedLoad(mask, d, in1 + idx); + BlendedStore(func(d, v, v1), mask, d, inout + idx); +#endif +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx], in2[idx])`. Example +// usage: FMA of elements from three arrays, stored into the first array. +template > +void Transform2(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const T* HWY_RESTRICT in2, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + const Vec v2 = LoadU(d, in2 + idx); + StoreU(func(d, v, v1, v2), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + for (; idx < count; ++idx) { + using V1 = Vec; + const V1 v = LoadU(d1, inout + idx); + const V1 v1 = LoadU(d1, in1 + idx); + const V1 v2 = LoadU(d1, in2 + idx); + StoreU(func(d1, v, v1, v2), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, inout + idx); + const Vec v1 = MaskedLoad(mask, d, in1 + idx); + const Vec v2 = MaskedLoad(mask, d, in2 + idx); + BlendedStore(func(d, v, v1, v2), mask, d, inout + idx); +#endif +} + +template > +void Replace(D d, T* HWY_RESTRICT inout, size_t count, T new_t, T old_t) { + const size_t N = Lanes(d); + const Vec old_v = Set(d, old_t); + const Vec new_v = Set(d, new_t); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + const Vec old_v1 = Set(d1, old_t); + const Vec new_v1 = Set(d1, new_t); + for (; idx < count; ++idx) { + using V1 = Vec; + const V1 v1 = LoadU(d1, inout + idx); + StoreU(IfThenElse(Eq(v1, old_v1), new_v1, v1), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, inout + idx); + BlendedStore(IfThenElse(Eq(v, old_v), new_v, v), mask, d, inout + idx); +#endif +} + +template > +void ReplaceIf(D d, T* HWY_RESTRICT inout, size_t count, T new_t, + const Func& func) { + const size_t N = Lanes(d); + const Vec new_v = Set(d, new_t); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(func(d, v), new_v, v), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + const Vec new_v1 = Set(d1, new_t); + for (; idx < count; ++idx) { + using V1 = Vec; + const V1 v = LoadU(d1, inout + idx); + StoreU(IfThenElse(func(d1, v), new_v1, v), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, inout + idx); + BlendedStore(IfThenElse(func(d, v), new_v, v), mask, d, inout + idx); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ diff --git a/media/highway/src/hwy/contrib/algo/transform_test.cc b/media/highway/src/hwy/contrib/algo/transform_test.cc new file mode 100644 index 000000000..335607ccf --- /dev/null +++ b/media/highway/src/hwy/contrib/algo/transform_test.cc @@ -0,0 +1,372 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include // memcpy + +#include "hwy/aligned_allocator.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/algo/transform_test.cc" //NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/algo/transform-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// If your project requires C++14 or later, you can ignore this and pass lambdas +// directly to Transform, without requiring an lvalue as we do here for C++11. +#if __cplusplus < 201402L +#define HWY_GENERIC_LAMBDA 0 +#else +#define HWY_GENERIC_LAMBDA 1 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +T Alpha() { + return static_cast(1.5); // arbitrary scalar +} + +// Returns random floating-point number in [-8, 8) to ensure computations do +// not exceed float32 precision. +template +T Random(RandomState& rng) { + const int32_t bits = static_cast(Random32(&rng)) & 1023; + const double val = (bits - 512) / 64.0; + // Clamp negative to zero for unsigned types. + return static_cast(HWY_MAX(hwy::LowestValue(), val)); +} + +// SCAL, AXPY names are from BLAS. +template +HWY_NOINLINE void SimpleSCAL(const T* x, T* out, size_t count) { + for (size_t i = 0; i < count; ++i) { + out[i] = Alpha() * x[i]; + } +} + +template +HWY_NOINLINE void SimpleAXPY(const T* x, const T* y, T* out, size_t count) { + for (size_t i = 0; i < count; ++i) { + out[i] = Alpha() * x[i] + y[i]; + } +} + +template +HWY_NOINLINE void SimpleFMA4(const T* x, const T* y, const T* z, T* out, + size_t count) { + for (size_t i = 0; i < count; ++i) { + out[i] = x[i] * y[i] + z[i]; + } +} + +// In C++14, we can instead define these as generic lambdas next to where they +// are invoked. +#if !HWY_GENERIC_LAMBDA + +// Generator that returns even numbers by doubling the output indices. +struct Gen2 { + template + Vec operator()(D d, VU vidx) const { + return BitCast(d, Add(vidx, vidx)); + } +}; + +struct SCAL { + template + Vec operator()(D d, V v) const { + using T = TFromD; + return Mul(Set(d, Alpha()), v); + } +}; + +struct AXPY { + template + Vec operator()(D d, V v, V v1) const { + using T = TFromD; + return MulAdd(Set(d, Alpha()), v, v1); + } +}; + +struct FMA4 { + template + Vec operator()(D /*d*/, V v, V v1, V v2) const { + return MulAdd(v, v1, v2); + } +}; + +#endif // !HWY_GENERIC_LAMBDA + +// Invokes Test (e.g. TestTransform1) with all arg combinations. T comes from +// ForFloatTypes. +template +struct ForeachCountAndMisalign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + for (size_t count = 0; count < 2 * N; ++count) { + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test()(d, count, ma, mb, rng); + } + } + } + } +}; + +// Output-only, no loads +struct TestGenerate { + template + void operator()(D d, size_t count, size_t misalign_a, size_t /*misalign_b*/, + RandomState& /*rng*/) { + using T = TFromD; + AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + count + 1); + T* actual = pa.get() + misalign_a; + + AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); + for (size_t i = 0; i < count; ++i) { + expected[i] = static_cast(2 * i); + } + + // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that + // the attribute also applies to lambdas? If so, remove HWY_ATTR. +#if HWY_GENERIC_LAMBDA + const auto gen2 = [](const auto d, const auto vidx) + HWY_ATTR { return BitCast(d, Add(vidx, vidx)); }; +#else + const Gen2 gen2; +#endif + actual[count] = T{0}; // sentinel + Generate(d, actual, count, gen2); + HWY_ASSERT_EQ(T{0}, actual[count]); // did not write past end + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), actual, count, + target_name, __FILE__, __LINE__); + } +}; + +// Zero extra input arrays +struct TestTransform { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + if (misalign_b != 0) return; + using T = TFromD; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr pa = + AllocateAligned(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random(rng); + } + + AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); + SimpleSCAL(a, expected.get(), count); + + // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that + // the attribute also applies to lambdas? If so, remove HWY_ATTR. +#if HWY_GENERIC_LAMBDA + const auto scal = [](const auto d, const auto v) + HWY_ATTR { return Mul(Set(d, Alpha()), v); }; +#else + const SCAL scal; +#endif + Transform(d, a, count, scal); + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, + __FILE__, __LINE__); + } +}; + +// One extra input array +struct TestTransform1 { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr pa = + AllocateAligned(HWY_MAX(1, misalign_a + count)); + AlignedFreeUniquePtr pb = + AllocateAligned(HWY_MAX(1, misalign_b + count)); + T* a = pa.get() + misalign_a; + T* b = pb.get() + misalign_b; + for (size_t i = 0; i < count; ++i) { + a[i] = Random(rng); + b[i] = Random(rng); + } + + AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); + SimpleAXPY(a, b, expected.get(), count); + +#if HWY_GENERIC_LAMBDA + const auto axpy = [](const auto d, const auto v, const auto v1) HWY_ATTR { + return MulAdd(Set(d, Alpha()), v, v1); + }; +#else + const AXPY axpy; +#endif + Transform1(d, a, count, b, axpy); + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, + __FILE__, __LINE__); + } +}; + +// Two extra input arrays +struct TestTransform2 { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr pa = + AllocateAligned(HWY_MAX(1, misalign_a + count)); + AlignedFreeUniquePtr pb = + AllocateAligned(HWY_MAX(1, misalign_b + count)); + AlignedFreeUniquePtr pc = + AllocateAligned(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + T* b = pb.get() + misalign_b; + T* c = pc.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random(rng); + b[i] = Random(rng); + c[i] = Random(rng); + } + + AlignedFreeUniquePtr expected = AllocateAligned(HWY_MAX(1, count)); + SimpleFMA4(a, b, c, expected.get(), count); + +#if HWY_GENERIC_LAMBDA + const auto fma4 = [](auto /*d*/, auto v, auto v1, auto v2) + HWY_ATTR { return MulAdd(v, v1, v2); }; +#else + const FMA4 fma4; +#endif + Transform2(d, a, count, b, c, fma4); + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, + __FILE__, __LINE__); + } +}; + +template +class IfEq { + public: + IfEq(T val) : val_(val) {} + + template + Mask operator()(D d, V v) const { + return Eq(v, Set(d, val_)); + } + + private: + T val_; +}; + +struct TestReplace { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + if (misalign_b != 0) return; + if (count == 0) return; + using T = TFromD; + AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + count); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random(rng); + } + AlignedFreeUniquePtr pb = AllocateAligned(count); + + AlignedFreeUniquePtr expected = AllocateAligned(count); + + std::vector positions(AdjustedReps(count)); + for (size_t& pos : positions) { + pos = static_cast(rng()) % count; + } + + for (size_t pos = 0; pos < count; ++pos) { + const T old_t = a[pos]; + const T new_t = Random(rng); + for (size_t i = 0; i < count; ++i) { + expected[i] = IsEqual(a[i], old_t) ? new_t : a[i]; + } + + // Copy so ReplaceIf gets the same input (and thus also outputs expected) + memcpy(pb.get(), a, count * sizeof(T)); + + Replace(d, a, count, new_t, old_t); + HWY_ASSERT_ARRAY_EQ(expected.get(), a, count); + + ReplaceIf(d, pb.get(), count, new_t, IfEq(old_t)); + HWY_ASSERT_ARRAY_EQ(expected.get(), pb.get(), count); + } + } +}; + +void TestAllGenerate() { + // The test BitCast-s the indices, which does not work for floats. + ForIntegerTypes(ForPartialVectors>()); +} + +void TestAllTransform() { + ForFloatTypes(ForPartialVectors>()); +} + +void TestAllTransform1() { + ForFloatTypes(ForPartialVectors>()); +} + +void TestAllTransform2() { + ForFloatTypes(ForPartialVectors>()); +} + +void TestAllReplace() { + ForFloatTypes(ForPartialVectors>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(TransformTest); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllGenerate); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform1); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform2); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllReplace); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/contrib/dot/dot-inl.h b/media/highway/src/hwy/contrib/dot/dot-inl.h new file mode 100644 index 000000000..e04636f1b --- /dev/null +++ b/media/highway/src/hwy/contrib/dot/dot-inl.h @@ -0,0 +1,252 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#include + +#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct Dot { + // Specify zero or more of these, ORed together, as the kAssumptions template + // argument to Compute. Each one may improve performance or reduce code size, + // at the cost of additional requirements on the arguments. + enum Assumptions { + // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T). + kAtLeastOneVector = 1, + // num_elements is divisible by N (a power of two, so this can be used if + // the problem size is known to be a power of two >= HWY_MAX_BYTES / + // sizeof(T)). + kMultipleOfVector = 2, + // RoundUpTo(num_elements, N) elements are accessible; their value does not + // matter (will be treated as if they were zero). + kPaddedToVector = 4, + }; + + // Returns sum{pa[i] * pb[i]} for float or double inputs. Aligning the + // pointers to a multiple of N elements is helpful but not required. + template , + HWY_IF_NOT_LANE_SIZE_D(D, 2)> + static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa, + const T* const HWY_RESTRICT pb, + const size_t num_elements) { + static_assert(IsFloat(), "MulAdd requires float type"); + using V = decltype(Zero(d)); + + const size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + // Only 2x unroll to avoid excessive code size. + T sum0 = T(0); + T sum1 = T(0); + for (; i + 2 <= num_elements; i += 2) { + sum0 += pa[i + 0] * pb[i + 0]; + sum1 += pa[i + 1] * pb[i + 1]; + } + if (i < num_elements) { + sum1 += pa[i] * pb[i]; + } + return sum0 + sum1; + } + + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive + // for unaligned inputs (each unaligned pointer halves the throughput + // because it occupies both L1 load ports for a cycle). We cannot have + // arrays of vectors on RVV/SVE, so always unroll 4x. + V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); + + // Main loop: unrolled + for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = LoadU(d, pa + i); + const auto b2 = LoadU(d, pb + i); + i += N; + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = LoadU(d, pa + i); + const auto b3 = LoadU(d, pb + i); + i += N; + sum3 = MulAdd(a3, b3, sum3); + } + + // Up to 3 iterations of whole vectors + for (; i + N <= num_elements; i += N) { + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum0 = MulAdd(a, b, sum0); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(d, remaining); + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1); + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(d, N - remaining); + const auto a = LoadU(d, pa + i); // always unaligned + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return GetLane(SumOfLanes(d, sum0)); + } + + // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a + // multiple of N elements is helpful but not required. + template + static HWY_INLINE float Compute(const D d, + const bfloat16_t* const HWY_RESTRICT pa, + const bfloat16_t* const HWY_RESTRICT pb, + const size_t num_elements) { + const RebindToUnsigned du16; + const Repartition df32; + + using V = decltype(Zero(df32)); + const size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for.. + float sum1 = 0.0f; // this unlikely(?) case. + for (; i + 2 <= num_elements; i += 2) { + sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]); + sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]); + } + if (i < num_elements) { + sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum0 + sum1; + } + + // See comment in the other Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. + V sum0 = Zero(df32); + V sum1 = Zero(df32); + V sum2 = Zero(df32); + V sum3 = Zero(df32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(du16, remaining); + const auto va = LoadU(d, pa + i); + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(du16, N - remaining); + const auto va = LoadU(d, pa + i); // always unaligned + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return GetLane(SumOfLanes(df32, sum0)); + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ diff --git a/media/highway/src/hwy/contrib/dot/dot_test.cc b/media/highway/src/hwy/contrib/dot/dot_test.cc new file mode 100644 index 000000000..12d7ab270 --- /dev/null +++ b/media/highway/src/hwy/contrib/dot/dot_test.cc @@ -0,0 +1,167 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "hwy/aligned_allocator.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/dot/dot-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +HWY_NOINLINE T SimpleDot(const T* pa, const T* pb, size_t num) { + double sum = 0.0; + for (size_t i = 0; i < num; ++i) { + sum += pa[i] * pb[i]; + } + return static_cast(sum); +} + +HWY_NOINLINE float SimpleDot(const bfloat16_t* pa, const bfloat16_t* pb, + size_t num) { + float sum = 0.0f; + for (size_t i = 0; i < num; ++i) { + sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum; +} + +template +void SetValue(const float value, T* HWY_RESTRICT ptr) { + *ptr = static_cast(value); +} +void SetValue(const float value, bfloat16_t* HWY_RESTRICT ptr) { + *ptr = BF16FromF32(value); +} + +class TestDot { + // Computes/verifies one dot product. + template + void Test(D d, size_t num, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD; + const size_t N = Lanes(d); + const auto random_t = [&rng]() { + const int32_t bits = static_cast(Random32(&rng)) & 1023; + return static_cast(bits - 512) * (1.0f / 64); + }; + + const size_t padded = + (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num; + AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + padded); + AlignedFreeUniquePtr pb = AllocateAligned(misalign_b + padded); + T* a = pa.get() + misalign_a; + T* b = pb.get() + misalign_b; + size_t i = 0; + for (; i < num; ++i) { + SetValue(random_t(), a + i); + SetValue(random_t(), b + i); + } + // Fill padding with NaN - the values are not used, but avoids MSAN errors. + for (; i < padded; ++i) { + ScalableTag df1; + SetValue(GetLane(NaN(df1)), a + i); + SetValue(GetLane(NaN(df1)), b + i); + } + + const auto expected = SimpleDot(a, b, num); + const auto actual = Dot::Compute(d, a, b, num); + const auto max = static_cast(8 * 8 * num); + HWY_ASSERT(-max <= actual && actual <= max); + HWY_ASSERT(expected - 1E-4 <= actual && actual <= expected + 1E-4); + } + + // Runs tests with various alignments. + template + void ForeachMisalign(D d, size_t num, RandomState& rng) { + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test(d, num, ma, mb, rng); + } + } + } + + // Runs tests with various lengths compatible with the given assumptions. + template + void ForeachCount(D d, RandomState& rng) { + const size_t N = Lanes(d); + const size_t counts[] = {1, + 3, + 7, + 16, + HWY_MAX(N / 2, 1), + HWY_MAX(2 * N / 3, 1), + N, + N + 1, + 4 * N / 3, + 3 * N, + 8 * N, + 8 * N + 2}; + for (size_t num : counts) { + if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue; + if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue; + ForeachMisalign(d, num, rng); + } + } + + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + // All 8 combinations of the three length-related flags: + ForeachCount<0>(d, rng); + ForeachCount(d, rng); + ForeachCount(d, rng); + ForeachCount(d, rng); + ForeachCount(d, rng); + ForeachCount(d, rng); + ForeachCount(d, rng); + ForeachCount(d, rng); + } +}; + +void TestAllDot() { ForFloatTypes(ForPartialVectors()); } +void TestAllDotBF16() { ForShrinkableVectors()(bfloat16_t()); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(DotTest); +HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot); +HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/contrib/image/image.cc b/media/highway/src/hwy/contrib/image/image.cc new file mode 100644 index 000000000..2bcdcd6c9 --- /dev/null +++ b/media/highway/src/hwy/contrib/image/image.cc @@ -0,0 +1,145 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/image/image.h" + +#include // swap +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/image/image.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +size_t GetVectorSize() { return Lanes(ScalableTag()); } +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(GetVectorSize); // Local function. +} // namespace + +size_t ImageBase::VectorSize() { + // Do not cache result - must return the current value, which may be greater + // than the first call if it was subject to DisableTargets! + return HWY_DYNAMIC_DISPATCH(GetVectorSize)(); +} + +size_t ImageBase::BytesPerRow(const size_t xsize, const size_t sizeof_t) { + const size_t vec_size = VectorSize(); + size_t valid_bytes = xsize * sizeof_t; + + // Allow unaligned accesses starting at the last valid value - this may raise + // msan errors unless the user calls InitializePaddingForUnalignedAccesses. + // Skip for the scalar case because no extra lanes will be loaded. + if (vec_size != 1) { + HWY_DASSERT(vec_size >= sizeof_t); + valid_bytes += vec_size - sizeof_t; + } + + // Round up to vector and cache line size. + const size_t align = HWY_MAX(vec_size, HWY_ALIGNMENT); + size_t bytes_per_row = RoundUpTo(valid_bytes, align); + + // During the lengthy window before writes are committed to memory, CPUs + // guard against read after write hazards by checking the address, but + // only the lower 11 bits. We avoid a false dependency between writes to + // consecutive rows by ensuring their sizes are not multiples of 2 KiB. + // Avoid2K prevents the same problem for the planes of an Image3. + if (bytes_per_row % HWY_ALIGNMENT == 0) { + bytes_per_row += align; + } + + HWY_DASSERT(bytes_per_row % align == 0); + return bytes_per_row; +} + +ImageBase::ImageBase(const size_t xsize, const size_t ysize, + const size_t sizeof_t) + : xsize_(static_cast(xsize)), + ysize_(static_cast(ysize)), + bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) { + HWY_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8); + + bytes_per_row_ = 0; + // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate + // if nonzero, because "zero" bytes still have padding/bookkeeping overhead. + if (xsize != 0 && ysize != 0) { + bytes_per_row_ = BytesPerRow(xsize, sizeof_t); + bytes_ = AllocateAligned(bytes_per_row_ * ysize); + HWY_ASSERT(bytes_.get() != nullptr); + InitializePadding(sizeof_t, Padding::kRoundUp); + } +} + +ImageBase::ImageBase(const size_t xsize, const size_t ysize, + const size_t bytes_per_row, void* const aligned) + : xsize_(static_cast(xsize)), + ysize_(static_cast(ysize)), + bytes_per_row_(bytes_per_row), + bytes_(static_cast(aligned), + AlignedFreer(&AlignedFreer::DoNothing, nullptr)) { + const size_t vec_size = VectorSize(); + HWY_ASSERT(bytes_per_row % vec_size == 0); + HWY_ASSERT(reinterpret_cast(aligned) % vec_size == 0); +} + +void ImageBase::InitializePadding(const size_t sizeof_t, Padding padding) { +#if HWY_IS_MSAN || HWY_IDE + if (xsize_ == 0 || ysize_ == 0) return; + + const size_t vec_size = VectorSize(); // Bytes, independent of sizeof_t! + if (vec_size == 1) return; // Scalar mode: no padding needed + + const size_t valid_size = xsize_ * sizeof_t; + const size_t initialize_size = padding == Padding::kRoundUp + ? RoundUpTo(valid_size, vec_size) + : valid_size + vec_size - sizeof_t; + if (valid_size == initialize_size) return; + + for (size_t y = 0; y < ysize_; ++y) { + uint8_t* HWY_RESTRICT row = static_cast(VoidRow(y)); +#if defined(__clang__) && (__clang_major__ <= 6) + // There's a bug in msan in clang-6 when handling AVX2 operations. This + // workaround allows tests to pass on msan, although it is slower and + // prevents msan warnings from uninitialized images. + memset(row, 0, initialize_size); +#else + memset(row + valid_size, 0, initialize_size - valid_size); +#endif // clang6 + } +#else + (void)sizeof_t; + (void)padding; +#endif // HWY_IS_MSAN +} + +void ImageBase::Swap(ImageBase& other) { + std::swap(xsize_, other.xsize_); + std::swap(ysize_, other.ysize_); + std::swap(bytes_per_row_, other.bytes_per_row_); + std::swap(bytes_, other.bytes_); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/image/image.h b/media/highway/src/hwy/contrib/image/image.h new file mode 100644 index 000000000..231f3c51a --- /dev/null +++ b/media/highway/src/hwy/contrib/image/image.h @@ -0,0 +1,471 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ +#define HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ + +// SIMD/multicore-friendly planar image representation with row accessors. + +#include +#include +#include + +#include +#include // std::move + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/highway_export.h" + +namespace hwy { + +// Type-independent parts of Image<> - reduces code duplication and facilitates +// moving member function implementations to cc file. +struct HWY_CONTRIB_DLLEXPORT ImageBase { + // Returns required alignment in bytes for externally allocated memory. + static size_t VectorSize(); + + // Returns distance [bytes] between the start of two consecutive rows, a + // multiple of VectorSize but NOT kAlias (see implementation). + static size_t BytesPerRow(const size_t xsize, const size_t sizeof_t); + + // No allocation (for output params or unused images) + ImageBase() + : xsize_(0), + ysize_(0), + bytes_per_row_(0), + bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {} + + // Allocates memory (this is the common case) + ImageBase(size_t xsize, size_t ysize, size_t sizeof_t); + + // References but does not take ownership of external memory. Useful for + // interoperability with other libraries. `aligned` must be aligned to a + // multiple of VectorSize() and `bytes_per_row` must also be a multiple of + // VectorSize() or preferably equal to BytesPerRow(). + ImageBase(size_t xsize, size_t ysize, size_t bytes_per_row, void* aligned); + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo() instead. + ImageBase(const ImageBase& other) = delete; + ImageBase& operator=(const ImageBase& other) = delete; + + // Move constructor (required for returning Image from function) + ImageBase(ImageBase&& other) noexcept = default; + + // Move assignment (required for std::vector) + ImageBase& operator=(ImageBase&& other) noexcept = default; + + void Swap(ImageBase& other); + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. Caller is responsible + // for ensuring xsize/ysize are <= the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + xsize_ = static_cast(xsize); + ysize_ = static_cast(ysize); + // NOTE: we can't recompute bytes_per_row for more compact storage and + // better locality because that would invalidate the image contents. + } + + // How many pixels. + HWY_INLINE size_t xsize() const { return xsize_; } + HWY_INLINE size_t ysize() const { return ysize_; } + + // NOTE: do not use this for copying rows - the valid xsize may be much less. + HWY_INLINE size_t bytes_per_row() const { return bytes_per_row_; } + + // Raw access to byte contents, for interfacing with other libraries. + // Unsigned char instead of char to avoid surprises (sign extension). + HWY_INLINE uint8_t* bytes() { + void* p = bytes_.get(); + return static_cast(HWY_ASSUME_ALIGNED(p, 64)); + } + HWY_INLINE const uint8_t* bytes() const { + const void* p = bytes_.get(); + return static_cast(HWY_ASSUME_ALIGNED(p, 64)); + } + + protected: + // Returns pointer to the start of a row. + HWY_INLINE void* VoidRow(const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (y >= ysize_) { + HWY_ABORT("Row(%d) >= %u\n", static_cast(y), ysize_); + } +#endif + + void* row = bytes_.get() + y * bytes_per_row_; + return HWY_ASSUME_ALIGNED(row, 64); + } + + enum class Padding { + // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default. + kRoundUp, + // Allow LoadU(d, row + x) for x <= xsize() - 1. This requires an extra + // vector to be initialized. If done by default, this would suppress + // legitimate msan warnings. We therefore require users to explicitly call + // InitializePadding before using unaligned loads (e.g. convolution). + kUnaligned + }; + + // Initializes the minimum bytes required to suppress msan warnings from + // legitimate (according to Padding mode) vector loads/stores on the right + // border, where some lanes are uninitialized and assumed to be unused. + void InitializePadding(size_t sizeof_t, Padding padding); + + // (Members are non-const to enable assignment during move-assignment.) + uint32_t xsize_; // In valid pixels, not including any padding. + uint32_t ysize_; + size_t bytes_per_row_; // Includes padding. + AlignedFreeUniquePtr bytes_; +}; + +// Single channel, aligned rows separated by padding. T must be POD. +// +// 'Single channel' (one 2D array per channel) simplifies vectorization +// (repeating the same operation on multiple adjacent components) without the +// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients +// can easily iterate over all components in a row and Image requires no +// knowledge of the pixel format beyond the component type "T". +// +// 'Aligned' means each row is aligned to the L1 cache line size. This prevents +// false sharing between two threads operating on adjacent rows. +// +// 'Padding' is still relevant because vectors could potentially be larger than +// a cache line. By rounding up row sizes to the vector size, we allow +// reading/writing ALIGNED vectors whose first lane is a valid sample. This +// avoids needing a separate loop to handle remaining unaligned lanes. +// +// This image layout could also be achieved with a vector and a row accessor +// function, but a class wrapper with support for "deleter" allows wrapping +// existing memory allocated by clients without copying the pixels. It also +// provides convenient accessors for xsize/ysize, which shortens function +// argument lists. Supports move-construction so it can be stored in containers. +template +class Image : public ImageBase { + public: + using T = ComponentType; + + Image() = default; + Image(const size_t xsize, const size_t ysize) + : ImageBase(xsize, ysize, sizeof(T)) {} + Image(const size_t xsize, const size_t ysize, size_t bytes_per_row, + void* aligned) + : ImageBase(xsize, ysize, bytes_per_row, aligned) {} + + void InitializePaddingForUnalignedAccesses() { + InitializePadding(sizeof(T), Padding::kUnaligned); + } + + HWY_INLINE const T* ConstRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + HWY_INLINE const T* ConstRow(const size_t y) { + return static_cast(VoidRow(y)); + } + + // Returns pointer to non-const. This allows passing const Image* parameters + // when the callee is only supposed to fill the pixels, as opposed to + // allocating or resizing the image. + HWY_INLINE T* MutableRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + HWY_INLINE T* MutableRow(const size_t y) { + return static_cast(VoidRow(y)); + } + + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must + // NOT be used to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { + return static_cast(bytes_per_row_ / sizeof(T)); + } +}; + +using ImageF = Image; + +// A bundle of 3 same-sized images. To fill an existing Image3 using +// single-channel producers, we also need access to each const Image*. Const +// prevents breaking the same-size invariant, while still allowing pixels to be +// changed via MutableRow. +template +class Image3 { + public: + using T = ComponentType; + using ImageT = Image; + static constexpr size_t kNumPlanes = 3; + + Image3() : planes_{ImageT(), ImageT(), ImageT()} {} + + Image3(const size_t xsize, const size_t ysize) + : planes_{ImageT(xsize, ysize), ImageT(xsize, ysize), + ImageT(xsize, ysize)} {} + + Image3(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + } + + Image3(ImageT&& plane0, ImageT&& plane1, ImageT&& plane2) { + if (!SameSize(plane0, plane1) || !SameSize(plane0, plane2)) { + HWY_ABORT( + "Not same size: %d x %d, %d x %d, %d x %d\n", + static_cast(plane0.xsize()), static_cast(plane0.ysize()), + static_cast(plane1.xsize()), static_cast(plane1.ysize()), + static_cast(plane2.xsize()), static_cast(plane2.ysize())); + } + planes_[0] = std::move(plane0); + planes_[1] = std::move(plane1); + planes_[2] = std::move(plane2); + } + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo instead. + Image3(const Image3& other) = delete; + Image3& operator=(const Image3& other) = delete; + + Image3& operator=(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + return *this; + } + + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const { + return static_cast(VoidPlaneRow(c, y)); + } + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) { + return static_cast(VoidPlaneRow(c, y)); + } + + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) const { + return static_cast(VoidPlaneRow(c, y)); + } + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) { + return static_cast(VoidPlaneRow(c, y)); + } + + HWY_INLINE const ImageT& Plane(size_t idx) const { return planes_[idx]; } + + void Swap(Image3& other) { + for (size_t c = 0; c < 3; ++c) { + other.planes_[c].Swap(planes_[c]); + } + } + + void ShrinkTo(const size_t xsize, const size_t ysize) { + for (ImageT& plane : planes_) { + plane.ShrinkTo(xsize, ysize); + } + } + + // Sizes of all three images are guaranteed to be equal. + HWY_INLINE size_t xsize() const { return planes_[0].xsize(); } + HWY_INLINE size_t ysize() const { return planes_[0].ysize(); } + // Returns offset [bytes] from one row to the next row of the same plane. + // WARNING: this must NOT be used to determine xsize, nor for copying rows - + // the valid xsize may be much less. + HWY_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); } + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must NOT be used + // to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); } + + private: + // Returns pointer to the start of a row. + HWY_INLINE void* VoidPlaneRow(const size_t c, const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (c >= kNumPlanes || y >= ysize()) { + HWY_ABORT("PlaneRow(%d, %d) >= %d\n", static_cast(c), + static_cast(y), static_cast(ysize())); + } +#endif + // Use the first plane's stride because the compiler might not realize they + // are all equal. Thus we only need a single multiplication for all planes. + const size_t row_offset = y * planes_[0].bytes_per_row(); + const void* row = planes_[c].bytes() + row_offset; + return static_cast( + HWY_ASSUME_ALIGNED(row, HWY_ALIGNMENT)); + } + + private: + ImageT planes_[kNumPlanes]; +}; + +using Image3F = Image3; + +// Rectangular region in image(s). Factoring this out of Image instead of +// shifting the pointer by x0/y0 allows this to apply to multiple images with +// different resolutions. Can compare size via SameSize(rect1, rect2). +class Rect { + public: + // Most windows are xsize_max * ysize_max, except those on the borders where + // begin + size_max > end. + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max, size_t xend, size_t yend) + : x0_(xbegin), + y0_(ybegin), + xsize_(ClampedSize(xbegin, xsize_max, xend)), + ysize_(ClampedSize(ybegin, ysize_max, yend)) {} + + // Construct with origin and known size (typically from another Rect). + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize, size_t ysize) + : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {} + + // Construct a rect that covers a whole image. + template + explicit Rect(const Image& image) + : Rect(0, 0, image.xsize(), image.ysize()) {} + + Rect() : Rect(0, 0, 0, 0) {} + + Rect(const Rect&) = default; + Rect& operator=(const Rect&) = default; + + Rect Subrect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max) { + return Rect(x0_ + xbegin, y0_ + ybegin, xsize_max, ysize_max, x0_ + xsize_, + y0_ + ysize_); + } + + template + const T* ConstRow(const Image* image, size_t y) const { + return image->ConstRow(y + y0_) + x0_; + } + + template + T* MutableRow(const Image* image, size_t y) const { + return image->MutableRow(y + y0_) + x0_; + } + + template + const T* ConstPlaneRow(const Image3& image, size_t c, size_t y) const { + return image.ConstPlaneRow(c, y + y0_) + x0_; + } + + template + T* MutablePlaneRow(Image3* image, const size_t c, size_t y) const { + return image->MutablePlaneRow(c, y + y0_) + x0_; + } + + // Returns true if this Rect fully resides in the given image. ImageT could be + // Image or Image3; however if ImageT is Rect, results are nonsensical. + template + bool IsInside(const ImageT& image) const { + return (x0_ + xsize_ <= image.xsize()) && (y0_ + ysize_ <= image.ysize()); + } + + size_t x0() const { return x0_; } + size_t y0() const { return y0_; } + size_t xsize() const { return xsize_; } + size_t ysize() const { return ysize_; } + + private: + // Returns size_max, or whatever is left in [begin, end). + static constexpr size_t ClampedSize(size_t begin, size_t size_max, + size_t end) { + return (begin + size_max <= end) ? size_max + : (end > begin ? end - begin : 0); + } + + size_t x0_; + size_t y0_; + + size_t xsize_; + size_t ysize_; +}; + +// Works for any image-like input type(s). +template +HWY_MAYBE_UNUSED bool SameSize(const Image1& image1, const Image2& image2) { + return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize(); +} + +// Mirrors out of bounds coordinates and returns valid coordinates unchanged. +// We assume the radius (distance outside the image) is small compared to the +// image size, otherwise this might not terminate. +// The mirror is outside the last column (border pixel is also replicated). +static HWY_INLINE HWY_MAYBE_UNUSED size_t Mirror(int64_t x, + const int64_t xsize) { + HWY_DASSERT(xsize != 0); + + // TODO(janwas): replace with branchless version + while (x < 0 || x >= xsize) { + if (x < 0) { + x = -x - 1; + } else { + x = 2 * xsize - 1 - x; + } + } + return static_cast(x); +} + +// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size): + +// Mirrors (repeating the edge pixel once). Useful for convolutions. +struct WrapMirror { + HWY_INLINE size_t operator()(const int64_t coord, const size_t size) const { + return Mirror(coord, static_cast(size)); + } +}; + +// Returns the same coordinate, for when we know "coord" is already valid (e.g. +// interior of an image). +struct WrapUnchanged { + HWY_INLINE size_t operator()(const int64_t coord, size_t /*size*/) const { + return static_cast(coord); + } +}; + +// Similar to Wrap* but for row pointers (reduces Row() multiplications). + +class WrapRowMirror { + public: + template + WrapRowMirror(const View& image, size_t ysize) + : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {} + + const float* operator()(const float* const HWY_RESTRICT row, + const int64_t stride) const { + if (row < first_row_) { + const int64_t num_before = first_row_ - row; + // Mirrored; one row before => row 0, two before = row 1, ... + return first_row_ + num_before - stride; + } + if (row > last_row_) { + const int64_t num_after = row - last_row_; + // Mirrored; one row after => last row, two after = last - 1, ... + return last_row_ - num_after + stride; + } + return row; + } + + private: + const float* const HWY_RESTRICT first_row_; + const float* const HWY_RESTRICT last_row_; +}; + +struct WrapRowUnchanged { + HWY_INLINE const float* operator()(const float* const HWY_RESTRICT row, + int64_t /*stride*/) const { + return row; + } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ diff --git a/media/highway/src/hwy/contrib/image/image_test.cc b/media/highway/src/hwy/contrib/image/image_test.cc new file mode 100644 index 000000000..6886577a4 --- /dev/null +++ b/media/highway/src/hwy/contrib/image/image_test.cc @@ -0,0 +1,152 @@ +// Copyright (c) the JPEG XL Project +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/image/image.h" + +#include +#include +#include +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/image/image_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target: +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Ensure we can always write full aligned vectors. +struct TestAlignedT { + template + void operator()(T /*unused*/) const { + std::mt19937 rng(129); + std::uniform_int_distribution dist(0, 16); + const ScalableTag d; + + for (size_t ysize = 1; ysize < 4; ++ysize) { + for (size_t xsize = 1; xsize < 64; ++xsize) { + Image img(xsize, ysize); + + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto values = Iota(d, static_cast(dist(rng))); + Store(values, d, row + x); + } + } + + // Sanity check to prevent optimizing out the writes + const auto x = std::uniform_int_distribution(0, xsize - 1)(rng); + const auto y = std::uniform_int_distribution(0, ysize - 1)(rng); + HWY_ASSERT(img.ConstRow(y)[x] < 16 + Lanes(d)); + } + } + } +}; + +void TestAligned() { ForUnsignedTypes(TestAlignedT()); } + +// Ensure we can write an unaligned vector starting at the last valid value. +struct TestUnalignedT { + template + void operator()(T /*unused*/) const { + std::mt19937 rng(129); + std::uniform_int_distribution dist(0, 3); + const ScalableTag d; + + for (size_t ysize = 1; ysize < 4; ++ysize) { + for (size_t xsize = 1; xsize < 128; ++xsize) { + Image img(xsize, ysize); + img.InitializePaddingForUnalignedAccesses(); + +// This test reads padding, which only works if it was initialized, +// which only happens in MSAN builds. +#if HWY_IS_MSAN || HWY_IDE + // Initialize only the valid samples + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; ++x) { + row[x] = static_cast(1u << dist(rng)); + } + } + + // Read padding bits + auto accum = Zero(d); + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; ++x) { + accum = Or(accum, LoadU(d, row + x)); + } + } + + // Ensure padding was zero + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + Store(accum, d, lanes.get()); + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(lanes[i] < 16); + } +#else // Check that writing padding does not overwrite valid samples + // Initialize only the valid samples + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; ++x) { + row[x] = static_cast(x); + } + } + + // Zero padding and rightmost sample + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + StoreU(Zero(d), d, row + xsize - 1); + } + + // Ensure no samples except the rightmost were overwritten + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize - 1; ++x) { + HWY_ASSERT_EQ(static_cast(x), row[x]); + } + } +#endif + } + } + } +}; + +void TestUnaligned() { ForUnsignedTypes(TestUnalignedT()); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(ImageTest); +HWY_EXPORT_AND_TEST_P(ImageTest, TestAligned); +HWY_EXPORT_AND_TEST_P(ImageTest, TestUnaligned); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/contrib/math/math-inl.h b/media/highway/src/hwy/contrib/math/math-inl.h new file mode 100644 index 000000000..b4cbb5d11 --- /dev/null +++ b/media/highway/src/hwy/contrib/math/math-inl.h @@ -0,0 +1,1242 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +/** + * Highway SIMD version of std::acos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc cosine of 'x' + */ +template +HWY_INLINE V Acos(const D d, V x); +template +HWY_NOINLINE V CallAcos(const D d, VecArg x) { + return Acos(d, x); +} + +/** + * Highway SIMD version of std::acosh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[1, +FLT_MAX], float64[1, +DBL_MAX] + * @return hyperbolic arc cosine of 'x' + */ +template +HWY_INLINE V Acosh(const D d, V x); +template +HWY_NOINLINE V CallAcosh(const D d, VecArg x) { + return Acosh(d, x); +} + +/** + * Highway SIMD version of std::asin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc sine of 'x' + */ +template +HWY_INLINE V Asin(const D d, V x); +template +HWY_NOINLINE V CallAsin(const D d, VecArg x) { + return Asin(d, x); +} + +/** + * Highway SIMD version of std::asinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic arc sine of 'x' + */ +template +HWY_INLINE V Asinh(const D d, V x); +template +HWY_NOINLINE V CallAsinh(const D d, VecArg x) { + return Asinh(d, x); +} + +/** + * Highway SIMD version of std::atan(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return arc tangent of 'x' + */ +template +HWY_INLINE V Atan(const D d, V x); +template +HWY_NOINLINE V CallAtan(const D d, VecArg x) { + return Atan(d, x); +} + +/** + * Highway SIMD version of std::atanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: (-1, +1) + * @return hyperbolic arc tangent of 'x' + */ +template +HWY_INLINE V Atanh(const D d, V x); +template +HWY_NOINLINE V CallAtanh(const D d, VecArg x) { + return Atanh(d, x); +} + +/** + * Highway SIMD version of std::cos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return cosine of 'x' + */ +template +HWY_INLINE V Cos(const D d, V x); +template +HWY_NOINLINE V CallCos(const D d, VecArg x) { + return Cos(d, x); +} + +/** + * Highway SIMD version of std::exp(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 1 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x + */ +template +HWY_INLINE V Exp(const D d, V x); +template +HWY_NOINLINE V CallExp(const D d, VecArg x) { + return Exp(d, x); +} + +/** + * Highway SIMD version of std::expm1(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x - 1 + */ +template +HWY_INLINE V Expm1(const D d, V x); +template +HWY_NOINLINE V CallExpm1(const D d, VecArg x) { + return Expm1(d, x); +} + +/** + * Highway SIMD version of std::log(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return natural logarithm of 'x' + */ +template +HWY_INLINE V Log(const D d, V x); +template +HWY_NOINLINE V CallLog(const D d, VecArg x) { + return Log(d, x); +} + +/** + * Highway SIMD version of std::log10(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 10 logarithm of 'x' + */ +template +HWY_INLINE V Log10(const D d, V x); +template +HWY_NOINLINE V CallLog10(const D d, VecArg x) { + return Log10(d, x); +} + +/** + * Highway SIMD version of std::log1p(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[0, +FLT_MAX], float64[0, +DBL_MAX] + * @return log(1 + x) + */ +template +HWY_INLINE V Log1p(const D d, V x); +template +HWY_NOINLINE V CallLog1p(const D d, VecArg x) { + return Log1p(d, x); +} + +/** + * Highway SIMD version of std::log2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 2 logarithm of 'x' + */ +template +HWY_INLINE V Log2(const D d, V x); +template +HWY_NOINLINE V CallLog2(const D d, VecArg x) { + return Log2(d, x); +} + +/** + * Highway SIMD version of std::sin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return sine of 'x' + */ +template +HWY_INLINE V Sin(const D d, V x); +template +HWY_NOINLINE V CallSin(const D d, VecArg x) { + return Sin(d, x); +} + +/** + * Highway SIMD version of std::sinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-88.7228, +88.7228], float64[-709, +709] + * @return hyperbolic sine of 'x' + */ +template +HWY_INLINE V Sinh(const D d, V x); +template +HWY_NOINLINE V CallSinh(const D d, VecArg x) { + return Sinh(d, x); +} + +/** + * Highway SIMD version of std::tanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic tangent of 'x' + */ +template +HWY_INLINE V Tanh(const D d, V x); +template +HWY_NOINLINE V CallTanh(const D d, VecArg x) { + return Tanh(d, x); +} + +//////////////////////////////////////////////////////////////////////////////// +// Implementation +//////////////////////////////////////////////////////////////////////////////// +namespace impl { + +// Estrin's Scheme is a faster method for evaluating large polynomials on +// super scalar architectures. It works by factoring the Horner's Method +// polynomial into power of two sub-trees that can be evaluated in parallel. +// Wikipedia Link: https://en.wikipedia.org/wiki/Estrin%27s_scheme +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1) { + return MulAdd(c1, x, c0); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2) { + T x2 = Mul(x, x); + return MulAdd(x2, c2, MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3) { + T x2 = Mul(x, x); + return MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, c4, MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(c5, x, c4), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, c6, MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, c8, + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(c9, x, c8), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, c10, MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd( + x8, MulAdd(x4, c12, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(c13, x, c12), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, c14, MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, c16, + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(c17, x, c16), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17, + T c18) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(x2, c18, MulAdd(c17, x, c16)), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} + +template +struct AsinImpl {}; +template +struct AtanImpl {}; +template +struct CosSinImpl {}; +template +struct ExpImpl {}; +template +struct LogImpl {}; + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666677296f); + const auto k1 = Set(d, +0.07495029271f); + const auto k2 = Set(d, +0.04547423869f); + const auto k3 = Set(d, +0.02424046025f); + const auto k4 = Set(d, +0.04197454825f); + + return Estrin(x2, k0, k1, k2, k3, k4); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666666666666497543); + const auto k1 = Set(d, +0.07500000000378581611); + const auto k2 = Set(d, +0.04464285681377102438); + const auto k3 = Set(d, +0.03038195928038132237); + const auto k4 = Set(d, +0.02237176181932048341); + const auto k5 = Set(d, +0.01735956991223614604); + const auto k6 = Set(d, +0.01388715184501609218); + const auto k7 = Set(d, +0.01215360525577377331); + const auto k8 = Set(d, +0.006606077476277170610); + const auto k9 = Set(d, +0.01929045477267910674); + const auto k10 = Set(d, -0.01581918243329996643); + const auto k11 = Set(d, +0.03161587650653934628); + + return Estrin(x2, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11); + } +}; + +#endif + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333331018686294555664062f); + const auto k1 = Set(d, +0.199926957488059997558594f); + const auto k2 = Set(d, -0.142027363181114196777344f); + const auto k3 = Set(d, +0.106347933411598205566406f); + const auto k4 = Set(d, -0.0748900920152664184570312f); + const auto k5 = Set(d, +0.0425049886107444763183594f); + const auto k6 = Set(d, -0.0159569028764963150024414f); + const auto k7 = Set(d, +0.00282363896258175373077393f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7), Mul(y, x), x); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333333333333311110369124); + const auto k1 = Set(d, +0.199999999996591265594148); + const auto k2 = Set(d, -0.14285714266771329383765); + const auto k3 = Set(d, +0.111111105648261418443745); + const auto k4 = Set(d, -0.090908995008245008229153); + const auto k5 = Set(d, +0.0769219538311769618355029); + const auto k6 = Set(d, -0.0666573579361080525984562); + const auto k7 = Set(d, +0.0587666392926673580854313); + const auto k8 = Set(d, -0.0523674852303482457616113); + const auto k9 = Set(d, +0.0466667150077840625632675); + const auto k10 = Set(d, -0.0407629191276836500001934); + const auto k11 = Set(d, +0.0337852580001353069993897); + const auto k12 = Set(d, -0.0254517624932312641616861); + const auto k13 = Set(d, +0.016599329773529201970117); + const auto k14 = Set(d, -0.00889896195887655491740809); + const auto k15 = Set(d, +0.00370026744188713119232403); + const auto k16 = Set(d, -0.00110611831486672482563471); + const auto k17 = Set(d, +0.000209850076645816976906797); + const auto k18 = Set(d, -1.88796008463073496563746e-5); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, + k12, k13, k14, k15, k16, k17, k18), + Mul(y, x), x); + } +}; + +#endif + +template <> +struct CosSinImpl { + // Rounds float toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + template + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -1.66666597127914428710938e-1f); + const auto k1 = Set(d, +8.33307858556509017944336e-3f); + const auto k2 = Set(d, -1.981069071916863322258e-4f); + const auto k3 = Set(d, +2.6083159809786593541503e-6f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3), Mul(y, x), x); + } + + template + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0f + kHalfPiPart1f + kHalfPiPart2f + kHalfPiPart3f ~= -pi/2 + const V kHalfPiPart0f = Set(d, -0.5f * 3.140625f); + const V kHalfPiPart1f = Set(d, -0.5f * 0.0009670257568359375f); + const V kHalfPiPart2f = Set(d, -0.5f * 6.2771141529083251953e-7f); + const V kHalfPiPart3f = Set(d, -0.5f * 1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kHalfPiPart0f, x); + x = MulAdd(qf, kHalfPiPart1f, x); + x = MulAdd(qf, kHalfPiPart2f, x); + x = MulAdd(qf, kHalfPiPart3f, x); + return x; + } + + template + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0f + kPiPart1f + kPiPart2f + kPiPart3f ~= -pi + const V kPiPart0f = Set(d, -3.140625f); + const V kPiPart1f = Set(d, -0.0009670257568359375f); + const V kPiPart2f = Set(d, -6.2771141529083251953e-7f); + const V kPiPart3f = Set(d, -1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kPiPart0f, x); + x = MulAdd(qf, kPiPart1f, x); + x = MulAdd(qf, kPiPart2f, x); + x = MulAdd(qf, kPiPart3f, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast(d, ShiftLeft<30>(AndNot(q, kTwo))); + } + + // ((q & 1) ? -0.0 : +0.0) + template + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast(d, ShiftLeft<31>(And(q, kOne))); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct CosSinImpl { + // Rounds double toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + template + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -0.166666666666666657414808); + const auto k1 = Set(d, +0.00833333333333332974823815); + const auto k2 = Set(d, -0.000198412698412696162806809); + const auto k3 = Set(d, +2.75573192239198747630416e-6); + const auto k4 = Set(d, -2.50521083763502045810755e-8); + const auto k5 = Set(d, +1.60590430605664501629054e-10); + const auto k6 = Set(d, -7.64712219118158833288484e-13); + const auto k7 = Set(d, +2.81009972710863200091251e-15); + const auto k8 = Set(d, -7.97255955009037868891952e-18); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8), Mul(y, x), x); + } + + template + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0d + kHalfPiPart1d + kHalfPiPart2d + kHalfPiPart3d ~= -pi/2 + const V kHalfPiPart0d = Set(d, -0.5 * 3.1415926218032836914); + const V kHalfPiPart1d = Set(d, -0.5 * 3.1786509424591713469e-8); + const V kHalfPiPart2d = Set(d, -0.5 * 1.2246467864107188502e-16); + const V kHalfPiPart3d = Set(d, -0.5 * 1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kHalfPiPart0d, x); + x = MulAdd(qf, kHalfPiPart1d, x); + x = MulAdd(qf, kHalfPiPart2d, x); + x = MulAdd(qf, kHalfPiPart3d, x); + return x; + } + + template + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0d + kPiPart1d + kPiPart2d + kPiPart3d ~= -pi + const V kPiPart0d = Set(d, -3.1415926218032836914); + const V kPiPart1d = Set(d, -3.1786509424591713469e-8); + const V kPiPart2d = Set(d, -1.2246467864107188502e-16); + const V kPiPart3d = Set(d, -1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kPiPart0d, x); + x = MulAdd(qf, kPiPart1d, x); + x = MulAdd(qf, kPiPart2d, x); + x = MulAdd(qf, kPiPart3d, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast( + d, ShiftLeft<62>(PromoteTo(Rebind(), AndNot(q, kTwo)))); + } + + // ((q & 1) ? -0.0 : +0.0) + template + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast( + d, ShiftLeft<63>(PromoteTo(Rebind(), And(q, kOne)))); + } +}; + +#endif + +template <> +struct ExpImpl { + // Rounds float toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + template + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5f); + const auto k1 = Set(d, +0.166666671633720397949219f); + const auto k2 = Set(d, +0.0416664853692054748535156f); + const auto k3 = Set(d, +0.00833336077630519866943359f); + const auto k4 = Set(d, +0.00139304355252534151077271f); + const auto k5 = Set(d, +0.000198527617612853646278381f); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5), Mul(x, x), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const VI32 kOffset = Set(di32, 0x7F); + return BitCast(d, ShiftLeft<23>(Add(x, kOffset))); + } + + // Sets the exponent of 'x' to 2^e. + template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0f + kLn2Part1f ~= -ln(2) + const V kLn2Part0f = Set(d, -0.693145751953125f); + const V kLn2Part1f = Set(d, -1.428606765330187045e-6f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kLn2Part0f, x); + x = MulAdd(qf, kLn2Part1f, x); + return x; + } +}; + +template <> +struct LogImpl { + template + HWY_INLINE Vec> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind di32; + const Rebind du32; + const auto kBias = Set(di32, 0x7F); + return Sub(BitCast(di32, ShiftRight<23>(BitCast(du32, x))), kBias); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.66666662693f); + const V k1 = Set(d, 0.40000972152f); + const V k2 = Set(d, 0.28498786688f); + const V k3 = Set(d, 0.24279078841f); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(k2, x4, k0), x2, Mul(MulAdd(k3, x4, k1), x4)); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct ExpImpl { + // Rounds double toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + template + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5); + const auto k1 = Set(d, +0.166666666666666851703837); + const auto k2 = Set(d, +0.0416666666666665047591422); + const auto k3 = Set(d, +0.00833333333331652721664984); + const auto k4 = Set(d, +0.00138888888889774492207962); + const auto k5 = Set(d, +0.000198412698960509205564975); + const auto k6 = Set(d, +2.4801587159235472998791e-5); + const auto k7 = Set(d, +2.75572362911928827629423e-6); + const auto k8 = Set(d, +2.75573911234900471893338e-7); + const auto k9 = Set(d, +2.51112930892876518610661e-8); + const auto k10 = Set(d, +2.08860621107283687536341e-9); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10), + Mul(x, x), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const Rebind di64; + const VI32 kOffset = Set(di32, 0x3FF); + return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset)))); + } + + // Sets the exponent of 'x' to 2^e. + template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0d + kLn2Part1d ~= -ln(2) + const V kLn2Part0d = Set(d, -0.6931471805596629565116018); + const V kLn2Part1d = Set(d, -0.28235290563031577122588448175e-12); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kLn2Part0d, x); + x = MulAdd(qf, kLn2Part1d, x); + return x; + } +}; + +template <> +struct LogImpl { + template + HWY_INLINE Vec> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind di64; + const Rebind du64; + return Sub(BitCast(di64, ShiftRight<52>(BitCast(du64, x))), + Set(di64, 0x3FF)); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.6666666666666735130); + const V k1 = Set(d, 0.3999999999940941908); + const V k2 = Set(d, 0.2857142874366239149); + const V k3 = Set(d, 0.2222219843214978396); + const V k4 = Set(d, 0.1818357216161805012); + const V k5 = Set(d, 0.1531383769920937332); + const V k6 = Set(d, 0.1479819860511658591); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(MulAdd(MulAdd(k6, x4, k4), x4, k2), x4, k0), x2, + (Mul(MulAdd(MulAdd(k5, x4, k3), x4, k1), x4))); + } +}; + +#endif + +template +HWY_INLINE V Log(const D d, V x) { + // http://git.musl-libc.org/cgit/musl/tree/src/math/log.c for more info. + using T = TFromD; + impl::LogImpl impl; + + constexpr bool kIsF32 = (sizeof(T) == 4); + + // Float Constants + const V kLn2Hi = Set(d, kIsF32 ? static_cast(0.69313812256f) + : static_cast(0.693147180369123816490)); + const V kLn2Lo = Set(d, kIsF32 ? static_cast(9.0580006145e-6f) + : static_cast(1.90821492927058770002e-10)); + const V kOne = Set(d, static_cast(+1.0)); + const V kMinNormal = Set(d, kIsF32 ? static_cast(1.175494351e-38f) + : static_cast(2.2250738585072014e-308)); + const V kScale = Set(d, kIsF32 ? static_cast(3.355443200e+7f) + : static_cast(1.8014398509481984e+16)); + + // Integer Constants + using TI = MakeSigned; + const Rebind di; + using VI = decltype(Zero(di)); + const VI kLowerBits = Set(di, kIsF32 ? static_cast(0x00000000L) + : static_cast(0xFFFFFFFFLL)); + const VI kMagic = Set(di, kIsF32 ? static_cast(0x3F3504F3L) + : static_cast(0x3FE6A09E00000000LL)); + const VI kExpMask = Set(di, kIsF32 ? static_cast(0x3F800000L) + : static_cast(0x3FF0000000000000LL)); + const VI kExpScale = + Set(di, kIsF32 ? static_cast(-25) : static_cast(-54)); + const VI kManMask = Set(di, kIsF32 ? static_cast(0x7FFFFFL) + : static_cast(0xFFFFF00000000LL)); + + // Scale up 'x' so that it is no longer denormalized. + VI exp_bits; + V exp; + if (kAllowSubnormals == true) { + const auto is_denormal = Lt(x, kMinNormal); + x = IfThenElse(is_denormal, Mul(x, kScale), x); + + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic)); + const VI exp_scale = + BitCast(di, IfThenElseZero(is_denormal, BitCast(d, kExpScale))); + exp = ConvertTo( + d, Add(exp_scale, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits)))); + } else { + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic)); + exp = ConvertTo(d, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits))); + } + + // Renormalize. + const V y = Or(And(x, BitCast(d, kLowerBits)), + BitCast(d, Add(And(exp_bits, kManMask), kMagic))); + + // Approximate and reconstruct. + const V ym1 = Sub(y, kOne); + const V z = Div(ym1, Add(y, kOne)); + + return MulSub( + exp, kLn2Hi, + Sub(MulSub(z, Sub(ym1, impl.LogPoly(d, z)), Mul(exp, kLn2Lo)), ym1)); +} + +} // namespace impl + +template +HWY_INLINE V Acos(const D d, V x) { + using T = TFromD; + + const V kZero = Zero(d); + const V kHalf = Set(d, static_cast(+0.5)); + const V kPi = Set(d, static_cast(+3.14159265358979323846264)); + const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V t = Mul(impl.AsinPoly(d, yy, y), Mul(y, yy)); + + const V t_plus_y = Add(t, y); + const V z = + IfThenElse(mask, Sub(kPiOverTwo, Add(Xor(y, sign_x), Xor(t, sign_x))), + Add(t_plus_y, t_plus_y)); + return IfThenElse(Or(mask, Ge(x, kZero)), z, Sub(kPi, z)); +} + +template +HWY_INLINE V Acosh(const D d, V x) { + using T = TFromD; + + const V kLarge = Set(d, static_cast(268435456.0)); + const V kLog2 = Set(d, static_cast(0.693147180559945286227)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const auto is_x_large = Gt(x, kLarge); + const auto is_x_gt_2 = Gt(x, kTwo); + + const V x_minus_1 = Sub(x, kOne); + const V y0 = MulSub(kTwo, x, Div(kOne, Add(Sqrt(MulSub(x, x, kOne)), x))); + const V y1 = + Add(Sqrt(MulAdd(x_minus_1, kTwo, Mul(x_minus_1, x_minus_1))), x_minus_1); + const V y2 = + IfThenElse(is_x_gt_2, IfThenElse(is_x_large, x, y0), Add(y1, kOne)); + const V z = impl::Log(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + return Add(IfThenElse(is_x_gt_2, z, + IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor))), + IfThenElseZero(is_x_large, kLog2)); +} + +template +HWY_INLINE V Asin(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kTwo = Set(d, static_cast(+2.0)); + const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V z0 = MulAdd(impl.AsinPoly(d, yy, y), Mul(yy, y), y); + const V z1 = NegMulAdd(z0, kTwo, kPiOverTwo); + return Or(IfThenElse(mask, z0, z1), sign_x); +} + +template +HWY_INLINE V Asinh(const D d, V x) { + using T = TFromD; + + const V kSmall = Set(d, static_cast(1.0 / 268435456.0)); + const V kLarge = Set(d, static_cast(268435456.0)); + const V kLog2 = Set(d, static_cast(0.693147180559945286227)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign_x = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign_x); + + const auto is_x_large = Gt(abs_x, kLarge); + const auto is_x_lt_2 = Lt(abs_x, kTwo); + + const V x2 = Mul(x, x); + const V sqrt_x2_plus_1 = Sqrt(Add(x2, kOne)); + + const V y0 = MulAdd(abs_x, kTwo, Div(kOne, Add(sqrt_x2_plus_1, abs_x))); + const V y1 = Add(Div(x2, Add(sqrt_x2_plus_1, kOne)), abs_x); + const V y2 = + IfThenElse(is_x_lt_2, Add(y1, kOne), IfThenElse(is_x_large, abs_x, y0)); + const V z = impl::Log(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + const auto large = IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor)); + const V y = IfThenElse(Lt(abs_x, kSmall), x, large); + return Or(Add(IfThenElse(is_x_lt_2, y, z), IfThenElseZero(is_x_large, kLog2)), + sign_x); +} + +template +HWY_INLINE V Atan(const D d, V x) { + using T = TFromD; + + const V kOne = Set(d, static_cast(+1.0)); + const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign = And(SignBit(d), x); + const V abs_x = Xor(x, sign); + const auto mask = Gt(abs_x, kOne); + + impl::AtanImpl impl; + const auto divisor = IfThenElse(mask, abs_x, kOne); + const V y = impl.AtanPoly(d, IfThenElse(mask, Div(kOne, divisor), abs_x)); + return Or(IfThenElse(mask, Sub(kPiOverTwo, y), y), sign); +} + +template +HWY_INLINE V Atanh(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kOne = Set(d, static_cast(+1.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + return Mul(Log1p(d, Div(Add(abs_x, abs_x), Sub(kOne, abs_x))), + Xor(kHalf, sign)); +} + +template +HWY_INLINE V Cos(const D d, V x) { + using T = TFromD; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast(0.31830988618379067153)); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + const VI32 kOne = Set(di32, 1); + + const V y = Abs(x); // cos(x) == cos(|x|) + + // Compute the quadrant, q = int(|x| / pi) * 2 + 1 + const VI32 q = Add(ShiftLeft<1>(impl.ToInt32(d, Mul(y, kOneOverPi))), kOne); + + // Reduce range, apply sign, and approximate. + return impl.Poly( + d, Xor(impl.CosReduce(d, y, q), impl.CosSignFromQuadrant(d, q))); +} + +template +HWY_INLINE V Exp(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kNegZero = Set(d, static_cast(-0.0)); + const V kOne = Set(d, static_cast(+1.0)); + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.LoadExpShortRange( + d, Add(impl.ExpPoly(d, impl.ExpReduce(d, x, q)), kOne), q); + return IfThenElseZero(Ge(x, kLowerBound), y); +} + +template +HWY_INLINE V Expm1(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kLn2Over2 = Set(d, static_cast(+0.346573590279972654708616)); + const V kNegOne = Set(d, static_cast(-1.0)); + const V kNegZero = Set(d, static_cast(-0.0)); + const V kOne = Set(d, static_cast(+1.0)); + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.ExpPoly(d, impl.ExpReduce(d, x, q)); + const V z = IfThenElse(Lt(Abs(x), kLn2Over2), y, + Sub(impl.LoadExpShortRange(d, Add(y, kOne), q), kOne)); + return IfThenElse(Lt(x, kLowerBound), kNegOne, z); +} + +template +HWY_INLINE V Log(const D d, V x) { + return impl::Log(d, x); +} + +template +HWY_INLINE V Log10(const D d, V x) { + using T = TFromD; + return Mul(Log(d, x), Set(d, static_cast(0.4342944819032518276511))); +} + +template +HWY_INLINE V Log1p(const D d, V x) { + using T = TFromD; + const V kOne = Set(d, static_cast(+1.0)); + + const V y = Add(x, kOne); + const auto is_pole = Eq(y, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y), kOne); + const auto non_pole = + Mul(impl::Log(d, y), Div(x, divisor)); + return IfThenElse(is_pole, x, non_pole); +} + +template +HWY_INLINE V Log2(const D d, V x) { + using T = TFromD; + return Mul(Log(d, x), Set(d, static_cast(1.44269504088896340735992))); +} + +template +HWY_INLINE V Sin(const D d, V x) { + using T = TFromD; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast(0.31830988618379067153)); + const V kHalf = Set(d, static_cast(0.5)); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + + const V abs_x = Abs(x); + const V sign_x = Xor(abs_x, x); + + // Compute the quadrant, q = int((|x| / pi) + 0.5) + const VI32 q = impl.ToInt32(d, MulAdd(abs_x, kOneOverPi, kHalf)); + + // Reduce range, apply sign, and approximate. + return impl.Poly(d, Xor(impl.SinReduce(d, abs_x, q), + Xor(impl.SinSignFromQuadrant(d, q), sign_x))); +} + +template +HWY_INLINE V Sinh(const D d, V x) { + using T = TFromD; + const V kHalf = Set(d, static_cast(+0.5)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, abs_x); + const V z = Mul(Div(Add(y, kTwo), Add(y, kOne)), Mul(y, kHalf)); + return Xor(z, sign); // Reapply the sign bit +} + +template +HWY_INLINE V Tanh(const D d, V x) { + using T = TFromD; + const V kLimit = Set(d, static_cast(18.714973875)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, Mul(abs_x, kTwo)); + const V z = IfThenElse(Gt(abs_x, kLimit), kOne, Div(y, Add(y, kTwo))); + return Xor(z, sign); // Reapply the sign bit +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ diff --git a/media/highway/src/hwy/contrib/math/math_test.cc b/media/highway/src/hwy/contrib/math/math_test.cc new file mode 100644 index 000000000..246a081d6 --- /dev/null +++ b/media/highway/src/hwy/contrib/math/math_test.cc @@ -0,0 +1,227 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#include + +#include // FLT_MAX +#include + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/math/math-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +inline Out BitCast(const In& in) { + static_assert(sizeof(Out) == sizeof(In), ""); + Out out; + CopyBytes(&in, &out); + return out; +} + +template +HWY_NOINLINE void TestMath(const std::string name, T (*fx1)(T), + Vec (*fxN)(D, VecArg>), D d, T min, T max, + uint64_t max_error_ulp) { + using UintT = MakeUnsigned; + + const UintT min_bits = BitCast(min); + const UintT max_bits = BitCast(max); + + // If min is negative and max is positive, the range needs to be broken into + // two pieces, [+0, max] and [-0, min], otherwise [min, max]. + int range_count = 1; + UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}}; + if ((min < 0.0) && (max > 0.0)) { + ranges[0][0] = BitCast(static_cast(+0.0)); + ranges[0][1] = max_bits; + ranges[1][0] = BitCast(static_cast(-0.0)); + ranges[1][1] = min_bits; + range_count = 2; + } + + uint64_t max_ulp = 0; + // Emulation is slower, so cannot afford as many. + constexpr UintT kSamplesPerRange = static_cast(AdjustedReps(4000)); + for (int range_index = 0; range_index < range_count; ++range_index) { + const UintT start = ranges[range_index][0]; + const UintT stop = ranges[range_index][1]; + const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange)); + for (UintT value_bits = start; value_bits <= stop; value_bits += step) { + // For reasons unknown, the HWY_MAX is necessary on RVV, otherwise + // value_bits can be less than start, and thus possibly NaN. + const T value = BitCast(HWY_MIN(HWY_MAX(start, value_bits), stop)); + const T actual = GetLane(fxN(d, Set(d, value))); + const T expected = fx1(value); + + // Skip small inputs and outputs on armv7, it flushes subnormals to zero. +#if HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7 + if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) { + continue; + } +#endif + + const auto ulp = hwy::detail::ComputeUlpDelta(actual, expected); + max_ulp = HWY_MAX(max_ulp, ulp); + if (ulp > max_error_ulp) { + fprintf(stderr, + "%s: %s(%f) expected %f actual %f ulp %" PRIu64 " max ulp %u\n", + hwy::TypeName(T(), Lanes(d)).c_str(), name.c_str(), value, + expected, actual, static_cast(ulp), + static_cast(max_error_ulp)); + } + } + } + fprintf(stderr, "%s: %s max_ulp %" PRIu64 "\n", + hwy::TypeName(T(), Lanes(d)).c_str(), name.c_str(), max_ulp); + HWY_ASSERT(max_ulp <= max_error_ulp); +} + +#define DEFINE_MATH_TEST_FUNC(NAME) \ + HWY_NOINLINE void TestAll##NAME() { \ + ForFloatTypes(ForPartialVectors()); \ + } + +#undef DEFINE_MATH_TEST +#define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \ + F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR) \ + struct Test##NAME { \ + template \ + HWY_NOINLINE void operator()(T, D d) { \ + if (sizeof(T) == 4) { \ + TestMath(HWY_STR(NAME), F32x1, F32xN, d, F32_MIN, F32_MAX, \ + F32_ERROR); \ + } else { \ + TestMath(HWY_STR(NAME), F64x1, F64xN, d, \ + static_cast(F64_MIN), static_cast(F64_MAX), \ + F64_ERROR); \ + } \ + } \ + }; \ + DEFINE_MATH_TEST_FUNC(NAME) + +// Floating point values closest to but less than 1.0 +const float kNearOneF = BitCast(0x3F7FFFFF); +const double kNearOneD = BitCast(0x3FEFFFFFFFFFFFFFULL); + +// The discrepancy is unacceptably large for MSYS2 (less accurate libm?), so +// only increase the error tolerance there. +constexpr uint64_t Cos64ULP() { +#if defined(__MINGW32__) + return 23; +#else + return 3; +#endif +} + +constexpr uint64_t ACosh32ULP() { +#if defined(__MINGW32__) + return 8; +#else + return 3; +#endif +} + +// clang-format off +DEFINE_MATH_TEST(Acos, + std::acos, CallAcos, -1.0f, +1.0f, 3, // NEON is 3 instead of 2 + std::acos, CallAcos, -1.0, +1.0, 2) +DEFINE_MATH_TEST(Acosh, + std::acosh, CallAcosh, +1.0f, +FLT_MAX, ACosh32ULP(), + std::acosh, CallAcosh, +1.0, +DBL_MAX, 3) +DEFINE_MATH_TEST(Asin, + std::asin, CallAsin, -1.0f, +1.0f, 4, // ARMv7 is 4 instead of 2 + std::asin, CallAsin, -1.0, +1.0, 2) +DEFINE_MATH_TEST(Asinh, + std::asinh, CallAsinh, -FLT_MAX, +FLT_MAX, 3, + std::asinh, CallAsinh, -DBL_MAX, +DBL_MAX, 3) +DEFINE_MATH_TEST(Atan, + std::atan, CallAtan, -FLT_MAX, +FLT_MAX, 3, + std::atan, CallAtan, -DBL_MAX, +DBL_MAX, 3) +DEFINE_MATH_TEST(Atanh, + std::atanh, CallAtanh, -kNearOneF, +kNearOneF, 4, // NEON is 4 instead of 3 + std::atanh, CallAtanh, -kNearOneD, +kNearOneD, 3) +DEFINE_MATH_TEST(Cos, + std::cos, CallCos, -39000.0f, +39000.0f, 3, + std::cos, CallCos, -39000.0, +39000.0, Cos64ULP()) +DEFINE_MATH_TEST(Exp, + std::exp, CallExp, -FLT_MAX, +104.0f, 1, + std::exp, CallExp, -DBL_MAX, +104.0, 1) +DEFINE_MATH_TEST(Expm1, + std::expm1, CallExpm1, -FLT_MAX, +104.0f, 4, + std::expm1, CallExpm1, -DBL_MAX, +104.0, 4) +DEFINE_MATH_TEST(Log, + std::log, CallLog, +FLT_MIN, +FLT_MAX, 1, + std::log, CallLog, +DBL_MIN, +DBL_MAX, 1) +DEFINE_MATH_TEST(Log10, + std::log10, CallLog10, +FLT_MIN, +FLT_MAX, 2, + std::log10, CallLog10, +DBL_MIN, +DBL_MAX, 2) +DEFINE_MATH_TEST(Log1p, + std::log1p, CallLog1p, +0.0f, +1e37f, 3, // NEON is 3 instead of 2 + std::log1p, CallLog1p, +0.0, +DBL_MAX, 2) +DEFINE_MATH_TEST(Log2, + std::log2, CallLog2, +FLT_MIN, +FLT_MAX, 2, + std::log2, CallLog2, +DBL_MIN, +DBL_MAX, 2) +DEFINE_MATH_TEST(Sin, + std::sin, CallSin, -39000.0f, +39000.0f, 3, + std::sin, CallSin, -39000.0, +39000.0, 4) // MSYS is 4 instead of 3 +DEFINE_MATH_TEST(Sinh, + std::sinh, CallSinh, -80.0f, +80.0f, 4, + std::sinh, CallSinh, -709.0, +709.0, 4) +DEFINE_MATH_TEST(Tanh, + std::tanh, CallTanh, -FLT_MAX, +FLT_MAX, 4, + std::tanh, CallTanh, -DBL_MAX, +DBL_MAX, 4) +// clang-format on + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMathTest); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcos); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcosh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAsin); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAsinh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtan); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtanh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllCos); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExpm1); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog1p); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog2); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllSin); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllSinh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllTanh); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/contrib/sort/BUILD b/media/highway/src/hwy/contrib/sort/BUILD new file mode 100644 index 000000000..3f56d6d74 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/BUILD @@ -0,0 +1,190 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +# Unused on Bazel builds, where this is not defined/known; Copybara replaces +# usages with an empty list. +COMPAT = [ + "//buildenv/target:non_prod", # includes mobile/vendor. +] + +# cc_library( +# name = "vxsort", +# srcs = [ +# "vxsort/isa_detection.cpp", +# "vxsort/isa_detection_msvc.cpp", +# "vxsort/isa_detection_sane.cpp", +# "vxsort/machine_traits.avx2.cpp", +# "vxsort/smallsort/avx2_load_mask_tables.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.double.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.float.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.double.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.float.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.cpp", +# "vxsort/vxsort_stats.cpp", +# ], +# hdrs = [ +# "vxsort/alignment.h", +# "vxsort/defs.h", +# "vxsort/isa_detection.h", +# "vxsort/machine_traits.avx2.h", +# "vxsort/machine_traits.avx512.h", +# "vxsort/machine_traits.h", +# "vxsort/packer.h", +# "vxsort/smallsort/bitonic_sort.AVX2.double.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.float.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.double.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.float.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.h", +# "vxsort/vxsort.h", +# "vxsort/vxsort_stats.h", +# ], +# compatible_with = [], +# textual_hdrs = [ +# "vxsort/vxsort_targets_disable.h", +# "vxsort/vxsort_targets_enable_avx2.h", +# "vxsort/vxsort_targets_enable_avx512.h", +# ], +# ) + +cc_library( + name = "vqsort", + srcs = [ + # Split into separate files to reduce MSVC build time. + "vqsort.cc", + "vqsort_128a.cc", + "vqsort_128d.cc", + "vqsort_f32a.cc", + "vqsort_f32d.cc", + "vqsort_f64a.cc", + "vqsort_f64d.cc", + "vqsort_i16a.cc", + "vqsort_i16d.cc", + "vqsort_i32a.cc", + "vqsort_i32d.cc", + "vqsort_i64a.cc", + "vqsort_i64d.cc", + "vqsort_kv64a.cc", + "vqsort_kv64d.cc", + "vqsort_kv128a.cc", + "vqsort_kv128d.cc", + "vqsort_u16a.cc", + "vqsort_u16d.cc", + "vqsort_u32a.cc", + "vqsort_u32d.cc", + "vqsort_u64a.cc", + "vqsort_u64d.cc", + ], + hdrs = [ + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = ["hwy_contrib_EXPORTS"], + textual_hdrs = [ + "shared-inl.h", + "sorting_networks-inl.h", + "traits-inl.h", + "traits128-inl.h", + "vqsort-inl.h", + # Placeholder for internal instrumentation. Do not remove. + ], + deps = [ + # Only if VQSORT_SECURE_RNG is set. + # "//third_party/absl/random", + "//:hwy", + # ":vxsort", # required if HAVE_VXSORT + ], +) + +# ----------------------------------------------------------------------------- +# Internal-only targets + +cc_library( + name = "helpers", + testonly = 1, + textual_hdrs = [ + "algo-inl.h", + "result-inl.h", + ], + deps = [ + ":vqsort", + "//:nanobenchmark", + # Required for HAVE_PDQSORT, but that is unused and this is + # unavailable to Bazel builds, hence commented out. + # "//third_party/boost/allowed", + # Avoid ips4o and thus TBB to work around hwloc build failure. + ], +) + +cc_binary( + name = "print_network", + testonly = 1, + srcs = ["print_network.cc"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + ], +) + +cc_test( + name = "sort_test", + size = "medium", + srcs = ["sort_test.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) + +cc_binary( + name = "bench_sort", + testonly = 1, + srcs = ["bench_sort.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) + +cc_binary( + name = "bench_parallel", + testonly = 1, + srcs = ["bench_parallel.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) diff --git a/media/highway/src/hwy/contrib/sort/README.md b/media/highway/src/hwy/contrib/sort/README.md new file mode 100644 index 000000000..a0051414d --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/README.md @@ -0,0 +1,87 @@ +# Vectorized and performance-portable Quicksort + +## Introduction + +As of 2022-06-07 this sorts large arrays of built-in types about ten times as +fast as `std::sort`. See also our +[blog post](https://opensource.googleblog.com/2022/06/Vectorized%20and%20performance%20portable%20Quicksort.html) +and [paper](https://arxiv.org/abs/2205.05982). + +## Instructions + +Here are instructions for reproducing our results on Linux and AWS (SVE, NEON). + +### Linux + +Please first ensure golang, and Clang (tested with 13.0.1) are installed via +your system's package manager. + +``` +go install github.com/bazelbuild/bazelisk@latest +git clone https://github.com/google/highway +cd highway +CC=clang CXX=clang++ ~/go/bin/bazelisk build -c opt hwy/contrib/sort:all +bazel-bin/hwy/contrib/sort/sort_test +bazel-bin/hwy/contrib/sort/bench_sort +``` + +### AWS Graviton3 + +Instance config: amazon linux 5.10 arm64, c7g.8xlarge (largest allowed config is +32 vCPU). Initial launch will fail. Wait a few minutes for an email saying the +config is verified, then re-launch. See IPv4 hostname in list of instances. + +`ssh -i /path/key.pem ec2-user@hostname` + +Note that the AWS CMake package is too old for llvm, so we build it first: +``` +wget https://cmake.org/files/v3.23/cmake-3.23.2.tar.gz +tar -xvzf cmake-3.23.2.tar.gz && cd cmake-3.23.2/ +./bootstrap -- -DCMAKE_USE_OPENSSL=OFF +make -j8 && sudo make install +cd .. +``` + +AWS clang is at version 11.1, which generates unnecessary `AND` instructions +which slow down the sort by 1.15x. We tested with clang trunk as of June 13 +(which reports Git hash 8f6512fea000c3a0d394864bb94e524bee375069). To build: + +``` +git clone --depth 1 https://github.com/llvm/llvm-project.git +cd llvm-project +mkdir -p build && cd build +/usr/local/bin/cmake ../llvm -DLLVM_ENABLE_PROJECTS="clang" -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi" -DCMAKE_BUILD_TYPE=Release +make -j32 && sudo make install +``` + +``` +sudo yum install go +go install github.com/bazelbuild/bazelisk@latest +git clone https://github.com/google/highway +cd highway +CC=/usr/local/bin/clang CXX=/usr/local/bin/clang++ ~/go/bin/bazelisk build -c opt --copt=-march=armv8.2-a+sve hwy/contrib/sort:all +bazel-bin/hwy/contrib/sort/sort_test +bazel-bin/hwy/contrib/sort/bench_sort +``` + +The above command line enables SVE, which is currently only available on +Graviton 3. You can also test NEON on the same processor, or other Arm CPUs, by +changing the `-march=` option to `--copt=-march=armv8.2-a+crypto`. Note that +such flags will be unnecessary once Clang supports `#pragma target` for NEON and +SVE intrinsics, as it does for x86. + +## Results + +`bench_sort` outputs the instruction set (AVX3 refers to AVX-512), the sort +algorithm (std for `std::sort`, vq for our vqsort), the type of keys being +sorted (f32 is float), the distribution of keys (uniform32 for uniform random +with range 0-2^32), the number of keys, then the throughput of sorted keys (i.e. +number of key bytes output per second). + +Example excerpt from Xeon 6154 (Skylake-X) CPU clocked at 3 GHz: + +``` +[ RUN ] BenchSortGroup/BenchSort.BenchAllSort/AVX3 + AVX3: std: f32: uniform32: 1.00E+06 54 MB/s ( 1 threads) + AVX3: vq: f32: uniform32: 1.00E+06 1143 MB/s ( 1 threads) +``` diff --git a/media/highway/src/hwy/contrib/sort/algo-inl.h b/media/highway/src/hwy/contrib/sort/algo-inl.h new file mode 100644 index 000000000..4b01e2de3 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/algo-inl.h @@ -0,0 +1,512 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +#include +#include // memcpy + +#include +#include // std::abs +#include + +#include "hwy/base.h" +#include "hwy/contrib/sort/vqsort.h" + +// Third-party algorithms +#define HAVE_AVX2SORT 0 +#define HAVE_IPS4O 0 +// When enabling, consider changing max_threads (required for Table 1a) +#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1) +#define HAVE_PDQSORT 0 +#define HAVE_SORT512 0 +#define HAVE_VXSORT 0 + +#if HAVE_AVX2SORT +HWY_PUSH_ATTRIBUTES("avx2,avx") +#include "avx2sort.h" //NOLINT +HWY_POP_ATTRIBUTES +#endif +#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O +#include "third_party/ips4o/include/ips4o.hpp" +#include "third_party/ips4o/include/ips4o/thread_pool.hpp" +#endif +#if HAVE_PDQSORT +#include "third_party/boost/allowed/sort/sort.hpp" +#endif +#if HAVE_SORT512 +#include "sort512.h" //NOLINT +#endif + +// vxsort is difficult to compile for multiple targets because it also uses +// .cpp files, and we'd also have to #undef its include guards. Instead, compile +// only for AVX2 or AVX3 depending on this macro. +#define VXSORT_AVX3 1 +#if HAVE_VXSORT +// inlined from vxsort_targets_enable_avx512 (must close before end of header) +#ifdef __GNUC__ +#ifdef __clang__ +#if VXSORT_AVX3 +#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \ + apply_to = any(function)) +#else +#pragma clang attribute push(__attribute__((target("avx2"))), \ + apply_to = any(function)) +#endif // VXSORT_AVX3 + +#else +#pragma GCC push_options +#if VXSORT_AVX3 +#pragma GCC target("avx512f,avx512dq") +#else +#pragma GCC target("avx2") +#endif // VXSORT_AVX3 +#endif +#endif + +#if VXSORT_AVX3 +#include "vxsort/machine_traits.avx512.h" +#else +#include "vxsort/machine_traits.avx2.h" +#endif // VXSORT_AVX3 +#include "vxsort/vxsort.h" +#ifdef __GNUC__ +#ifdef __clang__ +#pragma clang attribute pop +#else +#pragma GCC pop_options +#endif +#endif +#endif // HAVE_VXSORT + +namespace hwy { + +enum class Dist { kUniform8, kUniform16, kUniform32 }; + +static inline std::vector AllDist() { + return {/*Dist::kUniform8, Dist::kUniform16,*/ Dist::kUniform32}; +} + +static inline const char* DistName(Dist dist) { + switch (dist) { + case Dist::kUniform8: + return "uniform8"; + case Dist::kUniform16: + return "uniform16"; + case Dist::kUniform32: + return "uniform32"; + } + return "unreachable"; +} + +template +class InputStats { + public: + void Notify(T value) { + min_ = std::min(min_, value); + max_ = std::max(max_, value); + // Converting to integer would truncate floats, multiplying to save digits + // risks overflow especially when casting, so instead take the sum of the + // bit representations as the checksum. + uint64_t bits = 0; + static_assert(sizeof(T) <= 8, "Expected a built-in type"); + CopyBytes(&value, &bits); // not same size + sum_ += bits; + count_ += 1; + } + + bool operator==(const InputStats& other) const { + if (count_ != other.count_) { + HWY_ABORT("count %d vs %d\n", static_cast(count_), + static_cast(other.count_)); + } + + if (min_ != other.min_ || max_ != other.max_) { + HWY_ABORT("minmax %f/%f vs %f/%f\n", static_cast(min_), + static_cast(max_), static_cast(other.min_), + static_cast(other.max_)); + } + + // Sum helps detect duplicated/lost values + if (sum_ != other.sum_) { + HWY_ABORT("Sum mismatch %g %g; min %g max %g\n", + static_cast(sum_), static_cast(other.sum_), + static_cast(min_), static_cast(max_)); + } + + return true; + } + + private: + T min_ = hwy::HighestValue(); + T max_ = hwy::LowestValue(); + uint64_t sum_ = 0; + size_t count_ = 0; +}; + +enum class Algo { +#if HAVE_AVX2SORT + kSEA, +#endif +#if HAVE_IPS4O + kIPS4O, +#endif +#if HAVE_PARALLEL_IPS4O + kParallelIPS4O, +#endif +#if HAVE_PDQSORT + kPDQ, +#endif +#if HAVE_SORT512 + kSort512, +#endif +#if HAVE_VXSORT + kVXSort, +#endif + kStd, + kVQSort, + kHeap, +}; + +static inline const char* AlgoName(Algo algo) { + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return "sea"; +#endif +#if HAVE_IPS4O + case Algo::kIPS4O: + return "ips4o"; +#endif +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + return "par_ips4o"; +#endif +#if HAVE_PDQSORT + case Algo::kPDQ: + return "pdq"; +#endif +#if HAVE_SORT512 + case Algo::kSort512: + return "sort512"; +#endif +#if HAVE_VXSORT + case Algo::kVXSort: + return "vxsort"; +#endif + case Algo::kStd: + return "std"; + case Algo::kVQSort: + return "vq"; + case Algo::kHeap: + return "heap"; + } + return "unreachable"; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#endif + +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // HeapSort +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +class Xorshift128Plus { + static HWY_INLINE uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + public: + // Generates two vectors of 64-bit seeds via SplitMix64 and stores into + // `seeds`. Generating these afresh in each ChoosePivot is too expensive. + template + static void GenerateSeeds(DU64 du64, TFromD* HWY_RESTRICT seeds) { + seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < 2 * Lanes(du64); ++i) { + seeds[i] = SplitMix64(seeds[i - 1]); + } + } + + // Need to pass in the state because vector cannot be class members. + template + static VU64 RandomBits(VU64& state0, VU64& state1) { + VU64 s1 = state0; + VU64 s0 = state1; + const VU64 bits = Add(s1, s0); + state0 = s0; + s1 = Xor(s1, ShiftLeft<23>(s1)); + state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + return bits; + } +}; + +template +Vec RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) { + const VU64 bits = Xorshift128Plus::RandomBits(s0, s1); + return BitCast(d, And(bits, mask)); +} + +// It is important to avoid denormals, which are flushed to zero by SIMD but not +// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD. +template +Vec RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) { + using TF = TFromD; + const RebindToUnsigned du; + using VU = Vec; + + const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask); + +#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to smaller types + using TU = MakeUnsigned; + const VU bits = Set(du, static_cast(GetLane(bits64) & LimitsMax())); +#else + const VU bits = BitCast(du, bits64); +#endif + // Avoid NaN/denormal by only generating values in [1, 2), i.e. random + // mantissas with the exponent taken from the representation of 1.0. + const VU k1 = BitCast(du, Set(df, TF{1.0})); + const VU mantissa_mask = Set(du, MantissaMask()); + const VU representation = OrAnd(k1, bits, mantissa_mask); + return BitCast(df, representation); +} + +template +Vec MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) { + switch (sizeof_t) { + case 2: + return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull + : 0xFFFFFFFFFFFFFFFFull); + case 4: + return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull + : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull + : 0xFFFFFFFFFFFFFFFFull); + case 8: + return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull + : (dist == Dist::kUniform16) ? 0x000000000000FFFFull + : 0x00000000FFFFFFFFull); + default: + HWY_ABORT("Logic error"); + return Zero(du64); + } +} + +template +InputStats GenerateInput(const Dist dist, T* v, size_t num) { + SortTag du64; + using VU64 = Vec; + const size_t N64 = Lanes(du64); + auto seeds = hwy::AllocateAligned(2 * N64); + Xorshift128Plus::GenerateSeeds(du64, seeds.get()); + VU64 s0 = Load(du64, seeds.get()); + VU64 s1 = Load(du64, seeds.get() + N64); + +#if HWY_TARGET == HWY_SCALAR + const Sisd d; +#else + const Repartition d; +#endif + using V = Vec; + const size_t N = Lanes(d); + const VU64 mask = MaskForDist(du64, dist, sizeof(T)); + auto buf = hwy::AllocateAligned(N); + + size_t i = 0; + for (; i + N <= num; i += N) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, v + i); + } + if (i < num) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, buf.get()); + memcpy(v + i, buf.get(), (num - i) * sizeof(T)); + } + + InputStats input_stats; + for (size_t i = 0; i < num; ++i) { + input_stats.Notify(v[i]); + } + return input_stats; +} + +struct ThreadLocal { + Sorter sorter; +}; + +struct SharedState { +#if HAVE_PARALLEL_IPS4O + const unsigned max_threads = hwy::LimitsMax(); // 16 for Table 1a + ips4o::StdThreadPool pool{static_cast( + HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))}; +#endif + std::vector tls{1}; +}; + +// Bridge from keys (passed to Run) to lanes as expected by HeapSort. For +// non-128-bit keys they are the same: +template +void CallHeapSort(KeyType* HWY_RESTRICT keys, const size_t num_keys) { + using detail::TraitsLane; + using detail::SharedTraits; + if (Order().IsAscending()) { + const SharedTraits>> st; + return detail::HeapSort(st, keys, num_keys); + } else { + const SharedTraits>> st; + return detail::HeapSort(st, keys, num_keys); + } +} + +#if VQSORT_ENABLED +template +void CallHeapSort(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys) { + using detail::SharedTraits; + using detail::Traits128; + uint64_t* lanes = reinterpret_cast(keys); + const size_t num_lanes = num_keys * 2; + if (Order().IsAscending()) { + const SharedTraits> st; + return detail::HeapSort(st, lanes, num_lanes); + } else { + const SharedTraits> st; + return detail::HeapSort(st, lanes, num_lanes); + } +} + +template +void CallHeapSort(K64V64* HWY_RESTRICT keys, const size_t num_keys) { + using detail::SharedTraits; + using detail::Traits128; + uint64_t* lanes = reinterpret_cast(keys); + const size_t num_lanes = num_keys * 2; + if (Order().IsAscending()) { + const SharedTraits> st; + return detail::HeapSort(st, lanes, num_lanes); + } else { + const SharedTraits> st; + return detail::HeapSort(st, lanes, num_lanes); + } +} +#endif // VQSORT_ENABLED + +template +void Run(Algo algo, KeyType* HWY_RESTRICT inout, size_t num, + SharedState& shared, size_t thread) { + const std::less less; + const std::greater greater; + + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return avx2::quicksort(inout, static_cast(num)); +#endif + +#if HAVE_IPS4O + case Algo::kIPS4O: + if (Order().IsAscending()) { + return ips4o::sort(inout, inout + num, less); + } else { + return ips4o::sort(inout, inout + num, greater); + } +#endif + +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + if (Order().IsAscending()) { + return ips4o::parallel::sort(inout, inout + num, less, shared.pool); + } else { + return ips4o::parallel::sort(inout, inout + num, greater, shared.pool); + } +#endif + +#if HAVE_SORT512 + case Algo::kSort512: + HWY_ABORT("not supported"); + // return Sort512::Sort(inout, num); +#endif + +#if HAVE_PDQSORT + case Algo::kPDQ: + if (Order().IsAscending()) { + return boost::sort::pdqsort_branchless(inout, inout + num, less); + } else { + return boost::sort::pdqsort_branchless(inout, inout + num, greater); + } +#endif + +#if HAVE_VXSORT + case Algo::kVXSort: { +#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \ + (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2) + fprintf(stderr, "Do not call for target %s\n", + hwy::TargetName(HWY_TARGET)); + return; +#else +#if VXSORT_AVX3 + vxsort::vxsort vx; +#else + vxsort::vxsort vx; +#endif + if (Order().IsAscending()) { + return vx.sort(inout, inout + num - 1); + } else { + fprintf(stderr, "Skipping VX - does not support descending order\n"); + return; + } +#endif // enabled for this target + } +#endif // HAVE_VXSORT + + case Algo::kStd: + if (Order().IsAscending()) { + return std::sort(inout, inout + num, less); + } else { + return std::sort(inout, inout + num, greater); + } + + case Algo::kVQSort: + return shared.tls[thread].sorter(inout, num, Order()); + + case Algo::kHeap: + return CallHeapSort(inout, num); + + default: + HWY_ABORT("Not implemented"); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE diff --git a/media/highway/src/hwy/contrib/sort/bench_parallel.cc b/media/highway/src/hwy/contrib/sort/bench_parallel.cc new file mode 100644 index 000000000..1c8c928e2 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/bench_parallel.cc @@ -0,0 +1,238 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Concurrent, independent sorts for generating more memory traffic and testing +// scalability. + +#include +#include + +#include //NOLINT +#include +#include +#include //NOLINT +#include //NOLINT +#include +#include + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_parallel.cc" //NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/aligned_allocator.h" +// Last +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +class ThreadPool { + public: + // Starts the given number of worker threads and blocks until they are ready. + explicit ThreadPool( + const size_t num_threads = std::thread::hardware_concurrency()) + : num_threads_(num_threads) { + HWY_ASSERT(num_threads_ > 0); + threads_.reserve(num_threads_); + for (size_t i = 0; i < num_threads_; ++i) { + threads_.emplace_back(ThreadFunc, this, i); + } + + WorkersReadyBarrier(); + } + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator&(const ThreadPool&) = delete; + + // Waits for all threads to exit. + ~ThreadPool() { + StartWorkers(kWorkerExit); + + for (std::thread& thread : threads_) { + thread.join(); + } + } + + size_t NumThreads() const { return threads_.size(); } + + template + void RunOnThreads(size_t max_threads, const Func& func) { + task_ = &CallClosure; + data_ = &func; + StartWorkers(max_threads); + WorkersReadyBarrier(); + } + + private: + // After construction and between calls to Run, workers are "ready", i.e. + // waiting on worker_start_cv_. They are "started" by sending a "command" + // and notifying all worker_start_cv_ waiters. (That is why all workers + // must be ready/waiting - otherwise, the notification will not reach all of + // them and the main thread waits in vain for them to report readiness.) + using WorkerCommand = uint64_t; + + static constexpr WorkerCommand kWorkerWait = ~1ULL; + static constexpr WorkerCommand kWorkerExit = ~2ULL; + + // Calls a closure (lambda with captures). + template + static void CallClosure(const void* f, size_t thread) { + (*reinterpret_cast(f))(thread); + } + + void WorkersReadyBarrier() { + std::unique_lock lock(mutex_); + // Typically only a single iteration. + while (workers_ready_ != threads_.size()) { + workers_ready_cv_.wait(lock); + } + workers_ready_ = 0; + + // Safely handle spurious worker wakeups. + worker_start_command_ = kWorkerWait; + } + + // Precondition: all workers are ready. + void StartWorkers(const WorkerCommand worker_command) { + std::unique_lock lock(mutex_); + worker_start_command_ = worker_command; + // Workers will need this lock, so release it before they wake up. + lock.unlock(); + worker_start_cv_.notify_all(); + } + + static void ThreadFunc(ThreadPool* self, size_t thread) { + // Until kWorkerExit command received: + for (;;) { + std::unique_lock lock(self->mutex_); + // Notify main thread that this thread is ready. + if (++self->workers_ready_ == self->num_threads_) { + self->workers_ready_cv_.notify_one(); + } + RESUME_WAIT: + // Wait for a command. + self->worker_start_cv_.wait(lock); + const WorkerCommand command = self->worker_start_command_; + switch (command) { + case kWorkerWait: // spurious wakeup: + goto RESUME_WAIT; // lock still held, avoid incrementing ready. + case kWorkerExit: + return; // exits thread + default: + break; + } + + lock.unlock(); + // Command is the maximum number of threads that should run the task. + HWY_ASSERT(command < self->NumThreads()); + if (thread < command) { + self->task_(self->data_, thread); + } + } + } + + const size_t num_threads_; + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector threads_; + + std::mutex mutex_; // guards both cv and their variables. + std::condition_variable workers_ready_cv_; + size_t workers_ready_ = 0; + std::condition_variable worker_start_cv_; + WorkerCommand worker_start_command_; + + // Written by main thread, read by workers (after mutex lock/unlock). + std::function task_; // points to CallClosure + const void* data_; // points to caller's Func +}; + +template +void RunWithoutVerify(Traits st, const Dist dist, const size_t num_keys, + const Algo algo, SharedState& shared, size_t thread) { + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + using Order = typename Traits::Order; + const size_t num_lanes = num_keys * st.LanesPerKey(); + auto aligned = hwy::AllocateAligned(num_lanes); + + (void)GenerateInput(dist, aligned.get(), num_lanes); + + const Timestamp t0; + Run(algo, reinterpret_cast(aligned.get()), num_keys, shared, + thread); + HWY_ASSERT(aligned[0] < aligned[num_lanes - 1]); +} + +void BenchParallel() { + // Not interested in benchmark results for other targets on x86 + if (HWY_ARCH_X86 && (HWY_TARGET != HWY_AVX2 && HWY_TARGET != HWY_AVX3)) { + return; + } + + ThreadPool pool; + const size_t NT = pool.NumThreads(); + + detail::SharedTraits>> st; + using KeyType = typename decltype(st)::KeyType; + const size_t num_keys = size_t{100} * 1000 * 1000; + +#if HAVE_IPS4O + const Algo algo = Algo::kIPS4O; +#else + const Algo algo = Algo::kVQSort; +#endif + const Dist dist = Dist::kUniform32; + + SharedState shared; + shared.tls.resize(NT); + + std::vector results; + for (size_t nt = 1; nt < NT; nt += HWY_MAX(1, NT / 16)) { + Timestamp t0; + // Default capture because MSVC wants algo/dist but clang does not. + pool.RunOnThreads(nt, [=, &shared](size_t thread) { + RunWithoutVerify(st, dist, num_keys, algo, shared, thread); + }); + const double sec = SecondsSince(t0); + results.emplace_back(algo, dist, num_keys, nt, sec, sizeof(KeyType), + st.KeyString()); + results.back().Print(); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +namespace { +HWY_BEFORE_TEST(BenchParallel); +HWY_EXPORT_AND_TEST_P(BenchParallel, BenchParallel); +} // namespace +} // namespace hwy + +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/bench_sort.cc b/media/highway/src/hwy/contrib/sort/bench_sort.cc new file mode 100644 index 000000000..a668fde90 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/bench_sort.cc @@ -0,0 +1,310 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_sort.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/contrib/sort/sorting_networks-inl.h" // SharedTraits +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// Mode for larger sorts because M1 is able to access more than the per-core +// share of L2, so 1M elements might still be in cache. +#define SORT_100M 0 + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +// Defined within HWY_ONCE, used by BenchAllSort. +extern int64_t first_sort_target; + +namespace HWY_NAMESPACE { +namespace { +using detail::TraitsLane; +using detail::OrderAscending; +using detail::OrderDescending; +using detail::SharedTraits; + +#if VQSORT_ENABLED || HWY_IDE +using detail::OrderAscending128; +using detail::OrderAscendingKV128; +using detail::Traits128; + +template +HWY_NOINLINE void BenchPartition() { + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + const SortTag d; + detail::SharedTraits st; + const Dist dist = Dist::kUniform8; + double sum = 0.0; + + detail::Generator rng(&sum, 123); // for ChoosePivot + + const size_t max_log2 = AdjustedLog2Reps(20); + for (size_t log2 = max_log2; log2 < max_log2 + 1; ++log2) { + const size_t num_lanes = 1ull << log2; + const size_t num_keys = num_lanes / st.LanesPerKey(); + auto aligned = hwy::AllocateAligned(num_lanes); + auto buf = hwy::AllocateAligned( + HWY_MAX(hwy::SortConstants::PartitionBufNum(Lanes(d)), + hwy::SortConstants::PivotBufNum(sizeof(LaneType), Lanes(d)))); + + std::vector seconds; + const size_t num_reps = (1ull << (14 - log2 / 2)) * 30; + for (size_t rep = 0; rep < num_reps; ++rep) { + (void)GenerateInput(dist, aligned.get(), num_lanes); + + // The pivot value can influence performance. Do exactly what vqsort will + // do so that the performance (influenced by prefetching and branch + // prediction) is likely to predict the actual performance inside vqsort. + detail::DrawSamples(d, st, aligned.get(), num_lanes, buf.get(), rng); + detail::SortSamples(d, st, buf.get()); + auto pivot = detail::ChoosePivotByRank(d, st, buf.get()); + + const Timestamp t0; + detail::Partition(d, st, aligned.get(), num_lanes - 1, pivot, buf.get()); + seconds.push_back(SecondsSince(t0)); + // 'Use' the result to prevent optimizing out the partition. + sum += static_cast(aligned.get()[num_lanes / 2]); + } + + Result(Algo::kVQSort, dist, num_keys, 1, SummarizeMeasurements(seconds), + sizeof(KeyType), st.KeyString()) + .Print(); + } + HWY_ASSERT(sum != 999999); // Prevent optimizing out +} + +HWY_NOINLINE void BenchAllPartition() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3) { + return; + } + + BenchPartition>>(); + BenchPartition>>(); + BenchPartition>>(); + BenchPartition>(); + // BenchPartition>(); + BenchPartition>(); +} + +template +HWY_NOINLINE void BenchBase(std::vector& results) { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) { + return; + } + + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + const SortTag d; + detail::SharedTraits st; + const Dist dist = Dist::kUniform32; + + const size_t N = Lanes(d); + const size_t num_lanes = SortConstants::BaseCaseNum(N); + const size_t num_keys = num_lanes / st.LanesPerKey(); + auto keys = hwy::AllocateAligned(num_lanes); + auto buf = hwy::AllocateAligned(num_lanes + N); + + std::vector seconds; + double sum = 0; // prevents elision + constexpr size_t kMul = AdjustedReps(600); // ensures long enough to measure + + for (size_t rep = 0; rep < 30; ++rep) { + InputStats input_stats = + GenerateInput(dist, keys.get(), num_lanes); + + const Timestamp t0; + for (size_t i = 0; i < kMul; ++i) { + detail::BaseCase(d, st, keys.get(), keys.get() + num_lanes, num_lanes, + buf.get()); + sum += static_cast(keys[0]); + } + seconds.push_back(SecondsSince(t0)); + // printf("%f\n", seconds.back()); + + HWY_ASSERT(VerifySort(st, input_stats, keys.get(), num_lanes, "BenchBase")); + } + HWY_ASSERT(sum < 1E99); + results.emplace_back(Algo::kVQSort, dist, num_keys * kMul, 1, + SummarizeMeasurements(seconds), sizeof(KeyType), + st.KeyString()); +} + +HWY_NOINLINE void BenchAllBase() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3) { + return; + } + + std::vector results; + BenchBase>>(results); + BenchBase>>(results); + BenchBase>(results); + for (const Result& r : results) { + r.Print(); + } +} + +#else +void BenchAllPartition() {} +void BenchAllBase() {} +#endif // VQSORT_ENABLED + +std::vector AlgoForBench() { + return { +#if HAVE_AVX2SORT + Algo::kSEA, +#endif +#if HAVE_PARALLEL_IPS4O + Algo::kParallelIPS4O, +#elif HAVE_IPS4O + Algo::kIPS4O, +#endif +#if HAVE_PDQSORT + Algo::kPDQ, +#endif +#if HAVE_SORT512 + Algo::kSort512, +#endif +// Only include if we're compiling for the target it supports. +#if HAVE_VXSORT && ((VXSORT_AVX3 && HWY_TARGET == HWY_AVX3) || \ + (!VXSORT_AVX3 && HWY_TARGET == HWY_AVX2)) + Algo::kVXSort, +#endif + +#if !HAVE_PARALLEL_IPS4O +#if !SORT_100M + // These are 10-20x slower, but that's OK for the default size when we + // are not testing the parallel nor 100M modes. + Algo::kStd, Algo::kHeap, +#endif + + Algo::kVQSort, // only ~4x slower, but not required for Table 1a +#endif + }; +} + +template +HWY_NOINLINE void BenchSort(size_t num_keys) { + if (first_sort_target == 0) first_sort_target = HWY_TARGET; + + SharedState shared; + detail::SharedTraits st; + using Order = typename Traits::Order; + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + const size_t num_lanes = num_keys * st.LanesPerKey(); + auto aligned = hwy::AllocateAligned(num_lanes); + + const size_t reps = num_keys > 1000 * 1000 ? 10 : 30; + + for (Algo algo : AlgoForBench()) { + // Other algorithms don't depend on the vector instructions, so only run + // them for the first target. +#if !HAVE_VXSORT + if (algo != Algo::kVQSort && HWY_TARGET != first_sort_target) { + continue; + } +#endif + + for (Dist dist : AllDist()) { + std::vector seconds; + for (size_t rep = 0; rep < reps; ++rep) { + InputStats input_stats = + GenerateInput(dist, aligned.get(), num_lanes); + + const Timestamp t0; + Run(algo, reinterpret_cast(aligned.get()), num_keys, + shared, /*thread=*/0); + seconds.push_back(SecondsSince(t0)); + // printf("%f\n", seconds.back()); + + HWY_ASSERT( + VerifySort(st, input_stats, aligned.get(), num_lanes, "BenchSort")); + } + Result(algo, dist, num_keys, 1, SummarizeMeasurements(seconds), + sizeof(KeyType), st.KeyString()) + .Print(); + } // dist + } // algo +} + +HWY_NOINLINE void BenchAllSort() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) { + return; + } + + constexpr size_t K = 1000; + constexpr size_t M = K * K; + (void)K; + (void)M; + for (size_t num_keys : { +#if HAVE_PARALLEL_IPS4O || SORT_100M + 100 * M, +#else + 1 * M, +#endif + }) { + BenchSort>>(num_keys); + // BenchSort>>(num_keys); + // BenchSort>>(num_keys); + BenchSort>>(num_keys); + BenchSort>>(num_keys); + // BenchSort>>(num_keys); + // BenchSort>>(num_keys); + // BenchSort>>(num_keys); + +#if !HAVE_VXSORT && VQSORT_ENABLED + BenchSort>(num_keys); + BenchSort>(num_keys); +#endif + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +int64_t first_sort_target = 0; // none run yet +namespace { +HWY_BEFORE_TEST(BenchSort); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllPartition); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllBase); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllSort); +} // namespace +} // namespace hwy + +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/print_network.cc b/media/highway/src/hwy/contrib/sort/print_network.cc new file mode 100644 index 000000000..59cfebcfb --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/print_network.cc @@ -0,0 +1,191 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "hwy/base.h" + +// Based on A.7 in "Entwurf und Implementierung vektorisierter +// Sortieralgorithmen" and code by Mark Blacher. +void PrintMergeNetwork16x2() { + for (int i = 8; i < 16; ++i) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } + printf("\n"); +} + +void PrintMergeNetwork16x4() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.Reverse4(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.Reverse4(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.Reverse4(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse4(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } +} + +void PrintMergeNetwork16x8() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse8(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance2(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } +} + +void PrintMergeNetwork16x16() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse16(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance4(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance2(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } +} + +int main(int argc, char** argv) { + PrintMergeNetwork16x2(); + PrintMergeNetwork16x4(); + PrintMergeNetwork16x8(); + PrintMergeNetwork16x16(); + return 0; +} diff --git a/media/highway/src/hwy/contrib/sort/result-inl.h b/media/highway/src/hwy/contrib/sort/result-inl.h new file mode 100644 index 000000000..f3d842dfb --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/result-inl.h @@ -0,0 +1,139 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/algo-inl.h" + +// Normal include guard for non-SIMD parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +#include + +#include // std::sort +#include + +#include "hwy/base.h" +#include "hwy/nanobenchmark.h" + +namespace hwy { + +struct Timestamp { + Timestamp() { t = platform::Now(); } + double t; +}; + +static inline double SecondsSince(const Timestamp& t0) { + const Timestamp t1; + return t1.t - t0.t; +} + +// Returns trimmed mean (we don't want to run an out-of-L3-cache sort often +// enough for the mode to be reliable). +static inline double SummarizeMeasurements(std::vector& seconds) { + std::sort(seconds.begin(), seconds.end()); + double sum = 0; + int count = 0; + const size_t num = seconds.size(); + for (size_t i = num / 4; i < num / 2; ++i) { + sum += seconds[i]; + count += 1; + } + return sum / count; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct Result { + Result() {} + Result(const Algo algo, Dist dist, size_t num_keys, size_t num_threads, + double sec, size_t sizeof_key, const std::string& key_name) + : target(HWY_TARGET), + algo(algo), + dist(dist), + num_keys(num_keys), + num_threads(num_threads), + sec(sec), + sizeof_key(sizeof_key), + key_name(key_name) {} + + void Print() const { + const double bytes = static_cast(num_keys) * + static_cast(num_threads) * + static_cast(sizeof_key); + printf("%10s: %12s: %7s: %9s: %.2E %4.0f MB/s (%2zu threads)\n", + hwy::TargetName(target), AlgoName(algo), key_name.c_str(), + DistName(dist), static_cast(num_keys), bytes * 1E-6 / sec, + num_threads); + } + + int64_t target; + Algo algo; + Dist dist; + size_t num_keys = 0; + size_t num_threads = 0; + double sec = 0.0; + size_t sizeof_key = 0; + std::string key_name; +}; + +template +bool VerifySort(Traits st, const InputStats& input_stats, + const LaneType* out, size_t num_lanes, const char* caller) { + constexpr size_t N1 = st.LanesPerKey(); + HWY_ASSERT(num_lanes >= N1); + + InputStats output_stats; + // Ensure it matches the sort order + for (size_t i = 0; i < num_lanes - N1; i += N1) { + output_stats.Notify(out[i]); + if (N1 == 2) output_stats.Notify(out[i + 1]); + // Reverse order instead of checking !Compare1 so we accept equal keys. + if (st.Compare1(out + i + N1, out + i)) { + printf("%s: i=%d of %d lanes: N1=%d %5.0f %5.0f vs. %5.0f %5.0f\n\n", + caller, static_cast(i), static_cast(num_lanes), + static_cast(N1), static_cast(out[i + 1]), + static_cast(out[i + 0]), + static_cast(out[i + N1 + 1]), + static_cast(out[i + N1])); + HWY_ABORT("%d-bit sort is incorrect\n", + static_cast(sizeof(LaneType) * 8 * N1)); + } + } + output_stats.Notify(out[num_lanes - N1]); + if (N1 == 2) output_stats.Notify(out[num_lanes - N1 + 1]); + + return input_stats == output_stats; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE diff --git a/media/highway/src/hwy/contrib/sort/shared-inl.h b/media/highway/src/hwy/contrib/sort/shared-inl.h new file mode 100644 index 000000000..ea604ed91 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/shared-inl.h @@ -0,0 +1,133 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions shared between vqsort-inl and sorting_networks-inl. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +#include "hwy/base.h" + +namespace hwy { + +// Internal constants - these are to avoid magic numbers/literals and cannot be +// changed without also changing the associated code. +struct SortConstants { +// SortingNetwork reshapes its input into a matrix. This is the maximum number +// of *keys* per vector. +#if HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD + static constexpr size_t kMaxCols = 8; // avoid build timeout/stack overflow +#else + static constexpr size_t kMaxCols = 16; // enough for u32 in 512-bit vector +#endif + + // 16 rows is a compromise between using the 32 AVX-512/SVE/RVV registers, + // fitting within 16 AVX2 registers with only a few spills, keeping BaseCase + // code size reasonable (7 KiB for AVX-512 and 16 cols), and minimizing the + // extra logN factor for larger networks (for which only loose upper bounds + // on size are known). + static constexpr size_t kMaxRowsLog2 = 4; + static constexpr size_t kMaxRows = size_t{1} << kMaxRowsLog2; + + static constexpr HWY_INLINE size_t BaseCaseNum(size_t N) { + return kMaxRows * HWY_MIN(N, kMaxCols); + } + + // Unrolling is important (pipelining and amortizing branch mispredictions); + // 2x is sufficient to reach full memory bandwidth on SKX in Partition, but + // somewhat slower for sorting than 4x. + // + // To change, must also update left + 3 * N etc. in the loop. + static constexpr size_t kPartitionUnroll = 4; + + static constexpr HWY_INLINE size_t PartitionBufNum(size_t N) { + // The main loop reads kPartitionUnroll vectors, and first loads from + // both left and right beforehand, so it requires min = 2 * + // kPartitionUnroll vectors. To handle smaller amounts (only guaranteed + // >= BaseCaseNum), we partition the right side into a buffer. We need + // another vector at the end so CompressStore does not overwrite anything. + return (2 * kPartitionUnroll + 1) * N; + } + + // Chunk := group of keys loaded for sampling a pivot. Matches the typical + // cache line size of 64 bytes to get maximum benefit per L2 miss. If vectors + // are larger, use entire vectors to ensure we do not overrun the array. + static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t, size_t N) { + return HWY_MAX(64 / sizeof_t, N); + } + + static constexpr HWY_INLINE size_t PivotBufNum(size_t sizeof_t, size_t N) { + // 3 chunks of medians, 1 chunk of median medians plus two padding vectors. + return (3 + 1) * LanesPerChunk(sizeof_t, N) + 2 * N; + } + + template + static constexpr HWY_INLINE size_t BufNum(size_t N) { + // One extra for padding plus another for full-vector loads. + return HWY_MAX(BaseCaseNum(N) + 2 * N, + HWY_MAX(PartitionBufNum(N), PivotBufNum(sizeof(T), N))); + } + + template + static constexpr HWY_INLINE size_t BufBytes(size_t vector_size) { + return sizeof(T) * BufNum(vector_size / sizeof(T)); + } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#endif + +#include "hwy/highway.h" + +// vqsort isn't available on HWY_SCALAR, and builds time out on MSVC opt and +// Arm v7 debug. +#undef VQSORT_ENABLED +#if (HWY_TARGET == HWY_SCALAR) || \ + (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \ + (HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) +#define VQSORT_ENABLED 0 +#else +#define VQSORT_ENABLED 1 +#endif + +namespace hwy { +namespace HWY_NAMESPACE { + +// Default tag / vector width selector. +#if HWY_TARGET == HWY_RVV +// Use LMUL = 1/2; for SEW=64 this ends up emulated via vsetvl. +template +using SortTag = ScalableTag; +#else +template +using SortTag = ScalableTag; +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE diff --git a/media/highway/src/hwy/contrib/sort/sort_test.cc b/media/highway/src/hwy/contrib/sort/sort_test.cc new file mode 100644 index 000000000..2d1f1d516 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/sort_test.cc @@ -0,0 +1,626 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#include +#include +#include // memcpy + +#include +#include + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/sort_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/sort/vqsort.h" +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // BaseCase +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +using detail::OrderAscending; +using detail::OrderDescending; +using detail::SharedTraits; +using detail::TraitsLane; +#if VQSORT_ENABLED || HWY_IDE +using detail::OrderAscending128; +using detail::OrderAscendingKV128; +using detail::OrderAscendingKV64; +using detail::OrderDescending128; +using detail::OrderDescendingKV128; +using detail::OrderDescendingKV64; +using detail::Traits128; + +template +static HWY_NOINLINE void TestMedian3() { + using LaneType = typename Traits::LaneType; + using D = CappedTag; + SharedTraits st; + const D d; + using V = Vec; + for (uint32_t bits = 0; bits < 8; ++bits) { + const V v0 = Set(d, LaneType{(bits & (1u << 0)) ? 1u : 0u}); + const V v1 = Set(d, LaneType{(bits & (1u << 1)) ? 1u : 0u}); + const V v2 = Set(d, LaneType{(bits & (1u << 2)) ? 1u : 0u}); + const LaneType m = GetLane(detail::MedianOf3(st, v0, v1, v2)); + // If at least half(rounded up) of bits are 1, so is the median. + const size_t count = PopCount(bits); + HWY_ASSERT_EQ((count >= 2) ? static_cast(1) : 0, m); + } +} + +HWY_NOINLINE void TestAllMedian() { + TestMedian3 > >(); +} + +template +static HWY_NOINLINE void TestBaseCaseAscDesc() { + using LaneType = typename Traits::LaneType; + SharedTraits st; + const SortTag d; + const size_t N = Lanes(d); + const size_t base_case_num = SortConstants::BaseCaseNum(N); + const size_t N1 = st.LanesPerKey(); + + constexpr int kDebug = 0; + auto aligned_lanes = hwy::AllocateAligned(N + base_case_num + N); + auto buf = hwy::AllocateAligned(base_case_num + 2 * N); + + std::vector lengths; + lengths.push_back(HWY_MAX(1, N1)); + lengths.push_back(3 * N1); + lengths.push_back(base_case_num / 2); + lengths.push_back(base_case_num / 2 + N1); + lengths.push_back(base_case_num - N1); + lengths.push_back(base_case_num); + + std::vector misalignments; + misalignments.push_back(0); + misalignments.push_back(1); + if (N >= 6) misalignments.push_back(N / 2 - 1); + misalignments.push_back(N / 2); + misalignments.push_back(N / 2 + 1); + misalignments.push_back(HWY_MIN(2 * N / 3 + 3, size_t{N - 1})); + + for (bool asc : {false, true}) { + for (size_t len : lengths) { + for (size_t misalign : misalignments) { + LaneType* HWY_RESTRICT lanes = aligned_lanes.get() + misalign; + if (kDebug) { + printf("============%s asc %d N1 %d len %d misalign %d\n", + st.KeyString().c_str(), asc, static_cast(N1), + static_cast(len), static_cast(misalign)); + } + + for (size_t i = 0; i < misalign; ++i) { + aligned_lanes[i] = hwy::LowestValue(); + } + InputStats input_stats; + for (size_t i = 0; i < len; ++i) { + lanes[i] = asc ? static_cast(LaneType(i) + 1) + : static_cast(LaneType(len) - LaneType(i)); + input_stats.Notify(lanes[i]); + if (kDebug >= 2) { + printf("%3zu: %f\n", i, static_cast(lanes[i])); + } + } + for (size_t i = len; i < base_case_num + N; ++i) { + lanes[i] = hwy::LowestValue(); + } + + detail::BaseCase(d, st, lanes, lanes + len, len, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = 0; i < len; ++i) { + printf("%3zu: %f\n", i, static_cast(lanes[i])); + } + } + + HWY_ASSERT(VerifySort(st, input_stats, lanes, len, "BaseAscDesc")); + for (size_t i = 0; i < misalign; ++i) { + if (aligned_lanes[i] != hwy::LowestValue()) + HWY_ABORT("Overrun misalign at %d\n", static_cast(i)); + } + for (size_t i = len; i < base_case_num + N; ++i) { + if (lanes[i] != hwy::LowestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // misalign + } // len + } // asc +} + +template +static HWY_NOINLINE void TestBaseCase01() { + using LaneType = typename Traits::LaneType; + SharedTraits st; + const SortTag d; + const size_t N = Lanes(d); + const size_t base_case_num = SortConstants::BaseCaseNum(N); + const size_t N1 = st.LanesPerKey(); + + constexpr int kDebug = 0; + auto lanes = hwy::AllocateAligned(base_case_num + N); + auto buf = hwy::AllocateAligned(base_case_num + 2 * N); + + std::vector lengths; + lengths.push_back(HWY_MAX(1, N1)); + lengths.push_back(3 * N1); + lengths.push_back(base_case_num / 2); + lengths.push_back(base_case_num / 2 + N1); + lengths.push_back(base_case_num - N1); + lengths.push_back(base_case_num); + + for (size_t len : lengths) { + if (kDebug) { + printf("============%s 01 N1 %d len %d\n", st.KeyString().c_str(), + static_cast(N1), static_cast(len)); + } + const uint64_t kMaxBits = AdjustedLog2Reps(HWY_MIN(len, size_t{14})); + for (uint64_t bits = 0; bits < ((1ull << kMaxBits) - 1); ++bits) { + InputStats input_stats; + for (size_t i = 0; i < len; ++i) { + lanes[i] = (i < 64 && (bits & (1ull << i))) ? 1 : 0; + input_stats.Notify(lanes[i]); + if (kDebug >= 2) { + printf("%3zu: %f\n", i, static_cast(lanes[i])); + } + } + for (size_t i = len; i < base_case_num + N; ++i) { + lanes[i] = hwy::LowestValue(); + } + + detail::BaseCase(d, st, lanes.get(), lanes.get() + len, len, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = 0; i < len; ++i) { + printf("%3zu: %f\n", i, static_cast(lanes[i])); + } + } + + HWY_ASSERT(VerifySort(st, input_stats, lanes.get(), len, "Base01")); + for (size_t i = len; i < base_case_num + N; ++i) { + if (lanes[i] != hwy::LowestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // bits + } // len +} + +template +static HWY_NOINLINE void TestBaseCase() { + TestBaseCaseAscDesc(); + TestBaseCase01(); +} + +HWY_NOINLINE void TestAllBaseCase() { + // Workaround for stack overflow on MSVC debug. +#if defined(_MSC_VER) + return; +#endif + TestBaseCase > >(); + TestBaseCase > >(); + TestBaseCase >(); + TestBaseCase >(); +} + +template +static HWY_NOINLINE void VerifyPartition( + Traits st, typename Traits::LaneType* HWY_RESTRICT lanes, size_t left, + size_t border, size_t right, const size_t N1, + const typename Traits::LaneType* pivot) { + /* for (size_t i = left; i < right; ++i) { + if (i == border) printf("--\n"); + printf("%4zu: %3d\n", i, lanes[i]); + }*/ + + HWY_ASSERT(left % N1 == 0); + HWY_ASSERT(border % N1 == 0); + HWY_ASSERT(right % N1 == 0); + const bool asc = typename Traits::Order().IsAscending(); + for (size_t i = left; i < border; i += N1) { + if (st.Compare1(pivot, lanes + i)) { + HWY_ABORT( + "%s: asc %d left[%d] piv %.0f %.0f compares before %.0f %.0f " + "border %d", + st.KeyString().c_str(), asc, static_cast(i), + static_cast(pivot[1]), static_cast(pivot[0]), + static_cast(lanes[i + 1]), static_cast(lanes[i + 0]), + static_cast(border)); + } + } + for (size_t i = border; i < right; i += N1) { + if (!st.Compare1(pivot, lanes + i)) { + HWY_ABORT( + "%s: asc %d right[%d] piv %.0f %.0f compares after %.0f %.0f " + "border %d", + st.KeyString().c_str(), asc, static_cast(i), + static_cast(pivot[1]), static_cast(pivot[0]), + static_cast(lanes[i + 1]), static_cast(lanes[i]), + static_cast(border)); + } + } +} + +template +static HWY_NOINLINE void TestPartition() { + using LaneType = typename Traits::LaneType; + const SortTag d; + SharedTraits st; + const bool asc = typename Traits::Order().IsAscending(); + const size_t N = Lanes(d); + constexpr int kDebug = 0; + const size_t base_case_num = SortConstants::BaseCaseNum(N); + // left + len + align + const size_t total = 32 + (base_case_num + 4 * HWY_MAX(N, 4)) + 2 * N; + auto aligned_lanes = hwy::AllocateAligned(total); + auto buf = hwy::AllocateAligned(SortConstants::PartitionBufNum(N)); + + const size_t N1 = st.LanesPerKey(); + for (bool in_asc : {false, true}) { + for (int left_i : {0, 1, 4, 6, 7, 8, 12, 15, 22, 28, 30, 31}) { + const size_t left = static_cast(left_i) & ~(N1 - 1); + for (size_t ofs : {N, N + 1, N + 3, 2 * N, 2 * N + 2, 2 * N + 3, + 3 * N - 1, 4 * N - 3, 4 * N - 2}) { + const size_t len = (base_case_num + ofs) & ~(N1 - 1); + for (LaneType pivot1 : + {LaneType(0), LaneType(len / 3), LaneType(len / 2), + LaneType(2 * len / 3), LaneType(len)}) { + const LaneType pivot2[2] = {pivot1, 0}; + const auto pivot = st.SetKey(d, pivot2); + for (size_t misalign = 0; misalign < N; + misalign += st.LanesPerKey()) { + LaneType* HWY_RESTRICT lanes = aligned_lanes.get() + misalign; + const size_t right = left + len; + if (kDebug) { + printf( + "=========%s asc %d left %d len %d right %d piv %.0f %.0f\n", + st.KeyString().c_str(), asc, static_cast(left), + static_cast(len), static_cast(right), + static_cast(pivot2[1]), + static_cast(pivot2[0])); + } + + for (size_t i = 0; i < misalign; ++i) { + aligned_lanes[i] = hwy::LowestValue(); + } + for (size_t i = 0; i < left; ++i) { + lanes[i] = hwy::LowestValue(); + } + std::unordered_map counts; + for (size_t i = left; i < right; ++i) { + lanes[i] = static_cast( + in_asc ? LaneType(i + 1) - static_cast(left) + : static_cast(right) - LaneType(i)); + ++counts[lanes[i]]; + if (kDebug >= 2) { + printf("%3zu: %f\n", i, static_cast(lanes[i])); + } + } + for (size_t i = right; i < total - misalign; ++i) { + lanes[i] = hwy::LowestValue(); + } + + size_t border = + left + detail::Partition(d, st, lanes + left, right - left, + pivot, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = left; i < right; ++i) { + printf("%3zu: %f\n", i, static_cast(lanes[i])); + } + for (size_t i = right; i < total - misalign; ++i) { + printf("%3zu: sentinel %f\n", i, static_cast(lanes[i])); + } + } + for (size_t i = left; i < right; ++i) { + --counts[lanes[i]]; + } + for (auto kv : counts) { + if (kv.second != 0) { + PrintValue(kv.first); + HWY_ABORT("Incorrect count %d\n", kv.second); + } + } + VerifyPartition(st, lanes, left, border, right, N1, pivot2); + for (size_t i = 0; i < misalign; ++i) { + if (aligned_lanes[i] != hwy::LowestValue()) + HWY_ABORT("Overrun misalign at %d\n", static_cast(i)); + } + for (size_t i = 0; i < left; ++i) { + if (lanes[i] != hwy::LowestValue()) + HWY_ABORT("Overrun left at %d\n", static_cast(i)); + } + for (size_t i = right; i < total - misalign; ++i) { + if (lanes[i] != hwy::LowestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // misalign + } // pivot + } // len + } // left + } // asc +} + +HWY_NOINLINE void TestAllPartition() { + TestPartition > >(); + TestPartition >(); + +#if !HWY_IS_DEBUG_BUILD + TestPartition > >(); + TestPartition > >(); + TestPartition > >(); +#if HWY_HAVE_FLOAT64 + TestPartition > >(); +#endif + TestPartition >(); +#endif +} + +// (used for sample selection for choosing a pivot) +template +static HWY_NOINLINE void TestRandomGenerator() { + static_assert(!hwy::IsSigned(), ""); + SortTag du; + const size_t N = Lanes(du); + + detail::Generator rng(&N, N); + + const size_t lanes_per_block = HWY_MAX(64 / sizeof(TU), N); // power of two + + for (uint32_t num_blocks = 2; num_blocks < 100000; + num_blocks = 3 * num_blocks / 2) { + // Generate some numbers and ensure all are in range + uint64_t sum = 0; + constexpr size_t kReps = 10000; + for (size_t rep = 0; rep < kReps; ++rep) { + const uint32_t bits = rng() & 0xFFFFFFFF; + const size_t index = detail::RandomChunkIndex(num_blocks, bits); + HWY_ASSERT(((index + 1) * lanes_per_block) <= + num_blocks * lanes_per_block); + + sum += index; + } + + // Also ensure the mean is near the middle of the range + const double expected = (num_blocks - 1) / 2.0; + const double actual = static_cast(sum) / kReps; + HWY_ASSERT(0.9 * expected <= actual && actual <= 1.1 * expected); + } +} + +HWY_NOINLINE void TestAllGenerator() { + TestRandomGenerator(); + TestRandomGenerator(); +} + +#else +static void TestAllMedian() {} +static void TestAllBaseCase() {} +static void TestAllPartition() {} +static void TestAllGenerator() {} +#endif // VQSORT_ENABLED + +// Remembers input, and compares results to that of a reference algorithm. +template +class CompareResults { + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + + public: + CompareResults(const LaneType* in, size_t num_lanes) { + copy_.resize(num_lanes); + memcpy(copy_.data(), in, num_lanes * sizeof(LaneType)); + } + + bool Verify(const LaneType* output) { +#if HAVE_PDQSORT + const Algo reference = Algo::kPDQ; +#else + const Algo reference = Algo::kStd; +#endif + SharedState shared; + using Order = typename Traits::Order; + const Traits st; + const size_t num_keys = copy_.size() / st.LanesPerKey(); + Run(reference, reinterpret_cast(copy_.data()), num_keys, + shared, /*thread=*/0); +#if VQSORT_PRINT >= 3 + fprintf(stderr, "\nExpected:\n"); + for (size_t i = 0; i < copy_.size(); ++i) { + PrintValue(copy_[i]); + } + fprintf(stderr, "\n"); +#endif + for (size_t i = 0; i < copy_.size(); ++i) { + if (copy_[i] != output[i]) { + if (sizeof(KeyType) == 16) { + fprintf(stderr, + "%s Asc %d mismatch at %d of %d: %" PRIu64 " %" PRIu64 "\n", + st.KeyString().c_str(), Order().IsAscending(), + static_cast(i), static_cast(copy_.size()), + static_cast(copy_[i]), + static_cast(output[i])); + } else { + fprintf(stderr, "Type %s Asc %d mismatch at %d of %d: ", + st.KeyString().c_str(), Order().IsAscending(), + static_cast(i), static_cast(copy_.size())); + PrintValue(copy_[i]); + PrintValue(output[i]); + fprintf(stderr, "\n"); + } + return false; + } + } + return true; + } + + private: + std::vector copy_; +}; + +std::vector AlgoForTest() { + return { +#if HAVE_AVX2SORT + Algo::kSEA, +#endif +#if HAVE_IPS4O + Algo::kIPS4O, +#endif +#if HAVE_PDQSORT + Algo::kPDQ, +#endif +#if HAVE_SORT512 + Algo::kSort512, +#endif + Algo::kHeap, Algo::kVQSort, + }; +} + +template +void TestSort(size_t num_lanes) { +// Workaround for stack overflow on clang-cl (/F 8388608 does not help). +#if defined(_MSC_VER) + return; +#endif + using Order = typename Traits::Order; + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + SharedState shared; + SharedTraits st; + + // Round up to a whole number of keys. + num_lanes += (st.Is128() && (num_lanes & 1)); + const size_t num_keys = num_lanes / st.LanesPerKey(); + + constexpr size_t kMaxMisalign = 16; + auto aligned = + hwy::AllocateAligned(kMaxMisalign + num_lanes + kMaxMisalign); + for (Algo algo : AlgoForTest()) { + for (Dist dist : AllDist()) { + for (size_t misalign : {size_t{0}, size_t{st.LanesPerKey()}, + size_t{3 * st.LanesPerKey()}, kMaxMisalign / 2}) { + LaneType* lanes = aligned.get() + misalign; + + // Set up red zones before/after the keys to sort + for (size_t i = 0; i < misalign; ++i) { + aligned[i] = hwy::LowestValue(); + } + for (size_t i = 0; i < kMaxMisalign; ++i) { + lanes[num_lanes + i] = hwy::HighestValue(); + } +#if HWY_IS_MSAN + __msan_poison(aligned.get(), misalign * sizeof(LaneType)); + __msan_poison(lanes + num_lanes, kMaxMisalign * sizeof(LaneType)); +#endif + InputStats input_stats = + GenerateInput(dist, lanes, num_lanes); + + CompareResults compare(lanes, num_lanes); + Run(algo, reinterpret_cast(lanes), num_keys, shared, + /*thread=*/0); + HWY_ASSERT(compare.Verify(lanes)); + HWY_ASSERT(VerifySort(st, input_stats, lanes, num_lanes, "TestSort")); + + // Check red zones +#if HWY_IS_MSAN + __msan_unpoison(aligned.get(), misalign * sizeof(LaneType)); + __msan_unpoison(lanes + num_lanes, kMaxMisalign * sizeof(LaneType)); +#endif + for (size_t i = 0; i < misalign; ++i) { + if (aligned[i] != hwy::LowestValue()) + HWY_ABORT("Overrun left at %d\n", static_cast(i)); + } + for (size_t i = num_lanes; i < num_lanes + kMaxMisalign; ++i) { + if (lanes[i] != hwy::HighestValue()) + HWY_ABORT("Overrun right at %d\n", static_cast(i)); + } + } // misalign + } // dist + } // algo +} + +void TestAllSort() { + for (int num : {129, 504, 3 * 1000, 34567}) { + const size_t num_lanes = AdjustedReps(static_cast(num)); + TestSort > >(num_lanes); + TestSort > >(num_lanes); + + TestSort > >(num_lanes); + TestSort > >(num_lanes); + + TestSort > >(num_lanes); + TestSort > >(num_lanes); + + // WARNING: for float types, SIMD comparisons will flush denormals to + // zero, causing mismatches with scalar sorts. In this test, we avoid + // generating denormal inputs. + TestSort > >(num_lanes); +#if HWY_HAVE_FLOAT64 // protects algo-inl's GenerateRandom + if (Sorter::HaveFloat64()) { + TestSort > >(num_lanes); + } +#endif + +// Our HeapSort does not support 128-bit keys. +#if VQSORT_ENABLED + TestSort >(num_lanes); + TestSort >(num_lanes); + + TestSort >(num_lanes); + TestSort >(num_lanes); + + TestSort >(num_lanes); + TestSort >(num_lanes); +#endif + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +namespace { +HWY_BEFORE_TEST(SortTest); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllMedian); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllBaseCase); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllPartition); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllGenerator); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllSort); +} // namespace +} // namespace hwy + +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/sorting_networks-inl.h b/media/highway/src/hwy/contrib/sort/sorting_networks-inl.h new file mode 100644 index 000000000..3cc545b7a --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/sorting_networks-inl.h @@ -0,0 +1,695 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#endif + +#include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if VQSORT_ENABLED + +using Constants = hwy::SortConstants; + +// ------------------------------ SharedTraits + +// Code shared between all traits. It's unclear whether these can profitably be +// specialized for Lane vs Block, or optimized like SortPairsDistance1 using +// Compare/DupOdd. +template +struct SharedTraits : public Base { + // Conditionally swaps lane 0 with 2, 1 with 3 etc. + template + HWY_INLINE Vec SortPairsDistance2(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentPairs(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template + HWY_INLINE Vec SortPairsReverse8(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys8(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template + HWY_INLINE Vec SortPairsReverse16(D d, Vec v) const { + const Base* base = static_cast(this); + static_assert(Constants::kMaxCols <= 16, "Need actual Reverse16"); + Vec swapped = base->ReverseKeys(d, v); + base->Sort2(d, v, swapped); + return ConcatUpperLower(d, swapped, v); // 8 = half of the vector + } +}; + +// ------------------------------ Sorting network + +// (Green's irregular) sorting network for independent columns in 16 vectors. +template > +HWY_INLINE void Sort16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v4, v6); + st.Sort2(d, v5, v7); + st.Sort2(d, v8, va); + st.Sort2(d, v9, vb); + st.Sort2(d, vc, ve); + st.Sort2(d, vd, vf); + st.Sort2(d, v0, v4); + st.Sort2(d, v1, v5); + st.Sort2(d, v2, v6); + st.Sort2(d, v3, v7); + st.Sort2(d, v8, vc); + st.Sort2(d, v9, vd); + st.Sort2(d, va, ve); + st.Sort2(d, vb, vf); + st.Sort2(d, v0, v8); + st.Sort2(d, v1, v9); + st.Sort2(d, v2, va); + st.Sort2(d, v3, vb); + st.Sort2(d, v4, vc); + st.Sort2(d, v5, vd); + st.Sort2(d, v6, ve); + st.Sort2(d, v7, vf); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v3, vc); + st.Sort2(d, v7, vb); + st.Sort2(d, vd, ve); + st.Sort2(d, v4, v8); + st.Sort2(d, v1, v2); + st.Sort2(d, v1, v4); + st.Sort2(d, v7, vd); + st.Sort2(d, v2, v8); + st.Sort2(d, vb, ve); + st.Sort2(d, v2, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vd); + st.Sort2(d, v3, v8); + st.Sort2(d, v7, vc); + st.Sort2(d, v3, v5); + st.Sort2(d, v6, v8); + st.Sort2(d, v7, v9); + st.Sort2(d, va, vc); + st.Sort2(d, v3, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v7, v8); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vc); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); +} + +// ------------------------------ Merging networks + +// Blacher's hybrid bitonic/odd-even networks, generated by print_network.cc. + +template > +HWY_INLINE void Merge2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys2(d, v8); + v9 = st.ReverseKeys2(d, v9); + va = st.ReverseKeys2(d, va); + vb = st.ReverseKeys2(d, vb); + vc = st.ReverseKeys2(d, vc); + vd = st.ReverseKeys2(d, vd); + ve = st.ReverseKeys2(d, ve); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys2(d, v4); + vc = st.ReverseKeys2(d, vc); + v5 = st.ReverseKeys2(d, v5); + vd = st.ReverseKeys2(d, vd); + v6 = st.ReverseKeys2(d, v6); + ve = st.ReverseKeys2(d, ve); + v7 = st.ReverseKeys2(d, v7); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys2(d, v2); + v3 = st.ReverseKeys2(d, v3); + v6 = st.ReverseKeys2(d, v6); + v7 = st.ReverseKeys2(d, v7); + va = st.ReverseKeys2(d, va); + vb = st.ReverseKeys2(d, vb); + ve = st.ReverseKeys2(d, ve); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys2(d, v1); + v3 = st.ReverseKeys2(d, v3); + v5 = st.ReverseKeys2(d, v5); + v7 = st.ReverseKeys2(d, v7); + v9 = st.ReverseKeys2(d, v9); + vb = st.ReverseKeys2(d, vb); + vd = st.ReverseKeys2(d, vd); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template > +HWY_INLINE void Merge4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys4(d, v8); + v9 = st.ReverseKeys4(d, v9); + va = st.ReverseKeys4(d, va); + vb = st.ReverseKeys4(d, vb); + vc = st.ReverseKeys4(d, vc); + vd = st.ReverseKeys4(d, vd); + ve = st.ReverseKeys4(d, ve); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys4(d, v4); + vc = st.ReverseKeys4(d, vc); + v5 = st.ReverseKeys4(d, v5); + vd = st.ReverseKeys4(d, vd); + v6 = st.ReverseKeys4(d, v6); + ve = st.ReverseKeys4(d, ve); + v7 = st.ReverseKeys4(d, v7); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys4(d, v2); + v3 = st.ReverseKeys4(d, v3); + v6 = st.ReverseKeys4(d, v6); + v7 = st.ReverseKeys4(d, v7); + va = st.ReverseKeys4(d, va); + vb = st.ReverseKeys4(d, vb); + ve = st.ReverseKeys4(d, ve); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys4(d, v1); + v3 = st.ReverseKeys4(d, v3); + v5 = st.ReverseKeys4(d, v5); + v7 = st.ReverseKeys4(d, v7); + v9 = st.ReverseKeys4(d, v9); + vb = st.ReverseKeys4(d, vb); + vd = st.ReverseKeys4(d, vd); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse4(d, v0); + v1 = st.SortPairsReverse4(d, v1); + v2 = st.SortPairsReverse4(d, v2); + v3 = st.SortPairsReverse4(d, v3); + v4 = st.SortPairsReverse4(d, v4); + v5 = st.SortPairsReverse4(d, v5); + v6 = st.SortPairsReverse4(d, v6); + v7 = st.SortPairsReverse4(d, v7); + v8 = st.SortPairsReverse4(d, v8); + v9 = st.SortPairsReverse4(d, v9); + va = st.SortPairsReverse4(d, va); + vb = st.SortPairsReverse4(d, vb); + vc = st.SortPairsReverse4(d, vc); + vd = st.SortPairsReverse4(d, vd); + ve = st.SortPairsReverse4(d, ve); + vf = st.SortPairsReverse4(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template > +HWY_INLINE void Merge8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys8(d, v8); + v9 = st.ReverseKeys8(d, v9); + va = st.ReverseKeys8(d, va); + vb = st.ReverseKeys8(d, vb); + vc = st.ReverseKeys8(d, vc); + vd = st.ReverseKeys8(d, vd); + ve = st.ReverseKeys8(d, ve); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys8(d, v4); + vc = st.ReverseKeys8(d, vc); + v5 = st.ReverseKeys8(d, v5); + vd = st.ReverseKeys8(d, vd); + v6 = st.ReverseKeys8(d, v6); + ve = st.ReverseKeys8(d, ve); + v7 = st.ReverseKeys8(d, v7); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys8(d, v2); + v3 = st.ReverseKeys8(d, v3); + v6 = st.ReverseKeys8(d, v6); + v7 = st.ReverseKeys8(d, v7); + va = st.ReverseKeys8(d, va); + vb = st.ReverseKeys8(d, vb); + ve = st.ReverseKeys8(d, ve); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys8(d, v1); + v3 = st.ReverseKeys8(d, v3); + v5 = st.ReverseKeys8(d, v5); + v7 = st.ReverseKeys8(d, v7); + v9 = st.ReverseKeys8(d, v9); + vb = st.ReverseKeys8(d, vb); + vd = st.ReverseKeys8(d, vd); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse8(d, v0); + v1 = st.SortPairsReverse8(d, v1); + v2 = st.SortPairsReverse8(d, v2); + v3 = st.SortPairsReverse8(d, v3); + v4 = st.SortPairsReverse8(d, v4); + v5 = st.SortPairsReverse8(d, v5); + v6 = st.SortPairsReverse8(d, v6); + v7 = st.SortPairsReverse8(d, v7); + v8 = st.SortPairsReverse8(d, v8); + v9 = st.SortPairsReverse8(d, v9); + va = st.SortPairsReverse8(d, va); + vb = st.SortPairsReverse8(d, vb); + vc = st.SortPairsReverse8(d, vc); + vd = st.SortPairsReverse8(d, vd); + ve = st.SortPairsReverse8(d, ve); + vf = st.SortPairsReverse8(d, vf); + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +// Unused on MSVC, see below +#if !HWY_COMPILER_MSVC + +template > +HWY_INLINE void Merge16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, + V& vd, V& ve, V& vf) { + v8 = st.ReverseKeys16(d, v8); + v9 = st.ReverseKeys16(d, v9); + va = st.ReverseKeys16(d, va); + vb = st.ReverseKeys16(d, vb); + vc = st.ReverseKeys16(d, vc); + vd = st.ReverseKeys16(d, vd); + ve = st.ReverseKeys16(d, ve); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys16(d, v4); + vc = st.ReverseKeys16(d, vc); + v5 = st.ReverseKeys16(d, v5); + vd = st.ReverseKeys16(d, vd); + v6 = st.ReverseKeys16(d, v6); + ve = st.ReverseKeys16(d, ve); + v7 = st.ReverseKeys16(d, v7); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys16(d, v2); + v3 = st.ReverseKeys16(d, v3); + v6 = st.ReverseKeys16(d, v6); + v7 = st.ReverseKeys16(d, v7); + va = st.ReverseKeys16(d, va); + vb = st.ReverseKeys16(d, vb); + ve = st.ReverseKeys16(d, ve); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys16(d, v1); + v3 = st.ReverseKeys16(d, v3); + v5 = st.ReverseKeys16(d, v5); + v7 = st.ReverseKeys16(d, v7); + v9 = st.ReverseKeys16(d, v9); + vb = st.ReverseKeys16(d, vb); + vd = st.ReverseKeys16(d, vd); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse16(d, v0); + v1 = st.SortPairsReverse16(d, v1); + v2 = st.SortPairsReverse16(d, v2); + v3 = st.SortPairsReverse16(d, v3); + v4 = st.SortPairsReverse16(d, v4); + v5 = st.SortPairsReverse16(d, v5); + v6 = st.SortPairsReverse16(d, v6); + v7 = st.SortPairsReverse16(d, v7); + v8 = st.SortPairsReverse16(d, v8); + v9 = st.SortPairsReverse16(d, v9); + va = st.SortPairsReverse16(d, va); + vb = st.SortPairsReverse16(d, vb); + vc = st.SortPairsReverse16(d, vc); + vd = st.SortPairsReverse16(d, vd); + ve = st.SortPairsReverse16(d, ve); + vf = st.SortPairsReverse16(d, vf); + v0 = st.SortPairsDistance4(d, v0); + v1 = st.SortPairsDistance4(d, v1); + v2 = st.SortPairsDistance4(d, v2); + v3 = st.SortPairsDistance4(d, v3); + v4 = st.SortPairsDistance4(d, v4); + v5 = st.SortPairsDistance4(d, v5); + v6 = st.SortPairsDistance4(d, v6); + v7 = st.SortPairsDistance4(d, v7); + v8 = st.SortPairsDistance4(d, v8); + v9 = st.SortPairsDistance4(d, v9); + va = st.SortPairsDistance4(d, va); + vb = st.SortPairsDistance4(d, vb); + vc = st.SortPairsDistance4(d, vc); + vd = st.SortPairsDistance4(d, vd); + ve = st.SortPairsDistance4(d, ve); + vf = st.SortPairsDistance4(d, vf); + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +#endif // !HWY_COMPILER_MSVC + +// Reshapes `buf` into a matrix, sorts columns independently, and then merges +// into a sorted 1D array without transposing. +// +// `st` is SharedTraits>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. +// `buf` ensures full vectors are aligned, and enables loads/stores without +// bounds checks. +// +// NOINLINE because this is large and called twice from vqsort-inl.h. +// +// References: +// https://drops.dagstuhl.de/opus/volltexte/2021/13775/pdf/LIPIcs-SEA-2021-3.pdf +// https://github.com/simd-sorting/fast-and-robust/blob/master/avx2_sort_demo/avx2sort.h +// "Entwurf und Implementierung vektorisierter Sortieralgorithmen" (M. Blacher) +template +HWY_NOINLINE void SortingNetwork(Traits st, T* HWY_RESTRICT buf, size_t cols) { + const CappedTag d; + using V = decltype(Zero(d)); + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // The network width depends on the number of keys, not lanes. + constexpr size_t kLanesPerKey = st.LanesPerKey(); + const size_t keys = cols / kLanesPerKey; + constexpr size_t kMaxKeys = MaxLanes(d) / kLanesPerKey; + + // These are aligned iff cols == Lanes(d). We prefer unaligned/non-constexpr + // offsets to duplicating this code for every value of cols. + static_assert(Constants::kMaxRows == 16, "Update loads/stores/args"); + V v0 = LoadU(d, buf + 0x0 * cols); + V v1 = LoadU(d, buf + 0x1 * cols); + V v2 = LoadU(d, buf + 0x2 * cols); + V v3 = LoadU(d, buf + 0x3 * cols); + V v4 = LoadU(d, buf + 0x4 * cols); + V v5 = LoadU(d, buf + 0x5 * cols); + V v6 = LoadU(d, buf + 0x6 * cols); + V v7 = LoadU(d, buf + 0x7 * cols); + V v8 = LoadU(d, buf + 0x8 * cols); + V v9 = LoadU(d, buf + 0x9 * cols); + V va = LoadU(d, buf + 0xa * cols); + V vb = LoadU(d, buf + 0xb * cols); + V vc = LoadU(d, buf + 0xc * cols); + V vd = LoadU(d, buf + 0xd * cols); + V ve = LoadU(d, buf + 0xe * cols); + V vf = LoadU(d, buf + 0xf * cols); + + Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf); + + // Checking MaxLanes avoids generating HWY_ASSERT code for the unreachable + // code paths: if MaxLanes < 2, then keys <= cols < 2. + if (HWY_LIKELY(keys >= 2 && kMaxKeys >= 2)) { + Merge2(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, + vf); + + if (HWY_LIKELY(keys >= 4 && kMaxKeys >= 4)) { + Merge4(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, + vf); + + if (HWY_LIKELY(keys >= 8 && kMaxKeys >= 8)) { + Merge8(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, + ve, vf); + + // Avoids build timeout. Must match #if condition in kMaxCols. +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + if (HWY_LIKELY(keys >= 16 && kMaxKeys >= 16)) { + Merge16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, + ve, vf); + + static_assert(Constants::kMaxCols <= 16, "Add more branches"); + } +#endif + } + } + } + + StoreU(v0, d, buf + 0x0 * cols); + StoreU(v1, d, buf + 0x1 * cols); + StoreU(v2, d, buf + 0x2 * cols); + StoreU(v3, d, buf + 0x3 * cols); + StoreU(v4, d, buf + 0x4 * cols); + StoreU(v5, d, buf + 0x5 * cols); + StoreU(v6, d, buf + 0x6 * cols); + StoreU(v7, d, buf + 0x7 * cols); + StoreU(v8, d, buf + 0x8 * cols); + StoreU(v9, d, buf + 0x9 * cols); + StoreU(va, d, buf + 0xa * cols); + StoreU(vb, d, buf + 0xb * cols); + StoreU(vc, d, buf + 0xc * cols); + StoreU(vd, d, buf + 0xd * cols); + StoreU(ve, d, buf + 0xe * cols); + StoreU(vf, d, buf + 0xf * cols); +} + +#else +template +struct SharedTraits : public Base {}; +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE diff --git a/media/highway/src/hwy/contrib/sort/traits-inl.h b/media/highway/src/hwy/contrib/sort/traits-inl.h new file mode 100644 index 000000000..8b87c8262 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/traits-inl.h @@ -0,0 +1,527 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#endif + +#include + +#include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/contrib/sort/vqsort.h" // SortDescending +#include "hwy/highway.h" +#include "hwy/print.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if VQSORT_ENABLED || HWY_IDE + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +template +struct KeyLane { + static constexpr bool Is128() { return false; } + constexpr size_t LanesPerKey() const { return 1; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = T; + // What type to pass to Sorter::operator(). + using KeyType = T; + + std::string KeyString() const { + char string100[100]; + hwy::detail::TypeName(hwy::detail::MakeTypeInfo(), 1, string100); + return string100; + } + + // For HeapSort + HWY_INLINE void Swap(T* a, T* b) const { + const T temp = *a; + *a = *b; + *b = temp; + } + + template + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressNot(keys, mask); + } + + // Broadcasts one key into a vector + template + HWY_INLINE Vec SetKey(D d, const T* key) const { + return Set(d, *key); + } + + template + HWY_INLINE Mask EqualKeys(D /*tag*/, Vec a, Vec b) const { + return Eq(a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D /*tag*/, Vec a, Vec b) const { + return Ne(a, b); + } + + HWY_INLINE bool Equal1(const T* a, const T* b) { return *a == *b; } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return Reverse(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D d, Vec v) const { + return Reverse2(d, v); + } + + template + HWY_INLINE Vec ReverseKeys4(D d, Vec v) const { + return Reverse4(d, v); + } + + template + HWY_INLINE Vec ReverseKeys8(D d, Vec v) const { + return Reverse8(d, v); + } + + template + HWY_INLINE Vec ReverseKeys16(D d, Vec v) const { + static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit"); + return ReverseKeys(d, v); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEven(odd, even); + } + + template + HWY_INLINE Vec SwapAdjacentPairs(D d, const Vec v) const { + const Repartition du32; + return BitCast(d, Shuffle2301(BitCast(du32, v))); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return Shuffle1032(v); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return SwapAdjacentBlocks(v); + } + + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide > dw; +#endif + return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v))); + } + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { + // Assumes max vector size = 512 + return ConcatLowerUpper(d, v, v); + } + + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide > dw; +#endif + return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenPairs(D /* tag */, Vec odd, Vec even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide > dw; +#endif + return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { + return ConcatUpperLower(d, odd, even); + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +template +struct OrderAscending : public KeyLane { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *a < *b; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(a, b); + } + + // Two halves of Sort2, used in ScanMinMax. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Sub(v, Set(d, hwy::Epsilon())); + } +}; + +template +struct OrderDescending : public KeyLane { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *b < *a; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(b, a); + } + + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Add(v, Set(d, hwy::Epsilon())); + } +}; + +struct OrderAscendingKV64 : public KeyLane { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (*a >> 32) < (*b >> 32); + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + // Same as for regular lanes. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Sub(v, Set(d, 1)); + } +}; + +struct OrderDescendingKV64 : public KeyLane { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (*b >> 32) < (*a >> 32); + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(ShiftRight<32>(b), ShiftRight<32>(a)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Add(v, Set(d, 1)); + } +}; + +// Shared code that depends on Order. +template +struct TraitsLane : public Base { + // For each lane i: replaces a[i] with the first and b[i] with the second + // according to Base. + // Corresponds to a conditional swap, which is one "node" of a sorting + // network. Min/Max are cheaper than compare + blend at least for integers. + template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4 + // instructions. We can reduce it to a compare + 2 IfThenElse. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + if (sizeof(TFromD) == 8) { + const Mask cmp = base->Compare(d, a, b); + a = IfThenElse(cmp, a, b); + b = IfThenElse(cmp, b, a_copy); + return; + } +#endif + a = base->First(d, a, b); + b = base->Last(d, a_copy, b); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + // Further to the above optimization, Sort2+OddEvenKeys compile to four + // instructions; we can save one by combining two blends. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + const Vec cmp = VecFromMask(d, base->Compare(d, v, swapped)); + return IfVecThenElse(DupOdd(cmp), swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // (See above - we use Sort2 for non-64-bit types.) + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template + HWY_INLINE Vec SortPairsDistance4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentQuads(d, v); + // Only used in Merge16, so this will not be used on AVX2 (which only has 4 + // u64 lanes), so skip the above optimization for 64-bit AVX2. + Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } +}; + +#else + +// Base class shared between OrderAscending, OrderDescending. +template +struct KeyLane { + constexpr bool Is128() const { return false; } + constexpr size_t LanesPerKey() const { return 1; } + + using LaneType = T; + using KeyType = T; + + std::string KeyString() const { + char string100[100]; + hwy::detail::TypeName(hwy::detail::MakeTypeInfo(), 1, string100); + return string100; + } +}; + +template +struct OrderAscending : public KeyLane { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *a < *b; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) { + return Lt(a, b); + } +}; + +template +struct OrderDescending : public KeyLane { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *b < *a; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) { + return Lt(b, a); + } +}; + +template +struct TraitsLane : public Order { + // For HeapSort + template // MSVC doesn't find typename Order::LaneType. + HWY_INLINE void Swap(T* a, T* b) const { + const T temp = *a; + *a = *b; + *b = temp; + } + + template + HWY_INLINE Vec SetKey(D d, const TFromD* key) const { + return Set(d, *key); + } +}; + +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE diff --git a/media/highway/src/hwy/contrib/sort/traits128-inl.h b/media/highway/src/hwy/contrib/sort/traits128-inl.h new file mode 100644 index 000000000..c69206440 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/traits128-inl.h @@ -0,0 +1,492 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#endif + +#include + +#include "hwy/contrib/sort/shared-inl.h" +#include "hwy/contrib/sort/vqsort.h" // SortDescending +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if VQSORT_ENABLED || HWY_IDE + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +struct KeyAny128 { + static constexpr bool Is128() { return true; } + constexpr size_t LanesPerKey() const { return 2; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = uint64_t; + // KeyType and KeyString are defined by derived classes. + + HWY_INLINE void Swap(LaneType* a, LaneType* b) const { + const FixedTag d; + const auto temp = LoadU(d, a); + StoreU(LoadU(d, b), d, a); + StoreU(temp, d, b); + } + + template + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressBlocksNot(keys, mask); + } + + template + HWY_INLINE Vec SetKey(D d, const TFromD* key) const { + return LoadDup128(d, key); + } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return ReverseBlocks(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D /* tag */, const Vec v) const { + return SwapAdjacentBlocks(v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template + HWY_INLINE Vec ReverseKeys4(D d, const Vec v) const { + HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD)); + return ReverseKeys(d, v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { + HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD)); + return ConcatUpperLower(d, odd, even); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec ReverseKeys8(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 1024-bit vectors + } + + template + HWY_INLINE Vec ReverseKeys16(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 2048-bit vectors + } + + // This is only called for 8/16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentPairs(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentQuads(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 8 col networks (not supported). + template + HWY_INLINE Vec OddEvenQuads(D, Vec, Vec) const { + HWY_ASSERT(0); + } +}; + +// Base class shared between OrderAscending128, OrderDescending128. +struct Key128 : public KeyAny128 { + // What type to pass to Sorter::operator(). + using KeyType = hwy::uint128_t; + + std::string KeyString() const { return "U128"; } + + template + HWY_INLINE Mask EqualKeys(D d, Vec a, Vec b) const { + return Eq128(d, a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D d, Vec a, Vec b) const { + return Ne128(d, a, b); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) { + return a[0] == b[0] && a[1] == b[1]; + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +struct OrderAscending128 : public Key128 { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, a, b); + } + + // Used by CompareTop + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(a, b); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k0 = Zero(d); + const Vec k1 = OddEven(k0, Set(d, 1)); + const Mask borrow = Eq(v, k0); // don't-care, lo == 0 + // lo == 0? 1 : 0, 0 + const Vec adjust = ShiftLeftLanes<1>(IfThenElseZero(borrow, k1)); + return Sub(Sub(v, k1), adjust); + } +}; + +struct OrderDescending128 : public Key128 { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, b, a); + } + + // Used by CompareTop + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Zero(d), Set(d, 1)); + const Vec added = Add(v, k1); + const Mask overflowed = Lt(added, v); // false, overflowed + // overflowed? 1 : 0, 0 + const Vec adjust = ShiftLeftLanes<1>(IfThenElseZero(overflowed, k1)); + return Add(added, adjust); + } +}; + +// Base class shared between OrderAscendingKV128, OrderDescendingKV128. +struct KeyValue128 : public KeyAny128 { + // What type to pass to Sorter::operator(). + using KeyType = K64V64; + + std::string KeyString() const { return "KV128"; } + + template + HWY_INLINE Mask EqualKeys(D d, Vec a, Vec b) const { + return Eq128Upper(d, a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D d, Vec a, Vec b) const { + return Ne128Upper(d, a, b); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) { + return a[1] == b[1]; + } +}; + +struct OrderAscendingKV128 : public KeyValue128 { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return a[1] < b[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128Upper(d, a, b); + } + + // Used by CompareTop + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(a, b); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Min128Upper(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Max128Upper(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Set(d, 1), Zero(d)); + return Sub(v, k1); + } +}; + +struct OrderDescendingKV128 : public KeyValue128 { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return b[1] < a[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128Upper(d, b, a); + } + + // Used by CompareTop + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Max128Upper(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Min128Upper(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Set(d, 1), Zero(d)); + return Add(v, k1); + } +}; + +// Shared code that depends on Order. +template +class Traits128 : public Base { + // Special case for >= 256 bit vectors +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256 + // Returns vector with only the top u64 lane valid. Useful when the next step + // is to replicate the mask anyway. + template + HWY_INLINE HWY_MAYBE_UNUSED Vec CompareTop(D d, Vec a, Vec b) const { + const Base* base = static_cast(this); + const Mask eqHL = Eq(a, b); + const Vec ltHL = VecFromMask(d, base->CompareLanes(a, b)); +#if HWY_TARGET == HWY_SVE_256 + return IfThenElse(eqHL, DupEven(ltHL), ltHL); +#else + const Vec ltLX = ShiftLeftLanes<1>(ltHL); + return OrAnd(ltHL, VecFromMask(d, eqHL), ltLX); +#endif + } + + // We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in + // the most-significant of those lanes (the result of CompareTop), so + // replicate it 4x. Only called for >= 256-bit vectors. + template + HWY_INLINE V ReplicateTop4x(V v) const { +#if HWY_TARGET == HWY_SVE_256 + return svdup_lane_u64(v, 3); +#elif HWY_TARGET <= HWY_AVX3 + return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +#else // AVX2 + return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +#endif + } +#endif // HWY_TARGET + + public: + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const Base* base = static_cast(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->First(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const Base* base = static_cast(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->Last(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + const auto lt = base->Compare(d, a, b); + a = IfThenElse(lt, a, b); + b = IfThenElse(lt, b, a_copy); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256 + const Vec select = ReplicateTop4x(CompareTop(d, v, swapped)); + return IfVecThenElse(select, swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + + // Only specialize for AVX3 because this requires 512-bit vectors. +#if HWY_TARGET <= HWY_AVX3 + const Vec512 outHx = CompareTop(d, v, swapped); + // Similar to ReplicateTop4x, we want to gang together 2 comparison results + // (4 lanes). They are not contiguous, so use permute to replicate 4x. + alignas(64) uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7}; + const Vec512 select = + TableLookupLanes(outHx, SetTableIndices(d, kIndices)); + return IfVecThenElse(select, swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); +#endif + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template + HWY_INLINE Vec SortPairsDistance4(D, Vec) const { + // Only used by Merge16, which would require 2048 bit vectors (unsupported). + HWY_ASSERT(0); + } +}; + +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE diff --git a/media/highway/src/hwy/contrib/sort/vqsort-inl.h b/media/highway/src/hwy/contrib/sort/vqsort-inl.h new file mode 100644 index 000000000..10584d246 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort-inl.h @@ -0,0 +1,1443 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +#ifndef VQSORT_PRINT +#define VQSORT_PRINT 0 +#endif + +// Makes it harder for adversaries to predict our sampling locations, at the +// cost of 1-2% increased runtime. +#ifndef VQSORT_SECURE_RNG +#define VQSORT_SECURE_RNG 0 +#endif + +#if VQSORT_SECURE_RNG +#include "third_party/absl/random/random.h" +#endif + +#include // unconditional #include so we can use if(VQSORT_PRINT). +#include // memcpy + +#include "hwy/cache_control.h" // Prefetch +#include "hwy/contrib/sort/vqsort.h" // Fill24Bytes + +#if HWY_IS_MSAN +#include +#endif + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#endif + +#if VQSORT_PRINT +#include "hwy/print-inl.h" +#endif + +#include "hwy/contrib/sort/shared-inl.h" +#include "hwy/contrib/sort/sorting_networks-inl.h" +// Placeholder for internal instrumentation. Do not remove. +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +using Constants = hwy::SortConstants; + +// Wrappers to avoid #if in user code (interferes with code folding) + +HWY_INLINE void UnpoisonIfMemorySanitizer(void* p, size_t bytes) { +#if HWY_IS_MSAN + __msan_unpoison(p, bytes); +#else + (void)p; + (void)bytes; +#endif +} + +template +HWY_INLINE void MaybePrintVector(D d, const char* label, Vec v, + size_t start = 0, size_t max_lanes = 16) { +#if VQSORT_PRINT >= 2 // Print is only defined #if + Print(d, label, v, start, max_lanes); +#else + (void)d; + (void)label; + (void)v; + (void)start; + (void)max_lanes; +#endif +} + +// ------------------------------ HeapSort + +template +void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes, + size_t start) { + constexpr size_t N1 = st.LanesPerKey(); + const FixedTag d; + + while (start < num_lanes) { + const size_t left = 2 * start + N1; + const size_t right = 2 * start + 2 * N1; + if (left >= num_lanes) break; + size_t idx_larger = start; + const auto key_j = st.SetKey(d, lanes + start); + if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) { + idx_larger = left; + } + if (right < num_lanes && + AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger), + st.SetKey(d, lanes + right)))) { + idx_larger = right; + } + if (idx_larger == start) break; + st.Swap(lanes + start, lanes + idx_larger); + start = idx_larger; + } +} + +// Heapsort: O(1) space, O(N*logN) worst-case comparisons. +// Based on LLVM sanitizer_common.h, licensed under Apache-2.0. +template +void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) { + constexpr size_t N1 = st.LanesPerKey(); + + if (num_lanes < 2 * N1) return; + + // Build heap. + for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) { + SiftDown(st, lanes, num_lanes, i); + } + + for (size_t i = num_lanes - N1; i != 0; i -= N1) { + // Swap root with last + st.Swap(lanes + 0, lanes + i); + + // Sift down the new root. + SiftDown(st, lanes, i, 0); + } +} + +#if VQSORT_ENABLED || HWY_IDE + +// ------------------------------ BaseCase + +// Sorts `keys` within the range [0, num) via sorting network. +template +HWY_NOINLINE void BaseCase(D d, Traits st, T* HWY_RESTRICT keys, + T* HWY_RESTRICT keys_end, size_t num, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + using V = decltype(Zero(d)); + + // _Nonzero32 requires num - 1 != 0. + if (HWY_UNLIKELY(num <= 1)) return; + + // Reshape into a matrix with kMaxRows rows, and columns limited by the + // 1D `num`, which is upper-bounded by the vector width (see BaseCaseNum). + const size_t num_pow2 = size_t{1} + << (32 - Num0BitsAboveMS1Bit_Nonzero32( + static_cast(num - 1))); + HWY_DASSERT(num <= num_pow2 && num_pow2 <= Constants::BaseCaseNum(N)); + const size_t cols = + HWY_MAX(st.LanesPerKey(), num_pow2 >> Constants::kMaxRowsLog2); + HWY_DASSERT(cols <= N); + + // We can avoid padding and load/store directly to `keys` after checking the + // original input array has enough space. Except at the right border, it's OK + // to sort more than the current sub-array. Even if we sort across a previous + // partition point, we know that keys will not migrate across it. However, we + // must use the maximum size of the sorting network, because the StoreU of its + // last vector would otherwise write invalid data starting at kMaxRows * cols. + const size_t N_sn = Lanes(CappedTag()); + if (HWY_LIKELY(keys + N_sn * Constants::kMaxRows <= keys_end)) { + SortingNetwork(st, keys, N_sn); + return; + } + + // Copy `keys` to `buf`. + size_t i; + for (i = 0; i + N <= num; i += N) { + Store(LoadU(d, keys + i), d, buf + i); + } + SafeCopyN(num - i, d, keys + i, buf + i); + i = num; + + // Fill with padding - last in sort order, not copied to keys. + const V kPadding = st.LastValue(d); + // Initialize an extra vector because SortingNetwork loads full vectors, + // which may exceed cols*kMaxRows. + for (; i < (cols * Constants::kMaxRows + N); i += N) { + StoreU(kPadding, d, buf + i); + } + + SortingNetwork(st, buf, cols); + + for (i = 0; i + N <= num; i += N) { + StoreU(Load(d, buf + i), d, keys + i); + } + SafeCopyN(num - i, d, buf + i, keys + i); +} + +// ------------------------------ Partition + +// Consumes from `keys` until a multiple of kUnroll*N remains. +// Temporarily stores the right side into `buf`, then moves behind `num`. +// Returns the number of keys consumed from the left side. +template +HWY_NOINLINE size_t PartitionToMultipleOfUnroll(D d, Traits st, + T* HWY_RESTRICT keys, + size_t& num, const Vec pivot, + T* HWY_RESTRICT buf) { + constexpr size_t kUnroll = Constants::kPartitionUnroll; + const size_t N = Lanes(d); + size_t readL = 0; + T* HWY_RESTRICT posL = keys; + size_t bufR = 0; + // Partition requires both a multiple of kUnroll*N and at least + // 2*kUnroll*N for the initial loads. If less, consume all here. + const size_t num_rem = + (num < 2 * kUnroll * N) ? num : (num & (kUnroll * N - 1)); + size_t i = 0; + for (; i + N <= num_rem; i += N) { + const Vec vL = LoadU(d, keys + readL); + readL += N; + + const auto comp = st.Compare(d, pivot, vL); + posL += CompressBlendedStore(vL, Not(comp), d, posL); + bufR += CompressStore(vL, comp, d, buf + bufR); + } + // Last iteration: only use valid lanes. + if (HWY_LIKELY(i != num_rem)) { + const auto mask = FirstN(d, num_rem - i); + const Vec vL = LoadU(d, keys + readL); + + const auto comp = st.Compare(d, pivot, vL); + posL += CompressBlendedStore(vL, AndNot(comp, mask), d, posL); + bufR += CompressStore(vL, And(comp, mask), d, buf + bufR); + } + + // MSAN seems not to understand CompressStore. buf[0, bufR) are valid. + UnpoisonIfMemorySanitizer(buf, bufR * sizeof(T)); + + // Everything we loaded was put into buf, or behind the current `posL`, after + // which there is space for bufR items. First move items from `keys + num` to + // `posL` to free up space, then copy `buf` into the vacated `keys + num`. + // A loop with masked loads from `buf` is insufficient - we would also need to + // mask from `keys + num`. Combining a loop with memcpy for the remainders is + // slower than just memcpy, so we use that for simplicity. + num -= bufR; + memcpy(posL, keys + num, bufR * sizeof(T)); + memcpy(keys + num, buf, bufR * sizeof(T)); + return static_cast(posL - keys); // caller will shrink num by this. +} + +template +V OrXor(const V o, const V x1, const V x2) { + // TODO(janwas): add op so we can benefit from AVX-512 ternlog? + return Or(o, Xor(x1, x2)); +} + +// Note: we could track the OrXor of v and pivot to see if the entire left +// partition is equal, but that happens rarely and thus is a net loss. +template +HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec v, + const Vec pivot, T* HWY_RESTRICT keys, + size_t& writeL, size_t& remaining) { + const size_t N = Lanes(d); + + const auto comp = st.Compare(d, pivot, v); + + remaining -= N; + if (hwy::HWY_NAMESPACE::CompressIsPartition::value || + (HWY_MAX_BYTES == 16 && st.Is128())) { + // Non-native Compress (e.g. AVX2): we are able to partition a vector using + // a single Compress+two StoreU instead of two Compress[Blended]Store. The + // latter are more expensive. Because we store entire vectors, the contents + // between the updated writeL and writeR are ignored and will be overwritten + // by subsequent calls. This works because writeL and writeR are at least + // two vectors apart. + const auto lr = st.CompressKeys(v, comp); + const size_t num_left = N - CountTrue(d, comp); + StoreU(lr, d, keys + writeL); + // Now write the right-side elements (if any), such that the previous writeR + // is one past the end of the newly written right elements, then advance. + StoreU(lr, d, keys + remaining + writeL); + writeL += num_left; + } else { + // Native Compress[Store] (e.g. AVX3), which only keep the left or right + // side, not both, hence we require two calls. + const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL); + writeL += num_left; + + (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL); + } +} + +template +HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec v0, + const Vec v1, const Vec v2, + const Vec v3, const Vec pivot, + T* HWY_RESTRICT keys, size_t& writeL, + size_t& remaining) { + StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining); +} + +// Moves "<= pivot" keys to the front, and others to the back. pivot is +// broadcasted. Time-critical! +// +// Aligned loads do not seem to be worthwhile (not bottlenecked by load ports). +template +HWY_NOINLINE size_t Partition(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + const Vec pivot, T* HWY_RESTRICT buf) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // StoreLeftRight will CompressBlendedStore ending at `writeR`. Unless all + // lanes happen to be in the right-side partition, this will overrun `keys`, + // which triggers asan errors. Avoid by special-casing the last vector. + HWY_DASSERT(num > 2 * N); // ensured by HandleSpecialCases + num -= N; + size_t last = num; + const V vlast = LoadU(d, keys + last); + + const size_t consumedL = + PartitionToMultipleOfUnroll(d, st, keys, num, pivot, buf); + keys += consumedL; + last -= consumedL; + num -= consumedL; + constexpr size_t kUnroll = Constants::kPartitionUnroll; + + // Partition splits the vector into 3 sections, left to right: Elements + // smaller or equal to the pivot, unpartitioned elements and elements larger + // than the pivot. To write elements unconditionally on the loop body without + // overwriting existing data, we maintain two regions of the loop where all + // elements have been copied elsewhere (e.g. vector registers.). I call these + // bufferL and bufferR, for left and right respectively. + // + // These regions are tracked by the indices (writeL, writeR, left, right) as + // presented in the diagram below. + // + // writeL writeR + // \/ \/ + // | <= pivot | bufferL | unpartitioned | bufferR | > pivot | + // \/ \/ + // left right + // + // In the main loop body below we choose a side, load some elements out of the + // vector and move either `left` or `right`. Next we call into StoreLeftRight + // to partition the data, and the partitioned elements will be written either + // to writeR or writeL and the corresponding index will be moved accordingly. + // + // Note that writeR is not explicitly tracked as an optimization for platforms + // with conditional operations. Instead we track writeL and the number of + // elements left to process (`remaining`). From the diagram above we can see + // that: + // writeR - writeL = remaining => writeR = remaining + writeL + // + // Tracking `remaining` is advantageous because each iteration reduces the + // number of unpartitioned elements by a fixed amount, so we can compute + // `remaining` without data dependencies. + // + size_t writeL = 0; + size_t remaining = num; + + const T* HWY_RESTRICT readL = keys; + const T* HWY_RESTRICT readR = keys + num; + // Cannot load if there were fewer than 2 * kUnroll * N. + if (HWY_LIKELY(num != 0)) { + HWY_DASSERT(num >= 2 * kUnroll * N); + HWY_DASSERT((num & (kUnroll * N - 1)) == 0); + + // Make space for writing in-place by reading from readL/readR. + const V vL0 = LoadU(d, readL + 0 * N); + const V vL1 = LoadU(d, readL + 1 * N); + const V vL2 = LoadU(d, readL + 2 * N); + const V vL3 = LoadU(d, readL + 3 * N); + readL += kUnroll * N; + readR -= kUnroll * N; + const V vR0 = LoadU(d, readR + 0 * N); + const V vR1 = LoadU(d, readR + 1 * N); + const V vR2 = LoadU(d, readR + 2 * N); + const V vR3 = LoadU(d, readR + 3 * N); + + // readL/readR changed above, so check again before the loop. + while (readL != readR) { + V v0, v1, v2, v3; + + // Data-dependent but branching is faster than forcing branch-free. + const size_t capacityL = + static_cast((readL - keys) - static_cast(writeL)); + HWY_DASSERT(capacityL <= num); // >= 0 + // Load data from the end of the vector with less data (front or back). + // The next paragraphs explain how this works. + // + // let block_size = (kUnroll * N) + // On the loop prelude we load block_size elements from the front of the + // vector and an additional block_size elements from the back. On each + // iteration k elements are written to the front of the vector and + // (block_size - k) to the back. + // + // This creates a loop invariant where the capacity on the front + // (capacityL) and on the back (capacityR) always add to 2 * block_size. + // In other words: + // capacityL + capacityR = 2 * block_size + // capacityR = 2 * block_size - capacityL + // + // This means that: + // capacityL < capacityR <=> + // capacityL < 2 * block_size - capacityL <=> + // 2 * capacityL < 2 * block_size <=> + // capacityL < block_size + // + // Thus the check on the next line is equivalent to capacityL > capacityR. + // + if (kUnroll * N < capacityL) { + readR -= kUnroll * N; + v0 = LoadU(d, readR + 0 * N); + v1 = LoadU(d, readR + 1 * N); + v2 = LoadU(d, readR + 2 * N); + v3 = LoadU(d, readR + 3 * N); + hwy::Prefetch(readR - 3 * kUnroll * N); + } else { + v0 = LoadU(d, readL + 0 * N); + v1 = LoadU(d, readL + 1 * N); + v2 = LoadU(d, readL + 2 * N); + v3 = LoadU(d, readL + 3 * N); + readL += kUnroll * N; + hwy::Prefetch(readL + 3 * kUnroll * N); + } + + StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining); + } + + // Now finish writing the saved vectors to the middle. + StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining); + StoreLeftRight4(d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, remaining); + } + + // We have partitioned [left, right) such that writeL is the boundary. + HWY_DASSERT(remaining == 0); + // Make space for inserting vlast: move up to N of the first right-side keys + // into the unused space starting at last. If we have fewer, ensure they are + // the last items in that vector by subtracting from the *load* address, + // which is safe because we have at least two vectors (checked above). + const size_t totalR = last - writeL; + const size_t startR = totalR < N ? writeL + totalR - N : writeL; + StoreU(LoadU(d, keys + startR), d, keys + last); + + // Partition vlast: write L, then R, into the single-vector gap at writeL. + const auto comp = st.Compare(d, pivot, vlast); + writeL += CompressBlendedStore(vlast, Not(comp), d, keys + writeL); + (void)CompressBlendedStore(vlast, comp, d, keys + writeL); + + return consumedL + writeL; +} + +// Returns true and partitions if [keys, keys + num) contains only {valueL, +// valueR}. Otherwise, sets third to the first differing value; keys may have +// been reordered and a regular Partition is still necessary. +template +HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec valueL, + const Vec valueR, Vec& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + + size_t i = 0; + size_t writeL = 0; + + // As long as all lanes are equal to L or R, we can overwrite with valueL. + // This is faster than first counting, then backtracking to fill L and R. + for (; i + N <= num; i += N) { + const Vec v = LoadU(d, keys + i); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to AVX3. + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = st.EqualKeys(d, v, valueR); + // At least one other value present; will require a regular partition. + // On AVX-512, Or + AllTrue are folded into a single kortest if we are + // careful with the FindKnownFirstTrue argument, see below. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the + // loop, which is a pessimization because this if-true branch is cold. + // We can defeat this via Not(Xor), which is equivalent because eqL and + // eqR cannot be true at the same time. Can we elide the additional Not? + // FindFirstFalse instructions are generally unavailable, but we can + // fuse Not and Xor/Or into one ExclusiveNeither. + const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i, writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + for (; writeL + N <= i; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL); + return false; + } + StoreU(valueL, d, keys + writeL); + writeL += CountTrue(d, eqL); + } + + // Final vector, masked comparison (no effect if i == num) + const size_t remaining = num - i; + SafeCopyN(remaining, d, keys + i, buf); + const Vec v = Load(d, buf); + const Mask valid = FirstN(d, remaining); + const Mask eqL = And(st.EqualKeys(d, v, valueL), valid); + const Mask eqR = st.EqualKeys(d, v, valueR); + // Invalid lanes are considered equal. + const Mask eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. + if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i, + writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + for (; writeL + N <= i; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL); + return false; + } + BlendedStore(valueL, valid, d, keys + writeL); + writeL += CountTrue(d, eqL); + + // Fill right side + i = writeL; + for (; i + N <= num; i += N) { + StoreU(valueR, d, keys + i); + } + BlendedStore(valueR, FirstN(d, num - i), d, keys + i); + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Successful MaybePartitionTwoValue\n"); + } + return true; +} + +// Same as above, except that the pivot equals valueR, so scan right to left. +template +HWY_NOINLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec valueL, + const Vec valueR, Vec& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + + HWY_DASSERT(num >= N); + size_t pos = num - N; // current read/write position + size_t countR = 0; // number of valueR found + + // For whole vectors, in descending address order: as long as all lanes are + // equal to L or R, overwrite with valueR. This is faster than counting, then + // filling both L and R. Loop terminates after unsigned wraparound. + for (; pos < num; pos -= N) { + const Vec v = LoadU(d, keys + pos); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to AVX3. + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = st.EqualKeys(d, v, valueR); + // If there is a third value, stop and undo what we've done. On AVX-512, + // Or + AllTrue are folded into a single kortest, but only if we are + // careful with the FindKnownFirstTrue argument - see prior comment on that. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + pos + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. + HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + for (; pos + N <= endL; pos += N) { + StoreU(valueL, d, keys + pos); + } + BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos); + return false; + } + StoreU(valueR, d, keys + pos); + countR += CountTrue(d, eqR); + } + + // Final partial (or empty) vector, masked comparison. + const size_t remaining = pos + N; + HWY_DASSERT(remaining <= N); + const Vec v = LoadU(d, keys); // Safe because num >= N. + const Mask valid = FirstN(d, remaining); + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = And(st.EqualKeys(d, v, valueR), valid); + // Invalid lanes are considered equal. + const Mask eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. + if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. + HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + for (; pos + N <= endL; pos += N) { + StoreU(valueL, d, keys + pos); + } + BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos); + return false; + } + const size_t lastR = CountTrue(d, eqR); + countR += lastR; + + // First finish writing valueR - [0, N) lanes were not yet written. + StoreU(valueR, d, keys); // Safe because num >= N. + + // Fill left side (ascending order for clarity) + const size_t endL = num - countR; + size_t i = 0; + for (; i + N <= endL; i += N) { + StoreU(valueL, d, keys + i); + } + Store(valueL, d, buf); + SafeCopyN(endL - i, d, buf, keys + i); // avoids asan overrun + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, + "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n", + countR, pos, i, endL); + } + + return true; +} + +// `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the +// second key. This is the first path into `MaybePartitionTwoValue`, called +// when all samples are equal. Returns false if there are at least a third +// value and sets `third`. Otherwise, partitions the array and returns true. +template +HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec pivot, + T* HWY_RESTRICT keys, size_t num, + const size_t idx_second, const Vec second, + Vec& third, T* HWY_RESTRICT buf) { + // True if second comes before pivot. + const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second)); + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second, + is_pivotR); + } + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot))); + + // If pivot is R, we scan backwards over the entire array. Otherwise, + // we already scanned up to idx_second and can leave those in place. + return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot, + third, buf) + : MaybePartitionTwoValue(d, st, keys + idx_second, + num - idx_second, pivot, second, + third, buf); +} + +// Second path into `MaybePartitionTwoValue`, called when not all samples are +// equal. `samples` is sorted. +template +HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = 3 * 64 / sizeof(T); + constexpr size_t N1 = st.LanesPerKey(); + const Vec valueL = st.SetKey(d, samples); + const Vec valueR = st.SetKey(d, samples + kSampleLanes - N1); + HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR))); + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR))); + const Vec prev = st.PrevValue(d, valueR); + // If the sample has more than two values, then the keys have at least that + // many, and thus this special case is inapplicable. + if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) { + return false; + } + + // Must not overwrite samples because if this returns false, caller wants to + // read the original samples again. + T* HWY_RESTRICT buf = samples + kSampleLanes; + Vec third; // unused + return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf); +} + +// ------------------------------ Pivot sampling + +template +HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) { + const DFromV d; + // Slightly faster for 128-bit, apparently because not serially dependent. + if (st.Is128()) { + // Median = XOR-sum 'minus' the first and last. Calling First twice is + // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR. + const auto sum = Xor(Xor(v0, v1), v2); + const auto first = st.First(d, st.First(d, v0, v1), v2); + const auto last = st.Last(d, st.Last(d, v0, v1), v2); + return Xor(Xor(sum, first), last); + } + st.Sort2(d, v0, v2); + v1 = st.Last(d, v0, v1); + v1 = st.First(d, v1, v2); + return v1; +} + +#if VQSORT_SECURE_RNG +using Generator = absl::BitGen; +#else +// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028 +#pragma pack(push, 1) +class Generator { + public: + Generator(const void* heap, size_t num) { + Sorter::Fill24Bytes(heap, num, &a_); + k_ = 1; // stream index: must be odd + } + + explicit Generator(uint64_t seed) { + a_ = b_ = w_ = seed; + k_ = 1; + } + + uint64_t operator()() { + const uint64_t b = b_; + w_ += k_; + const uint64_t next = a_ ^ w_; + a_ = (b + (b << 3)) ^ (b >> 11); + const uint64_t rot = (b << 24) | (b >> 40); + b_ = rot + next; + return next; + } + + private: + uint64_t a_; + uint64_t b_; + uint64_t w_; + uint64_t k_; // increment +}; +#pragma pack(pop) + +#endif // !VQSORT_SECURE_RNG + +// Returns slightly biased random index of a chunk in [0, num_chunks). +// See https://www.pcg-random.org/posts/bounded-rands.html. +HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) { + const uint64_t chunk_index = (static_cast(bits) * num_chunks) >> 32; + HWY_DASSERT(chunk_index < num_chunks); + return static_cast(chunk_index); +} + +// Writes samples from `keys[0, num)` into `buf`. +template +HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf, Generator& rng) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // Power of two + const size_t lanes_per_chunk = Constants::LanesPerChunk(sizeof(T), N); + + // Align start of keys to chunks. We always have at least 2 chunks because the + // base case would have handled anything up to 16 vectors, i.e. >= 4 chunks. + HWY_DASSERT(num >= 2 * lanes_per_chunk); + const size_t misalign = + (reinterpret_cast(keys) / sizeof(T)) & (lanes_per_chunk - 1); + if (misalign != 0) { + const size_t consume = lanes_per_chunk - misalign; + keys += consume; + num -= consume; + } + + // Generate enough random bits for 9 uint32 + uint64_t* bits64 = reinterpret_cast(buf); + for (size_t i = 0; i < 5; ++i) { + bits64[i] = rng(); + } + const uint32_t* bits = reinterpret_cast(buf); + + const uint32_t lpc32 = static_cast(lanes_per_chunk); + // Avoid division + const size_t log2_lpc = Num0BitsBelowLS1Bit_Nonzero32(lpc32); + const size_t num_chunks64 = num >> log2_lpc; + // Clamp to uint32 for RandomChunkIndex + const uint32_t num_chunks = + static_cast(HWY_MIN(num_chunks64, 0xFFFFFFFFull)); + + const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) << log2_lpc; + const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) << log2_lpc; + const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) << log2_lpc; + const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) << log2_lpc; + const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) << log2_lpc; + const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) << log2_lpc; + const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) << log2_lpc; + const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) << log2_lpc; + const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) << log2_lpc; + for (size_t i = 0; i < lanes_per_chunk; i += N) { + const V v0 = Load(d, keys + offset0 + i); + const V v1 = Load(d, keys + offset1 + i); + const V v2 = Load(d, keys + offset2 + i); + const V medians0 = MedianOf3(st, v0, v1, v2); + Store(medians0, d, buf + i); + + const V v3 = Load(d, keys + offset3 + i); + const V v4 = Load(d, keys + offset4 + i); + const V v5 = Load(d, keys + offset5 + i); + const V medians1 = MedianOf3(st, v3, v4, v5); + Store(medians1, d, buf + i + lanes_per_chunk); + + const V v6 = Load(d, keys + offset6 + i); + const V v7 = Load(d, keys + offset7 + i); + const V v8 = Load(d, keys + offset8 + i); + const V medians2 = MedianOf3(st, v6, v7, v8); + Store(medians2, d, buf + i + lanes_per_chunk * 2); + } +} + +// For detecting inputs where (almost) all keys are equal. +template +HWY_INLINE bool UnsortedSampleEqual(D d, Traits st, + const TFromD* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = 3 * 64 / sizeof(TFromD); + const size_t N = Lanes(d); + using V = Vec; + + const V first = st.SetKey(d, samples); + // OR of XOR-difference may be faster than comparison. + V diff = Zero(d); + size_t i = 0; + for (; i + N <= kSampleLanes; i += N) { + const V v = Load(d, samples + i); + diff = OrXor(diff, first, v); + } + // Remainder, if any. + const V v = Load(d, samples + i); + const auto valid = FirstN(d, kSampleLanes - i); + diff = IfThenElse(valid, OrXor(diff, first, v), diff); + + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + return AllTrue(du, Eq(BitCast(du, diff), Zero(du))); +} + +template +HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) { + // buf contains 192 bytes, so 16 128-bit vectors are necessary and sufficient. + constexpr size_t kSampleLanes = 3 * 64 / sizeof(T); + const CappedTag d128; + const size_t N128 = Lanes(d128); + constexpr size_t kCols = HWY_MIN(16 / sizeof(T), Constants::kMaxCols); + constexpr size_t kBytes = kCols * Constants::kMaxRows * sizeof(T); + static_assert(192 <= kBytes, ""); + // Fill with padding - last in sort order. + const auto kPadding = st.LastValue(d128); + // Initialize an extra vector because SortingNetwork loads full vectors, + // which may exceed cols*kMaxRows. + for (size_t i = kSampleLanes; i <= kBytes / sizeof(T); i += N128) { + StoreU(kPadding, d128, buf + i); + } + + SortingNetwork(st, buf, kCols); + + if (VQSORT_PRINT >= 2) { + const size_t N = Lanes(d); + fprintf(stderr, "Samples:\n"); + for (size_t i = 0; i < kSampleLanes; i += N) { + MaybePrintVector(d, "", Load(d, buf + i), 0, N); + } + } +} + +// ------------------------------ Pivot selection + +enum class PivotResult { + kDone, // stop without partitioning (all equal, or two-value partition) + kNormal, // partition and recurse left and right + kIsFirst, // partition but skip left recursion + kWasLast, // partition but skip right recursion +}; + +HWY_INLINE const char* PivotResultString(PivotResult result) { + switch (result) { + case PivotResult::kDone: + return "done"; + case PivotResult::kNormal: + return "normal"; + case PivotResult::kIsFirst: + return "first"; + case PivotResult::kWasLast: + return "last"; + } + return "unknown"; +} + +template +HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = 3 * 64 / sizeof(T); + constexpr size_t N1 = st.LanesPerKey(); + + constexpr size_t kRankMid = kSampleLanes / 2; + static_assert(kRankMid % N1 == 0, "Mid is not an aligned key"); + + // Find the previous value not equal to the median. + size_t rank_prev = kRankMid - N1; + for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) { + // All previous samples are equal to the median. + if (rank_prev == 0) return 0; + } + + size_t rank_next = rank_prev + N1; + for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) { + // The median is also the largest sample. If it is also the largest key, + // we'd end up with an empty right partition, so choose the previous key. + if (rank_next == kSampleLanes - N1) return rank_prev; + } + + // If we choose the median as pivot, the ratio of keys ending in the left + // partition will likely be rank_next/kSampleLanes (if the sample is + // representative). This is because equal-to-pivot values also land in the + // left - it's infeasible to do an in-place vectorized 3-way partition. + // Check whether prev would lead to a more balanced partition. + const size_t excess_if_median = rank_next - kRankMid; + const size_t excess_if_prev = kRankMid - rank_prev; + return excess_if_median < excess_if_prev ? kRankMid : rank_prev; +} + +// Returns pivot chosen from `samples`. It will never be the largest key +// (thus the right partition will never be empty). +template +HWY_INLINE Vec ChoosePivotByRank(D d, Traits st, + const T* HWY_RESTRICT samples) { + const size_t pivot_rank = PivotRank(st, samples); + const Vec pivot = st.SetKey(d, samples + pivot_rank); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, " Pivot rank %zu = %f\n", pivot_rank, + static_cast(GetLane(pivot))); + } + return pivot; +} + +// Returns true if all keys equal `pivot`, otherwise returns false and sets +// `*first_mismatch' to the index of the first differing key. +template +HWY_NOINLINE bool AllEqual(D d, Traits st, const Vec pivot, + const T* HWY_RESTRICT keys, size_t num, + size_t* HWY_RESTRICT first_mismatch) { + const size_t N = Lanes(d); + // Ensures we can use overlapping loads for the tail; see HandleSpecialCases. + HWY_DASSERT(num >= N); + const Vec zero = Zero(d); + + // Vector-align keys + i. + const size_t misalign = + (reinterpret_cast(keys) / sizeof(T)) & (N - 1); + HWY_DASSERT(misalign % st.LanesPerKey() == 0); + const size_t consume = N - misalign; + { + const Vec v = LoadU(d, keys); + // Only check masked lanes; consider others to be equal. + const Mask diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot)); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = lane; + return false; + } + } + size_t i = consume; + HWY_DASSERT(((reinterpret_cast(keys + i) / sizeof(T)) & (N - 1)) == + 0); + + // Sticky bits registering any difference between `keys` and the first key. + // We use vector XOR because it may be cheaper than comparisons, especially + // for 128-bit. 2x unrolled for more ILP. + Vec diff0 = zero; + Vec diff1 = zero; + + // We want to stop once a difference has been found, but without slowing + // down the loop by comparing during each iteration. The compromise is to + // compare after a 'group', which consists of kLoops times two vectors. + constexpr size_t kLoops = 8; + const size_t lanes_per_group = kLoops * 2 * N; + + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec v0 = Load(d, keys + i + loop * 2 * N); + const Vec v1 = Load(d, keys + i + loop * 2 * N + N); + diff0 = OrXor(diff0, v0, pivot); + diff1 = OrXor(diff1, v1, pivot); + } + diff0 = Or(diff0, diff1); + + // If there was a difference in the entire group: (use du because we must + // avoid floating-point comparisons for -0) + const RebindToUnsigned du; + if (HWY_UNLIKELY(!AllTrue(du, Eq(BitCast(du, diff0), Zero(du))))) { + // .. then loop until the first one, with termination guarantee. + for (;; i += N) { + const Vec v = Load(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + } + } + + // Whole vectors, no unrolling, compare directly + for (; i + N <= num; i += N) { + const Vec v = Load(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + // Always re-check the last (unaligned) vector to reduce branching. + i = num - N; + const Vec v = LoadU(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "All keys equal\n"); + } + return true; // all equal +} + +template +HWY_NOINLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for before\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec first = pivot; + + // Whole group, unrolled + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec curr = LoadU(d, keys + i + loop * N); + first = st.First(d, first, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + // Whole vectors, no unrolling + for (; i + N <= num; i += N) { + const Vec curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the first +} + +template +HWY_NOINLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for after\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec last = pivot; + + // Whole group, unrolled + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec curr = LoadU(d, keys + i + loop * N); + last = st.Last(d, last, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + // Whole vectors, no unrolling + for (; i + N <= num; i += N) { + const Vec curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the last +} + +// Returns pivot chosen from `keys[0, num)`. It will never be the largest key +// (thus the right partition will never be empty). +template +HWY_INLINE Vec ChoosePivotForEqualSamples(D d, Traits st, + T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT samples, + Vec second, Vec third, + PivotResult& result) { + const Vec pivot = st.SetKey(d, samples); // the single unique sample + + // Early out for mostly-0 arrays, where pivot is often FirstValue. + if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) { + result = PivotResult::kIsFirst; + return pivot; + } + if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) { + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Check if pivot is between two known values. If so, it is not the first nor + // the last and we can avoid scanning. + st.Sort2(d, second, third); + HWY_DASSERT(AllTrue(d, st.Compare(d, second, third))); + const bool before = !AllFalse(d, st.Compare(d, second, pivot)); + const bool after = !AllFalse(d, st.Compare(d, pivot, third)); + // Only reached if there are three keys, which means pivot is either first, + // last, or in between. Thus there is another key that comes before or after. + HWY_DASSERT(before || after); + if (HWY_UNLIKELY(before)) { + // Neither first nor last. + if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // We didn't find anything after pivot, so it is the last. Because keys + // equal to the pivot go to the left partition, the right partition would be + // empty and Partition will not have changed anything. Instead use the + // previous value in sort order, which is not necessarily an actual key. + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Has after, and we found one before: in the middle. + if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // Pivot is first. We could consider a special partition mode that only + // reads from and writes to the right side, and later fills in the left + // side, which we know is equal to the pivot. However, that leads to more + // cache misses if the array is large, and doesn't save much, hence is a + // net loss. + result = PivotResult::kIsFirst; + return pivot; +} + +// ------------------------------ Quicksort recursion + +template +HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 2) { + const size_t N = Lanes(d); + if (num < N) return; + + Vec first = st.LastValue(d); + Vec last = st.FirstValue(d); + + size_t i = 0; + for (; i + N <= num; i += N) { + const Vec v = LoadU(d, keys + i); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + if (HWY_LIKELY(i != num)) { + HWY_DASSERT(num >= N); // See HandleSpecialCases + const Vec v = LoadU(d, keys + num - N); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + + first = st.FirstOfLanes(d, first, buf); + last = st.LastOfLanes(d, last, buf); + MaybePrintVector(d, "first", first, 0, st.LanesPerKey()); + MaybePrintVector(d, "last", last, 0, st.LanesPerKey()); + } +} + +// keys_end is the end of the entire user input, not just the current subarray +// [keys, keys + num). +template +HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys, + T* HWY_RESTRICT keys_end, const size_t num, + T* HWY_RESTRICT buf, Generator& rng, + size_t remaining_levels) { + HWY_DASSERT(num != 0); + + if (HWY_UNLIKELY(num <= Constants::BaseCaseNum(Lanes(d)))) { + BaseCase(d, st, keys, keys_end, num, buf); + return; + } + + // Move after BaseCase so we skip printing for small subarrays. + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu\n", remaining_levels, + num); + PrintMinMax(d, st, keys, num, buf); + } + + DrawSamples(d, st, keys, num, buf, rng); + + Vec pivot; + PivotResult result = PivotResult::kNormal; + if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) { + pivot = st.SetKey(d, buf); + size_t idx_second = 0; + if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) { + return; + } + HWY_DASSERT(idx_second % st.LanesPerKey() == 0); + // Must capture the value before PartitionIfTwoKeys may overwrite it. + const Vec second = st.SetKey(d, keys + idx_second); + MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey()); + MaybePrintVector(d, "second", second, 0, st.LanesPerKey()); + + Vec third; + if (HWY_UNLIKELY(PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second, + second, third, buf))) { + return; // Done, skip recursion because each side has all-equal keys. + } + + // We can no longer start scanning from idx_second because + // PartitionIfTwoKeys may have reordered keys. + pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third, + result); + // If kNormal, `pivot` is very common but not the first/last. It is + // tempting to do a 3-way partition (to avoid moving the =pivot keys a + // second time), but that is a net loss due to the extra comparisons. + } else { + SortSamples(d, st, buf); + + if (HWY_UNLIKELY(PartitionIfTwoSamples(d, st, keys, num, buf))) { + return; + } + + pivot = ChoosePivotByRank(d, st, buf); + } + + // Too many recursions. This is unlikely to happen because we select pivots + // from large (though still O(1)) samples. + if (HWY_UNLIKELY(remaining_levels == 0)) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "HeapSort reached, size=%zu\n", num); + } + HeapSort(st, keys, num); // Slow but N*logN. + return; + } + + const size_t bound = Partition(d, st, keys, num, pivot, buf); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "bound %zu num %zu result %s\n", bound, num, + PivotResultString(result)); + } + if (HWY_LIKELY(result != PivotResult::kIsFirst)) { + // The left partition is not empty because the pivot is one of the keys. + HWY_DASSERT(0 != bound && bound != num); + Recurse(d, st, keys, keys_end, bound, buf, rng, remaining_levels - 1); + } + if (HWY_LIKELY(result != PivotResult::kWasLast)) { + // ChoosePivot* ensure pivot != last, so the right partition is never empty. + HWY_DASSERT(bound != num); + Recurse(d, st, keys + bound, keys_end, num - bound, buf, rng, + remaining_levels - 1); + } +} + +// Returns true if sorting is finished. +template +HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, + size_t num) { + const size_t N = Lanes(d); + const size_t base_case_num = Constants::BaseCaseNum(N); + + // 128-bit keys require vectors with at least two u64 lanes, which is always + // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the + // hardware vector width is less than 128bit / fraction. + const bool partial_128 = !IsFull(d) && N < 2 && st.Is128(); + // Partition assumes its input is at least two vectors. If vectors are huge, + // base_case_num may actually be smaller. If so, which is only possible on + // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of + // HWY_LANES to account for the largest possible LMUL. + constexpr bool kPotentiallyHuge = + HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols; + const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num); + if (partial_128 || huge_vec) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "WARNING: using slow HeapSort: partial %d huge %d\n", + partial_128, huge_vec); + } + HeapSort(st, keys, num); + return true; + } + + // Small arrays are already handled by Recurse. + + // We could also check for already sorted/reverse/equal, but that's probably + // counterproductive if vqsort is used as a base case. + + return false; // not finished sorting +} + +#endif // VQSORT_ENABLED +} // namespace detail + +// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`. +// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons. +// Non-stable (order of equal keys may change), except for the common case where +// the upper bits of T are the key, and the lower bits are a sequential or at +// least unique ID. +// There is no upper limit on `num`, but note that pivots may be chosen by +// sampling only from the first 256 GiB. +// +// `d` is typically SortTag (chooses between full and partial vectors). +// `st` is SharedTraits>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. +template +void Sort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "=============== Sort num %zu\n", num); + } + +#if VQSORT_ENABLED || HWY_IDE +#if !HWY_HAVE_SCALABLE + // On targets with fixed-size vectors, avoid _using_ the allocated memory. + // We avoid (potentially expensive for small input sizes) allocations on + // platforms where no targets are scalable. For 512-bit vectors, this fits on + // the stack (several KiB). + HWY_ALIGN T storage[SortConstants::BufNum(HWY_LANES(T))] = {}; + static_assert(sizeof(storage) <= 8192, "Unexpectedly large, check size"); + buf = storage; +#endif // !HWY_HAVE_SCALABLE + + if (detail::HandleSpecialCases(d, st, keys, num)) return; + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. + if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return Sort(CappedTag(), st, keys, num, buf); + } +#endif // HWY_MAX_BYTES > 64 + + detail::Generator rng(keys, num); + + // Introspection: switch to worst-case N*logN heapsort after this many. + const size_t max_levels = 2 * hwy::CeilLog2(num) + 4; + detail::Recurse(d, st, keys, keys + num, num, buf, rng, max_levels); +#else + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n"); + } + return detail::HeapSort(st, keys, num); +#endif // VQSORT_ENABLED +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE diff --git a/media/highway/src/hwy/contrib/sort/vqsort.cc b/media/highway/src/hwy/contrib/sort/vqsort.cc new file mode 100644 index 000000000..b3bac0720 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort.cc @@ -0,0 +1,184 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#include // memset + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/shared-inl.h" + +// Architectures for which we know HWY_HAVE_SCALABLE == 0. This opts into an +// optimization that replaces dynamic allocation with stack storage. +#ifndef VQSORT_STACK +#if HWY_ARCH_X86 || HWY_ARCH_WASM +#define VQSORT_STACK 1 +#else +#define VQSORT_STACK 0 +#endif +#endif // VQSORT_STACK + +#if !VQSORT_STACK +#include "hwy/aligned_allocator.h" +#endif + +// Check if we have sys/random.h. First skip some systems on which the check +// itself (features.h) might be problematic. +#if defined(ANDROID) || defined(__ANDROID__) || HWY_ARCH_RVV +#define VQSORT_GETRANDOM 0 +#endif + +#if !defined(VQSORT_GETRANDOM) && HWY_OS_LINUX +#include + +// ---- which libc +#if defined(__UCLIBC__) +#define VQSORT_GETRANDOM 1 // added Mar 2015, before uclibc-ng 1.0 + +#elif defined(__GLIBC__) && defined(__GLIBC_PREREQ) +#if __GLIBC_PREREQ(2, 25) +#define VQSORT_GETRANDOM 1 +#else +#define VQSORT_GETRANDOM 0 +#endif + +#else +// Assume MUSL, which has getrandom since 2018. There is no macro to test, see +// https://www.openwall.com/lists/musl/2013/03/29/13. +#define VQSORT_GETRANDOM 1 + +#endif // ---- which libc +#endif // linux + +#if !defined(VQSORT_GETRANDOM) +#define VQSORT_GETRANDOM 0 +#endif + +// Seed source for SFC generator: 1=getrandom, 2=CryptGenRandom +// (not all Android support the getrandom wrapper) +#ifndef VQSORT_SECURE_SEED + +#if VQSORT_GETRANDOM +#define VQSORT_SECURE_SEED 1 +#elif defined(_WIN32) || defined(_WIN64) +#define VQSORT_SECURE_SEED 2 +#else +#define VQSORT_SECURE_SEED 0 +#endif + +#endif // VQSORT_SECURE_SEED + +#if !VQSORT_SECURE_RNG + +#include +#if VQSORT_SECURE_SEED == 1 +#include +#elif VQSORT_SECURE_SEED == 2 +#include +#pragma comment(lib, "advapi32.lib") +// Must come after windows.h. +#include +#endif // VQSORT_SECURE_SEED + +#endif // !VQSORT_SECURE_RNG + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +size_t VectorSize() { return Lanes(ScalableTag()); } +bool HaveFloat64() { return HWY_HAVE_FLOAT64; } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(VectorSize); +HWY_EXPORT(HaveFloat64); + +} // namespace + +Sorter::Sorter() { +#if VQSORT_STACK + ptr_ = nullptr; // Sort will use stack storage instead +#else + // Determine the largest buffer size required for any type by trying them all. + // (The capping of N in BaseCaseNum means that smaller N but larger sizeof_t + // may require a larger buffer.) + const size_t vector_size = HWY_DYNAMIC_DISPATCH(VectorSize)(); + const size_t max_bytes = + HWY_MAX(HWY_MAX(SortConstants::BufBytes(vector_size), + SortConstants::BufBytes(vector_size)), + SortConstants::BufBytes(vector_size)); + ptr_ = hwy::AllocateAlignedBytes(max_bytes, nullptr, nullptr); + + // Prevent msan errors by initializing. + memset(ptr_, 0, max_bytes); +#endif +} + +void Sorter::Delete() { +#if !VQSORT_STACK + FreeAlignedBytes(ptr_, nullptr, nullptr); + ptr_ = nullptr; +#endif +} + +#if !VQSORT_SECURE_RNG + +void Sorter::Fill24Bytes(const void* seed_heap, size_t seed_num, void* bytes) { +#if VQSORT_SECURE_SEED == 1 + // May block if urandom is not yet initialized. + const ssize_t ret = getrandom(bytes, 24, /*flags=*/0); + if (ret == 24) return; +#elif VQSORT_SECURE_SEED == 2 + HCRYPTPROV hProvider{}; + if (CryptAcquireContextA(&hProvider, nullptr, nullptr, PROV_RSA_FULL, + CRYPT_VERIFYCONTEXT)) { + const BOOL ok = + CryptGenRandom(hProvider, 24, reinterpret_cast(bytes)); + CryptReleaseContext(hProvider, 0); + if (ok) return; + } +#endif + + // VQSORT_SECURE_SEED == 0, or one of the above failed. Get some entropy from + // stack/heap/code addresses and the clock() timer. + uint64_t* words = reinterpret_cast(bytes); + uint64_t** seed_stack = &words; + void (*seed_code)(const void*, size_t, void*) = &Fill24Bytes; + const uintptr_t bits_stack = reinterpret_cast(seed_stack); + const uintptr_t bits_heap = reinterpret_cast(seed_heap); + const uintptr_t bits_code = reinterpret_cast(seed_code); + const uint64_t bits_time = static_cast(clock()); + words[0] = bits_stack ^ bits_time ^ seed_num; + words[1] = bits_heap ^ bits_time ^ seed_num; + words[2] = bits_code ^ bits_time ^ seed_num; +} + +#endif // !VQSORT_SECURE_RNG + +bool Sorter::HaveFloat64() { return HWY_DYNAMIC_DISPATCH(HaveFloat64)(); } + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort.h b/media/highway/src/hwy/contrib/sort/vqsort.h new file mode 100644 index 000000000..88d78ac7f --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort.h @@ -0,0 +1,108 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Interface to vectorized quicksort with dynamic dispatch. +// Blog post: https://tinyurl.com/vqsort-blog +// Paper with measurements: https://arxiv.org/abs/2205.05982 +// +// To ensure the overhead of using wide vectors (e.g. AVX2 or AVX-512) is +// worthwhile, we recommend using this code for sorting arrays whose size is at +// least 512 KiB. + +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ + +#include "hwy/base.h" + +namespace hwy { + +// Tag arguments that determine the sort order. +struct SortAscending { + constexpr bool IsAscending() const { return true; } +}; +struct SortDescending { + constexpr bool IsAscending() const { return false; } +}; + +// Allocates O(1) space. Type-erased RAII wrapper over hwy/aligned_allocator.h. +// This allows amortizing the allocation over multiple sorts. +class HWY_CONTRIB_DLLEXPORT Sorter { + public: + Sorter(); + ~Sorter() { Delete(); } + + // Move-only + Sorter(const Sorter&) = delete; + Sorter& operator=(const Sorter&) = delete; + Sorter(Sorter&& other) { + Delete(); + ptr_ = other.ptr_; + other.ptr_ = nullptr; + } + Sorter& operator=(Sorter&& other) { + Delete(); + ptr_ = other.ptr_; + other.ptr_ = nullptr; + return *this; + } + + // Sorts keys[0, n). Dispatches to the best available instruction set, + // and does not allocate memory. + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // For internal use only + static void Fill24Bytes(const void* seed_heap, size_t seed_num, void* bytes); + static bool HaveFloat64(); + + private: + void Delete(); + + template + T* Get() const { + return static_cast(ptr_); + } + + void* ptr_ = nullptr; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ diff --git a/media/highway/src/hwy/contrib/sort/vqsort_128a.cc b/media/highway/src/hwy/contrib/sort/vqsort_128a.cc new file mode 100644 index 000000000..40daea85c --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_128a.cc @@ -0,0 +1,62 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void Sort128Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Asc); +} // namespace + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(Sort128Asc) + (reinterpret_cast(keys), n * 2, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_128d.cc b/media/highway/src/hwy/contrib/sort/vqsort_128d.cc new file mode 100644 index 000000000..357da840c --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_128d.cc @@ -0,0 +1,62 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void Sort128Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Desc); +} // namespace + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(Sort128Desc) + (reinterpret_cast(keys), n * 2, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_f32a.cc b/media/highway/src/hwy/contrib/sort/vqsort_f32a.cc new file mode 100644 index 000000000..3856eea5d --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_f32a.cc @@ -0,0 +1,53 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF32Asc(float* HWY_RESTRICT keys, size_t num, float* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Asc); +} // namespace + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortF32Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_f32d.cc b/media/highway/src/hwy/contrib/sort/vqsort_f32d.cc new file mode 100644 index 000000000..7f5f97cdf --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_f32d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF32Desc(float* HWY_RESTRICT keys, size_t num, + float* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Desc); +} // namespace + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortF32Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_f64a.cc b/media/highway/src/hwy/contrib/sort/vqsort_f64a.cc new file mode 100644 index 000000000..287d5214e --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_f64a.cc @@ -0,0 +1,61 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF64Asc(double* HWY_RESTRICT keys, size_t num, + double* HWY_RESTRICT buf) { +#if HWY_HAVE_FLOAT64 + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +#else + (void)keys; + (void)num; + (void)buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Asc); +} // namespace + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortF64Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_f64d.cc b/media/highway/src/hwy/contrib/sort/vqsort_f64d.cc new file mode 100644 index 000000000..74d40c1ed --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_f64d.cc @@ -0,0 +1,61 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF64Desc(double* HWY_RESTRICT keys, size_t num, + double* HWY_RESTRICT buf) { +#if HWY_HAVE_FLOAT64 + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +#else + (void)keys; + (void)num; + (void)buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Desc); +} // namespace + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortF64Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_i16a.cc b/media/highway/src/hwy/contrib/sort/vqsort_i16a.cc new file mode 100644 index 000000000..ef4bb75bc --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_i16a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI16Asc(int16_t* HWY_RESTRICT keys, size_t num, + int16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Asc); +} // namespace + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI16Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_i16d.cc b/media/highway/src/hwy/contrib/sort/vqsort_i16d.cc new file mode 100644 index 000000000..6507ed608 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_i16d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI16Desc(int16_t* HWY_RESTRICT keys, size_t num, + int16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Desc); +} // namespace + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI16Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_i32a.cc b/media/highway/src/hwy/contrib/sort/vqsort_i32a.cc new file mode 100644 index 000000000..ae65be997 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_i32a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI32Asc(int32_t* HWY_RESTRICT keys, size_t num, + int32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Asc); +} // namespace + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI32Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_i32d.cc b/media/highway/src/hwy/contrib/sort/vqsort_i32d.cc new file mode 100644 index 000000000..3ce276ee9 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_i32d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI32Desc(int32_t* HWY_RESTRICT keys, size_t num, + int32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Desc); +} // namespace + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI32Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_i64a.cc b/media/highway/src/hwy/contrib/sort/vqsort_i64a.cc new file mode 100644 index 000000000..901b8ead8 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_i64a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI64Asc(int64_t* HWY_RESTRICT keys, size_t num, + int64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Asc); +} // namespace + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI64Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_i64d.cc b/media/highway/src/hwy/contrib/sort/vqsort_i64d.cc new file mode 100644 index 000000000..7713f2eb8 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_i64d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI64Desc(int64_t* HWY_RESTRICT keys, size_t num, + int64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Desc); +} // namespace + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI64Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_kv128a.cc b/media/highway/src/hwy/contrib/sort/vqsort_kv128a.cc new file mode 100644 index 000000000..1e02742ef --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_kv128a.cc @@ -0,0 +1,65 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128a.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV128Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV128Asc); +} // namespace + +void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortKV128Asc) + (reinterpret_cast(keys), n * 2, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_kv128d.cc b/media/highway/src/hwy/contrib/sort/vqsort_kv128d.cc new file mode 100644 index 000000000..3dd53b5da --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_kv128d.cc @@ -0,0 +1,65 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128d.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV128Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV128Desc); +} // namespace + +void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortKV128Desc) + (reinterpret_cast(keys), n * 2, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_kv64a.cc b/media/highway/src/hwy/contrib/sort/vqsort_kv64a.cc new file mode 100644 index 000000000..c513e3c4c --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_kv64a.cc @@ -0,0 +1,65 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64a.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV64Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV64Asc); +} // namespace + +void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortKV64Asc) + (reinterpret_cast(keys), n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_kv64d.cc b/media/highway/src/hwy/contrib/sort/vqsort_kv64d.cc new file mode 100644 index 000000000..c6c5fdcf7 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_kv64d.cc @@ -0,0 +1,65 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64d.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV64Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag d; + detail::SharedTraits> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV64Desc); +} // namespace + +void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortKV64Desc) + (reinterpret_cast(keys), n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_u16a.cc b/media/highway/src/hwy/contrib/sort/vqsort_u16a.cc new file mode 100644 index 000000000..0a97ffa92 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_u16a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU16Asc(uint16_t* HWY_RESTRICT keys, size_t num, + uint16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Asc); +} // namespace + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU16Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_u16d.cc b/media/highway/src/hwy/contrib/sort/vqsort_u16d.cc new file mode 100644 index 000000000..286ebbba6 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_u16d.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU16Desc(uint16_t* HWY_RESTRICT keys, size_t num, + uint16_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> + st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Desc); +} // namespace + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU16Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_u32a.cc b/media/highway/src/hwy/contrib/sort/vqsort_u32a.cc new file mode 100644 index 000000000..b6a69e6e2 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_u32a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU32Asc(uint32_t* HWY_RESTRICT keys, size_t num, + uint32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Asc); +} // namespace + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU32Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_u32d.cc b/media/highway/src/hwy/contrib/sort/vqsort_u32d.cc new file mode 100644 index 000000000..38fc1e1bf --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_u32d.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU32Desc(uint32_t* HWY_RESTRICT keys, size_t num, + uint32_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> + st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Desc); +} // namespace + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU32Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_u64a.cc b/media/highway/src/hwy/contrib/sort/vqsort_u64a.cc new file mode 100644 index 000000000..a29824a6f --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_u64a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU64Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Asc); +} // namespace + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU64Asc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/contrib/sort/vqsort_u64d.cc b/media/highway/src/hwy/contrib/sort/vqsort_u64d.cc new file mode 100644 index 000000000..d69245862 --- /dev/null +++ b/media/highway/src/hwy/contrib/sort/vqsort_u64d.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU64Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag d; + detail::SharedTraits>> + st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Desc); +} // namespace + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU64Desc)(keys, n, Get()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/detect_compiler_arch.h b/media/highway/src/hwy/detect_compiler_arch.h new file mode 100644 index 000000000..98c6a55b0 --- /dev/null +++ b/media/highway/src/hwy/detect_compiler_arch.h @@ -0,0 +1,234 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ +#define HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ + +// Detects compiler and arch from predefined macros. Zero dependencies for +// inclusion by foreach_target.h. + +// Add to #if conditions to prevent IDE from graying out code. +#if (defined __CDT_PARSER__) || (defined __INTELLISENSE__) || \ + (defined Q_CREATOR_RUN) || (defined(__CLANGD__)) +#define HWY_IDE 1 +#else +#define HWY_IDE 0 +#endif + +//------------------------------------------------------------------------------ +// Compiler + +// Actual MSVC, not clang-cl, which defines _MSC_VER but doesn't behave like +// MSVC in other aspects (e.g. HWY_DIAGNOSTICS). +#if defined(_MSC_VER) && !defined(__clang__) +#define HWY_COMPILER_MSVC _MSC_VER +#else +#define HWY_COMPILER_MSVC 0 +#endif + +#if defined(_MSC_VER) && defined(__clang__) +#define HWY_COMPILER_CLANGCL _MSC_VER +#else +#define HWY_COMPILER_CLANGCL 0 +#endif + +#ifdef __INTEL_COMPILER +#define HWY_COMPILER_ICC __INTEL_COMPILER +#else +#define HWY_COMPILER_ICC 0 +#endif + +#ifdef __INTEL_LLVM_COMPILER +#define HWY_COMPILER_ICX __INTEL_LLVM_COMPILER +#else +#define HWY_COMPILER_ICX 0 +#endif + +// HWY_COMPILER_GCC is a generic macro for all compilers implementing the GNU +// compiler extensions (eg. Clang, Intel...) +#ifdef __GNUC__ +#define HWY_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__) +#else +#define HWY_COMPILER_GCC 0 +#endif + +// Clang or clang-cl, not GCC. +#ifdef __clang__ +// In case of Apple LLVM (whose version number is unrelated to that of LLVM) or +// an invalid version number, deduce it from the presence of warnings. +// Adapted from https://github.com/simd-everywhere/simde/ simde-detect-clang.h. +#if defined(__apple_build_version__) || __clang_major__ >= 999 +#if __has_warning("-Wbitwise-instead-of-logical") +#define HWY_COMPILER_CLANG 1400 +#elif __has_warning("-Wreserved-identifier") +#define HWY_COMPILER_CLANG 1300 +#elif __has_warning("-Wformat-insufficient-args") +#define HWY_COMPILER_CLANG 1200 +#elif __has_warning("-Wimplicit-const-int-float-conversion") +#define HWY_COMPILER_CLANG 1100 +#elif __has_warning("-Wmisleading-indentation") +#define HWY_COMPILER_CLANG 1000 +#elif defined(__FILE_NAME__) +#define HWY_COMPILER_CLANG 900 +#elif __has_warning("-Wextra-semi-stmt") || \ + __has_builtin(__builtin_rotateleft32) +#define HWY_COMPILER_CLANG 800 +// For reasons unknown, XCode 10.3 (Apple LLVM version 10.0.1) is apparently +// based on Clang 7, but does not support the warning we test. +// See https://en.wikipedia.org/wiki/Xcode#Toolchain_versions and +// https://trac.macports.org/wiki/XcodeVersionInfo. +#elif __has_warning("-Wc++98-compat-extra-semi") || \ + (defined(__apple_build_version__) && __apple_build_version__ >= 10010000) +#define HWY_COMPILER_CLANG 700 +#else // Anything older than 7.0 is not recommended for Highway. +#define HWY_COMPILER_CLANG 600 +#endif // __has_warning chain +#else // use normal version +#define HWY_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__) +#endif +#else // Not clang +#define HWY_COMPILER_CLANG 0 +#endif + +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG +#define HWY_COMPILER_GCC_ACTUAL HWY_COMPILER_GCC +#else +#define HWY_COMPILER_GCC_ACTUAL 0 +#endif + +// More than one may be nonzero, but we want at least one. +#if 0 == (HWY_COMPILER_MSVC + HWY_COMPILER_CLANGCL + HWY_COMPILER_ICC + \ + HWY_COMPILER_GCC + HWY_COMPILER_CLANG) +#error "Unsupported compiler" +#endif + +// We should only detect one of these (only clang/clangcl overlap) +#if 1 < \ + (!!HWY_COMPILER_MSVC + !!HWY_COMPILER_ICC + !!HWY_COMPILER_GCC_ACTUAL + \ + !!(HWY_COMPILER_CLANGCL | HWY_COMPILER_CLANG)) +#error "Detected multiple compilers" +#endif + +#ifdef __has_builtin +#define HWY_HAS_BUILTIN(name) __has_builtin(name) +#else +#define HWY_HAS_BUILTIN(name) 0 +#endif + +#ifdef __has_attribute +#define HWY_HAS_ATTRIBUTE(name) __has_attribute(name) +#else +#define HWY_HAS_ATTRIBUTE(name) 0 +#endif + +#ifdef __has_feature +#define HWY_HAS_FEATURE(name) __has_feature(name) +#else +#define HWY_HAS_FEATURE(name) 0 +#endif + +//------------------------------------------------------------------------------ +// Architecture + +#if defined(__i386__) || defined(_M_IX86) +#define HWY_ARCH_X86_32 1 +#else +#define HWY_ARCH_X86_32 0 +#endif + +#if defined(__x86_64__) || defined(_M_X64) +#define HWY_ARCH_X86_64 1 +#else +#define HWY_ARCH_X86_64 0 +#endif + +#if HWY_ARCH_X86_32 && HWY_ARCH_X86_64 +#error "Cannot have both x86-32 and x86-64" +#endif + +#if HWY_ARCH_X86_32 || HWY_ARCH_X86_64 +#define HWY_ARCH_X86 1 +#else +#define HWY_ARCH_X86 0 +#endif + +#if defined(__powerpc64__) || defined(_M_PPC) +#define HWY_ARCH_PPC 1 +#else +#define HWY_ARCH_PPC 0 +#endif + +#if defined(__ARM_ARCH_ISA_A64) || defined(__aarch64__) || defined(_M_ARM64) +#define HWY_ARCH_ARM_A64 1 +#else +#define HWY_ARCH_ARM_A64 0 +#endif + +#if (defined(__ARM_ARCH) && __ARM_ARCH == 7) || (defined(_M_ARM) && _M_ARM == 7) +#define HWY_ARCH_ARM_V7 1 +#else +#define HWY_ARCH_ARM_V7 0 +#endif + +#if HWY_ARCH_ARM_A64 && HWY_ARCH_ARM_V7 +#error "Cannot have both A64 and V7" +#endif + +// Any *supported* version of Arm, i.e. 7 or later +#if HWY_ARCH_ARM_A64 || HWY_ARCH_ARM_V7 +#define HWY_ARCH_ARM 1 +#else +#define HWY_ARCH_ARM 0 +#endif + +// Older than v7 (e.g. armel aka Arm v5), in which case we do not support SIMD. +#if (defined(__arm__) || defined(_M_ARM)) && !HWY_ARCH_ARM +#define HWY_ARCH_ARM_OLD 1 +#else +#define HWY_ARCH_ARM_OLD 0 +#endif + +#if defined(__EMSCRIPTEN__) || defined(__wasm__) || defined(__WASM__) +#define HWY_ARCH_WASM 1 +#else +#define HWY_ARCH_WASM 0 +#endif + +#ifdef __riscv +#define HWY_ARCH_RVV 1 +#else +#define HWY_ARCH_RVV 0 +#endif + +// It is an error to detect multiple architectures at the same time, but OK to +// detect none of the above. +#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \ + HWY_ARCH_WASM + HWY_ARCH_RVV) > 1 +#error "Must not detect more than one architecture" +#endif + +#if defined(_WIN32) || defined(_WIN64) +#define HWY_OS_WIN 1 +#else +#define HWY_OS_WIN 0 +#endif + +#if defined(linux) || defined(__linux__) +#define HWY_OS_LINUX 1 +#else +#define HWY_OS_LINUX 0 +#endif + +#endif // HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ diff --git a/media/highway/src/hwy/detect_targets.h b/media/highway/src/hwy/detect_targets.h new file mode 100644 index 000000000..7f7e179b3 --- /dev/null +++ b/media/highway/src/hwy/detect_targets.h @@ -0,0 +1,478 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_TARGETS_H_ +#define HIGHWAY_HWY_DETECT_TARGETS_H_ + +// Defines targets and chooses which to enable. + +#include "hwy/detect_compiler_arch.h" + +//------------------------------------------------------------------------------ +// Optional configuration + +// See g3doc/quick_reference.md for documentation of these macros. + +// Uncomment to override the default baseline determined from predefined macros: +// #define HWY_BASELINE_TARGETS (HWY_SSE4 | HWY_SCALAR) + +// Uncomment to override the default blocklist: +// #define HWY_BROKEN_TARGETS HWY_AVX3 + +// Uncomment to definitely avoid generating those target(s): +// #define HWY_DISABLED_TARGETS HWY_SSE4 + +// Uncomment to avoid emitting BMI/BMI2/FMA instructions (allows generating +// AVX2 target for VMs which support AVX2 but not the other instruction sets) +// #define HWY_DISABLE_BMI2_FMA + +// Uncomment to enable SSSE3/SSE4 on MSVC even if AVX is not enabled +// #define HWY_WANT_SSSE3 +// #define HWY_WANT_SSE4 + +//------------------------------------------------------------------------------ +// Targets + +// Unique bit value for each target. A lower value is "better" (e.g. more lanes) +// than a higher value within the same group/platform - see HWY_STATIC_TARGET. +// +// All values are unconditionally defined so we can test HWY_TARGETS without +// first checking the HWY_ARCH_*. +// +// The C99 preprocessor evaluates #if expressions using intmax_t types. This +// holds at least 64 bits in practice (verified 2022-07-18 via Godbolt on +// 32-bit clang/GCC/MSVC compilers for x86/Arm7/AArch32/RISC-V/WASM). We now +// avoid overflow when computing HWY_TARGETS (subtracting one instead of +// left-shifting 2^62), but still do not use bit 63 because it is the sign bit. + +// --------------------------- x86: 15 targets (+ one fallback) +// Bits 0..6 reserved (7 targets) +// Currently satisfiable by Ice Lake (VNNI, VPCLMULQDQ, VPOPCNTDQ, VBMI, VBMI2, +// VAES, BITALG). Later to be added: BF16 (Cooper Lake). VP2INTERSECT is only in +// Tiger Lake? We do not yet have uses for GFNI. +#define HWY_AVX3_DL (1LL << 7) // see HWY_WANT_AVX3_DL below +#define HWY_AVX3 (1LL << 8) +#define HWY_AVX2 (1LL << 9) +// Bit 10: reserved for AVX +#define HWY_SSE4 (1LL << 11) +#define HWY_SSSE3 (1LL << 12) +// Bits 13..14 reserved for SSE3 or SSE2 (2 targets) +// The highest bit in the HWY_TARGETS mask that a x86 target can have. Used for +// dynamic dispatch. All x86 target bits must be lower or equal to +// (1 << HWY_HIGHEST_TARGET_BIT_X86) and they can only use +// HWY_MAX_DYNAMIC_TARGETS in total. +#define HWY_HIGHEST_TARGET_BIT_X86 14 + +// --------------------------- Arm: 15 targets (+ one fallback) +// Bits 15..23 reserved (9 targets) +#define HWY_SVE2_128 (1LL << 24) // specialized target (e.g. Arm N2) +#define HWY_SVE_256 (1LL << 25) // specialized target (e.g. Arm V1) +#define HWY_SVE2 (1LL << 26) +#define HWY_SVE (1LL << 27) +#define HWY_NEON (1LL << 28) // On A64, includes/requires AES +// Bit 29 reserved (Helium?) +#define HWY_HIGHEST_TARGET_BIT_ARM 29 + +// --------------------------- RISC-V: 9 targets (+ one fallback) +// Bits 30..36 reserved (7 targets) +#define HWY_RVV (1LL << 37) +// Bit 38 reserved +#define HWY_HIGHEST_TARGET_BIT_RVV 38 + +// --------------------------- Future expansion: 4 targets +// Bits 39..42 reserved + + +// --------------------------- IBM Power: 9 targets (+ one fallback) +// Bits 43..48 reserved (6 targets) +#define HWY_PPC8 (1LL << 49) // v2.07 or 3 +// Bits 50..51 reserved for prior VSX/AltiVec (2 targets) +#define HWY_HIGHEST_TARGET_BIT_PPC 51 + +// --------------------------- WebAssembly: 9 targets (+ one fallback) +// Bits 52..57 reserved (6 targets) +#define HWY_WASM_EMU256 (1LL << 58) // Experimental +#define HWY_WASM (1LL << 59) +// Bits 60 reserved +#define HWY_HIGHEST_TARGET_BIT_WASM 60 + +// --------------------------- Emulation: 2 targets + +#define HWY_EMU128 (1LL << 61) +// We do not add/left-shift, so this will not overflow to a negative number. +#define HWY_SCALAR (1LL << 62) +#define HWY_HIGHEST_TARGET_BIT_SCALAR 62 + +// Do not use bit 63 - would be confusing to have negative numbers. + +//------------------------------------------------------------------------------ +// Set default blocklists + +// Disabled means excluded from enabled at user's request. A separate config +// macro allows disabling without deactivating the blocklist below. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS 0 +#endif + +// Broken means excluded from enabled due to known compiler issues. Allow the +// user to override this blocklist without any guarantee of success. +#ifndef HWY_BROKEN_TARGETS + +// x86 clang-6: we saw multiple AVX2/3 compile errors and in one case invalid +// SSE4 codegen (possibly only for msan), so disable all those targets. +#if HWY_ARCH_X86 && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_BROKEN_TARGETS (HWY_SSE4 | HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL) +// This entails a major speed reduction, so warn unless the user explicitly +// opts in to scalar-only. +#if !defined(HWY_COMPILE_ONLY_SCALAR) +#pragma message("x86 Clang <= 6: define HWY_COMPILE_ONLY_SCALAR or upgrade.") +#endif + +// 32-bit may fail to compile AVX2/3. +#elif HWY_ARCH_X86_32 +#define HWY_BROKEN_TARGETS (HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL) + +// MSVC AVX3 support is buggy: https://github.com/Mysticial/Flops/issues/16 +#elif HWY_COMPILER_MSVC != 0 +#define HWY_BROKEN_TARGETS (HWY_AVX3 | HWY_AVX3_DL) + +// armv7be has not been tested and is not yet supported. +#elif HWY_ARCH_ARM_V7 && \ + (defined(__ARM_BIG_ENDIAN) || \ + (defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN)) +#define HWY_BROKEN_TARGETS (HWY_NEON) + +// SVE[2] require recent clang or gcc versions. +#elif (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) +#define HWY_BROKEN_TARGETS (HWY_SVE | HWY_SVE2 | HWY_SVE_256 | HWY_SVE2_128) + +#else +#define HWY_BROKEN_TARGETS 0 +#endif + +#endif // HWY_BROKEN_TARGETS + +// Enabled means not disabled nor blocklisted. +#define HWY_ENABLED(targets) \ + ((targets) & ~((HWY_DISABLED_TARGETS) | (HWY_BROKEN_TARGETS))) + +// Opt-out for EMU128 (affected by a GCC bug on multiple arches, fixed in 12.3: +// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=106322). This is separate +// from HWY_BROKEN_TARGETS because it affects the fallback target, which must +// always be enabled. If 1, we instead choose HWY_SCALAR even without +// HWY_COMPILE_ONLY_SCALAR being set. +#if !defined(HWY_BROKEN_EMU128) // allow overriding +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1203 +#define HWY_BROKEN_EMU128 1 +#else +#define HWY_BROKEN_EMU128 0 +#endif +#endif // HWY_BROKEN_EMU128 + +//------------------------------------------------------------------------------ +// Detect baseline targets using predefined macros + +// Baseline means the targets for which the compiler is allowed to generate +// instructions, implying the target CPU would have to support them. This does +// not take the blocklist into account. + +#if defined(HWY_COMPILE_ONLY_SCALAR) || HWY_BROKEN_EMU128 +#define HWY_BASELINE_SCALAR HWY_SCALAR +#else +#define HWY_BASELINE_SCALAR HWY_EMU128 +#endif + +// Also check HWY_ARCH to ensure that simulating unknown platforms ends up with +// HWY_TARGET == HWY_BASELINE_SCALAR. + +#if HWY_ARCH_WASM && defined(__wasm_simd128__) +#if defined(HWY_WANT_WASM2) +#define HWY_BASELINE_WASM HWY_WASM_EMU256 +#else +#define HWY_BASELINE_WASM HWY_WASM +#endif // HWY_WANT_WASM2 +#else +#define HWY_BASELINE_WASM 0 +#endif + +// Avoid choosing the PPC target until we have an implementation. +#if HWY_ARCH_PPC && defined(__VSX__) && 0 +#define HWY_BASELINE_PPC8 HWY_PPC8 +#else +#define HWY_BASELINE_PPC8 0 +#endif + +#define HWY_BASELINE_SVE2 0 +#define HWY_BASELINE_SVE 0 +#define HWY_BASELINE_NEON 0 + +#if HWY_ARCH_ARM + +#if defined(__ARM_FEATURE_SVE2) +#undef HWY_BASELINE_SVE2 // was 0, will be re-defined +// If user specified -msve-vector-bits=128, they assert the vector length is +// 128 bits and we should use the HWY_SVE2_128 (more efficient for some ops). +#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 128 +#define HWY_BASELINE_SVE2 HWY_SVE2_128 +// Otherwise we're not sure what the vector length will be. The baseline must be +// unconditionally valid, so we can only assume HWY_SVE2. However, when running +// on a CPU with 128-bit vectors, user code that supports dynamic dispatch will +// still benefit from HWY_SVE2_128 because we add it to HWY_ATTAINABLE_TARGETS. +#else +#define HWY_BASELINE_SVE2 HWY_SVE2 +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE2 + +#if defined(__ARM_FEATURE_SVE) +#undef HWY_BASELINE_SVE // was 0, will be re-defined +// See above. If user-specified vector length matches our optimization, use it. +#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 256 +#define HWY_BASELINE_SVE HWY_SVE_256 +#else +#define HWY_BASELINE_SVE HWY_SVE +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE + +// GCC 4.5.4 only defines __ARM_NEON__; 5.4 defines both. +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#undef HWY_BASELINE_NEON +#define HWY_BASELINE_NEON HWY_NEON +#endif + +#endif // HWY_ARCH_ARM + +// Special handling for MSVC because it has fewer predefined macros: +#if HWY_COMPILER_MSVC + +// 1) We can only be sure SSSE3/SSE4 are enabled if AVX is: +// https://stackoverflow.com/questions/18563978/. +#if defined(__AVX__) +#define HWY_CHECK_SSSE3 1 +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSSE3 0 +#define HWY_CHECK_SSE4 0 +#endif + +// 2) Cannot check for PCLMUL/AES and BMI2/FMA/F16C individually; we assume +// PCLMUL/AES are available if SSE4 is, and BMI2/FMA/F16C if AVX2 is. +#define HWY_CHECK_PCLMUL_AES 1 +#define HWY_CHECK_BMI2_FMA 1 +#define HWY_CHECK_F16C 1 + +#else // non-MSVC + +#if defined(__SSSE3__) +#define HWY_CHECK_SSSE3 1 +#else +#define HWY_CHECK_SSSE3 0 +#endif + +#if defined(__SSE4_1__) && defined(__SSE4_2__) +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSE4 0 +#endif + +// If these are disabled, they should not gate the availability of SSE4/AVX2. +#if defined(HWY_DISABLE_PCLMUL_AES) || (defined(__PCLMUL__) && defined(__AES__)) +#define HWY_CHECK_PCLMUL_AES 1 +#else +#define HWY_CHECK_PCLMUL_AES 0 +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) || (defined(__BMI2__) && defined(__FMA__)) +#define HWY_CHECK_BMI2_FMA 1 +#else +#define HWY_CHECK_BMI2_FMA 0 +#endif + +#if defined(HWY_DISABLE_F16C) || defined(__F16C__) +#define HWY_CHECK_F16C 1 +#else +#define HWY_CHECK_F16C 0 +#endif + +#endif // non-MSVC + +#if HWY_ARCH_X86 && (HWY_WANT_SSSE3 || HWY_CHECK_SSSE3) +#define HWY_BASELINE_SSSE3 HWY_SSSE3 +#else +#define HWY_BASELINE_SSSE3 0 +#endif + +#if HWY_ARCH_X86 && (HWY_WANT_SSE4 || (HWY_CHECK_SSE4 && HWY_CHECK_PCLMUL_AES)) +#define HWY_BASELINE_SSE4 HWY_SSE4 +#else +#define HWY_BASELINE_SSE4 0 +#endif + +#if HWY_BASELINE_SSE4 != 0 && HWY_CHECK_BMI2_FMA && HWY_CHECK_F16C && \ + defined(__AVX2__) +#define HWY_BASELINE_AVX2 HWY_AVX2 +#else +#define HWY_BASELINE_AVX2 0 +#endif + +// Require everything in AVX2 plus AVX-512 flags (also set by MSVC) +#if HWY_BASELINE_AVX2 != 0 && defined(__AVX512F__) && defined(__AVX512BW__) && \ + defined(__AVX512DQ__) && defined(__AVX512VL__) +#define HWY_BASELINE_AVX3 HWY_AVX3 +#else +#define HWY_BASELINE_AVX3 0 +#endif + +// TODO(janwas): not yet known whether these will be set by MSVC +#if HWY_BASELINE_AVX3 != 0 && defined(__AVXVNNI__) && defined(__VAES__) && \ + defined(__VPCLMULQDQ__) && defined(__AVX512VBMI__) && \ + defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__) && \ + defined(__AVX512BITALG__) +#define HWY_BASELINE_AVX3_DL HWY_AVX3_DL +#else +#define HWY_BASELINE_AVX3_DL 0 +#endif + +#if HWY_ARCH_RVV && defined(__riscv_vector) +#define HWY_BASELINE_RVV HWY_RVV +#else +#define HWY_BASELINE_RVV 0 +#endif + +// Allow the user to override this without any guarantee of success. +#ifndef HWY_BASELINE_TARGETS +#define HWY_BASELINE_TARGETS \ + (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \ + HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | HWY_BASELINE_NEON | \ + HWY_BASELINE_SSSE3 | HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | \ + HWY_BASELINE_AVX3 | HWY_BASELINE_AVX3_DL | HWY_BASELINE_RVV) +#endif // HWY_BASELINE_TARGETS + +//------------------------------------------------------------------------------ +// Choose target for static dispatch + +#define HWY_ENABLED_BASELINE HWY_ENABLED(HWY_BASELINE_TARGETS) +#if HWY_ENABLED_BASELINE == 0 +#error "At least one baseline target must be defined and enabled" +#endif + +// Best baseline, used for static dispatch. This is the least-significant 1-bit +// within HWY_ENABLED_BASELINE and lower bit values imply "better". +#define HWY_STATIC_TARGET (HWY_ENABLED_BASELINE & -HWY_ENABLED_BASELINE) + +// Start by assuming static dispatch. If we later use dynamic dispatch, this +// will be defined to other targets during the multiple-inclusion, and finally +// return to the initial value. Defining this outside begin/end_target ensures +// inl headers successfully compile by themselves (required by Bazel). +#define HWY_TARGET HWY_STATIC_TARGET + +//------------------------------------------------------------------------------ +// Choose targets for dynamic dispatch according to one of four policies + +#if 1 < (defined(HWY_COMPILE_ONLY_SCALAR) + defined(HWY_COMPILE_ONLY_EMU128) + \ + defined(HWY_COMPILE_ONLY_STATIC)) +#error "Can only define one of HWY_COMPILE_ONLY_{SCALAR|EMU128|STATIC} - bug?" +#endif +// Defining one of HWY_COMPILE_ONLY_* will trump HWY_COMPILE_ALL_ATTAINABLE. + +// Clang, GCC and MSVC allow runtime dispatch on x86. +#if HWY_ARCH_X86 +#define HWY_HAVE_RUNTIME_DISPATCH 1 +// On Arm, currently only GCC does, and we require Linux to detect CPU +// capabilities. +#elif HWY_ARCH_ARM && HWY_COMPILER_GCC_ACTUAL && HWY_OS_LINUX +#define HWY_HAVE_RUNTIME_DISPATCH 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH 0 +#endif + +// AVX3_DL is not widely available yet. To reduce code size and compile time, +// only include it in the set of attainable targets (for dynamic dispatch) if +// the user opts in, OR it is in the baseline (we check whether enabled below). +#if defined(HWY_WANT_AVX3_DL) || (HWY_BASELINE & HWY_AVX3_DL) +#define HWY_ATTAINABLE_AVX3_DL HWY_AVX3_DL +#else +#define HWY_ATTAINABLE_AVX3_DL 0 +#endif + +#if HWY_ARCH_ARM_A64 && (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE | HWY_SVE_256))) +#define HWY_ATTAINABLE_SVE HWY_ENABLED(HWY_SVE | HWY_SVE_256) +#else +#define HWY_ATTAINABLE_SVE 0 +#endif + +#if HWY_ARCH_ARM_A64 && (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE2 | HWY_SVE2_128))) +#define HWY_ATTAINABLE_SVE2 HWY_ENABLED(HWY_SVE2 | HWY_SVE2_128) +#else +#define HWY_ATTAINABLE_SVE2 0 +#endif + +// Attainable means enabled and the compiler allows intrinsics (even when not +// allowed to autovectorize). Used in 3 and 4. +#if HWY_ARCH_X86 +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_SSSE3 | HWY_SSE4 | HWY_AVX2 | \ + HWY_AVX3 | HWY_ATTAINABLE_AVX3_DL) +#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_NEON | HWY_ATTAINABLE_SVE | \ + HWY_ATTAINABLE_SVE2) +#else +#define HWY_ATTAINABLE_TARGETS \ + (HWY_ENABLED_BASELINE | HWY_ATTAINABLE_SVE | HWY_ATTAINABLE_SVE2) +#endif + +// 1) For older compilers: avoid SIMD intrinsics, but still support all ops. +#if defined(HWY_COMPILE_ONLY_EMU128) && !HWY_BROKEN_EMU128 +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_EMU128 // override baseline +#define HWY_TARGETS HWY_EMU128 + +// 1b) HWY_SCALAR is less capable than HWY_EMU128 (which supports all ops), but +// we currently still support it for backwards compatibility. +#elif defined(HWY_COMPILE_ONLY_SCALAR) || \ + (defined(HWY_COMPILE_ONLY_EMU128) && HWY_BROKEN_EMU128) +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_SCALAR // override baseline +#define HWY_TARGETS HWY_SCALAR + +// 2) For forcing static dispatch without code changes (removing HWY_EXPORT) +#elif defined(HWY_COMPILE_ONLY_STATIC) +#define HWY_TARGETS HWY_STATIC_TARGET + +// 3) For tests: include all attainable targets (in particular: scalar) +#elif defined(HWY_COMPILE_ALL_ATTAINABLE) || defined(HWY_IS_TEST) +#define HWY_TARGETS HWY_ATTAINABLE_TARGETS + +// 4) Default: attainable WITHOUT non-best baseline. This reduces code size by +// excluding superseded targets, in particular scalar. Note: HWY_STATIC_TARGET +// may be 2^62 (HWY_SCALAR), so we must not left-shift/add it. Subtracting one +// sets all lower bits (better targets), then we also include the static target. +#else +#define HWY_TARGETS \ + (HWY_ATTAINABLE_TARGETS & ((HWY_STATIC_TARGET - 1LL) | HWY_STATIC_TARGET)) + +#endif // target policy + +// HWY_ONCE and the multiple-inclusion mechanism rely on HWY_STATIC_TARGET being +// one of the dynamic targets. This also implies HWY_TARGETS != 0 and +// (HWY_TARGETS & HWY_ENABLED_BASELINE) != 0. +#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0 +#error "Logic error: best baseline should be included in dynamic targets" +#endif + +#endif // HIGHWAY_HWY_DETECT_TARGETS_H_ diff --git a/media/highway/src/hwy/examples/benchmark.cc b/media/highway/src/hwy/examples/benchmark.cc new file mode 100644 index 000000000..8ab810894 --- /dev/null +++ b/media/highway/src/hwy/examples/benchmark.cc @@ -0,0 +1,254 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#include +#include +#include + +#include +#include // iota + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/examples/benchmark.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// Must come after foreach_target.h to avoid redefinition errors. +#include "hwy/aligned_allocator.h" +#include "hwy/highway.h" +#include "hwy/nanobenchmark.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +#if HWY_TARGET != HWY_SCALAR +using hwy::HWY_NAMESPACE::CombineShiftRightLanes; +#endif + +class TwoArray { + public: + // Must be a multiple of the vector lane count * 8. + static size_t NumItems() { return 3456; } + + TwoArray() + : a_(AllocateAligned(NumItems() * 2)), b_(a_.get() + NumItems()) { + // = 1, but compiler doesn't know + const float init = static_cast(Unpredictable1()); + std::iota(a_.get(), a_.get() + NumItems(), init); + std::iota(b_, b_ + NumItems(), init); + } + + protected: + AlignedFreeUniquePtr a_; + float* b_; +}; + +// Measures durations, verifies results, prints timings. +template +void RunBenchmark(const char* caption) { + printf("%10s: ", caption); + const size_t kNumInputs = 1; + const size_t num_items = Benchmark::NumItems() * size_t(Unpredictable1()); + const FuncInput inputs[kNumInputs] = {num_items}; + Result results[kNumInputs]; + + Benchmark benchmark; + + Params p; + p.verbose = false; + p.max_evals = 7; + p.target_rel_mad = 0.002; + const size_t num_results = MeasureClosure( + [&benchmark](const FuncInput input) { return benchmark(input); }, inputs, + kNumInputs, results, p); + if (num_results != kNumInputs) { + fprintf(stderr, "MeasureClosure failed.\n"); + } + + benchmark.Verify(num_items); + + for (size_t i = 0; i < num_results; ++i) { + const double cycles_per_item = + results[i].ticks / static_cast(results[i].input); + const double mad = results[i].variability * cycles_per_item; + printf("%6" PRIu64 ": %6.3f (+/- %5.3f)\n", + static_cast(results[i].input), cycles_per_item, mad); + } +} + +void Intro() { + const float in[16] = {1, 2, 3, 4, 5, 6}; + float out[16]; + const ScalableTag d; // largest possible vector + for (size_t i = 0; i < 16; i += Lanes(d)) { + const auto vec = LoadU(d, in + i); // no alignment requirement + auto result = Mul(vec, vec); + result = Add(result, result); // can update if not const + StoreU(result, d, out + i); + } + printf("\nF(x)->2*x^2, F(%.0f) = %.1f\n", in[2], out[2]); +} + +// BEGINNER: dot product +// 0.4 cyc/float = bronze, 0.25 = silver, 0.15 = gold! +class BenchmarkDot : public TwoArray { + public: + BenchmarkDot() : dot_{-1.0f} {} + + FuncOutput operator()(const size_t num_items) { + const ScalableTag d; + const size_t N = Lanes(d); + using V = decltype(Zero(d)); + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // We cannot use an array because V might be a sizeless type. For reasonable + // code, we unroll 4x, but 8x might help (2 FMA ports * 4 cycle latency). + V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); + const float* const HWY_RESTRICT pa = &a_[0]; + const float* const HWY_RESTRICT pb = b_; + for (size_t i = 0; i < num_items; i += 4 * N) { + const auto a0 = Load(d, pa + i + 0 * N); + const auto b0 = Load(d, pb + i + 0 * N); + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = Load(d, pa + i + 1 * N); + const auto b1 = Load(d, pb + i + 1 * N); + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = Load(d, pa + i + 2 * N); + const auto b2 = Load(d, pb + i + 2 * N); + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = Load(d, pa + i + 3 * N); + const auto b3 = Load(d, pb + i + 3 * N); + sum3 = MulAdd(a3, b3, sum3); + } + // Reduction tree: sum of all accumulators by pairs into sum0. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + dot_ = GetLane(SumOfLanes(d, sum0)); + return static_cast(dot_); + } + void Verify(size_t num_items) { + if (dot_ == -1.0f) { + fprintf(stderr, "Dot: must call Verify after benchmark"); + abort(); + } + + const float expected = + std::inner_product(a_.get(), a_.get() + num_items, b_, 0.0f); + const float rel_err = std::abs(expected - dot_) / expected; + if (rel_err > 1.1E-6f) { + fprintf(stderr, "Dot: expected %e actual %e (%e)\n", expected, dot_, + rel_err); + abort(); + } + } + + private: + float dot_; // for Verify +}; + +// INTERMEDIATE: delta coding +// 1.0 cycles/float = bronze, 0.7 = silver, 0.4 = gold! +struct BenchmarkDelta : public TwoArray { + FuncOutput operator()(const size_t num_items) const { +#if HWY_TARGET == HWY_SCALAR + b_[0] = a_[0]; + for (size_t i = 1; i < num_items; ++i) { + b_[i] = a_[i] - a_[i - 1]; + } +#elif HWY_CAP_GE256 + // Larger vectors are split into 128-bit blocks, easiest to use the + // unaligned load support to shift between them. + const ScalableTag df; + const size_t N = Lanes(df); + size_t i; + b_[0] = a_[0]; + for (i = 1; i < N; ++i) { + b_[i] = a_[i] - a_[i - 1]; + } + for (; i < num_items; i += N) { + const auto a = Load(df, &a_[i]); + const auto shifted = LoadU(df, &a_[i - 1]); + Store(a - shifted, df, &b_[i]); + } +#else // 128-bit + // Slightly better than unaligned loads + const HWY_CAPPED(float, 4) df; + const size_t N = Lanes(df); + size_t i; + b_[0] = a_[0]; + for (i = 1; i < N; ++i) { + b_[i] = a_[i] - a_[i - 1]; + } + auto prev = Load(df, &a_[0]); + for (; i < num_items; i += Lanes(df)) { + const auto a = Load(df, &a_[i]); + const auto shifted = CombineShiftRightLanes<3>(df, a, prev); + prev = a; + Store(Sub(a, shifted), df, &b_[i]); + } +#endif + return static_cast(b_[num_items - 1]); + } + + void Verify(size_t num_items) { + for (size_t i = 0; i < num_items; ++i) { + const float expected = (i == 0) ? a_[0] : a_[i] - a_[i - 1]; + const float err = std::abs(expected - b_[i]); + if (err > 1E-6f) { + fprintf(stderr, "Delta: expected %e, actual %e\n", expected, b_[i]); + } + } + } +}; + +void RunBenchmarks() { + Intro(); + printf("------------------------ %s\n", TargetName(HWY_TARGET)); + RunBenchmark("dot"); + RunBenchmark("delta"); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_EXPORT(RunBenchmarks); + +void Run() { + for (int64_t target : SupportedAndGeneratedTargets()) { + SetSupportedTargetsForTest(target); + HWY_DYNAMIC_DISPATCH(RunBenchmarks)(); + } + SetSupportedTargetsForTest(0); // Reset the mask afterwards. +} + +} // namespace hwy + +int main(int /*argc*/, char** /*argv*/) { + hwy::Run(); + return 0; +} +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/examples/skeleton-inl.h b/media/highway/src/hwy/examples/skeleton-inl.h new file mode 100644 index 000000000..8aec33e66 --- /dev/null +++ b/media/highway/src/hwy/examples/skeleton-inl.h @@ -0,0 +1,66 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo of functions that might be called from multiple SIMD modules (either +// other -inl.h files, or a .cc file between begin/end_target-inl). This is +// optional - all SIMD code can reside in .cc files. However, this allows +// splitting code into different files while still inlining instead of requiring +// calling through function pointers. + +// Per-target include guard. This is only required when using dynamic dispatch, +// i.e. including foreach_target.h. For static dispatch, a normal include +// guard would be fine because the header is only compiled once. +#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#else +#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#endif + +// It is fine to #include normal or *-inl headers. +#include + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +// Highway ops reside here; ADL does not find templates nor builtins. +namespace hn = hwy::HWY_NAMESPACE; + +// Example of a type-agnostic (caller-specified lane type) and width-agnostic +// (uses best available instruction set) function in a header. +// +// Computes x[i] = mul_array[i] * x_array[i] + add_array[i] for i < size. +template +HWY_MAYBE_UNUSED void MulAddLoop(const D d, const T* HWY_RESTRICT mul_array, + const T* HWY_RESTRICT add_array, + const size_t size, T* HWY_RESTRICT x_array) { + for (size_t i = 0; i < size; i += hn::Lanes(d)) { + const auto mul = hn::Load(d, mul_array + i); + const auto add = hn::Load(d, add_array + i); + auto x = hn::Load(d, x_array + i); + x = hn::MulAdd(mul, x, add); + hn::Store(x, d, x_array + i); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#endif // include guard diff --git a/media/highway/src/hwy/examples/skeleton.cc b/media/highway/src/hwy/examples/skeleton.cc new file mode 100644 index 000000000..2e820b6a9 --- /dev/null +++ b/media/highway/src/hwy/examples/skeleton.cc @@ -0,0 +1,121 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/examples/skeleton.h" + +#include + +// >>>> for dynamic dispatch only, skip if you want static dispatch + +// First undef to prevent error when re-included. +#undef HWY_TARGET_INCLUDE +// For dynamic dispatch, specify the name of the current file (unfortunately +// __FILE__ is not reliable) so that foreach_target.h can re-include it. +#define HWY_TARGET_INCLUDE "hwy/examples/skeleton.cc" +// Generates code for each enabled target by re-including this source file. +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// <<<< end of dynamic dispatch + +// Must come after foreach_target.h to avoid redefinition errors. +#include "hwy/highway.h" + +// Optional, can instead add HWY_ATTR to all functions. +HWY_BEFORE_NAMESPACE(); + +namespace skeleton { +// This namespace name is unique per target, which allows code for multiple +// targets to co-exist in the same translation unit. Required when using dynamic +// dispatch, otherwise optional. +namespace HWY_NAMESPACE { + +// Highway ops reside here; ADL does not find templates nor builtins. +namespace hn = hwy::HWY_NAMESPACE; + +// Computes log2 by converting to a vector of floats. Compiled once per target. +template +HWY_ATTR_NO_MSAN void OneFloorLog2(const DF df, + const uint8_t* HWY_RESTRICT values, + uint8_t* HWY_RESTRICT log2) { + // Type tags for converting to other element types (Rebind = same count). + const hn::RebindToSigned d32; + const hn::Rebind d8; + + const auto u8 = hn::Load(d8, values); + const auto bits = hn::BitCast(d32, hn::ConvertTo(df, hn::PromoteTo(d32, u8))); + const auto exponent = hn::Sub(hn::ShiftRight<23>(bits), hn::Set(d32, 127)); + hn::Store(hn::DemoteTo(d8, exponent), d8, log2); +} + +void CodepathDemo() { + // Highway defaults to portability, but per-target codepaths may be selected + // via #if HWY_TARGET == HWY_SSE4 or by testing capability macros: +#if HWY_HAVE_INTEGER64 + const char* gather = "Has int64"; +#else + const char* gather = "No int64"; +#endif + printf("Target %s: %s\n", hwy::TargetName(HWY_TARGET), gather); +} + +void FloorLog2(const uint8_t* HWY_RESTRICT values, size_t count, + uint8_t* HWY_RESTRICT log2) { + CodepathDemo(); + + const hn::ScalableTag df; + const size_t N = hn::Lanes(df); + size_t i = 0; + for (; i + N <= count; i += N) { + OneFloorLog2(df, values + i, log2 + i); + } + for (; i < count; ++i) { + hn::CappedTag d1; + OneFloorLog2(d1, values + i, log2 + i); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +// The table of pointers to the various implementations in HWY_NAMESPACE must +// be compiled only once (foreach_target #includes this file multiple times). +// HWY_ONCE is true for only one of these 'compilation passes'. +#if HWY_ONCE + +namespace skeleton { + +// This macro declares a static array used for dynamic dispatch; it resides in +// the same outer namespace that contains FloorLog2. +HWY_EXPORT(FloorLog2); + +// This function is optional and only needed in the case of exposing it in the +// header file. Otherwise using HWY_DYNAMIC_DISPATCH(FloorLog2) in this module +// is equivalent to inlining this function. +HWY_DLLEXPORT void CallFloorLog2(const uint8_t* HWY_RESTRICT in, + const size_t count, + uint8_t* HWY_RESTRICT out) { + // This must reside outside of HWY_NAMESPACE because it references (calls the + // appropriate one from) the per-target implementations there. + // For static dispatch, use HWY_STATIC_DISPATCH. + return HWY_DYNAMIC_DISPATCH(FloorLog2)(in, count, out); +} + +// Optional: anything to compile only once, e.g. non-SIMD implementations of +// public functions provided by this module, can go inside #if HWY_ONCE. + +} // namespace skeleton +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/examples/skeleton.h b/media/highway/src/hwy/examples/skeleton.h new file mode 100644 index 000000000..381ac69af --- /dev/null +++ b/media/highway/src/hwy/examples/skeleton.h @@ -0,0 +1,36 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo interface to target-specific code in skeleton.cc + +// Normal header with include guard and namespace. +#ifndef HIGHWAY_HWY_EXAMPLES_SKELETON_H_ +#define HIGHWAY_HWY_EXAMPLES_SKELETON_H_ + +#include + +// Platform-specific definitions used for declaring an interface, independent of +// the SIMD instruction set. +#include "hwy/base.h" // HWY_RESTRICT + +namespace skeleton { + +// Computes base-2 logarithm by converting to float. Supports dynamic dispatch. +HWY_DLLEXPORT void CallFloorLog2(const uint8_t* HWY_RESTRICT in, + const size_t count, uint8_t* HWY_RESTRICT out); + +} // namespace skeleton + +#endif // HIGHWAY_HWY_EXAMPLES_SKELETON_H_ diff --git a/media/highway/src/hwy/examples/skeleton_test.cc b/media/highway/src/hwy/examples/skeleton_test.cc new file mode 100644 index 000000000..c7c26bf5b --- /dev/null +++ b/media/highway/src/hwy/examples/skeleton_test.cc @@ -0,0 +1,110 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Example of unit test for the "skeleton" library. + +#include "hwy/examples/skeleton.h" + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "examples/skeleton_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// Must come after foreach_target.h to avoid redefinition errors. +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +// Optional: factor out parts of the implementation into *-inl.h +// (must also come after foreach_target.h to avoid redefinition errors) +#include "hwy/examples/skeleton-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +// Calls function defined in skeleton.cc. +struct TestFloorLog2 { + template + HWY_NOINLINE void operator()(T /*unused*/, DF df) { + const size_t count = 5 * hn::Lanes(df); + auto in = hwy::AllocateAligned(count); + auto expected = hwy::AllocateAligned(count); + + hwy::RandomState rng; + for (size_t i = 0; i < count; ++i) { + expected[i] = Random32(&rng) & 7; + in[i] = static_cast(1u << expected[i]); + } + auto out = hwy::AllocateAligned(count); + CallFloorLog2(in.get(), count, out.get()); + int sum = 0; + for (size_t i = 0; i < count; ++i) { + HWY_ASSERT_EQ(expected[i], out[i]); + sum += out[i]; + } + hwy::PreventElision(sum); + } +}; + +HWY_NOINLINE void TestAllFloorLog2() { + hn::ForPartialVectors()(float()); +} + +// Calls function defined in skeleton-inl.h. +struct TestSumMulAdd { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + hwy::RandomState rng; + const size_t count = 4096; + EXPECT_EQ(0, count % hn::Lanes(d)); + auto mul = hwy::AllocateAligned(count); + auto x = hwy::AllocateAligned(count); + auto add = hwy::AllocateAligned(count); + for (size_t i = 0; i < count; ++i) { + mul[i] = static_cast(Random32(&rng) & 0xF); + x[i] = static_cast(Random32(&rng) & 0xFF); + add[i] = static_cast(Random32(&rng) & 0xFF); + } + double expected_sum = 0.0; + for (size_t i = 0; i < count; ++i) { + expected_sum += mul[i] * x[i] + add[i]; + } + + MulAddLoop(d, mul.get(), add.get(), count, x.get()); + HWY_ASSERT_EQ(4344240.0, expected_sum); + } +}; + +HWY_NOINLINE void TestAllSumMulAdd() { + hn::ForFloatTypes(hn::ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace skeleton { +HWY_BEFORE_TEST(SkeletonTest); +HWY_EXPORT_AND_TEST_P(SkeletonTest, TestAllFloorLog2); +HWY_EXPORT_AND_TEST_P(SkeletonTest, TestAllSumMulAdd); +} // namespace skeleton + +#endif diff --git a/media/highway/src/hwy/foreach_target.h b/media/highway/src/hwy/foreach_target.h new file mode 100644 index 000000000..3929905ca --- /dev/null +++ b/media/highway/src/hwy/foreach_target.h @@ -0,0 +1,261 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_FOREACH_TARGET_H_ +#define HIGHWAY_HWY_FOREACH_TARGET_H_ + +// Re-includes the translation unit zero or more times to compile for any +// targets except HWY_STATIC_TARGET. Defines unique HWY_TARGET each time so that +// highway.h defines the corresponding macro/namespace. + +#include "hwy/detect_targets.h" + +// *_inl.h may include other headers, which requires include guards to prevent +// repeated inclusion. The guards must be reset after compiling each target, so +// the header is again visible. This is done by flipping HWY_TARGET_TOGGLE, +// defining it if undefined and vice versa. This macro is initially undefined +// so that IDEs don't gray out the contents of each header. +#ifdef HWY_TARGET_TOGGLE +#error "This macro must not be defined outside foreach_target.h" +#endif + +#ifdef HWY_HIGHWAY_INCLUDED // highway.h include guard +// Trigger fixup at the bottom of this header. +#define HWY_ALREADY_INCLUDED + +// The next highway.h must re-include set_macros-inl.h because the first +// highway.h chose the static target instead of what we will set below. +#undef HWY_SET_MACROS_PER_TARGET +#endif + +// Disable HWY_EXPORT in user code until we have generated all targets. Note +// that a subsequent highway.h will not override this definition. +#undef HWY_ONCE +#define HWY_ONCE (0 || HWY_IDE) + +// Avoid warnings on #include HWY_TARGET_INCLUDE by hiding them from the IDE; +// also skip if only 1 target defined (no re-inclusion will be necessary). +#if !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +#if !defined(HWY_TARGET_INCLUDE) +#error ">1 target enabled => define HWY_TARGET_INCLUDE before foreach_target.h" +#endif + +#if (HWY_TARGETS & HWY_EMU128) && (HWY_STATIC_TARGET != HWY_EMU128) +#undef HWY_TARGET +#define HWY_TARGET HWY_EMU128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SCALAR) && (HWY_STATIC_TARGET != HWY_SCALAR) +#undef HWY_TARGET +#define HWY_TARGET HWY_SCALAR +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_NEON) && (HWY_STATIC_TARGET != HWY_NEON) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_RVV) && (HWY_STATIC_TARGET != HWY_RVV) +#undef HWY_TARGET +#define HWY_TARGET HWY_RVV +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE) && (HWY_STATIC_TARGET != HWY_SVE) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2) && (HWY_STATIC_TARGET != HWY_SVE2) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE_256) && (HWY_STATIC_TARGET != HWY_SVE_256) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE_256 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2_128) && (HWY_STATIC_TARGET != HWY_SVE2_128) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2_128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSSE3) && (HWY_STATIC_TARGET != HWY_SSSE3) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSSE3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSE4) && (HWY_STATIC_TARGET != HWY_SSE4) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSE4 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX2) && (HWY_STATIC_TARGET != HWY_AVX2) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3) && (HWY_STATIC_TARGET != HWY_AVX3) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3_DL) && (HWY_STATIC_TARGET != HWY_AVX3_DL) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_DL +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_WASM_EMU256) && (HWY_STATIC_TARGET != HWY_WASM_EMU256) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM_EMU256 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_WASM) && (HWY_STATIC_TARGET != HWY_WASM) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_PPC8) && (HWY_STATIC_TARGET != HWY_PPC8) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC8 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#endif // !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +// Now that all but the static target have been generated, re-enable HWY_EXPORT. +#undef HWY_ONCE +#define HWY_ONCE 1 + +// If we re-include once per enabled target, the translation unit's +// implementation would have to be skipped via #if to avoid redefining symbols. +// We instead skip the re-include for HWY_STATIC_TARGET, and generate its +// implementation when resuming compilation of the translation unit. +#undef HWY_TARGET +#define HWY_TARGET HWY_STATIC_TARGET + +#ifdef HWY_ALREADY_INCLUDED +// Revert the previous toggle to prevent redefinitions for the static target. +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif + +// Force re-inclusion of set_macros-inl.h now that HWY_TARGET is restored. +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif +#endif + +#endif // HIGHWAY_HWY_FOREACH_TARGET_H_ diff --git a/media/highway/src/hwy/highway.h b/media/highway/src/hwy/highway.h new file mode 100644 index 000000000..4640f31e8 --- /dev/null +++ b/media/highway/src/hwy/highway.h @@ -0,0 +1,378 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This include guard is checked by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. NOTE: ops/*-inl.h are included +// after/outside this include guard. +#ifndef HWY_HIGHWAY_INCLUDED +#define HWY_HIGHWAY_INCLUDED + +// Main header required before using vector types. + +#include "hwy/base.h" +#include "hwy/targets.h" + +namespace hwy { + +// API version (https://semver.org/); keep in sync with CMakeLists.txt. +#define HWY_MAJOR 1 +#define HWY_MINOR 0 +#define HWY_PATCH 2 + +//------------------------------------------------------------------------------ +// Shorthand for tags (defined in shared-inl.h) used to select overloads. +// Note that ScalableTag is preferred over HWY_FULL, and CappedTag over +// HWY_CAPPED(T, N). + +// HWY_FULL(T[,LMUL=1]) is a native vector/group. LMUL is the number of +// registers in the group, and is ignored on targets that do not support groups. +#define HWY_FULL1(T) hwy::HWY_NAMESPACE::ScalableTag +#define HWY_FULL2(T, LMUL) \ + hwy::HWY_NAMESPACE::ScalableTag +#define HWY_3TH_ARG(arg1, arg2, arg3, ...) arg3 +// Workaround for MSVC grouping __VA_ARGS__ into a single argument +#define HWY_FULL_RECOMPOSER(args_with_paren) HWY_3TH_ARG args_with_paren +// Trailing comma avoids -pedantic false alarm +#define HWY_CHOOSE_FULL(...) \ + HWY_FULL_RECOMPOSER((__VA_ARGS__, HWY_FULL2, HWY_FULL1, )) +#define HWY_FULL(...) HWY_CHOOSE_FULL(__VA_ARGS__())(__VA_ARGS__) + +// Vector of up to MAX_N lanes. It's better to use full vectors where possible. +#define HWY_CAPPED(T, MAX_N) hwy::HWY_NAMESPACE::CappedTag + +//------------------------------------------------------------------------------ +// Export user functions for static/dynamic dispatch + +// Evaluates to 0 inside a translation unit if it is generating anything but the +// static target (the last one if multiple targets are enabled). Used to prevent +// redefinitions of HWY_EXPORT. Unless foreach_target.h is included, we only +// compile once anyway, so this is 1 unless it is or has been included. +#ifndef HWY_ONCE +#define HWY_ONCE 1 +#endif + +// HWY_STATIC_DISPATCH(FUNC_NAME) is the namespace-qualified FUNC_NAME for +// HWY_STATIC_TARGET (the only defined namespace unless HWY_TARGET_INCLUDE is +// defined), and can be used to deduce the return type of Choose*. +#if HWY_STATIC_TARGET == HWY_SCALAR +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SCALAR::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_EMU128 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_EMU128::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_RVV +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_RVV::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_WASM_EMU256 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM_EMU256::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_WASM +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_NEON +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE_256 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE_256::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE2_128 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2_128::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_PPC8 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC8::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SSSE3 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSSE3::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SSE4 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE4::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3_DL +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_DL::FUNC_NAME +#endif + +// HWY_CHOOSE_*(FUNC_NAME) expands to the function pointer for that target or +// nullptr is that target was not compiled. +#if HWY_TARGETS & HWY_EMU128 +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_EMU128::FUNC_NAME +#elif HWY_TARGETS & HWY_SCALAR +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_SCALAR::FUNC_NAME +#else +// When HWY_SCALAR/HWY_EMU128 are not present and other targets were disabled at +// runtime, fall back to the baseline with HWY_STATIC_DISPATCH(). +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME) +#endif + +#if HWY_TARGETS & HWY_WASM_EMU256 +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) &N_WASM_EMU256::FUNC_NAME +#else +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_WASM +#define HWY_CHOOSE_WASM(FUNC_NAME) &N_WASM::FUNC_NAME +#else +#define HWY_CHOOSE_WASM(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_RVV +#define HWY_CHOOSE_RVV(FUNC_NAME) &N_RVV::FUNC_NAME +#else +#define HWY_CHOOSE_RVV(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_NEON +#define HWY_CHOOSE_NEON(FUNC_NAME) &N_NEON::FUNC_NAME +#else +#define HWY_CHOOSE_NEON(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE +#define HWY_CHOOSE_SVE(FUNC_NAME) &N_SVE::FUNC_NAME +#else +#define HWY_CHOOSE_SVE(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE2 +#define HWY_CHOOSE_SVE2(FUNC_NAME) &N_SVE2::FUNC_NAME +#else +#define HWY_CHOOSE_SVE2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE_256 +#define HWY_CHOOSE_SVE_256(FUNC_NAME) &N_SVE_256::FUNC_NAME +#else +#define HWY_CHOOSE_SVE_256(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE2_128 +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) &N_SVE2_128::FUNC_NAME +#else +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_PPC8 +#define HWY_CHOOSE_PCC8(FUNC_NAME) &N_PPC8::FUNC_NAME +#else +#define HWY_CHOOSE_PPC8(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SSSE3 +#define HWY_CHOOSE_SSSE3(FUNC_NAME) &N_SSSE3::FUNC_NAME +#else +#define HWY_CHOOSE_SSSE3(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SSE4 +#define HWY_CHOOSE_SSE4(FUNC_NAME) &N_SSE4::FUNC_NAME +#else +#define HWY_CHOOSE_SSE4(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX2 +#define HWY_CHOOSE_AVX2(FUNC_NAME) &N_AVX2::FUNC_NAME +#else +#define HWY_CHOOSE_AVX2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3 +#define HWY_CHOOSE_AVX3(FUNC_NAME) &N_AVX3::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3_DL +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) &N_AVX3_DL::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) nullptr +#endif + +// MSVC 2017 workaround: the non-type template parameter to ChooseAndCall +// apparently cannot be an array. Use a function pointer instead, which has the +// disadvantage that we call the static (not best) target on the first call to +// any HWY_DYNAMIC_DISPATCH. +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1915 +#define HWY_DISPATCH_WORKAROUND 1 +#else +#define HWY_DISPATCH_WORKAROUND 0 +#endif + +// Provides a static member function which is what is called during the first +// HWY_DYNAMIC_DISPATCH, where GetIndex is still zero, and instantiations of +// this function are the first entry in the tables created by HWY_EXPORT. +template +struct FunctionCache { + public: + typedef RetType(FunctionType)(Args...); + +#if HWY_DISPATCH_WORKAROUND + template + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (*func)(args...); + } +#else + // A template function that when instantiated has the same signature as the + // function being called. This function initializes the bit array of targets + // supported by the current CPU and then calls the appropriate entry within + // the HWY_EXPORT table. Subsequent calls via HWY_DYNAMIC_DISPATCH to any + // exported functions, even those defined by different translation units, + // will dispatch directly to the best available target. + template + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (table[chosen_target.GetIndex()])(args...); + } +#endif // HWY_DISPATCH_WORKAROUND +}; + +// Used to deduce the template parameters RetType and Args from a function. +template +FunctionCache DeduceFunctionCache(RetType (*)(Args...)) { + return FunctionCache(); +} + +#define HWY_DISPATCH_TABLE(FUNC_NAME) \ + HWY_CONCAT(FUNC_NAME, HighwayDispatchTable) + +// HWY_EXPORT(FUNC_NAME); expands to a static array that is used by +// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime. This +// static array must be defined at the same namespace level as the function +// it is exporting. +// After being exported, it can be called from other parts of the same source +// file using HWY_DYNAMIC_DISPATCH(), in particular from a function wrapper +// like in the following example: +// +// #include "hwy/highway.h" +// HWY_BEFORE_NAMESPACE(); +// namespace skeleton { +// namespace HWY_NAMESPACE { +// +// void MyFunction(int a, char b, const char* c) { ... } +// +// // NOLINTNEXTLINE(google-readability-namespace-comments) +// } // namespace HWY_NAMESPACE +// } // namespace skeleton +// HWY_AFTER_NAMESPACE(); +// +// namespace skeleton { +// HWY_EXPORT(MyFunction); // Defines the dispatch table in this scope. +// +// void MyFunction(int a, char b, const char* c) { +// return HWY_DYNAMIC_DISPATCH(MyFunction)(a, b, c); +// } +// } // namespace skeleton +// + +#if HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// Simplified version for IDE or the dynamic dispatch case with only one target. +// This case still uses a table, although of a single element, to provide the +// same compile error conditions as with the dynamic dispatch case when multiple +// targets are being compiled. +#define HWY_EXPORT(FUNC_NAME) \ + HWY_MAYBE_UNUSED static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const \ + HWY_DISPATCH_TABLE(FUNC_NAME)[1] = {&HWY_STATIC_DISPATCH(FUNC_NAME)} +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME) + +#else + +// Simplified version for MSVC 2017: function pointer instead of table. +#if HWY_DISPATCH_WORKAROUND + +#define HWY_EXPORT(FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + FUNC_NAME)[HWY_MAX_DYNAMIC_TARGETS + 2] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the function from HWY_STATIC_TARGET. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH( \ + FUNC_NAME)))::ChooseAndCall<&HWY_STATIC_DISPATCH(FUNC_NAME)>, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } + +#else + +// Dynamic dispatch case with one entry per dynamic target plus the fallback +// target and the initialization wrapper. +#define HWY_EXPORT(FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + FUNC_NAME)[HWY_MAX_DYNAMIC_TARGETS + 2] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH( \ + FUNC_NAME)))::ChooseAndCall, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } + +#endif // HWY_DISPATCH_WORKAROUND + +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) \ + (*(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()])) + +#endif // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// DEPRECATED names; please use HWY_HAVE_* instead. +#define HWY_CAP_INTEGER64 HWY_HAVE_INTEGER64 +#define HWY_CAP_FLOAT16 HWY_HAVE_FLOAT16 +#define HWY_CAP_FLOAT64 HWY_HAVE_FLOAT64 + +} // namespace hwy + +#endif // HWY_HIGHWAY_INCLUDED + +//------------------------------------------------------------------------------ + +// NOTE: the following definitions and ops/*.h depend on HWY_TARGET, so we want +// to include them once per target, which is ensured by the toggle check. +// Because ops/*.h are included under it, they do not need their own guard. +#if defined(HWY_HIGHWAY_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_HIGHWAY_PER_TARGET +#undef HWY_HIGHWAY_PER_TARGET +#else +#define HWY_HIGHWAY_PER_TARGET +#endif + +// These define ops inside namespace hwy::HWY_NAMESPACE. +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 +#include "hwy/ops/x86_128-inl.h" +#elif HWY_TARGET == HWY_AVX2 +#include "hwy/ops/x86_256-inl.h" +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL +#include "hwy/ops/x86_512-inl.h" +#elif HWY_TARGET == HWY_PPC8 +#error "PPC is not yet supported" +#elif HWY_TARGET == HWY_NEON +#include "hwy/ops/arm_neon-inl.h" +#elif HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || \ + HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +#include "hwy/ops/arm_sve-inl.h" +#elif HWY_TARGET == HWY_WASM_EMU256 +#include "hwy/ops/wasm_256-inl.h" +#elif HWY_TARGET == HWY_WASM +#include "hwy/ops/wasm_128-inl.h" +#elif HWY_TARGET == HWY_RVV +#include "hwy/ops/rvv-inl.h" +#elif HWY_TARGET == HWY_EMU128 +#include "hwy/ops/emu128-inl.h" +#elif HWY_TARGET == HWY_SCALAR +#include "hwy/ops/scalar-inl.h" +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +#include "hwy/ops/generic_ops-inl.h" + +#endif // HWY_HIGHWAY_PER_TARGET diff --git a/media/highway/src/hwy/highway_export.h b/media/highway/src/hwy/highway_export.h new file mode 100644 index 000000000..30edc17d0 --- /dev/null +++ b/media/highway/src/hwy/highway_export.h @@ -0,0 +1,74 @@ +// Pseudo-generated file to handle both cmake & bazel build system. + +// Initial generation done using cmake code: +// include(GenerateExportHeader) +// generate_export_header(hwy EXPORT_MACRO_NAME HWY_DLLEXPORT EXPORT_FILE_NAME +// hwy/highway_export.h) +// code reformatted using clang-format --style=Google + +#ifndef HWY_DLLEXPORT_H +#define HWY_DLLEXPORT_H + +#if !defined(HWY_SHARED_DEFINE) +#define HWY_DLLEXPORT +#define HWY_CONTRIB_DLLEXPORT +#define HWY_TEST_DLLEXPORT +#else // !HWY_SHARED_DEFINE + +#ifndef HWY_DLLEXPORT +#if defined(hwy_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllexport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllimport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_EXPORTS) +#endif // HWY_DLLEXPORT + +#ifndef HWY_CONTRIB_DLLEXPORT +#if defined(hwy_contrib_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllexport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_contrib_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllimport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_contrib_EXPORTS) +#endif // HWY_CONTRIB_DLLEXPORT + +#ifndef HWY_TEST_DLLEXPORT +#if defined(hwy_test_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllexport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_test_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllimport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_test_EXPORTS) +#endif // HWY_TEST_DLLEXPORT + +#endif // !HWY_SHARED_DEFINE + +#endif /* HWY_DLLEXPORT_H */ diff --git a/media/highway/src/hwy/highway_test.cc b/media/highway/src/hwy/highway_test.cc new file mode 100644 index 000000000..4838e72f4 --- /dev/null +++ b/media/highway/src/hwy/highway_test.cc @@ -0,0 +1,485 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "highway_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/nanobenchmark.h" // Unpredictable1 +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +HWY_NOINLINE void TestCappedLimit(T /* tag */) { + CappedTag d; + // Ensure two ops compile + HWY_ASSERT_VEC_EQ(d, Zero(d), Set(d, T{0})); + + // Ensure we do not write more than kLimit lanes + const size_t N = Lanes(d); + if (kLimit < N) { + auto lanes = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T{0}); + Store(Set(d, T{1}), d, lanes.get()); + for (size_t i = kLimit; i < N; ++i) { + HWY_ASSERT_EQ(lanes[i], T{0}); + } + } +} + +// Adapter for ForAllTypes - we are constructing our own Simd<> and thus do not +// use ForPartialVectors etc. +struct TestCapped { + template + void operator()(T t) const { + TestCappedLimit<1>(t); + TestCappedLimit<3>(t); + TestCappedLimit<5>(t); + TestCappedLimit<1ull << 15>(t); + } +}; + +HWY_NOINLINE void TestAllCapped() { ForAllTypes(TestCapped()); } + +// For testing that ForPartialVectors reaches every possible size: +using NumLanesSet = std::bitset; + +// Monostate pattern because ForPartialVectors takes a template argument, not a +// functor by reference. +static NumLanesSet* NumLanesForSize(size_t sizeof_t) { + HWY_ASSERT(sizeof_t <= sizeof(uint64_t)); + static NumLanesSet num_lanes[sizeof(uint64_t) + 1]; + return num_lanes + sizeof_t; +} +static size_t* MaxLanesForSize(size_t sizeof_t) { + HWY_ASSERT(sizeof_t <= sizeof(uint64_t)); + static size_t num_lanes[sizeof(uint64_t) + 1] = {0}; + return num_lanes + sizeof_t; +} + +struct TestMaxLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const size_t kMax = MaxLanes(d); // for RVV, includes LMUL + HWY_ASSERT(N <= kMax); + HWY_ASSERT(kMax <= (HWY_MAX_BYTES / sizeof(T))); + + NumLanesForSize(sizeof(T))->set(N); + *MaxLanesForSize(sizeof(T)) = HWY_MAX(*MaxLanesForSize(sizeof(T)), N); + } +}; + +HWY_NOINLINE void TestAllMaxLanes() { + ForAllTypes(ForPartialVectors()); + + // Ensure ForPartialVectors visited all powers of two [1, N]. + for (size_t sizeof_t : {sizeof(uint8_t), sizeof(uint16_t), sizeof(uint32_t), + sizeof(uint64_t)}) { + const size_t N = *MaxLanesForSize(sizeof_t); + for (size_t i = 1; i <= N; i += i) { + if (!NumLanesForSize(sizeof_t)->test(i)) { + fprintf(stderr, "T=%d: did not visit for N=%d, max=%d\n", + static_cast(sizeof_t), static_cast(i), + static_cast(N)); + HWY_ASSERT(false); + } + } + } +} + +struct TestSet { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Zero + const auto v0 = Zero(d); + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + std::fill(expected.get(), expected.get() + N, T(0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), v0); + + // Set + const auto v2 = Set(d, T(2)); + for (size_t i = 0; i < N; ++i) { + expected[i] = 2; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), v2); + + // Iota + const auto vi = Iota(d, T(5)); + for (size_t i = 0; i < N; ++i) { + expected[i] = T(5 + i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), vi); + + // Undefined + const auto vu = Undefined(d); + Store(vu, d, expected.get()); + } +}; + +HWY_NOINLINE void TestAllSet() { ForAllTypes(ForPartialVectors()); } + +// Ensures wraparound (mod 2^bits) +struct TestOverflow { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(1)); + const auto vmax = Set(d, LimitsMax()); + const auto vmin = Set(d, LimitsMin()); + // Unsigned underflow / negative -> positive + HWY_ASSERT_VEC_EQ(d, vmax, Sub(vmin, v1)); + // Unsigned overflow / positive -> negative + HWY_ASSERT_VEC_EQ(d, vmin, Add(vmax, v1)); + } +}; + +HWY_NOINLINE void TestAllOverflow() { + ForIntegerTypes(ForPartialVectors()); +} + +struct TestClamp { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto v2 = Set(d, 2); + + HWY_ASSERT_VEC_EQ(d, v1, Clamp(v2, v0, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Clamp(v0, v1, v2)); + } +}; + +HWY_NOINLINE void TestAllClamp() { + ForAllTypes(ForPartialVectors()); +} + +struct TestSignBitInteger { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto all = VecFromMask(d, Eq(v0, v0)); + const auto vs = SignBit(d); + const auto other = Sub(vs, Set(d, 1)); + + // Shifting left by one => overflow, equal zero + HWY_ASSERT_VEC_EQ(d, v0, Add(vs, vs)); + // Verify the lower bits are zero (only +/- and logical ops are available + // for all types) + HWY_ASSERT_VEC_EQ(d, all, Add(vs, other)); + } +}; + +struct TestSignBitFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vs = SignBit(d); + const auto vp = Set(d, 2.25); + const auto vn = Set(d, -2.25); + HWY_ASSERT_VEC_EQ(d, Or(vp, vs), vn); + HWY_ASSERT_VEC_EQ(d, AndNot(vs, vn), vp); + HWY_ASSERT_VEC_EQ(d, v0, vs); + } +}; + +HWY_NOINLINE void TestAllSignBit() { + ForIntegerTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +// inline to work around incorrect SVE codegen (only first 128 bits used). +template +HWY_INLINE void AssertNaN(D d, VecArg v, const char* file, int line) { + using T = TFromD; + const size_t N = Lanes(d); + if (!AllTrue(d, IsNaN(v))) { + Print(d, "not all NaN", v, 0, N); + Print(d, "mask", VecFromMask(d, IsNaN(v)), 0, N); + const std::string type_name = TypeName(T(), N); + // RVV lacks PRIu64 and MSYS still has problems with %zu, so print bytes to + // avoid truncating doubles. + uint8_t bytes[HWY_MAX(sizeof(T), 8)] = {0}; + const T lane = GetLane(v); + CopyBytes(&lane, bytes); + Abort(file, line, + "Expected %s NaN, got %E (bytes %02x %02x %02x %02x %02x %02x %02x " + "%02x)", + type_name.c_str(), lane, bytes[0], bytes[1], bytes[2], bytes[3], + bytes[4], bytes[5], bytes[6], bytes[7]); + } +} + +#define HWY_ASSERT_NAN(d, v) AssertNaN(d, v, __FILE__, __LINE__) + +struct TestNaN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + HWY_ASSERT_NAN(d, nan); + + // Arithmetic + HWY_ASSERT_NAN(d, Add(nan, v1)); + HWY_ASSERT_NAN(d, Add(v1, nan)); + HWY_ASSERT_NAN(d, Sub(nan, v1)); + HWY_ASSERT_NAN(d, Sub(v1, nan)); + HWY_ASSERT_NAN(d, Mul(nan, v1)); + HWY_ASSERT_NAN(d, Mul(v1, nan)); + HWY_ASSERT_NAN(d, Div(nan, v1)); + HWY_ASSERT_NAN(d, Div(v1, nan)); + + // FMA + HWY_ASSERT_NAN(d, MulAdd(nan, v1, v1)); + HWY_ASSERT_NAN(d, MulAdd(v1, nan, v1)); + HWY_ASSERT_NAN(d, MulAdd(v1, v1, nan)); + HWY_ASSERT_NAN(d, MulSub(nan, v1, v1)); + HWY_ASSERT_NAN(d, MulSub(v1, nan, v1)); + HWY_ASSERT_NAN(d, MulSub(v1, v1, nan)); + HWY_ASSERT_NAN(d, NegMulAdd(nan, v1, v1)); + HWY_ASSERT_NAN(d, NegMulAdd(v1, nan, v1)); + HWY_ASSERT_NAN(d, NegMulAdd(v1, v1, nan)); + HWY_ASSERT_NAN(d, NegMulSub(nan, v1, v1)); + HWY_ASSERT_NAN(d, NegMulSub(v1, nan, v1)); + HWY_ASSERT_NAN(d, NegMulSub(v1, v1, nan)); + + // Rcp/Sqrt + HWY_ASSERT_NAN(d, Sqrt(nan)); + + // Sign manipulation + HWY_ASSERT_NAN(d, Abs(nan)); + HWY_ASSERT_NAN(d, Neg(nan)); + HWY_ASSERT_NAN(d, CopySign(nan, v1)); + HWY_ASSERT_NAN(d, CopySignToAbs(nan, v1)); + + // Rounding + HWY_ASSERT_NAN(d, Ceil(nan)); + HWY_ASSERT_NAN(d, Floor(nan)); + HWY_ASSERT_NAN(d, Round(nan)); + HWY_ASSERT_NAN(d, Trunc(nan)); + + // Logical (And/AndNot/Xor will clear NaN!) + HWY_ASSERT_NAN(d, Or(nan, v1)); + + // Comparison + HWY_ASSERT(AllFalse(d, Eq(nan, v1))); + HWY_ASSERT(AllFalse(d, Gt(nan, v1))); + HWY_ASSERT(AllFalse(d, Lt(nan, v1))); + HWY_ASSERT(AllFalse(d, Ge(nan, v1))); + HWY_ASSERT(AllFalse(d, Le(nan, v1))); + + // Reduction + HWY_ASSERT_NAN(d, SumOfLanes(d, nan)); +// TODO(janwas): re-enable after QEMU/Spike are fixed +#if HWY_TARGET != HWY_RVV + HWY_ASSERT_NAN(d, MinOfLanes(d, nan)); + HWY_ASSERT_NAN(d, MaxOfLanes(d, nan)); +#endif + + // Min +#if HWY_ARCH_X86 && (HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_EMU128) + // x86 SIMD returns the second operand if any input is NaN. + HWY_ASSERT_VEC_EQ(d, v1, Min(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Max(nan, v1)); + HWY_ASSERT_NAN(d, Min(v1, nan)); + HWY_ASSERT_NAN(d, Max(v1, nan)); +#elif HWY_ARCH_WASM + // Should return NaN if any input is NaN, but does not for scalar. + // TODO(janwas): remove once this is fixed. +#elif HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7 + // ARMv7 NEON returns NaN if any input is NaN. + HWY_ASSERT_NAN(d, Min(v1, nan)); + HWY_ASSERT_NAN(d, Max(v1, nan)); + HWY_ASSERT_NAN(d, Min(nan, v1)); + HWY_ASSERT_NAN(d, Max(nan, v1)); +#else + // IEEE 754-2019 minimumNumber is defined as the other argument if exactly + // one is NaN, and qNaN if both are. + HWY_ASSERT_VEC_EQ(d, v1, Min(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Max(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, nan)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, nan)); +#endif + HWY_ASSERT_NAN(d, Min(nan, nan)); + HWY_ASSERT_NAN(d, Max(nan, nan)); + } +}; + +// For functions only available for float32 +struct TestF32NaN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + HWY_ASSERT_NAN(d, ApproximateReciprocal(nan)); + HWY_ASSERT_NAN(d, ApproximateReciprocalSqrt(nan)); + HWY_ASSERT_NAN(d, AbsDiff(nan, v1)); + HWY_ASSERT_NAN(d, AbsDiff(v1, nan)); + } +}; + +HWY_NOINLINE void TestAllNaN() { + ForFloatTypes(ForPartialVectors()); + ForPartialVectors()(float()); +} + +struct TestIsNaN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto inf = IfThenElse(Eq(v1, Set(d, T(1))), Inf(d), v1); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + const auto neg = Set(d, T{-1}); + HWY_ASSERT_NAN(d, nan); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(inf)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(CopySign(inf, neg))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsNaN(nan)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsNaN(CopySign(nan, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(v1)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(Zero(d))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(Set(d, hwy::LowestValue()))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(Set(d, hwy::HighestValue()))); + } +}; + +HWY_NOINLINE void TestAllIsNaN() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestIsInf { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto inf = IfThenElse(Eq(v1, Set(d, T(1))), Inf(d), v1); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + const auto neg = Set(d, T{-1}); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsInf(inf)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsInf(CopySign(inf, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(nan)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(CopySign(nan, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(v1)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(Zero(d))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(Set(d, hwy::LowestValue()))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(Set(d, hwy::HighestValue()))); + } +}; + +HWY_NOINLINE void TestAllIsInf() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestIsFinite { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto inf = IfThenElse(Eq(v1, Set(d, T(1))), Inf(d), v1); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + const auto neg = Set(d, T{-1}); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(inf)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(CopySign(inf, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(nan)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(CopySign(nan, neg))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsFinite(v1)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsFinite(Zero(d))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsFinite(Set(d, hwy::LowestValue()))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), + IsFinite(Set(d, hwy::HighestValue()))); + } +}; + +HWY_NOINLINE void TestAllIsFinite() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestCopyAndAssign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // copy V + const auto v3 = Iota(d, 3); + auto v3b(v3); + HWY_ASSERT_VEC_EQ(d, v3, v3b); + + // assign V + auto v3c = Undefined(d); + v3c = v3; + HWY_ASSERT_VEC_EQ(d, v3, v3c); + } +}; + +HWY_NOINLINE void TestAllCopyAndAssign() { + ForAllTypes(ForPartialVectors()); +} + +struct TestGetLane { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + HWY_ASSERT_EQ(T(0), GetLane(Zero(d))); + HWY_ASSERT_EQ(T(1), GetLane(Set(d, 1))); + } +}; + +HWY_NOINLINE void TestAllGetLane() { + ForAllTypes(ForPartialVectors()); +} + +struct TestDFromV { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + using D0 = DFromV; // not necessarily same as D + const auto v0b = And(v0, Set(D0(), 1)); // but vectors can interoperate + HWY_ASSERT_VEC_EQ(d, v0, v0b); + } +}; + +HWY_NOINLINE void TestAllDFromV() { + ForAllTypes(ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HighwayTest); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllCapped); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllMaxLanes); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllSet); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllOverflow); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllClamp); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllSignBit); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllNaN); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllIsNaN); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllIsInf); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllIsFinite); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllCopyAndAssign); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllGetLane); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllDFromV); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/hwy.version b/media/highway/src/hwy/hwy.version new file mode 100644 index 000000000..9ff6be6a2 --- /dev/null +++ b/media/highway/src/hwy/hwy.version @@ -0,0 +1,19 @@ +HWY_0 { + global: + extern "C++" { + *hwy::*; + }; + + local: + # Hide all the std namespace symbols. std namespace is explicitly marked + # as visibility(default) and header-only functions or methods (such as those + # from templates) should be exposed in shared libraries as weak symbols but + # this is only needed when we expose those types in the shared library API + # in any way. We don't use C++ std types in the API and we also don't + # support exceptions in the library. + # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion + # about this. + extern "C++" { + *std::*; + }; +}; diff --git a/media/highway/src/hwy/nanobenchmark.cc b/media/highway/src/hwy/nanobenchmark.cc new file mode 100644 index 000000000..e03ed4cf6 --- /dev/null +++ b/media/highway/src/hwy/nanobenchmark.cc @@ -0,0 +1,762 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/nanobenchmark.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#include +#include +#include +#include // clock_gettime + +#include // sort +#include +#include +#include //NOLINT +#include +#include // iota +#include +#include +#include + +#if defined(_WIN32) || defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#include +#endif + +#if defined(__APPLE__) +#include +#include +#endif + +#if defined(__HAIKU__) +#include +#endif + +#include "hwy/base.h" +#if HWY_ARCH_PPC && defined(__GLIBC__) +#include // NOLINT __ppc_get_timebase_freq +#elif HWY_ARCH_X86 + +#if HWY_COMPILER_MSVC +#include +#else +#include // NOLINT +#endif // HWY_COMPILER_MSVC + +#endif // HWY_ARCH_X86 + +namespace hwy { +namespace { +namespace timer { + +// Ticks := platform-specific timer values (CPU cycles on x86). Must be +// unsigned to guarantee wraparound on overflow. +using Ticks = uint64_t; + +// Start/Stop return absolute timestamps and must be placed immediately before +// and after the region to measure. We provide separate Start/Stop functions +// because they use different fences. +// +// Background: RDTSC is not 'serializing'; earlier instructions may complete +// after it, and/or later instructions may complete before it. 'Fences' ensure +// regions' elapsed times are independent of such reordering. The only +// documented unprivileged serializing instruction is CPUID, which acts as a +// full fence (no reordering across it in either direction). Unfortunately +// the latency of CPUID varies wildly (perhaps made worse by not initializing +// its EAX input). Because it cannot reliably be deducted from the region's +// elapsed time, it must not be included in the region to measure (i.e. +// between the two RDTSC). +// +// The newer RDTSCP is sometimes described as serializing, but it actually +// only serves as a half-fence with release semantics. Although all +// instructions in the region will complete before the final timestamp is +// captured, subsequent instructions may leak into the region and increase the +// elapsed time. Inserting another fence after the final RDTSCP would prevent +// such reordering without affecting the measured region. +// +// Fortunately, such a fence exists. The LFENCE instruction is only documented +// to delay later loads until earlier loads are visible. However, Intel's +// reference manual says it acts as a full fence (waiting until all earlier +// instructions have completed, and delaying later instructions until it +// completes). AMD assigns the same behavior to MFENCE. +// +// We need a fence before the initial RDTSC to prevent earlier instructions +// from leaking into the region, and arguably another after RDTSC to avoid +// region instructions from completing before the timestamp is recorded. +// When surrounded by fences, the additional RDTSCP half-fence provides no +// benefit, so the initial timestamp can be recorded via RDTSC, which has +// lower overhead than RDTSCP because it does not read TSC_AUX. In summary, +// we define Start = LFENCE/RDTSC/LFENCE; Stop = RDTSCP/LFENCE. +// +// Using Start+Start leads to higher variance and overhead than Stop+Stop. +// However, Stop+Stop includes an LFENCE in the region measurements, which +// adds a delay dependent on earlier loads. The combination of Start+Stop +// is faster than Start+Start and more consistent than Stop+Stop because +// the first LFENCE already delayed subsequent loads before the measured +// region. This combination seems not to have been considered in prior work: +// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c +// +// Note: performance counters can measure 'exact' instructions-retired or +// (unhalted) cycle counts. The RDPMC instruction is not serializing and also +// requires fences. Unfortunately, it is not accessible on all OSes and we +// prefer to avoid kernel-mode drivers. Performance counters are also affected +// by several under/over-count errata, so we use the TSC instead. + +// Returns a 64-bit timestamp in unit of 'ticks'; to convert to seconds, +// divide by InvariantTicksPerSecond. +inline Ticks Start() { + Ticks t; +#if HWY_ARCH_PPC && defined(__GLIBC__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = __rdtsc(); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rdx", "memory", "cc"); +#elif HWY_ARCH_RVV + asm volatile("rdtime %0" : "=r"(t)); +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER counter; + (void)QueryPerformanceCounter(&counter); + t = counter.QuadPart; +#elif defined(__APPLE__) + t = mach_absolute_time(); +#elif defined(__HAIKU__) + t = system_time_nsecs(); // since boot +#else // POSIX + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + t = static_cast(ts.tv_sec * 1000000000LL + ts.tv_nsec); +#endif + return t; +} + +// WARNING: on x86, caller must check HasRDTSCP before using this! +inline Ticks Stop() { + uint64_t t; +#if HWY_ARCH_PPC && defined(__GLIBC__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + unsigned aux; + t = __rdtscp(&aux); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rcx", "rdx", "memory", "cc"); +#else + t = Start(); +#endif + return t; +} + +} // namespace timer + +namespace robust_statistics { + +// Sorts integral values in ascending order (e.g. for Mode). About 3x faster +// than std::sort for input distributions with very few unique values. +template +void CountingSort(T* values, size_t num_values) { + // Unique values and their frequency (similar to flat_map). + using Unique = std::pair; + std::vector unique; + for (size_t i = 0; i < num_values; ++i) { + const T value = values[i]; + const auto pos = + std::find_if(unique.begin(), unique.end(), + [value](const Unique u) { return u.first == value; }); + if (pos == unique.end()) { + unique.push_back(std::make_pair(value, 1)); + } else { + ++pos->second; + } + } + + // Sort in ascending order of value (pair.first). + std::sort(unique.begin(), unique.end()); + + // Write that many copies of each unique value to the array. + T* HWY_RESTRICT p = values; + for (const auto& value_count : unique) { + std::fill(p, p + value_count.second, value_count.first); + p += value_count.second; + } + NANOBENCHMARK_CHECK(p == values + num_values); +} + +// @return i in [idx_begin, idx_begin + half_count) that minimizes +// sorted[i + half_count] - sorted[i]. +template +size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin, + const size_t half_count) { + T min_range = std::numeric_limits::max(); + size_t min_idx = 0; + + for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) { + NANOBENCHMARK_CHECK(sorted[idx] <= sorted[idx + half_count]); + const T range = sorted[idx + half_count] - sorted[idx]; + if (range < min_range) { + min_range = range; + min_idx = idx; + } + } + + return min_idx; +} + +// Returns an estimate of the mode by calling MinRange on successively +// halved intervals. "sorted" must be in ascending order. This is the +// Half Sample Mode estimator proposed by Bickel in "On a fast, robust +// estimator of the mode", with complexity O(N log N). The mode is less +// affected by outliers in highly-skewed distributions than the median. +// The averaging operation below assumes "T" is an unsigned integer type. +template +T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) { + size_t idx_begin = 0; + size_t half_count = num_values / 2; + while (half_count > 1) { + idx_begin = MinRange(sorted, idx_begin, half_count); + half_count >>= 1; + } + + const T x = sorted[idx_begin + 0]; + if (half_count == 0) { + return x; + } + NANOBENCHMARK_CHECK(half_count == 1); + const T average = (x + sorted[idx_begin + 1] + 1) / 2; + return average; +} + +// Returns the mode. Side effect: sorts "values". +template +T Mode(T* values, const size_t num_values) { + CountingSort(values, num_values); + return ModeOfSorted(values, num_values); +} + +template +T Mode(T (&values)[N]) { + return Mode(&values[0], N); +} + +// Returns the median value. Side effect: sorts "values". +template +T Median(T* values, const size_t num_values) { + NANOBENCHMARK_CHECK(!values->empty()); + std::sort(values, values + num_values); + const size_t half = num_values / 2; + // Odd count: return middle + if (num_values % 2) { + return values[half]; + } + // Even count: return average of middle two. + return (values[half] + values[half - 1] + 1) / 2; +} + +// Returns a robust measure of variability. +template +T MedianAbsoluteDeviation(const T* values, const size_t num_values, + const T median) { + NANOBENCHMARK_CHECK(num_values != 0); + std::vector abs_deviations; + abs_deviations.reserve(num_values); + for (size_t i = 0; i < num_values; ++i) { + const int64_t abs = std::abs(static_cast(values[i]) - + static_cast(median)); + abs_deviations.push_back(static_cast(abs)); + } + return Median(abs_deviations.data(), num_values); +} + +} // namespace robust_statistics +} // namespace +namespace platform { +namespace { + +// Prevents the compiler from eliding the computations that led to "output". +template +inline void PreventElision(T&& output) { +#if HWY_COMPILER_MSVC == 0 + // Works by indicating to the compiler that "output" is being read and + // modified. The +r constraint avoids unnecessary writes to memory, but only + // works for built-in types (typically FuncOutput). + asm volatile("" : "+r"(output) : : "memory"); +#else + // MSVC does not support inline assembly anymore (and never supported GCC's + // RTL constraints). Self-assignment with #pragma optimize("off") might be + // expected to prevent elision, but it does not with MSVC 2015. Type-punning + // with volatile pointers generates inefficient code on MSVC 2017. + static std::atomic dummy(T{}); + dummy.store(output, std::memory_order_relaxed); +#endif +} + +// Measures the actual current frequency of Ticks. We cannot rely on the nominal +// frequency encoded in x86 BrandString because it is misleading on M1 Rosetta, +// and not reported by AMD. CPUID 0x15 is also not yet widely supported. Also +// used on RISC-V and ARM64. +HWY_MAYBE_UNUSED double MeasureNominalClockRate() { + double max_ticks_per_sec = 0.0; + // Arbitrary, enough to ignore 2 outliers without excessive init time. + for (int rep = 0; rep < 3; ++rep) { + auto time0 = std::chrono::steady_clock::now(); + using Time = decltype(time0); + const timer::Ticks ticks0 = timer::Start(); + const Time time_min = time0 + std::chrono::milliseconds(10); + + Time time1; + timer::Ticks ticks1; + for (;;) { + time1 = std::chrono::steady_clock::now(); + // Ideally this would be Stop, but that requires RDTSCP on x86. To avoid + // another codepath, just use Start instead. now() presumably has its own + // fence-like behavior. + ticks1 = timer::Start(); // Do not use Stop, see comment above + if (time1 >= time_min) break; + } + + const double dticks = static_cast(ticks1 - ticks0); + std::chrono::duration> dtime = time1 - time0; + const double ticks_per_sec = dticks / dtime.count(); + max_ticks_per_sec = std::max(max_ticks_per_sec, ticks_per_sec); + } + return max_ticks_per_sec; +} + +#if HWY_ARCH_X86 + +void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* HWY_RESTRICT abcd) { +#if HWY_COMPILER_MSVC + int regs[4]; + __cpuidex(regs, level, count); + for (int i = 0; i < 4; ++i) { + abcd[i] = regs[i]; + } +#else + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif +} + +bool HasRDTSCP() { + uint32_t abcd[4]; + Cpuid(0x80000001U, 0, abcd); // Extended feature flags + return (abcd[3] & (1u << 27)) != 0; // RDTSCP +} + +std::string BrandString() { + char brand_string[49]; + std::array abcd; + + // Check if brand string is supported (it is on all reasonable Intel/AMD) + Cpuid(0x80000000U, 0, abcd.data()); + if (abcd[0] < 0x80000004U) { + return std::string(); + } + + for (size_t i = 0; i < 3; ++i) { + Cpuid(static_cast(0x80000002U + i), 0, abcd.data()); + CopyBytes(&abcd[0], brand_string + i * 16); // not same size + } + brand_string[48] = 0; + return brand_string; +} + +#endif // HWY_ARCH_X86 + +} // namespace + +HWY_DLLEXPORT double InvariantTicksPerSecond() { +#if HWY_ARCH_PPC && defined(__GLIBC__) + return static_cast(__ppc_get_timebase_freq()); +#elif HWY_ARCH_X86 || HWY_ARCH_RVV || (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) + // We assume the x86 TSC is invariant; it is on all recent Intel/AMD CPUs. + static const double freq = MeasureNominalClockRate(); + return freq; +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER freq; + (void)QueryPerformanceFrequency(&freq); + return static_cast(freq.QuadPart); +#elif defined(__APPLE__) + // https://developer.apple.com/library/mac/qa/qa1398/_index.html + mach_timebase_info_data_t timebase; + (void)mach_timebase_info(&timebase); + return static_cast(timebase.denom) / timebase.numer * 1E9; +#else + return 1E9; // Haiku and clock_gettime return nanoseconds. +#endif +} + +HWY_DLLEXPORT double Now() { + static const double mul = 1.0 / InvariantTicksPerSecond(); + return static_cast(timer::Start()) * mul; +} + +HWY_DLLEXPORT uint64_t TimerResolution() { +#if HWY_ARCH_X86 + bool can_use_stop = platform::HasRDTSCP(); +#else + constexpr bool can_use_stop = true; +#endif + + // Nested loop avoids exceeding stack/L1 capacity. + timer::Ticks repetitions[Params::kTimerSamples]; + for (size_t rep = 0; rep < Params::kTimerSamples; ++rep) { + timer::Ticks samples[Params::kTimerSamples]; + if (can_use_stop) { + for (size_t i = 0; i < Params::kTimerSamples; ++i) { + const timer::Ticks t0 = timer::Start(); + const timer::Ticks t1 = timer::Stop(); // we checked HasRDTSCP above + samples[i] = t1 - t0; + } + } else { + for (size_t i = 0; i < Params::kTimerSamples; ++i) { + const timer::Ticks t0 = timer::Start(); + const timer::Ticks t1 = timer::Start(); // do not use Stop, see above + samples[i] = t1 - t0; + } + } + repetitions[rep] = robust_statistics::Mode(samples); + } + return robust_statistics::Mode(repetitions); +} + +} // namespace platform +namespace { + +static const timer::Ticks timer_resolution = platform::TimerResolution(); + +// Estimates the expected value of "lambda" values with a variable number of +// samples until the variability "rel_mad" is less than "max_rel_mad". +template +timer::Ticks SampleUntilStable(const double max_rel_mad, double* rel_mad, + const Params& p, const Lambda& lambda) { + // Choose initial samples_per_eval based on a single estimated duration. + timer::Ticks t0 = timer::Start(); + lambda(); + timer::Ticks t1 = timer::Stop(); // Caller checks HasRDTSCP + timer::Ticks est = t1 - t0; + static const double ticks_per_second = platform::InvariantTicksPerSecond(); + const size_t ticks_per_eval = + static_cast(ticks_per_second * p.seconds_per_eval); + size_t samples_per_eval = est == 0 + ? p.min_samples_per_eval + : static_cast(ticks_per_eval / est); + samples_per_eval = HWY_MAX(samples_per_eval, p.min_samples_per_eval); + + std::vector samples; + samples.reserve(1 + samples_per_eval); + samples.push_back(est); + + // Percentage is too strict for tiny differences, so also allow a small + // absolute "median absolute deviation". + const timer::Ticks max_abs_mad = (timer_resolution + 99) / 100; + *rel_mad = 0.0; // ensure initialized + + for (size_t eval = 0; eval < p.max_evals; ++eval, samples_per_eval *= 2) { + samples.reserve(samples.size() + samples_per_eval); + for (size_t i = 0; i < samples_per_eval; ++i) { + t0 = timer::Start(); + lambda(); + t1 = timer::Stop(); // Caller checks HasRDTSCP + samples.push_back(t1 - t0); + } + + if (samples.size() >= p.min_mode_samples) { + est = robust_statistics::Mode(samples.data(), samples.size()); + } else { + // For "few" (depends also on the variance) samples, Median is safer. + est = robust_statistics::Median(samples.data(), samples.size()); + } + NANOBENCHMARK_CHECK(est != 0); + + // Median absolute deviation (mad) is a robust measure of 'variability'. + const timer::Ticks abs_mad = robust_statistics::MedianAbsoluteDeviation( + samples.data(), samples.size(), est); + *rel_mad = static_cast(abs_mad) / static_cast(est); + + if (*rel_mad <= max_rel_mad || abs_mad <= max_abs_mad) { + if (p.verbose) { + printf("%6" PRIu64 " samples => %5" PRIu64 " (abs_mad=%4" PRIu64 + ", rel_mad=%4.2f%%)\n", + static_cast(samples.size()), + static_cast(est), static_cast(abs_mad), + *rel_mad * 100.0); + } + return est; + } + } + + if (p.verbose) { + printf("WARNING: rel_mad=%4.2f%% still exceeds %4.2f%% after %6" PRIu64 + " samples.\n", + *rel_mad * 100.0, max_rel_mad * 100.0, + static_cast(samples.size())); + } + return est; +} + +using InputVec = std::vector; + +// Returns vector of unique input values. +InputVec UniqueInputs(const FuncInput* inputs, const size_t num_inputs) { + InputVec unique(inputs, inputs + num_inputs); + std::sort(unique.begin(), unique.end()); + unique.erase(std::unique(unique.begin(), unique.end()), unique.end()); + return unique; +} + +// Returns how often we need to call func for sufficient precision. +size_t NumSkip(const Func func, const uint8_t* arg, const InputVec& unique, + const Params& p) { + // Min elapsed ticks for any input. + timer::Ticks min_duration = ~timer::Ticks(0); + + for (const FuncInput input : unique) { + double rel_mad; + const timer::Ticks total = SampleUntilStable( + p.target_rel_mad, &rel_mad, p, + [func, arg, input]() { platform::PreventElision(func(arg, input)); }); + min_duration = HWY_MIN(min_duration, total - timer_resolution); + } + + // Number of repetitions required to reach the target resolution. + const size_t max_skip = p.precision_divisor; + // Number of repetitions given the estimated duration. + const size_t num_skip = + min_duration == 0 + ? 0 + : static_cast((max_skip + min_duration - 1) / min_duration); + if (p.verbose) { + printf("res=%" PRIu64 " max_skip=%" PRIu64 " min_dur=%" PRIu64 + " num_skip=%" PRIu64 "\n", + static_cast(timer_resolution), + static_cast(max_skip), static_cast(min_duration), + static_cast(num_skip)); + } + return num_skip; +} + +// Replicates inputs until we can omit "num_skip" occurrences of an input. +InputVec ReplicateInputs(const FuncInput* inputs, const size_t num_inputs, + const size_t num_unique, const size_t num_skip, + const Params& p) { + InputVec full; + if (num_unique == 1) { + full.assign(p.subset_ratio * num_skip, inputs[0]); + return full; + } + + full.reserve(p.subset_ratio * num_skip * num_inputs); + for (size_t i = 0; i < p.subset_ratio * num_skip; ++i) { + full.insert(full.end(), inputs, inputs + num_inputs); + } + std::mt19937 rng; + std::shuffle(full.begin(), full.end(), rng); + return full; +} + +// Copies the "full" to "subset" in the same order, but with "num_skip" +// randomly selected occurrences of "input_to_skip" removed. +void FillSubset(const InputVec& full, const FuncInput input_to_skip, + const size_t num_skip, InputVec* subset) { + const size_t count = + static_cast(std::count(full.begin(), full.end(), input_to_skip)); + // Generate num_skip random indices: which occurrence to skip. + std::vector omit(count); + std::iota(omit.begin(), omit.end(), 0); + // omit[] is the same on every call, but that's OK because they identify the + // Nth instance of input_to_skip, so the position within full[] differs. + std::mt19937 rng; + std::shuffle(omit.begin(), omit.end(), rng); + omit.resize(num_skip); + std::sort(omit.begin(), omit.end()); + + uint32_t occurrence = ~0u; // 0 after preincrement + size_t idx_omit = 0; // cursor within omit[] + size_t idx_subset = 0; // cursor within *subset + for (const FuncInput next : full) { + if (next == input_to_skip) { + ++occurrence; + // Haven't removed enough already + if (idx_omit < num_skip) { + // This one is up for removal + if (occurrence == omit[idx_omit]) { + ++idx_omit; + continue; + } + } + } + if (idx_subset < subset->size()) { + (*subset)[idx_subset++] = next; + } + } + NANOBENCHMARK_CHECK(idx_subset == subset->size()); + NANOBENCHMARK_CHECK(idx_omit == omit.size()); + NANOBENCHMARK_CHECK(occurrence == count - 1); +} + +// Returns total ticks elapsed for all inputs. +timer::Ticks TotalDuration(const Func func, const uint8_t* arg, + const InputVec* inputs, const Params& p, + double* max_rel_mad) { + double rel_mad; + const timer::Ticks duration = + SampleUntilStable(p.target_rel_mad, &rel_mad, p, [func, arg, inputs]() { + for (const FuncInput input : *inputs) { + platform::PreventElision(func(arg, input)); + } + }); + *max_rel_mad = HWY_MAX(*max_rel_mad, rel_mad); + return duration; +} + +// (Nearly) empty Func for measuring timer overhead/resolution. +HWY_NOINLINE FuncOutput EmptyFunc(const void* /*arg*/, const FuncInput input) { + return input; +} + +// Returns overhead of accessing inputs[] and calling a function; this will +// be deducted from future TotalDuration return values. +timer::Ticks Overhead(const uint8_t* arg, const InputVec* inputs, + const Params& p) { + double rel_mad; + // Zero tolerance because repeatability is crucial and EmptyFunc is fast. + return SampleUntilStable(0.0, &rel_mad, p, [arg, inputs]() { + for (const FuncInput input : *inputs) { + platform::PreventElision(EmptyFunc(arg, input)); + } + }); +} + +} // namespace + +HWY_DLLEXPORT int Unpredictable1() { return timer::Start() != ~0ULL; } + +HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, + const FuncInput* inputs, const size_t num_inputs, + Result* results, const Params& p) { + NANOBENCHMARK_CHECK(num_inputs != 0); + +#if HWY_ARCH_X86 + if (!platform::HasRDTSCP()) { + fprintf(stderr, "CPU '%s' does not support RDTSCP, skipping benchmark.\n", + platform::BrandString().c_str()); + return 0; + } +#endif + + const InputVec& unique = UniqueInputs(inputs, num_inputs); + + const size_t num_skip = NumSkip(func, arg, unique, p); // never 0 + if (num_skip == 0) return 0; // NumSkip already printed error message + // (slightly less work on x86 to cast from signed integer) + const float mul = 1.0f / static_cast(static_cast(num_skip)); + + const InputVec& full = + ReplicateInputs(inputs, num_inputs, unique.size(), num_skip, p); + InputVec subset(full.size() - num_skip); + + const timer::Ticks overhead = Overhead(arg, &full, p); + const timer::Ticks overhead_skip = Overhead(arg, &subset, p); + if (overhead < overhead_skip) { + fprintf(stderr, "Measurement failed: overhead %" PRIu64 " < %" PRIu64 "\n", + static_cast(overhead), + static_cast(overhead_skip)); + return 0; + } + + if (p.verbose) { + printf("#inputs=%5" PRIu64 ",%5" PRIu64 " overhead=%5" PRIu64 ",%5" PRIu64 + "\n", + static_cast(full.size()), + static_cast(subset.size()), + static_cast(overhead), + static_cast(overhead_skip)); + } + + double max_rel_mad = 0.0; + const timer::Ticks total = TotalDuration(func, arg, &full, p, &max_rel_mad); + + for (size_t i = 0; i < unique.size(); ++i) { + FillSubset(full, unique[i], num_skip, &subset); + const timer::Ticks total_skip = + TotalDuration(func, arg, &subset, p, &max_rel_mad); + + if (total < total_skip) { + fprintf(stderr, "Measurement failed: total %" PRIu64 " < %" PRIu64 "\n", + static_cast(total), static_cast(total_skip)); + return 0; + } + + const timer::Ticks duration = + (total - overhead) - (total_skip - overhead_skip); + results[i].input = unique[i]; + results[i].ticks = static_cast(duration) * mul; + results[i].variability = static_cast(max_rel_mad); + } + + return unique.size(); +} + +} // namespace hwy diff --git a/media/highway/src/hwy/nanobenchmark.h b/media/highway/src/hwy/nanobenchmark.h new file mode 100644 index 000000000..f0910b4b9 --- /dev/null +++ b/media/highway/src/hwy/nanobenchmark.h @@ -0,0 +1,194 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_NANOBENCHMARK_H_ +#define HIGHWAY_HWY_NANOBENCHMARK_H_ + +// Benchmarks functions of a single integer argument with realistic branch +// prediction hit rates. Uses a robust estimator to summarize the measurements. +// The precision is about 0.2%. +// +// Examples: see nanobenchmark_test.cc. +// +// Background: Microbenchmarks such as http://github.com/google/benchmark +// can measure elapsed times on the order of a microsecond. Shorter functions +// are typically measured by repeating them thousands of times and dividing +// the total elapsed time by this count. Unfortunately, repetition (especially +// with the same input parameter!) influences the runtime. In time-critical +// code, it is reasonable to expect warm instruction/data caches and TLBs, +// but a perfect record of which branches will be taken is unrealistic. +// Unless the application also repeatedly invokes the measured function with +// the same parameter, the benchmark is measuring something very different - +// a best-case result, almost as if the parameter were made a compile-time +// constant. This may lead to erroneous conclusions about branch-heavy +// algorithms outperforming branch-free alternatives. +// +// Our approach differs in three ways. Adding fences to the timer functions +// reduces variability due to instruction reordering, improving the timer +// resolution to about 40 CPU cycles. However, shorter functions must still +// be invoked repeatedly. For more realistic branch prediction performance, +// we vary the input parameter according to a user-specified distribution. +// Thus, instead of VaryInputs(Measure(Repeat(func))), we change the +// loop nesting to Measure(Repeat(VaryInputs(func))). We also estimate the +// central tendency of the measurement samples with the "half sample mode", +// which is more robust to outliers and skewed data than the mean or median. + +#include +#include + +#include "hwy/highway_export.h" + +// Enables sanity checks that verify correct operation at the cost of +// longer benchmark runs. +#ifndef NANOBENCHMARK_ENABLE_CHECKS +#define NANOBENCHMARK_ENABLE_CHECKS 0 +#endif + +#define NANOBENCHMARK_CHECK_ALWAYS(condition) \ + while (!(condition)) { \ + fprintf(stderr, "Nanobenchmark check failed at line %d\n", __LINE__); \ + abort(); \ + } + +#if NANOBENCHMARK_ENABLE_CHECKS +#define NANOBENCHMARK_CHECK(condition) NANOBENCHMARK_CHECK_ALWAYS(condition) +#else +#define NANOBENCHMARK_CHECK(condition) +#endif + +namespace hwy { + +namespace platform { + +// Returns tick rate, useful for converting measurements to seconds. Invariant +// means the tick counter frequency is independent of CPU throttling or sleep. +// This call may be expensive, callers should cache the result. +HWY_DLLEXPORT double InvariantTicksPerSecond(); + +// Returns current timestamp [in seconds] relative to an unspecified origin. +// Features: monotonic (no negative elapsed time), steady (unaffected by system +// time changes), high-resolution (on the order of microseconds). +HWY_DLLEXPORT double Now(); + +// Returns ticks elapsed in back to back timer calls, i.e. a function of the +// timer resolution (minimum measurable difference) and overhead. +// This call is expensive, callers should cache the result. +HWY_DLLEXPORT uint64_t TimerResolution(); + +} // namespace platform + +// Returns 1, but without the compiler knowing what the value is. This prevents +// optimizing out code. +HWY_DLLEXPORT int Unpredictable1(); + +// Input influencing the function being measured (e.g. number of bytes to copy). +using FuncInput = size_t; + +// "Proof of work" returned by Func to ensure the compiler does not elide it. +using FuncOutput = uint64_t; + +// Function to measure: either 1) a captureless lambda or function with two +// arguments or 2) a lambda with capture, in which case the first argument +// is reserved for use by MeasureClosure. +using Func = FuncOutput (*)(const void*, FuncInput); + +// Internal parameters that determine precision/resolution/measuring time. +struct Params { + // For measuring timer overhead/resolution. Used in a nested loop => + // quadratic time, acceptable because we know timer overhead is "low". + // constexpr because this is used to define array bounds. + static constexpr size_t kTimerSamples = 256; + + // Best-case precision, expressed as a divisor of the timer resolution. + // Larger => more calls to Func and higher precision. + size_t precision_divisor = 1024; + + // Ratio between full and subset input distribution sizes. Cannot be less + // than 2; larger values increase measurement time but more faithfully + // model the given input distribution. + size_t subset_ratio = 2; + + // Together with the estimated Func duration, determines how many times to + // call Func before checking the sample variability. Larger values increase + // measurement time, memory/cache use and precision. + double seconds_per_eval = 4E-3; + + // The minimum number of samples before estimating the central tendency. + size_t min_samples_per_eval = 7; + + // The mode is better than median for estimating the central tendency of + // skewed/fat-tailed distributions, but it requires sufficient samples + // relative to the width of half-ranges. + size_t min_mode_samples = 64; + + // Maximum permissible variability (= median absolute deviation / center). + double target_rel_mad = 0.002; + + // Abort after this many evals without reaching target_rel_mad. This + // prevents infinite loops. + size_t max_evals = 9; + + // Whether to print additional statistics to stdout. + bool verbose = true; +}; + +// Measurement result for each unique input. +struct Result { + FuncInput input; + + // Robust estimate (mode or median) of duration. + float ticks; + + // Measure of variability (median absolute deviation relative to "ticks"). + float variability; +}; + +// Precisely measures the number of ticks elapsed when calling "func" with the +// given inputs, shuffled to ensure realistic branch prediction hit rates. +// +// "func" returns a 'proof of work' to ensure its computations are not elided. +// "arg" is passed to Func, or reserved for internal use by MeasureClosure. +// "inputs" is an array of "num_inputs" (not necessarily unique) arguments to +// "func". The values should be chosen to maximize coverage of "func". This +// represents a distribution, so a value's frequency should reflect its +// probability in the real application. Order does not matter; for example, a +// uniform distribution over [0, 4) could be represented as {3,0,2,1}. +// Returns how many Result were written to "results": one per unique input, or +// zero if the measurement failed (an error message goes to stderr). +HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, + const FuncInput* inputs, const size_t num_inputs, + Result* results, const Params& p = Params()); + +// Calls operator() of the given closure (lambda function). +template +static FuncOutput CallClosure(const Closure* f, const FuncInput input) { + return (*f)(input); +} + +// Same as Measure, except "closure" is typically a lambda function of +// FuncInput -> FuncOutput with a capture list. +template +static inline size_t MeasureClosure(const Closure& closure, + const FuncInput* inputs, + const size_t num_inputs, Result* results, + const Params& p = Params()) { + return Measure(reinterpret_cast(&CallClosure), + reinterpret_cast(&closure), inputs, num_inputs, + results, p); +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_NANOBENCHMARK_H_ diff --git a/media/highway/src/hwy/nanobenchmark_test.cc b/media/highway/src/hwy/nanobenchmark_test.cc new file mode 100644 index 000000000..0d153a14c --- /dev/null +++ b/media/highway/src/hwy/nanobenchmark_test.cc @@ -0,0 +1,94 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/nanobenchmark.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#include +#include + +#include + +#include "hwy/tests/test_util-inl.h" + +namespace hwy { +namespace { + +// Governs duration of test; avoid timeout in debug builds. +#if HWY_IS_DEBUG_BUILD +constexpr size_t kMaxEvals = 3; +#else +constexpr size_t kMaxEvals = 4; +#endif + +FuncOutput Div(const void*, FuncInput in) { + // Here we're measuring the throughput because benchmark invocations are + // independent. Any dividend will do; the divisor is nonzero. + return 0xFFFFF / in; +} + +template +void MeasureDiv(const FuncInput (&inputs)[N]) { + printf("Measuring integer division (output on final two lines)\n"); + Result results[N]; + Params params; + params.max_evals = kMaxEvals; + const size_t num_results = Measure(&Div, nullptr, inputs, N, results, params); + for (size_t i = 0; i < num_results; ++i) { + printf("%5" PRIu64 ": %6.2f ticks; MAD=%4.2f%%\n", + static_cast(results[i].input), results[i].ticks, + results[i].variability * 100.0); + } +} + +std::mt19937 rng; + +// A function whose runtime depends on rng. +FuncOutput Random(const void* /*arg*/, FuncInput in) { + const size_t r = rng() & 0xF; + FuncOutput ret = static_cast(in); + for (size_t i = 0; i < r; ++i) { + ret /= ((rng() & 1) + 2); + } + return ret; +} + +// Ensure the measured variability is high. +template +void MeasureRandom(const FuncInput (&inputs)[N]) { + Result results[N]; + Params p; + p.max_evals = kMaxEvals; + p.verbose = false; + const size_t num_results = Measure(&Random, nullptr, inputs, N, results, p); + for (size_t i = 0; i < num_results; ++i) { + NANOBENCHMARK_CHECK(results[i].variability > 1E-3); + } +} + +TEST(NanobenchmarkTest, RunAll) { + const int unpredictable = Unpredictable1(); // == 1, unknown to compiler. + static const FuncInput inputs[] = {static_cast(unpredictable) + 2, + static_cast(unpredictable + 9)}; + + MeasureDiv(inputs); + MeasureRandom(inputs); +} + +} // namespace +} // namespace hwy diff --git a/media/highway/src/hwy/ops/arm_neon-inl.h b/media/highway/src/hwy/ops/arm_neon-inl.h new file mode 100644 index 000000000..f85fcf8f5 --- /dev/null +++ b/media/highway/src/hwy/ops/arm_neon-inl.h @@ -0,0 +1,6664 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit ARM64 NEON vectors and operations. +// External include guard in highway.h - see comment there. + +// ARM NEON intrinsics are documented at: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon] + +#include +#include + +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); + +// Must come after HWY_BEFORE_NAMESPACE so that the intrinsics are compiled with +// the same target attribute as our code, see #834. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +#include +HWY_DIAGNOSTICS(pop) + +// Must come after arm_neon.h. +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { // for code folding and Raw128 + +// Macros used to define single and double function calls for multiple types +// for full and half vectors. These macros are undefined at the end of the file. + +// HWY_NEON_BUILD_TPL_* is the template<...> prefix to the function. +#define HWY_NEON_BUILD_TPL_1 +#define HWY_NEON_BUILD_TPL_2 +#define HWY_NEON_BUILD_TPL_3 + +// HWY_NEON_BUILD_RET_* is return type; type arg is without _t suffix so we can +// extend it to int32x4x2_t packs. +#define HWY_NEON_BUILD_RET_1(type, size) Vec128 +#define HWY_NEON_BUILD_RET_2(type, size) Vec128 +#define HWY_NEON_BUILD_RET_3(type, size) Vec128 + +// HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives. +#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128 a +#define HWY_NEON_BUILD_PARAM_2(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_PARAM_3(type, size) \ + const Vec128 a, const Vec128 b, \ + const Vec128 c + +// HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying +// function. +#define HWY_NEON_BUILD_ARG_1 a.raw +#define HWY_NEON_BUILD_ARG_2 a.raw, b.raw +#define HWY_NEON_BUILD_ARG_3 a.raw, b.raw, c.raw + +// We use HWY_NEON_EVAL(func, ...) to delay the evaluation of func until after +// the __VA_ARGS__ have been expanded. This allows "func" to be a macro on +// itself like with some of the library "functions" such as vshlq_u8. For +// example, HWY_NEON_EVAL(vshlq_u8, MY_PARAMS) where MY_PARAMS is defined as +// "a, b" (without the quotes) will end up expanding "vshlq_u8(a, b)" if needed. +// Directly writing vshlq_u8(MY_PARAMS) would fail since vshlq_u8() macro +// expects two arguments. +#define HWY_NEON_EVAL(func, ...) func(__VA_ARGS__) + +// Main macro definition that defines a single function for the given type and +// size of vector, using the underlying (prefix##infix##suffix) function and +// the template, return type, parameters and arguments defined by the "args" +// parameters passed here (see HWY_NEON_BUILD_* macros defined before). +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_CONCAT(HWY_NEON_BUILD_TPL_, args) \ + HWY_API HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size) \ + name(HWY_CONCAT(HWY_NEON_BUILD_PARAM_, args)(type, size)) { \ + return HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)( \ + HWY_NEON_EVAL(prefix##infix##suffix, HWY_NEON_BUILD_ARG_##args)); \ + } + +// The HWY_NEON_DEF_FUNCTION_* macros define all the variants of a function +// called "name" using the set of neon functions starting with the given +// "prefix" for all the variants of certain types, as specified next to each +// macro. For example, the prefix "vsub" can be used to define the operator- +// using args=2. + +// uint8_t +#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 8, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 4, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 2, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 1, name, prefix, infix, u8, args) + +// int8_t +#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 8, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 4, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 2, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 1, name, prefix, infix, s8, args) + +// uint16_t +#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 4, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 2, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 1, name, prefix, infix, u16, args) + +// int16_t +#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 4, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 2, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 1, name, prefix, infix, s16, args) + +// uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 2, name, prefix, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 1, name, prefix, infix, u32, args) + +// int32_t +#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 2, name, prefix, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 1, name, prefix, infix, s32, args) + +// uint64_t +#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) + +// int64_t +#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) + +// float +#define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 2, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 1, name, prefix, infix, f32, args) + +// double +#if HWY_ARCH_ARM_A64 +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) \ + HWY_NEON_DEF_FUNCTION(float64, 1, name, prefix, infix, f64, args) +#else +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) +#endif + +// float and double + +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) + +// Helper macros to define for more than one type. +// uint8_t, uint16_t and uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) + +// int8_t, int16_t and int32_t +#define HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) + +// uint8_t, uint16_t, uint32_t and uint64_t +#define HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) + +// int8_t, int16_t, int32_t and int64_t +#define HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) + +// All int*_t and uint*_t up to 64 +#define HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) + +// All previous types. +#define HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_UIF81632(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) + +// Emulation of some intrinsics on armv7. +#if HWY_ARCH_ARM_V7 +#define vuzp1_s8(x, y) vuzp_s8(x, y).val[0] +#define vuzp1_u8(x, y) vuzp_u8(x, y).val[0] +#define vuzp1_s16(x, y) vuzp_s16(x, y).val[0] +#define vuzp1_u16(x, y) vuzp_u16(x, y).val[0] +#define vuzp1_s32(x, y) vuzp_s32(x, y).val[0] +#define vuzp1_u32(x, y) vuzp_u32(x, y).val[0] +#define vuzp1_f32(x, y) vuzp_f32(x, y).val[0] +#define vuzp1q_s8(x, y) vuzpq_s8(x, y).val[0] +#define vuzp1q_u8(x, y) vuzpq_u8(x, y).val[0] +#define vuzp1q_s16(x, y) vuzpq_s16(x, y).val[0] +#define vuzp1q_u16(x, y) vuzpq_u16(x, y).val[0] +#define vuzp1q_s32(x, y) vuzpq_s32(x, y).val[0] +#define vuzp1q_u32(x, y) vuzpq_u32(x, y).val[0] +#define vuzp1q_f32(x, y) vuzpq_f32(x, y).val[0] +#define vuzp2_s8(x, y) vuzp_s8(x, y).val[1] +#define vuzp2_u8(x, y) vuzp_u8(x, y).val[1] +#define vuzp2_s16(x, y) vuzp_s16(x, y).val[1] +#define vuzp2_u16(x, y) vuzp_u16(x, y).val[1] +#define vuzp2_s32(x, y) vuzp_s32(x, y).val[1] +#define vuzp2_u32(x, y) vuzp_u32(x, y).val[1] +#define vuzp2_f32(x, y) vuzp_f32(x, y).val[1] +#define vuzp2q_s8(x, y) vuzpq_s8(x, y).val[1] +#define vuzp2q_u8(x, y) vuzpq_u8(x, y).val[1] +#define vuzp2q_s16(x, y) vuzpq_s16(x, y).val[1] +#define vuzp2q_u16(x, y) vuzpq_u16(x, y).val[1] +#define vuzp2q_s32(x, y) vuzpq_s32(x, y).val[1] +#define vuzp2q_u32(x, y) vuzpq_u32(x, y).val[1] +#define vuzp2q_f32(x, y) vuzpq_f32(x, y).val[1] +#define vzip1_s8(x, y) vzip_s8(x, y).val[0] +#define vzip1_u8(x, y) vzip_u8(x, y).val[0] +#define vzip1_s16(x, y) vzip_s16(x, y).val[0] +#define vzip1_u16(x, y) vzip_u16(x, y).val[0] +#define vzip1_f32(x, y) vzip_f32(x, y).val[0] +#define vzip1_u32(x, y) vzip_u32(x, y).val[0] +#define vzip1_s32(x, y) vzip_s32(x, y).val[0] +#define vzip1q_s8(x, y) vzipq_s8(x, y).val[0] +#define vzip1q_u8(x, y) vzipq_u8(x, y).val[0] +#define vzip1q_s16(x, y) vzipq_s16(x, y).val[0] +#define vzip1q_u16(x, y) vzipq_u16(x, y).val[0] +#define vzip1q_s32(x, y) vzipq_s32(x, y).val[0] +#define vzip1q_u32(x, y) vzipq_u32(x, y).val[0] +#define vzip1q_f32(x, y) vzipq_f32(x, y).val[0] +#define vzip2_s8(x, y) vzip_s8(x, y).val[1] +#define vzip2_u8(x, y) vzip_u8(x, y).val[1] +#define vzip2_s16(x, y) vzip_s16(x, y).val[1] +#define vzip2_u16(x, y) vzip_u16(x, y).val[1] +#define vzip2_s32(x, y) vzip_s32(x, y).val[1] +#define vzip2_u32(x, y) vzip_u32(x, y).val[1] +#define vzip2_f32(x, y) vzip_f32(x, y).val[1] +#define vzip2q_s8(x, y) vzipq_s8(x, y).val[1] +#define vzip2q_u8(x, y) vzipq_u8(x, y).val[1] +#define vzip2q_s16(x, y) vzipq_s16(x, y).val[1] +#define vzip2q_u16(x, y) vzipq_u16(x, y).val[1] +#define vzip2q_s32(x, y) vzipq_s32(x, y).val[1] +#define vzip2q_u32(x, y) vzipq_u32(x, y).val[1] +#define vzip2q_f32(x, y) vzipq_f32(x, y).val[1] +#endif + +// Wrappers over uint8x16x2_t etc. so we can define StoreInterleaved2 overloads +// for all vector types, even those (bfloat16_t) where the underlying vector is +// the same as others (uint16_t). +template +struct Tuple2; +template +struct Tuple3; +template +struct Tuple4; + +template <> +struct Tuple2 { + uint8x16x2_t raw; +}; +template +struct Tuple2 { + uint8x8x2_t raw; +}; +template <> +struct Tuple2 { + int8x16x2_t raw; +}; +template +struct Tuple2 { + int8x8x2_t raw; +}; +template <> +struct Tuple2 { + uint16x8x2_t raw; +}; +template +struct Tuple2 { + uint16x4x2_t raw; +}; +template <> +struct Tuple2 { + int16x8x2_t raw; +}; +template +struct Tuple2 { + int16x4x2_t raw; +}; +template <> +struct Tuple2 { + uint32x4x2_t raw; +}; +template +struct Tuple2 { + uint32x2x2_t raw; +}; +template <> +struct Tuple2 { + int32x4x2_t raw; +}; +template +struct Tuple2 { + int32x2x2_t raw; +}; +template <> +struct Tuple2 { + uint64x2x2_t raw; +}; +template +struct Tuple2 { + uint64x1x2_t raw; +}; +template <> +struct Tuple2 { + int64x2x2_t raw; +}; +template +struct Tuple2 { + int64x1x2_t raw; +}; + +template <> +struct Tuple2 { + uint16x8x2_t raw; +}; +template +struct Tuple2 { + uint16x4x2_t raw; +}; +template <> +struct Tuple2 { + uint16x8x2_t raw; +}; +template +struct Tuple2 { + uint16x4x2_t raw; +}; + +template <> +struct Tuple2 { + float32x4x2_t raw; +}; +template +struct Tuple2 { + float32x2x2_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple2 { + float64x2x2_t raw; +}; +template +struct Tuple2 { + float64x1x2_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template <> +struct Tuple3 { + uint8x16x3_t raw; +}; +template +struct Tuple3 { + uint8x8x3_t raw; +}; +template <> +struct Tuple3 { + int8x16x3_t raw; +}; +template +struct Tuple3 { + int8x8x3_t raw; +}; +template <> +struct Tuple3 { + uint16x8x3_t raw; +}; +template +struct Tuple3 { + uint16x4x3_t raw; +}; +template <> +struct Tuple3 { + int16x8x3_t raw; +}; +template +struct Tuple3 { + int16x4x3_t raw; +}; +template <> +struct Tuple3 { + uint32x4x3_t raw; +}; +template +struct Tuple3 { + uint32x2x3_t raw; +}; +template <> +struct Tuple3 { + int32x4x3_t raw; +}; +template +struct Tuple3 { + int32x2x3_t raw; +}; +template <> +struct Tuple3 { + uint64x2x3_t raw; +}; +template +struct Tuple3 { + uint64x1x3_t raw; +}; +template <> +struct Tuple3 { + int64x2x3_t raw; +}; +template +struct Tuple3 { + int64x1x3_t raw; +}; + +template <> +struct Tuple3 { + uint16x8x3_t raw; +}; +template +struct Tuple3 { + uint16x4x3_t raw; +}; +template <> +struct Tuple3 { + uint16x8x3_t raw; +}; +template +struct Tuple3 { + uint16x4x3_t raw; +}; + +template <> +struct Tuple3 { + float32x4x3_t raw; +}; +template +struct Tuple3 { + float32x2x3_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple3 { + float64x2x3_t raw; +}; +template +struct Tuple3 { + float64x1x3_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template <> +struct Tuple4 { + uint8x16x4_t raw; +}; +template +struct Tuple4 { + uint8x8x4_t raw; +}; +template <> +struct Tuple4 { + int8x16x4_t raw; +}; +template +struct Tuple4 { + int8x8x4_t raw; +}; +template <> +struct Tuple4 { + uint16x8x4_t raw; +}; +template +struct Tuple4 { + uint16x4x4_t raw; +}; +template <> +struct Tuple4 { + int16x8x4_t raw; +}; +template +struct Tuple4 { + int16x4x4_t raw; +}; +template <> +struct Tuple4 { + uint32x4x4_t raw; +}; +template +struct Tuple4 { + uint32x2x4_t raw; +}; +template <> +struct Tuple4 { + int32x4x4_t raw; +}; +template +struct Tuple4 { + int32x2x4_t raw; +}; +template <> +struct Tuple4 { + uint64x2x4_t raw; +}; +template +struct Tuple4 { + uint64x1x4_t raw; +}; +template <> +struct Tuple4 { + int64x2x4_t raw; +}; +template +struct Tuple4 { + int64x1x4_t raw; +}; + +template <> +struct Tuple4 { + uint16x8x4_t raw; +}; +template +struct Tuple4 { + uint16x4x4_t raw; +}; +template <> +struct Tuple4 { + uint16x8x4_t raw; +}; +template +struct Tuple4 { + uint16x4x4_t raw; +}; + +template <> +struct Tuple4 { + float32x4x4_t raw; +}; +template +struct Tuple4 { + float32x2x4_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple4 { + float64x2x4_t raw; +}; +template +struct Tuple4 { + float64x1x4_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template +struct Raw128; + +// 128 +template <> +struct Raw128 { + using type = uint8x16_t; +}; + +template <> +struct Raw128 { + using type = uint16x8_t; +}; + +template <> +struct Raw128 { + using type = uint32x4_t; +}; + +template <> +struct Raw128 { + using type = uint64x2_t; +}; + +template <> +struct Raw128 { + using type = int8x16_t; +}; + +template <> +struct Raw128 { + using type = int16x8_t; +}; + +template <> +struct Raw128 { + using type = int32x4_t; +}; + +template <> +struct Raw128 { + using type = int64x2_t; +}; + +template <> +struct Raw128 { + using type = uint16x8_t; +}; + +template <> +struct Raw128 { + using type = uint16x8_t; +}; + +template <> +struct Raw128 { + using type = float32x4_t; +}; + +#if HWY_ARCH_ARM_A64 +template <> +struct Raw128 { + using type = float64x2_t; +}; +#endif + +// 64 +template <> +struct Raw128 { + using type = uint8x8_t; +}; + +template <> +struct Raw128 { + using type = uint16x4_t; +}; + +template <> +struct Raw128 { + using type = uint32x2_t; +}; + +template <> +struct Raw128 { + using type = uint64x1_t; +}; + +template <> +struct Raw128 { + using type = int8x8_t; +}; + +template <> +struct Raw128 { + using type = int16x4_t; +}; + +template <> +struct Raw128 { + using type = int32x2_t; +}; + +template <> +struct Raw128 { + using type = int64x1_t; +}; + +template <> +struct Raw128 { + using type = uint16x4_t; +}; + +template <> +struct Raw128 { + using type = uint16x4_t; +}; + +template <> +struct Raw128 { + using type = float32x2_t; +}; + +#if HWY_ARCH_ARM_A64 +template <> +struct Raw128 { + using type = float64x1_t; +}; +#endif + +// 32 (same as 64) +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +// 16 (same as 64) +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +// 8 (same as 64) +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + HWY_INLINE Vec128() {} + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + HWY_INLINE explicit Vec128(const Raw raw) : raw(raw) {} + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +// FF..FF or 0. +template +class Mask128 { + // ARM C Language Extensions return and expect unsigned type. + using Raw = typename detail::Raw128, N>::type; + + public: + HWY_INLINE Mask128() {} + Mask128(const Mask128&) = default; + Mask128& operator=(const Mask128&) = default; + HWY_INLINE explicit Mask128(const Raw raw) : raw(raw) {} + + Raw raw; +}; + +template +using Mask64 = Mask128; + +namespace detail { + +// Deduce Simd from Vec128 +struct DeduceD { + template + Simd operator()(Vec128) const { + return Simd(); + } +}; + +} // namespace detail + +template +using DFromV = decltype(detail::DeduceD()(V())); + +template +using TFromV = TFromD>; + +// ------------------------------ BitCast + +namespace detail { + +// Converts from Vec128 to Vec128 using the +// vreinterpret*_u8_*() set of functions. +#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \ + Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw + +// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined. +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return v; +} + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) + +// Special cases for [b]float16_t, which have the same Raw as uint16_t. +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} + +#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return v; +} + +// 64-bit or less: + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s8_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_u16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_u32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_f32_u8(v.raw)); +} +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_u64_u8(v.raw)); +} +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_s64_u8(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec64 BitCastFromByte(Full64 /* tag */, + Vec128 v) { + return Vec64(vreinterpret_f64_u8(v.raw)); +} +#endif + +// 128-bit full: + +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s8_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u16_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s16_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_f32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u64_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s64_u8(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_f64_u8(v.raw)); +} +#endif + +// Special cases for [b]float16_t, which have the same Raw as uint16_t. +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(BitCastFromByte(Simd(), v).raw); +} +template +HWY_INLINE Vec128 BitCastFromByte( + Simd /* tag */, Vec128 v) { + return Vec128(BitCastFromByte(Simd(), v).raw); +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns a vector with all lanes set to "t". +#define HWY_NEON_BUILD_TPL_HWY_SET1 +#define HWY_NEON_BUILD_RET_HWY_SET1(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_SET1(type, size) \ + Simd /* tag */, const type##_t t +#define HWY_NEON_BUILD_ARG_HWY_SET1 t + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(Set, vdup, _n_, HWY_SET1) + +#undef HWY_NEON_BUILD_TPL_HWY_SET1 +#undef HWY_NEON_BUILD_RET_HWY_SET1 +#undef HWY_NEON_BUILD_PARAM_HWY_SET1 +#undef HWY_NEON_BUILD_ARG_HWY_SET1 + +// Returns an all-zero vector. +template +HWY_API Vec128 Zero(Simd d) { + return Set(d, 0); +} + +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128(Zero(Simd()).raw); +} + +template +using VFromD = decltype(Zero(D())); + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd /*d*/) { + typename detail::Raw128::type a; + return Vec128(a); +} + +HWY_DIAGNOSTICS(pop) + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ------------------------------ GetLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_GET template +#define HWY_NEON_BUILD_RET_HWY_GET(type, size) type##_t +#define HWY_NEON_BUILD_PARAM_HWY_GET(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_GET v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(GetLane, vget, _lane_, HWY_GET) + +#undef HWY_NEON_BUILD_TPL_HWY_GET +#undef HWY_NEON_BUILD_RET_HWY_GET +#undef HWY_NEON_BUILD_PARAM_HWY_GET +#undef HWY_NEON_BUILD_ARG_HWY_GET + +} // namespace detail + +template +HWY_API TFromV GetLane(const V v) { + return detail::GetLane<0>(v); +} + +// ------------------------------ ExtractLane + +// Requires one overload per vector length because GetLane<3> is a compile error +// if v is a uint32x2_t. +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return detail::GetLane<0>(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + case 8: + return detail::GetLane<8>(v); + case 9: + return detail::GetLane<9>(v); + case 10: + return detail::GetLane<10>(v); + case 11: + return detail::GetLane<11>(v); + case 12: + return detail::GetLane<12>(v); + case 13: + return detail::GetLane<13>(v); + case 14: + return detail::GetLane<14>(v); + case 15: + return detail::GetLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_INSERT template +#define HWY_NEON_BUILD_RET_HWY_INSERT(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_INSERT(type, size) \ + Vec128 v, type##_t t +#define HWY_NEON_BUILD_ARG_HWY_INSERT t, v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(InsertLane, vset, _lane_, HWY_INSERT) + +#undef HWY_NEON_BUILD_TPL_HWY_INSERT +#undef HWY_NEON_BUILD_RET_HWY_INSERT +#undef HWY_NEON_BUILD_PARAM_HWY_INSERT +#undef HWY_NEON_BUILD_ARG_HWY_INSERT + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2) + +// ------------------------------ Subtraction +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2) + +// ------------------------------ SumsOf8 + +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return Vec64(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); +} + +// ------------------------------ SaturatedAdd +// Only defined for uint8_t, uint16_t and their signed versions, as in other +// architectures. + +// Returns a + b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedAdd, vqadd, _, 2) + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedSub, vqsub, _, 2) + +// Not part of API, used in implementation. +namespace detail { +HWY_NEON_DEF_FUNCTION_UINT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_64(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_64(SaturatedSub, vqsub, _, 2) +} // namespace detail + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 +HWY_NEON_DEF_FUNCTION_UINT_8(AverageRound, vrhadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2) + +// ------------------------------ Neg + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below + +HWY_API Vec64 Neg(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vneg_s64(v.raw)); +#else + return Zero(Full64()) - v; +#endif +} + +HWY_API Vec128 Neg(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vnegq_s64(v.raw)); +#else + return Zero(Full128()) - v; +#endif +} + +// ------------------------------ ShiftLeft + +// Customize HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + template \ + HWY_API Vec128 name(const Vec128 v) { \ + return kBits == 0 ? v \ + : Vec128(HWY_NEON_EVAL( \ + prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ + } + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, ignored) + +HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ RotateRight (ShiftRight, Or) + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// NOTE: vxarq_u64 can be applied to uint64_t, but we do not yet have a +// mechanism for checking for extensions to ARMv8. + +// ------------------------------ Shl + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u8(v.raw, vreinterpretq_s8_u8(bits.raw))); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u8(v.raw, vreinterpret_s8_u8(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u16(v.raw, vreinterpretq_s16_u16(bits.raw))); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u16(v.raw, vreinterpret_s16_u16(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u32(v.raw, vreinterpretq_s32_u32(bits.raw))); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u32(v.raw, vreinterpret_s32_u32(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); +} +HWY_API Vec64 operator<<(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s8(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s8(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s16(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s16(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s32(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s32(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s64(v.raw, bits.raw)); +} +HWY_API Vec64 operator<<(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_s64(v.raw, bits.raw)); +} + +// ------------------------------ Shr (Neg) + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int8x16_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u8(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int8x8_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int16x8_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u16(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int16x4_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int32x4_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u32(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int32x2_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int64x2_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + const int64x1_t neg_bits = Neg(BitCast(Full64(), bits)).raw; + return Vec64(vshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s8(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s32(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64(vshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ ShiftLeftSame (Shl) + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, int bits) { + return v << Set(Simd(), static_cast(bits)); +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { + return v >> Set(Simd(), static_cast(bits)); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_u16(a.raw, b.raw)); +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_u32(a.raw, b.raw)); +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_u16(a.raw, b.raw)); +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_u32(a.raw, b.raw)); +} + +// Signed +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_s16(a.raw, b.raw)); +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_s32(a.raw, b.raw)); +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_s16(a.raw, b.raw)); +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_s32(a.raw, b.raw)); +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + int32x4_t rlo = vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); +#if HWY_ARCH_ARM_A64 + int32x4_t rhi = vmull_high_s16(a.raw, b.raw); +#else + int32x4_t rhi = vmull_s16(vget_high_s16(a.raw), vget_high_s16(b.raw)); +#endif + return Vec128( + vuzp2q_s16(vreinterpretq_s16_s32(rlo), vreinterpretq_s16_s32(rhi))); +} +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + uint32x4_t rlo = vmull_u16(vget_low_u16(a.raw), vget_low_u16(b.raw)); +#if HWY_ARCH_ARM_A64 + uint32x4_t rhi = vmull_high_u16(a.raw, b.raw); +#else + uint32x4_t rhi = vmull_u16(vget_high_u16(a.raw), vget_high_u16(b.raw)); +#endif + return Vec128( + vuzp2q_u16(vreinterpretq_u16_u32(rlo), vreinterpretq_u16_u32(rhi))); +} + +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + int16x8_t hi_lo = vreinterpretq_s16_s32(vmull_s16(a.raw, b.raw)); + return Vec128(vget_low_s16(vuzp2q_s16(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + uint16x8_t hi_lo = vreinterpretq_u16_u32(vmull_u16(a.raw, b.raw)); + return Vec128(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); +} + +HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { + return Vec128(vqrdmulhq_s16(a.raw, b.raw)); +} +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + return Vec128(vqrdmulh_s16(a.raw, b.raw)); +} + +// ------------------------------ Floating-point mul / div + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) + +// Approximate reciprocal +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128(vrecpeq_f32(v.raw)); +} +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128(vrecpe_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator/, vdiv, _, 2) +#else +// Not defined on armv7: approximate +namespace detail { + +HWY_INLINE Vec128 ReciprocalNewtonRaphsonStep( + const Vec128 recip, const Vec128 divisor) { + return Vec128(vrecpsq_f32(recip.raw, divisor.raw)); +} +template +HWY_INLINE Vec128 ReciprocalNewtonRaphsonStep( + const Vec128 recip, Vec128 divisor) { + return Vec128(vrecps_f32(recip.raw, divisor.raw)); +} + +} // namespace detail + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + auto x = ApproximateReciprocal(b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + return a * x; +} +#endif + +// ------------------------------ Absolute value of difference. + +HWY_API Vec128 AbsDiff(const Vec128 a, const Vec128 b) { + return Vec128(vabdq_f32(a.raw, b.raw)); +} +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Vec128(vabd_f32(a.raw, b.raw)); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns add + mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfma_f32(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, + const Vec128 add) { + return Vec128(vfmaq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return mul * x + add; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 MulAdd(const Vec64 mul, const Vec64 x, + const Vec64 add) { + return Vec64(vfma_f64(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 MulAdd(const Vec128 mul, const Vec128 x, + const Vec128 add) { + return Vec128(vfmaq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns add - mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfms_f32(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 NegMulAdd(const Vec128 mul, const Vec128 x, + const Vec128 add) { + return Vec128(vfmsq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return add - mul * x; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 NegMulAdd(const Vec64 mul, const Vec64 x, + const Vec64 add) { + return Vec64(vfms_f64(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfmsq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} + +#if HWY_ARCH_ARM_A64 +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} +#endif + +// ------------------------------ Floating-point square root (IfThenZeroElse) + +// Approximate reciprocal square root +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128(vrsqrteq_f32(v.raw)); +} +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128(vrsqrte_f32(v.raw)); +} + +// Full precision square root +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Sqrt, vsqrt, _, 1) +#else +namespace detail { + +HWY_INLINE Vec128 ReciprocalSqrtStep(const Vec128 root, + const Vec128 recip) { + return Vec128(vrsqrtsq_f32(root.raw, recip.raw)); +} +template +HWY_INLINE Vec128 ReciprocalSqrtStep(const Vec128 root, + Vec128 recip) { + return Vec128(vrsqrts_f32(root.raw, recip.raw)); +} + +} // namespace detail + +// Not defined on armv7: approximate +template +HWY_API Vec128 Sqrt(const Vec128 v) { + auto recip = ApproximateReciprocalSqrt(v); + + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + + const auto root = v * recip; + return IfThenZeroElse(v == Zero(Simd()), root); +} +#endif + +// ================================================== LOGICAL + +// ------------------------------ Not + +// There is no 64-bit vmvn, so cast instead of using HWY_NEON_DEF_FUNCTION. +template +HWY_API Vec128 Not(const Vec128 v) { + const Full128 d; + const Repartition d8; + return BitCast(d, Vec128(vmvnq_u8(BitCast(d8, v).raw))); +} +template +HWY_API Vec128 Not(const Vec128 v) { + const Simd d; + const Repartition d8; + using V8 = decltype(Zero(d8)); + return BitCast(d, V8(vmvn_u8(BitCast(d8, v).raw))); +} + +// ------------------------------ And +HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 And(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) & BitCast(du, b)); +} + +// ------------------------------ AndNot + +namespace detail { +// reversed_andnot returns a & ~b. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2) +} // namespace detail + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return detail::reversed_andnot(mask, not_mask); +} + +// Uses the u32/64 defined above. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + const DFromV d; + const RebindToUnsigned du; + VFromD ret = + detail::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); + return BitCast(d, ret); +} + +// ------------------------------ Or + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) | BitCast(du, b)); +} + +// ------------------------------ Xor + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) ^ BitCast(du, b)); +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { + const Full128 d8; + return Vec128(vcntq_u8(BitCast(d8, v).raw)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + const Simd d8; + return Vec128(vcnt_u8(BitCast(d8, v).raw)); +} + +// ARM lacks popcount for lane sizes > 1, so take pairwise sums of the bytes. +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u8(bytes)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u8(bytes)); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u16(vpaddlq_u8(bytes))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u16(vpaddl_u8(bytes))); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(bytes)))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u32(vpaddl_u16(vpaddl_u8(bytes)))); +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +// ================================================== SIGN + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s8(v.raw)); +} +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s16(v.raw)); +} +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s32(v.raw)); +} +// i64 is implemented after BroadcastSignBit. +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_f32(v.raw)); +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s8(v.raw)); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s16(v.raw)); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s32(v.raw)); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabs_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_f64(v.raw)); +} + +HWY_API Vec64 Abs(const Vec64 v) { + return Vec64(vabs_f64(v.raw)); +} +#endif + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} + +// ================================================== MASK + +// ------------------------------ To/from vector + +// Mask and Vec have the same representation (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const Simd, N, 0> du; + return Mask128(BitCast(du, v).raw); +} + +template +HWY_API Vec128 VecFromMask(Simd d, const Mask128 v) { + return BitCast(d, Vec128, N>(v.raw)); +} + +// ------------------------------ RebindMask + +template +HWY_API Mask128 RebindMask(Simd dto, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MaskFromVec(BitCast(dto, VecFromMask(Simd(), m))); +} + +// ------------------------------ IfThenElse(mask, yes, no) = mask ? b : a. + +#define HWY_NEON_BUILD_TPL_HWY_IF +#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ + const Mask128 mask, const Vec128 yes, \ + const Vec128 no +#define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) + +#undef HWY_NEON_BUILD_TPL_HWY_IF +#undef HWY_NEON_BUILD_RET_HWY_IF +#undef HWY_NEON_BUILD_PARAM_HWY_IF +#undef HWY_NEON_BUILD_ARG_HWY_IF + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(const Mask128 mask, + const Vec128 yes) { + return yes & VecFromMask(Simd(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(const Mask128 mask, + const Vec128 no) { + return AndNot(VecFromMask(Simd(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Simd d; + const RebindToSigned di; + + Mask128 m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + const auto zero = Zero(Simd()); + return Max(zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +// ------------------------------ Shuffle2301 (for i64 compares) + +// Swap 32-bit halves in 64-bits +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_u32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_s32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_f32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_u32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_s32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_f32(v.raw)); +} + +#define HWY_NEON_BUILD_TPL_HWY_COMPARE +#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw + +// ------------------------------ Equality +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator==, vceq, _, HWY_COMPARE) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator==, vceq, _, HWY_COMPARE) +#else +// No 64-bit comparisons on armv7: emulate them below, after Shuffle2301. +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator==, vceq, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator==, vceq, _, HWY_COMPARE) +#endif + +// ------------------------------ Strict inequality (signed, float) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<, vclt, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<, vclt, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<, vclt, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<, vclt, _, HWY_COMPARE) + +// ------------------------------ Weak inequality (float) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) + +#undef HWY_NEON_BUILD_TPL_HWY_COMPARE +#undef HWY_NEON_BUILD_RET_HWY_COMPARE +#undef HWY_NEON_BUILD_PARAM_HWY_COMPARE +#undef HWY_NEON_BUILD_ARG_HWY_COMPARE + +// ------------------------------ ARMv7 i64 compare (Shuffle2301, Eq) + +#if HWY_ARCH_ARM_V7 + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const int64x2_t sub = vqsubq_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec128(sub))); +} +HWY_API Mask128 operator<(const Vec64 a, + const Vec64 b) { + const int64x1_t sub = vqsub_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec64(sub))); +} + +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = AndNot(a, b) | AndNot(a ^ b, a - b); + return MaskFromVec(BitCast(du, BroadcastSignBit(BitCast(di, msb)))); +} + +#endif + +// ------------------------------ operator!= (operator==) + +// Customize HWY_NEON_DEF_FUNCTION to call 2 functions. +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +// This cannot have _any_ template argument (in x86_128 we can at least have N +// as an argument), otherwise it is not more specialized than rewritten +// operator== in C++20, leading to compile errors. +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_API Mask128 name(Vec128 a, \ + Vec128 b) { \ + return Not(a == b); \ + } + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator!=, ignored, ignored, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return operator<(b, a); +} +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return operator<=(b, a); +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask128 FirstN(const Simd d, size_t num) { + const RebindToSigned di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +} + +// ------------------------------ TestBit (Eq) + +#define HWY_NEON_BUILD_TPL_HWY_TESTBIT +#define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_TESTBIT(type, size) \ + Vec128 v, Vec128 bit +#define HWY_NEON_BUILD_ARG_HWY_TESTBIT v.raw, bit.raw + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(TestBit, vtst, _, HWY_TESTBIT) +#else +// No 64-bit versions on armv7 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) + +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} + +#endif +#undef HWY_NEON_BUILD_TPL_HWY_TESTBIT +#undef HWY_NEON_BUILD_RET_HWY_TESTBIT +#undef HWY_NEON_BUILD_PARAM_HWY_TESTBIT +#undef HWY_NEON_BUILD_ARG_HWY_TESTBIT + +// ------------------------------ Abs i64 (IfThenElse, BroadcastSignBit) +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vabsq_s64(v.raw)); +#else + const auto zero = Zero(Full128()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} +HWY_API Vec64 Abs(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vabs_s64(v.raw)); +#else + const auto zero = Zero(Full64()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +// ------------------------------ Min (IfThenElse, BroadcastSignBit) + +// Unsigned +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, a) - BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const Vec128 sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), a, b); +#endif +} + +// Float: IEEE minimumNumber on v8, otherwise NaN if any is NaN. +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vminnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vmin, _, 2) +#endif + +// ------------------------------ Max (IfThenElse, BroadcastSignBit) + +// Unsigned (no u64) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, b) + BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed (no i64) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const Vec128 sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), b, a); +#endif +} + +// Float: IEEE maximumNumber on v8, otherwise NaN if any is NaN. +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmaxnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2) +#endif + +// ================================================== MEMORY + +// ------------------------------ Load 128 + +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u8(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u16(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u32(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const uint64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u64(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s8(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s16(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s32(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const int64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s64(unaligned)); +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const float* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f32(unaligned)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec128 LoadU(Full128 /* tag */, + const double* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f64(unaligned)); +} +#endif + +// ------------------------------ Load 64 + +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint8_t* HWY_RESTRICT p) { + return Vec64(vld1_u8(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint16_t* HWY_RESTRICT p) { + return Vec64(vld1_u16(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint32_t* HWY_RESTRICT p) { + return Vec64(vld1_u32(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const uint64_t* HWY_RESTRICT p) { + return Vec64(vld1_u64(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int8_t* HWY_RESTRICT p) { + return Vec64(vld1_s8(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int16_t* HWY_RESTRICT p) { + return Vec64(vld1_s16(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int32_t* HWY_RESTRICT p) { + return Vec64(vld1_s32(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const int64_t* HWY_RESTRICT p) { + return Vec64(vld1_s64(p)); +} +HWY_API Vec64 LoadU(Full64 /* tag */, + const float* HWY_RESTRICT p) { + return Vec64(vld1_f32(p)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 LoadU(Full64 /* tag */, + const double* HWY_RESTRICT p) { + return Vec64(vld1_f64(p)); +} +#endif +// ------------------------------ Load 32 + +// Actual 32-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +HWY_API Vec32 LoadU(Full32 /*tag*/, + const uint32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_u32(p)); +} +HWY_API Vec32 LoadU(Full32 /*tag*/, + const int32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_s32(p)); +} +HWY_API Vec32 LoadU(Full32 /*tag*/, const float* HWY_RESTRICT p) { + return Vec32(vld1_dup_f32(p)); +} + +template +HWY_API Vec32 LoadU(Full32 d, const T* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} + +// ------------------------------ Load 16 + +// Actual 16-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +HWY_API Vec128 LoadU(Simd /*tag*/, + const uint16_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_u16(p)); +} +HWY_API Vec128 LoadU(Simd /*tag*/, + const int16_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_s16(p)); +} + +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + const Repartition d16; + uint16_t buf; + CopyBytes<2>(p, &buf); + return BitCast(d, LoadU(d16, &buf)); +} + +// ------------------------------ Load 8 + +HWY_API Vec128 LoadU(Simd, + const uint8_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_u8(p)); +} + +HWY_API Vec128 LoadU(Simd, + const int8_t* HWY_RESTRICT p) { + return Vec128(vld1_dup_s8(p)); +} + +// [b]float16_t use the same Raw as uint16_t, so forward to that. +template +HWY_API Vec128 LoadU(Simd d, + const float16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return Vec128(LoadU(du16, pu16).raw); +} +template +HWY_API Vec128 LoadU(Simd d, + const bfloat16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return Vec128(LoadU(du16, pu16).raw); +} + +// On ARM, Load is the same as LoadU. +template +HWY_API Vec128 Load(Simd d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec128 LoadDup128(Simd d, + const T* const HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store 128 + +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint8_t* HWY_RESTRICT unaligned) { + vst1q_u8(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint16_t* HWY_RESTRICT unaligned) { + vst1q_u16(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint32_t* HWY_RESTRICT unaligned) { + vst1q_u32(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + uint64_t* HWY_RESTRICT unaligned) { + vst1q_u64(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int8_t* HWY_RESTRICT unaligned) { + vst1q_s8(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int16_t* HWY_RESTRICT unaligned) { + vst1q_s16(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int32_t* HWY_RESTRICT unaligned) { + vst1q_s32(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + int64_t* HWY_RESTRICT unaligned) { + vst1q_s64(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT unaligned) { + vst1q_f32(unaligned, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT unaligned) { + vst1q_f64(unaligned, v.raw); +} +#endif + +// ------------------------------ Store 64 + +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint8_t* HWY_RESTRICT p) { + vst1_u8(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint16_t* HWY_RESTRICT p) { + vst1_u16(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint32_t* HWY_RESTRICT p) { + vst1_u32(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + uint64_t* HWY_RESTRICT p) { + vst1_u64(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int8_t* HWY_RESTRICT p) { + vst1_s8(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int16_t* HWY_RESTRICT p) { + vst1_s16(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int32_t* HWY_RESTRICT p) { + vst1_s32(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + int64_t* HWY_RESTRICT p) { + vst1_s64(p, v.raw); +} +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + float* HWY_RESTRICT p) { + vst1_f32(p, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_API void StoreU(const Vec64 v, Full64 /* tag */, + double* HWY_RESTRICT p) { + vst1_f64(p, v.raw); +} +#endif + +// ------------------------------ Store 32 + +HWY_API void StoreU(const Vec32 v, Full32, + uint32_t* HWY_RESTRICT p) { + vst1_lane_u32(p, v.raw, 0); +} +HWY_API void StoreU(const Vec32 v, Full32, + int32_t* HWY_RESTRICT p) { + vst1_lane_s32(p, v.raw, 0); +} +HWY_API void StoreU(const Vec32 v, Full32, + float* HWY_RESTRICT p) { + vst1_lane_f32(p, v.raw, 0); +} + +template +HWY_API void StoreU(const Vec32 v, Full32 d, T* HWY_RESTRICT p) { + const Repartition d32; + const uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} + +// ------------------------------ Store 16 + +HWY_API void StoreU(const Vec128 v, Simd, + uint16_t* HWY_RESTRICT p) { + vst1_lane_u16(p, v.raw, 0); +} +HWY_API void StoreU(const Vec128 v, Simd, + int16_t* HWY_RESTRICT p) { + vst1_lane_s16(p, v.raw, 0); +} + +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + const Repartition d16; + const uint16_t buf = GetLane(BitCast(d16, v)); + CopyBytes<2>(&buf, p); +} + +// ------------------------------ Store 8 + +HWY_API void StoreU(const Vec128 v, Simd, + uint8_t* HWY_RESTRICT p) { + vst1_lane_u8(p, v.raw, 0); +} +HWY_API void StoreU(const Vec128 v, Simd, + int8_t* HWY_RESTRICT p) { + vst1_lane_s8(p, v.raw, 0); +} + +// [b]float16_t use the same Raw as uint16_t, so forward to that. +template +HWY_API void StoreU(Vec128 v, Simd d, + float16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return StoreU(Vec128(v.raw), du16, pu16); +} +template +HWY_API void StoreU(Vec128 v, Simd d, + bfloat16_t* HWY_RESTRICT p) { + const RebindToUnsigned du16; + const auto pu16 = reinterpret_cast(p); + return StoreU(Vec128(v.raw), du16, pu16); +} + +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// On ARM, Store is the same as StoreU. +template +HWY_API void Store(Vec128 v, Simd d, T* HWY_RESTRICT aligned) { + StoreU(v, d, aligned); +} + +HWY_DIAGNOSTICS(pop) + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + // Treat as unsigned so that we correctly support float16. + const RebindToUnsigned du; + const auto blended = + IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p))); + StoreU(BitCast(d, blended), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(const Vec128 v, Simd d, + T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend to full vector. +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_u8(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vmovl_u16(vget_low_u16(a))); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_u16(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_u32(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { + return BitCast(d, Vec128(vmovl_u8(v.raw))); +} +HWY_API Vec128 PromoteTo(Full128 d, const Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return BitCast(d, Vec128(vmovl_u16(vget_low_u16(a)))); +} +HWY_API Vec128 PromoteTo(Full128 d, const Vec64 v) { + return BitCast(d, Vec128(vmovl_u16(v.raw))); +} + +// Unsigned: zero-extend to half vector. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u16(vmovl_u8(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vget_low_u32(vmovl_u16(vget_low_u16(a)))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u32(vmovl_u16(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u64(vmovl_u32(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd d, + const Vec128 v) { + return BitCast(d, Vec128(vget_low_u16(vmovl_u8(v.raw)))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + uint32x4_t b = vmovl_u16(vget_low_u16(a)); + return Vec128(vget_low_s32(vreinterpretq_s32_u32(b))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint32x4_t a = vmovl_u16(v.raw); + return Vec128(vget_low_s32(vreinterpretq_s32_u32(a))); +} + +// Signed: replicate sign bit to full vector. +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_s8(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec32 v) { + int16x8_t a = vmovl_s8(v.raw); + return Vec128(vmovl_s16(vget_low_s16(a))); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_s16(v.raw)); +} +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vmovl_s32(v.raw)); +} + +// Signed: replicate sign bit to half vector. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s16(vmovl_s8(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + int16x8_t a = vmovl_s8(v.raw); + int32x4_t b = vmovl_s16(vget_low_s16(a)); + return Vec128(vget_low_s32(b)); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s32(vmovl_s16(v.raw))); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s64(vmovl_s32(v.raw))); +} + +#if __ARM_FP & 2 + +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); + return Vec128(f32); +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); + return Vec128(vget_low_f32(f32)); +} + +#else + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +#endif + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + return Vec128(vcvt_f64_f32(v.raw)); +} + +HWY_API Vec64 PromoteTo(Full64 /* tag */, + const Vec32 v) { + return Vec64(vget_low_f64(vcvt_f64_f32(v.raw))); +} + +HWY_API Vec128 PromoteTo(Full128 /* tag */, + const Vec64 v) { + const int64x2_t i64 = vmovl_s32(v.raw); + return Vec128(vcvtq_f64_s64(i64)); +} + +HWY_API Vec64 PromoteTo(Full64 /* tag */, + const Vec32 v) { + const int64x1_t i64 = vget_low_s64(vmovl_s32(v.raw)); + return Vec64(vcvt_f64_s64(i64)); +} + +#endif + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +// From full vector to half or quarter +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovun_s32(v.raw)); +} +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovn_s32(v.raw)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec128 v) { + const uint16x4_t a = vqmovun_s32(v.raw); + return Vec32(vqmovn_u16(vcombine_u16(a, a))); +} +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovun_s16(v.raw)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec128 v) { + const int16x4_t a = vqmovn_s32(v.raw); + return Vec32(vqmovn_s16(vcombine_s16(a, a))); +} +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec64(vqmovn_s16(v.raw)); +} + +// From half vector to partial half +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); + return Vec128(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s16(vcombine_s16(v.raw, v.raw))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, v.raw)); + return Vec128(vqmovn_s16(vcombine_s16(a, a))); +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s16(vcombine_s16(v.raw, v.raw))); +} + +#if __ARM_FP & 2 + +HWY_API Vec128 DemoteTo(Full64 /* tag */, + const Vec128 v) { + return Vec128{vreinterpret_u16_f16(vcvt_f16_f32(v.raw))}; +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const float16x4_t f16 = vcvt_f16_f32(vcombine_f32(v.raw, v.raw)); + return Vec128(vreinterpret_u16_f16(f16)); +} + +#else + +template +HWY_API Vec128 DemoteTo(Simd df16, + const Vec128 v) { + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128(DemoteTo(du16, bits16).raw); +} + +#endif + +template +HWY_API Vec128 DemoteTo(Simd dbf16, + const Vec128 v) { + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec64 DemoteTo(Full64 /* tag */, const Vec128 v) { + return Vec64(vcvt_f32_f64(v.raw)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, const Vec64 v) { + return Vec32(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); +} + +HWY_API Vec64 DemoteTo(Full64 /* tag */, + const Vec128 v) { + const int64x2_t i64 = vcvtq_s64_f64(v.raw); + return Vec64(vqmovn_s64(i64)); +} +HWY_API Vec32 DemoteTo(Full32 /* tag */, + const Vec64 v) { + const int64x1_t i64 = vcvt_s64_f64(v.raw); + // There is no i64x1 -> i32x1 narrow, so expand to int64x2_t first. + const int64x2_t i64x2 = vcombine_s64(i64, i64); + return Vec32(vqmovn_s64(i64x2)); +} + +#endif + +HWY_API Vec32 U8FromU32(const Vec128 v) { + const uint8x16_t org_v = detail::BitCastToByte(v).raw; + const uint8x16_t w = vuzp1q_u8(org_v, org_v); + return Vec32(vget_low_u8(vuzp1q_u8(w, w))); +} +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const uint8x8_t org_v = detail::BitCastToByte(v).raw; + const uint8x8_t w = vuzp1_u8(org_v, org_v); + return Vec128(vuzp1_u8(w, w)); +} + +// In the following DemoteTo functions, |b| is purposely undefined. +// The value a needs to be extended to 128 bits so that vqmovn can be +// used and |b| is undefined so that no extra overhead is introduced. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 a = DemoteTo(Simd(), v); + Vec128 b; + uint16x8_t c = vcombine_u16(a.raw, b.raw); + return Vec128(vqmovn_u16(c)); +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 a = DemoteTo(Simd(), v); + Vec128 b; + int16x8_t c = vcombine_s16(a.raw, b.raw); + return Vec128(vqmovn_s16(c)); +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Convert integer <=> floating-point + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f32_s32(v.raw)); +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_s32(v.raw)); +} + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f32_u32(v.raw)); +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_u32(v.raw)); +} + +// Truncates (rounds toward zero). +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_s32_f32(v.raw)); +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_s32_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f64_s64(v.raw)); +} +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_f64_s64(v.raw)); +} + +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f64_u64(v.raw)); +} +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_f64_u64(v.raw)); +} + +// Truncates (rounds toward zero). +HWY_API Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_s64_f64(v.raw)); +} +HWY_API Vec64 ConvertTo(Full64 /* tag */, + const Vec64 v) { + return Vec64(vcvt_s64_f64(v.raw)); +} + +#endif + +// ------------------------------ Round (IfThenElse, mask, logical) + +#if HWY_ARCH_ARM_A64 +// Toward nearest integer +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Round, vrndn, _, 1) + +// Toward zero, aka truncate +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Trunc, vrnd, _, 1) + +// Toward +infinity, aka ceiling +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Ceil, vrndp, _, 1) + +// Toward -infinity, aka floor +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Floor, vrndm, _, 1) +#else + +// ------------------------------ Trunc + +// ARMv7 only supports truncation to integer. We can either convert back to +// float (3 floating-point and 2 logic operations) or manipulate the binary32 +// representation, clearing the lowest 23-exp mantissa bits. This requires 9 +// integer operations and 3 constants, which is likely more expensive. + +namespace detail { + +// The original value is already the desired result if NaN or the magnitude is +// large (i.e. the value is already an integer). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +template +HWY_API Vec128 Trunc(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), int_f, v); +} + +template +HWY_API Vec128 Round(const Vec128 v) { + const DFromV df; + + // ARMv7 also lacks a native NearestInt, but we can instead rely on rounding + // (we assume the current mode is nearest-even) after addition with a large + // value such that no mantissa bits remain. We may need a compiler flag for + // precise floating-point to prevent this from being "optimized" out. + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +template +HWY_API Vec128 Ceil(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +template +HWY_API Vec128 Floor(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#endif + +// ------------------------------ NearestInt (Round) + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s32_f32(v.raw)); +} +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtn_s32_f32(v.raw)); +} + +#else + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const RebindToSigned> di; + return ConvertTo(di, Round(v)); +} + +#endif + +// ------------------------------ Floating-point classification +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +// <= 64 bit: just return different type +template +HWY_API Vec128 LowerHalf(const Vec128 v) { + return Vec128(v.raw); +} + +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u8(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u16(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u32(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_u64(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s8(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s16(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s32(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_s64(v.raw)); +} +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_f32(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 LowerHalf(const Vec128 v) { + return Vec64(vget_low_f64(v.raw)); +} +#endif +HWY_API Vec64 LowerHalf(const Vec128 v) { + const Full128 du; + const Full64 dbh; + return BitCast(dbh, LowerHalf(BitCast(du, v))); +} + +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return LowerHalf(v); +} + +// ------------------------------ CombineShiftRightBytes + +// 128-bit +template > +HWY_API V128 CombineShiftRightBytes(Full128 d, V128 hi, V128 lo) { + static_assert(0 < kBytes && kBytes < 16, "kBytes must be in [1, 15]"); + const Repartition d8; + uint8x16_t v8 = vextq_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, Vec128(v8)); +} + +// 64-bit +template +HWY_API Vec64 CombineShiftRightBytes(Full64 d, Vec64 hi, Vec64 lo) { + static_assert(0 < kBytes && kBytes < 8, "kBytes must be in [1, 7]"); + const Repartition d8; + uint8x8_t v8 = vext_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, VFromD(v8)); +} + +// <= 32-bit defined after ShiftLeftBytes. + +// ------------------------------ Shift vector by constant #bytes + +namespace detail { + +// Partially specialize because kBytes = 0 and >= size are compile errors; +// callers replace the latter with 0xFF for easier specialization. +template +struct ShiftLeftBytesT { + // Full + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + const Full128 d; + return CombineShiftRightBytes<16 - kBytes>(d, v, Zero(d)); + } + + // Partial + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + // Expand to 64-bit so we only use the native EXT instruction. + const Full64 d64; + const auto zero64 = Zero(d64); + const decltype(zero64) v64(v.raw); + return Vec128( + CombineShiftRightBytes<8 - kBytes>(d64, v64, zero64).raw); + } +}; +template <> +struct ShiftLeftBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftLeftBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 /* v */) { + return Zero(Simd()); + } +}; + +template +struct ShiftRightBytesT { + template + HWY_INLINE Vec128 operator()(Vec128 v) { + const Simd d; + // For < 64-bit vectors, zero undefined lanes so we shift in zeros. + if (N * sizeof(T) < 8) { + constexpr size_t kReg = N * sizeof(T) == 16 ? 16 : 8; + const Simd dreg; + v = Vec128( + IfThenElseZero(FirstN(dreg, N), VFromD(v.raw)).raw); + } + return CombineShiftRightBytes(d, Zero(d), v); + } +}; +template <> +struct ShiftRightBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftRightBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 /* v */) { + return Zero(Simd()); + } +}; + +} // namespace detail + +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + return detail::ShiftLeftBytesT < kBytes >= N * sizeof(T) ? 0xFF + : kBytes > ()(v); +} + +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + return ShiftLeftBytes(Simd(), v); +} + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(Simd(), v); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + return detail::ShiftRightBytesT < kBytes >= N * sizeof(T) ? 0xFF + : kBytes > ()(v); +} + +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// Calls ShiftLeftBytes +template +HWY_API Vec128 CombineShiftRightBytes(Simd d, Vec128 hi, + Vec128 lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full64 d_full8; + const Repartition d_full; + using V64 = VFromD; + const V64 hi64(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V64 lo64 = ShiftLeftBytes<8 - kSize>(V64(BitCast(d8, lo).raw)); + const V64 r = CombineShiftRightBytes<8 - kSize + kBytes>(d_full8, hi64, lo64); + // After casting to full 64-bit vector of correct type, shrink to 32-bit + return Vec128(BitCast(d_full, r).raw); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u8(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u16(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u32(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_u64(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s8(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s16(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s32(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_s64(v.raw)); +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64(vget_high_f32(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64 UpperHalf(Full64 /* tag */, + const Vec128 v) { + return Vec64(vget_high_f64(v.raw)); +} +#endif + +HWY_API Vec64 UpperHalf(Full64 dbh, + const Vec128 v) { + const RebindToUnsigned duh; + const Twice du; + return BitCast(dbh, UpperHalf(duh, BitCast(du, v))); +} + +// Partial +template +HWY_API Vec128 UpperHalf(Half> /* tag */, + Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); + return Vec128(upper.raw); +} + +// ------------------------------ Broadcast/splat any lane + +#if HWY_ARCH_ARM_A64 +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_u64(v.raw, kLane)); +} +// Vec64 is defined below. + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_s64(v.raw, kLane)); +} +// Vec64 is defined below. + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_f64(v.raw, kLane)); +} +template +HWY_API Vec64 Broadcast(const Vec64 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +#else +// No vdupq_laneq_* on armv7: use vgetq_lane_* + vdupq_n_*. + +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_u16(vgetq_lane_u16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_u32(vgetq_lane_u32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); +} +// Vec64 is defined below. + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_s16(vgetq_lane_s16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_s32(vgetq_lane_s32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); +} +// Vec64 is defined below. + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_f32(vgetq_lane_f32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} + +#endif + +template +HWY_API Vec64 Broadcast(const Vec64 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} +template +HWY_API Vec64 Broadcast(const Vec64 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + typename detail::Raw128::type raw; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#endif + + const Repartition d8; + using V8 = VFromD; + const Repartition d16; + + // Broadcast each lane index to all bytes of T and shift to bytes + static_assert(sizeof(T) == 4 || sizeof(T) == 8, ""); + if (sizeof(T) == 4) { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); + return Indices128{BitCast(d, sum).raw}; + } else { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<3>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); + return Indices128{BitCast(d, sum).raw}; + } +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + const RebindToSigned di; + return BitCast( + d, TableLookupBytes(BitCast(di, v), BitCast(di, Vec128{idx.raw}))); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return v; +} + +// Two lanes: shuffle +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return Vec128(Shuffle2301(v)); +} + +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32_u16(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32q_u16(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64_u32(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse2(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u32(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64_u16(BitCast(du, v).raw))); +} +template +HWY_API Vec128 Reverse4(Full128 d, const Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u16(BitCast(du, v).raw))); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { + return Reverse(d, v); +} + +template +HWY_API Vec128 Reverse8(Simd, const Vec128) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ Other shuffles (TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return CombineShiftRightBytes<8>(Full128(), v, v); +} +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + return CombineShiftRightBytes<8>(Full128(), v, v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return CombineShiftRightBytes<4>(Full128(), v, v); +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return CombineShiftRightBytes<12>(Full128(), v, v); +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveLower, vzip1, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveLower, vzip1, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_u64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_s64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_f64(a.raw, b.raw)); +} +#else +// ARMv7 emulation. +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), b, Shuffle01(a)); +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), b, Shuffle01(a)); +} +#endif + +// Floats +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_f32(a.raw, b.raw)); +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1_f32(a.raw, b.raw)); +} + +// < 64 bit parts +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128(InterleaveLower(Vec64(a.raw), Vec64(b.raw)).raw); +} + +// Additional overload for the optional Simd<> tag. +template > +HWY_API V InterleaveLower(Simd /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveUpper, vzip2, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveUpper, vzip2, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vzip2q_u64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return Vec128(vzip2q_s64(a.raw, b.raw)); +} +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return Vec128(vzip2q_f64(a.raw, b.raw)); +} +#else +// ARMv7 emulation. +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), Shuffle01(b), a); +} +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return CombineShiftRightBytes<8>(Full128(), Shuffle01(b), a); +} +#endif + +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + return Vec128(vzip2q_f32(a.raw, b.raw)); +} +HWY_API Vec64 InterleaveUpper(const Vec64 a, + const Vec64 b) { + return Vec64(vzip2_f32(a.raw, b.raw)); +} + +} // namespace detail + +// Full register +template > +HWY_API V InterleaveUpper(Simd /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template > +HWY_API V InterleaveUpper(Simd d, V a, V b) { + const Half d2; + return InterleaveLower(d, V(UpperHalf(d2, a).raw), V(UpperHalf(d2, b).raw)); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { + const Repartition du16; + const RebindToUnsigned du32; + const Vec128 zero = Zero(du16); + const Vec128 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const Vec128 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const Vec128 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const Vec128 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +HWY_API Vec128 ReorderWidenMulAccumulate(Full128 /*d32*/, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmlal_high_s16(sum1.raw, a.raw, b.raw)); +#else + const Full64 dh; + sum1 = Vec128( + vmlal_s16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + return Vec128( + vmlal_s16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); +} + +HWY_API Vec64 ReorderWidenMulAccumulate(Full64 d32, + Vec64 a, + Vec64 b, + const Vec64 sum0, + Vec64& sum1) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128 mul_3210(vmull_s16(a.raw, b.raw)); + const Vec64 mul_32 = UpperHalf(d32, mul_3210); + sum1 += mul_32; + return sum0 + LowerHalf(mul_3210); +} + +HWY_API Vec32 ReorderWidenMulAccumulate(Full32 d32, + Vec32 a, + Vec32 b, + const Vec32 sum0, + Vec32& sum1) { + const Vec128 mul_xx10(vmull_s16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + sum1 += mul1; + return sum0 + mul0; +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// Full result +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u8(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, + Vec64 hi, Vec64 lo) { + return Vec128(vcombine_u16(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, + Vec64 hi, Vec64 lo) { + return Vec128(vcombine_u32(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, + Vec64 hi, Vec64 lo) { + return Vec128(vcombine_u64(lo.raw, hi.raw)); +} + +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s8(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s16(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s32(lo.raw, hi.raw)); +} +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s64(lo.raw, hi.raw)); +} + +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_f32(lo.raw, hi.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec128 Combine(Full128 /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_f64(lo.raw, hi.raw)); +} +#endif + +// < 64bit input, <= 64 bit result +template +HWY_API Vec128 Combine(Simd d, Vec128 hi, + Vec128 lo) { + // First double N (only lower halves will be used). + const Vec128 hi2(hi.raw); + const Vec128 lo2(lo.raw); + // Repartition to two unsigned lanes (each the size of the valid input). + const Simd, 2, 0> du; + return BitCast(d, InterleaveLower(BitCast(du, lo2), BitCast(du, hi2))); +} + +// ------------------------------ ZeroExtendVector (Combine) + +template +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ ConcatLowerLower + +// 64 or 128-bit input: just interleave +template +HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition, decltype(d)> du; + return BitCast(d, InterleaveLower(BitCast(du, lo), BitCast(du, hi))); +} + +namespace detail { +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveOdd, vtrn2, _, 2) +#else + +// vtrn returns a struct with even and odd result. +#define HWY_NEON_BUILD_TPL_HWY_TRN +#define HWY_NEON_BUILD_RET_HWY_TRN(type, size) type##x##size##x2_t +// Pass raw args so we can accept uint16x2 args, for which there is no +// corresponding uint16x2x2 return type. +#define HWY_NEON_BUILD_PARAM_HWY_TRN(TYPE, size) \ + Raw128::type a, Raw128::type b +#define HWY_NEON_BUILD_ARG_HWY_TRN a, b + +// Cannot use UINT8 etc. type macros because the x2_t tuples are only defined +// for full and half vectors. +HWY_NEON_DEF_FUNCTION(uint8, 16, InterleaveEvenOdd, vtrnq, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint8, 8, InterleaveEvenOdd, vtrn, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 8, InterleaveEvenOdd, vtrnq, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 4, InterleaveEvenOdd, vtrn, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 4, InterleaveEvenOdd, vtrnq, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 2, InterleaveEvenOdd, vtrn, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 16, InterleaveEvenOdd, vtrnq, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 8, InterleaveEvenOdd, vtrn, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 8, InterleaveEvenOdd, vtrnq, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 4, InterleaveEvenOdd, vtrn, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 4, InterleaveEvenOdd, vtrnq, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 2, InterleaveEvenOdd, vtrn, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 4, InterleaveEvenOdd, vtrnq, _, f32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 2, InterleaveEvenOdd, vtrn, _, f32, HWY_TRN) +#endif +} // namespace detail + +// <= 32-bit input/output +template +HWY_API Vec128 ConcatLowerLower(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveEven(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[0])); +#endif +} + +// ------------------------------ ConcatUpperUpper + +// 64 or 128-bit input: just interleave +template +HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition, decltype(d)> du; + return BitCast(d, InterleaveUpper(du, BitCast(du, lo), BitCast(du, hi))); +} + +// <= 32-bit input/output +template +HWY_API Vec128 ConcatUpperUpper(const Simd d, Vec128 hi, + Vec128 lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveOdd(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[1])); +#endif +} + +// ------------------------------ ConcatLowerUpper (ShiftLeftBytes) + +// 64 or 128-bit input: extract from concatenated +template +HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, + Vec128 lo) { + return CombineShiftRightBytes(d, hi, lo); +} + +// <= 32-bit input/output +template +HWY_API Vec128 ConcatLowerUpper(const Simd d, Vec128 hi, + Vec128 lo) { + constexpr size_t kSize = N * sizeof(T); + const Repartition d8; + const Full64 d8x8; + const Full64 d64; + using V8x8 = VFromD; + const V8x8 hi8x8(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V8x8 lo8x8 = ShiftLeftBytes<8 - kSize>(V8x8(BitCast(d8, lo).raw)); + const V8x8 r = CombineShiftRightBytes<8 - kSize / 2>(d8x8, hi8x8, lo8x8); + // Back to original lane type, then shrink N. + return Vec128(BitCast(d64, r).raw); +} + +// ------------------------------ ConcatUpperLower + +// Works for all N. +template +HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, + Vec128 lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd (InterleaveUpper) + +namespace detail { +// There is no vuzpq_u64. +HWY_NEON_DEF_FUNCTION_UIF81632(ConcatEven, vuzp1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF81632(ConcatOdd, vuzp2, _, 2) +} // namespace detail + +// Full/half vector +template = 8>* = nullptr> +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + return detail::ConcatOdd(lo, hi); +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx1Lx1 = BitCast(dw2, ConcatOdd(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec128(BitCast(d2, ConcatEven(dw2, Hx1Lx1, Hx1Lx1)).raw); +} + +// Any type x2 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// Full/half vector +template = 8>* = nullptr> +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + return detail::ConcatEven(lo, hi); +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx0Lx0 = BitCast(dw2, ConcatEven(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec128(BitCast(d2, ConcatEven(dw2, Hx0Lx0, Hx0Lx0)).raw); +} + +// Any type x2 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[0]); +#endif +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(Simd(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[1]); +#endif +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(Simd(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + const Simd d; + const Repartition d8; + alignas(16) constexpr uint8_t kBytes[16] = { + ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 0 : 0xFF, + ((2 / sizeof(T)) & 1) ? 0 : 0xFF, ((3 / sizeof(T)) & 1) ? 0 : 0xFF, + ((4 / sizeof(T)) & 1) ? 0 : 0xFF, ((5 / sizeof(T)) & 1) ? 0 : 0xFF, + ((6 / sizeof(T)) & 1) ? 0 : 0xFF, ((7 / sizeof(T)) & 1) ? 0 : 0xFF, + ((8 / sizeof(T)) & 1) ? 0 : 0xFF, ((9 / sizeof(T)) & 1) ? 0 : 0xFF, + ((10 / sizeof(T)) & 1) ? 0 : 0xFF, ((11 / sizeof(T)) & 1) ? 0 : 0xFF, + ((12 / sizeof(T)) & 1) ? 0 : 0xFF, ((13 / sizeof(T)) & 1) ? 0 : 0xFF, + ((14 / sizeof(T)) & 1) ? 0 : 0xFF, ((15 / sizeof(T)) & 1) ? 0 : 0xFF, + }; + const auto vec = BitCast(d, Load(d8, kBytes)); + return IfThenElse(MaskFromVec(vec), b, a); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + +// ------------------------------ ReorderDemote2To (OddEven) + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + const RebindToUnsigned du16; + const Repartition du32; + const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec128 ReorderDemote2To(Full128 d16, + Vec128 a, Vec128 b) { + const Vec64 a16(vqmovn_s32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128(vqmovn_high_s32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovn_s32(b.raw)); + return Combine(d16, a16, b16); +#endif +} + +HWY_API Vec64 ReorderDemote2To(Full64 /*d16*/, + Vec64 a, Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, a, b); + return Vec64(vqmovn_s32(ab.raw)); +} + +HWY_API Vec32 ReorderDemote2To(Full32 /*d16*/, + Vec32 a, Vec32 b) { + const Full128 d32; + const Vec64 ab(vzip1_s32(a.raw, b.raw)); + return Vec32(vqmovn_s32(Combine(d32, ab, ab).raw)); +} + +// ================================================== CRYPTO + +#if defined(__ARM_FEATURE_AES) || \ + (HWY_HAVE_RUNTIME_DISPATCH && HWY_ARCH_ARM_A64) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + // NOTE: it is important that AESE and AESMC be consecutive instructions so + // they can be fused. AESE includes AddRoundKey, which is a different ordering + // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual + // round key (the compiler will hopefully optimize this for multiple rounds). + return Vec128(vaesmcq_u8(vaeseq_u8(state.raw, vdupq_n_u8(0)))) ^ + round_key; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128(vaeseq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + +HWY_API Vec128 CLMulLower(Vec128 a, Vec128 b) { + return Vec128((uint64x2_t)vmull_p64(GetLane(a), GetLane(b))); +} + +HWY_API Vec128 CLMulUpper(Vec128 a, Vec128 b) { + return Vec128( + (uint64x2_t)vmull_high_p64((poly64x2_t)a.raw, (poly64x2_t)b.raw)); +} + +#endif // __ARM_FEATURE_AES + +// ================================================== MISC + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Truncations + +template * = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return Vec128{v1.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + const auto v4 = detail::ConcatEven(v3, v3); + return LowerHalf(LowerHalf(LowerHalf(v4))); +} + +HWY_API Vec32 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +HWY_API Vec64 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +// ------------------------------ MulEven (ConcatEven) + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const Full128 d; + int32x4_t a_packed = ConcatEven(d, a, a).raw; + int32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const Full128 d; + uint32x4_t a_packed = ConcatEven(d, a, a).raw; + uint32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + int32x2_t a_packed = ConcatEven(d, a, a).raw; + int32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + uint32x2_t a_packed = ConcatEven(d, a, a).raw; + uint32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { + uint64_t hi; + uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 0), vgetq_lane_u64(b.raw, 0), &hi); + return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +} + +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { + uint64_t hi; + uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 1), vgetq_lane_u64(b.raw, 1), &hi); + return Vec128(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +} + +// ------------------------------ TableLookupBytes (Combine, LowerHalf) + +// Both full +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const Full128 d; + const Repartition d8; +#if HWY_ARCH_ARM_A64 + return BitCast(d, Vec128(vqtbl1q_u8(BitCast(d8, bytes).raw, + BitCast(d8, from).raw))); +#else + uint8x16_t table0 = BitCast(d8, bytes).raw; + uint8x8x2_t table; + table.val[0] = vget_low_u8(table0); + table.val[1] = vget_high_u8(table0); + uint8x16_t idx = BitCast(d8, from).raw; + uint8x8_t low = vtbl2_u8(table, vget_low_u8(idx)); + uint8x8_t hi = vtbl2_u8(table, vget_high_u8(idx)); + return BitCast(d, Vec128(vcombine_u8(low, hi))); +#endif +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const Full128 d_full; + const Vec64 from64(from.raw); + const auto idx_full = Combine(d_full, from64, from64); + const auto out_full = TableLookupBytes(bytes, idx_full); + return Vec128(LowerHalf(Half(), out_full).raw); +} + +// Partial table vector +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const Full128 d_full; + return TableLookupBytes(Combine(d_full, bytes, bytes), from); +} + +// Partial both +template +HWY_API VFromD>> TableLookupBytes( + Vec128 bytes, Vec128 from) { + const Simd d; + const Simd d_idx; + const Repartition d_idx8; + // uint8x8 + const auto bytes8 = BitCast(Repartition(), bytes); + const auto from8 = BitCast(d_idx8, from); + const VFromD v8(vtbl1_u8(bytes8.raw, from8.raw)); + return BitCast(d_idx, v8); +} + +// For all vector widths; ARM anyway zeroes if >= 0x10. +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} + +// u32/i32/f32: N=2 +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return v10 + Shuffle2301(v10); +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Min(v10, Shuffle2301(v10)); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Max(v10, Shuffle2301(v10)); +} + +// full vectors +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + return Vec128(vdupq_n_u32(vaddvq_u32(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + return Vec128(vdupq_n_s32(vaddvq_s32(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + return Vec128(vdupq_n_f32(vaddvq_f32(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v) { + return Vec128(vdupq_n_u64(vaddvq_u64(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v) { + return Vec128(vdupq_n_s64(vaddvq_s64(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v) { + return Vec128(vdupq_n_f64(vaddvq_f64(v.raw))); +} +#else +// ARMv7 version for everything except doubles. +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + uint32x4x2_t v0 = vuzpq_u32(v.raw, v.raw); + uint32x4_t c0 = vaddq_u32(v0.val[0], v0.val[1]); + uint32x4x2_t v1 = vuzpq_u32(c0, c0); + return Vec128(vaddq_u32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + int32x4x2_t v0 = vuzpq_s32(v.raw, v.raw); + int32x4_t c0 = vaddq_s32(v0.val[0], v0.val[1]); + int32x4x2_t v1 = vuzpq_s32(c0, c0); + return Vec128(vaddq_s32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v) { + float32x4x2_t v0 = vuzpq_f32(v.raw, v.raw); + float32x4_t c0 = vaddq_f32(v0.val[0], v0.val[1]); + float32x4x2_t v1 = vuzpq_f32(c0, c0); + return Vec128(vaddq_f32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v) { + return v + Shuffle01(v); +} +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v) { + return v + Shuffle01(v); +} +#endif + +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// For u64/i64[/f64]. +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +template +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// Helper function to set 64 bits and potentially return a smaller vector. The +// overload is required to call the q vs non-q intrinsics. Note that 8-bit +// LoadMaskBits only requires 16 bits, but 64 avoids casting. +template +HWY_INLINE Vec128 Set64(Simd /* tag */, uint64_t mask_bits) { + const auto v64 = Vec64(vdup_n_u64(mask_bits)); + return Vec128(BitCast(Full64(), v64).raw); +} +template +HWY_INLINE Vec128 Set64(Full128 d, uint64_t mask_bits) { + return BitCast(d, Vec128(vdupq_n_u64(mask_bits))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const auto vmask_bits = Set64(du, mask_bits); + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vmask_bits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask128 LoadMaskBits(Simd d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(N + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Returns mask[i]? 0xF : 0 in each nibble. This is more efficient than +// BitsFromMask for use in (partial) CountTrue, FindFirstTrue and AllFalse. +template +HWY_INLINE uint64_t NibblesFromMask(const Full128 d, Mask128 mask) { + const Full128 du16; + const Vec128 vu16 = BitCast(du16, VecFromMask(d, mask)); + const Vec64 nib(vshrn_n_u16(vu16.raw, 4)); + return GetLane(BitCast(Full64(), nib)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(const Full64 d, Mask64 mask) { + // There is no vshrn_n_u16 for uint16x4, so zero-extend. + const Twice d2; + const Vec128 v128 = ZeroExtendVector(d2, VecFromMask(d, mask)); + // No need to mask, upper half is zero thanks to ZeroExtendVector. + return NibblesFromMask(d2, MaskFromVec(v128)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(Simd /*d*/, Mask128 mask) { + const Mask64 mask64(mask.raw); + const uint64_t nib = NibblesFromMask(Full64(), mask64); + // Clear nibbles from upper half of 64-bits + constexpr size_t kBytes = sizeof(T) * N; + return nib & ((1ull << (kBytes * 4)) - 1); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint8_t kSliceLanes[16] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, + }; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(Full128(), mask)) & Load(du, kSliceLanes); + +#if HWY_ARCH_ARM_A64 + // Can't vaddv - we need two separate bytes (16 bits). + const uint8x8_t x2 = vget_low_u8(vpaddq_u8(values.raw, values.raw)); + const uint8x8_t x4 = vpadd_u8(x2, x2); + const uint8x8_t x8 = vpadd_u8(x4, x4); + return vget_lane_u64(vreinterpret_u64_u8(x8), 0); +#else + // Don't have vpaddq, so keep doubling lane size. + const uint16x8_t x2 = vpaddlq_u8(values.raw); + const uint32x4_t x4 = vpaddlq_u16(x2); + const uint64x2_t x8 = vpaddlq_u32(x4); + return (vgetq_lane_u64(x8, 1) << 8) | vgetq_lane_u64(x8, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; + +#if HWY_ARCH_ARM_A64 + return vaddv_u8(values.raw); +#else + const uint16x4_t x2 = vpaddl_u8(values.raw); + const uint32x2_t x4 = vpaddl_u16(x2); + const uint64x1_t x8 = vpaddl_u32(x4); + return vget_lane_u64(x8, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint16_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u16(values.raw); +#else + const uint32x4_t x2 = vpaddlq_u16(values.raw); + const uint64x2_t x4 = vpaddlq_u32(x2); + return vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u16(values.raw); +#else + const uint32x2_t x2 = vpaddl_u16(values.raw); + const uint64x1_t x4 = vpaddl_u32(x2); + return vget_lane_u64(x4, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u32(values.raw); +#else + const uint64x2_t x2 = vpaddlq_u32(values.raw); + return vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint32_t kSliceLanes[2] = {1, 2}; + const Simd d; + const RebindToUnsigned du; + const Vec128 slice(Load(Full64(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u32(values.raw); +#else + const uint64x1_t x2 = vpaddl_u32(values.raw); + return vget_lane_u64(x2, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + alignas(16) constexpr uint64_t kSliceLanes[2] = {1, 2}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, m)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u64(values.raw); +#else + return vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 m) { + const Full64 d; + const Full64 du; + const Vec64 values = BitCast(du, VecFromMask(d, m)) & Set(du, 1); + return vget_lane_u64(values.raw, 0); +} + +// Returns the lowest N for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) >= 8) ? bits : (bits & ((1ull << N) - 1)); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +// Returns number of lanes whose mask is set. +// +// Masks are either FF..FF or 0. Unfortunately there is no reduce-sub op +// ("vsubv"). ANDing with 1 would work but requires a constant. Negating also +// changes each lane to 1 (if mask set) or 0. +// NOTE: PopCount also operates on vectors, so we still have to do horizontal +// sums separately. We specialize CountTrue for full vectors (negating instead +// of PopCount because it avoids an extra shift), and use PopCount of +// NibblesFromMask for partial vectors. + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { + const Full128 di; + const int8x16_t ones = + vnegq_s8(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s8(ones)); +#else + const int16x8_t x2 = vpaddlq_s8(ones); + const int32x4_t x4 = vpaddlq_s16(x2); + const int64x2_t x8 = vpaddlq_s32(x4); + return static_cast(vgetq_lane_s64(x8, 0) + vgetq_lane_s64(x8, 1)); +#endif +} +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> /*tag*/, const Mask128 mask) { + const Full128 di; + const int16x8_t ones = + vnegq_s16(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s16(ones)); +#else + const int32x4_t x2 = vpaddlq_s16(ones); + const int64x2_t x4 = vpaddlq_s32(x2); + return static_cast(vgetq_lane_s64(x4, 0) + vgetq_lane_s64(x4, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 mask) { + const Full128 di; + const int32x4_t ones = + vnegq_s32(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s32(ones)); +#else + const int64x2_t x2 = vpaddlq_s32(ones); + return static_cast(vgetq_lane_s64(x2, 0) + vgetq_lane_s64(x2, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128 mask) { +#if HWY_ARCH_ARM_A64 + const Full128 di; + const int64x2_t ones = + vnegq_s64(BitCast(di, VecFromMask(Full128(), mask)).raw); + return static_cast(vaddvq_s64(ones)); +#else + const Full128 du; + const auto mask_u = VecFromMask(du, RebindMask(du, mask)); + const uint64x2_t ones = vshrq_n_u64(mask_u.raw, 63); + return static_cast(vgetq_lane_u64(ones, 0) + vgetq_lane_u64(ones, 1)); +#endif +} + +} // namespace detail + +// Full +template +HWY_API size_t CountTrue(Full128 /* tag */, const Mask128 mask) { + return detail::CountTrue(hwy::SizeTag(), mask); +} + +// Partial +template +HWY_API size_t CountTrue(Simd d, const Mask128 mask) { + constexpr int kDiv = 4 * sizeof(T); + return PopCount(detail::NibblesFromMask(d, mask)) / kDiv; +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd d, + const Mask128 mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + constexpr size_t kDiv = 4 * sizeof(T); + return Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv; +} + +template +HWY_API intptr_t FindFirstTrue(const Simd d, + const Mask128 mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + if (nib == 0) return -1; + constexpr int kDiv = 4 * sizeof(T); + return static_cast(Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv); +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, + uint8_t* bits) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API bool AllFalse(const Simd d, const Mask128 m) { + return detail::NibblesFromMask(d, m) == 0; +} + +// Full +template +HWY_API bool AllTrue(const Full128 d, const Mask128 m) { + return detail::NibblesFromMask(d, m) == ~0ull; +} +// Partial +template +HWY_API bool AllTrue(const Simd d, const Mask128 m) { + constexpr size_t kBytes = sizeof(T) * N; + return detail::NibblesFromMask(d, m) == (1ull << (kBytes * 4)) - 1; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +namespace detail { + +// Load 8 bytes, replicate into upper half so ZipLower can use the lower half. +HWY_INLINE Vec128 Load8Bytes(Full128 /*d*/, + const uint8_t* bytes) { + return Vec128(vreinterpretq_u8_u64( + vld1q_dup_u64(reinterpret_cast(bytes)))); +} + +// Load 8 bytes and return half-reg with N <= 8 bytes. +template +HWY_INLINE Vec128 Load8Bytes(Simd d, + const uint8_t* bytes) { + return Load(d, bytes); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<2> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // ARM does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<2> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // ARM does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<4> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<4> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<8> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#endif + +// Helper function called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. +template +HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { + const auto idx = + detail::IdxFromBits(hwy::SizeTag(), mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, const uint64_t mask_bits) { + const auto idx = + detail::IdxFromNotBits(hwy::SizeTag(), mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. + const Simd d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + return detail::Compress(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNot(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits + +template +HWY_INLINE Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, + Simd d, T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + using TU = TFromD; + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Mask128 store_mask = RebindMask(d, FirstN(du, count)); + const Vec128 compressed = detail::Compress(BitCast(du, v), mask_bits); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#define HWY_NEON_BUILD_ARG_HWY_LOAD_INT from + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_LOAD_INT(T, N) HWY_IF_GE64(T, N) +#define HWY_NEON_DEF_FUNCTION_LOAD_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64 +#define HWY_IF_LOAD_INT(T, N) \ + hwy::EnableIf= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +// Must return raw tuple because Tuple2 lack a ctor, and we cannot use +// brace-initialization in HWY_NEON_DEF_FUNCTION because some functions return +// void. +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple2().raw) +// Tuple tag arg allows overloading (cannot just overload on return type) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple2 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved2, vld2, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple3().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple3 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved3, vld3, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple4().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple4 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved4, vld4, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#undef HWY_NEON_DEF_FUNCTION_LOAD_INT +#undef HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#undef HWY_NEON_BUILD_ARG_HWY_LOAD_INT +} // namespace detail + +template +HWY_API void LoadInterleaved2(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1) { + auto raw = detail::LoadInterleaved2(unaligned, detail::Tuple2()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); +} + +// <= 32 bits: avoid loading more than N bytes by copying to buffer +template +HWY_API void LoadInterleaved2(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1) { + // The smallest vector registers are 64-bits and we want space for two. + alignas(16) T buf[2 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved2(buf, detail::Tuple2()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void LoadInterleaved2(Full128 d, T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1) { + const Half dh; + VFromD v00, v10, v01, v11; + LoadInterleaved2(dh, unaligned, v00, v10); + LoadInterleaved2(dh, unaligned + 2, v01, v11); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved3 + +template +HWY_API void LoadInterleaved3(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2) { + auto raw = detail::LoadInterleaved3(unaligned, detail::Tuple3()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void LoadInterleaved3(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2) { + // The smallest vector registers are 64-bits and we want space for three. + alignas(16) T buf[3 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved3(buf, detail::Tuple3()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void LoadInterleaved3(Full128 d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2) { + const Half dh; + VFromD v00, v10, v20, v01, v11, v21; + LoadInterleaved3(dh, unaligned, v00, v10, v20); + LoadInterleaved3(dh, unaligned + 3, v01, v11, v21); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved4 + +template +HWY_API void LoadInterleaved4(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2, + Vec128& v3) { + auto raw = detail::LoadInterleaved4(unaligned, detail::Tuple4()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); + v3 = Vec128(raw.val[3]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void LoadInterleaved4(Simd /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1, Vec128& v2, + Vec128& v3) { + alignas(16) T buf[4 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved4(buf, detail::Tuple4()); + v0 = Vec128(raw.val[0]); + v1 = Vec128(raw.val[1]); + v2 = Vec128(raw.val[2]); + v3 = Vec128(raw.val[3]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void LoadInterleaved4(Full128 d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2, + Vec128& v3) { + const Half dh; + VFromD v00, v10, v20, v30, v01, v11, v21, v31; + LoadInterleaved4(dh, unaligned, v00, v10, v20, v30); + LoadInterleaved4(dh, unaligned + 4, v01, v11, v21, v31); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); + v3 = Combine(d, v31, v30); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_LOAD_INT + +// ------------------------------ StoreInterleaved2 + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_STORE_INT +#define HWY_NEON_BUILD_RET_HWY_STORE_INT(type, size) void +#define HWY_NEON_BUILD_ARG_HWY_STORE_INT to, tup.raw + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_STORE_INT(T, N) HWY_IF_GE64(T, N) +#define HWY_NEON_DEF_FUNCTION_STORE_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64 +#define HWY_IF_STORE_INT(T, N) \ + hwy::EnableIf= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple2 tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved2, vst2, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple3 tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved3, vst3, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple4 tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved4, vst4, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#undef HWY_NEON_DEF_FUNCTION_STORE_INT +#undef HWY_NEON_BUILD_TPL_HWY_STORE_INT +#undef HWY_NEON_BUILD_RET_HWY_STORE_INT +#undef HWY_NEON_BUILD_ARG_HWY_STORE_INT +} // namespace detail + +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[2 * 8 / sizeof(T)]; + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, buf); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Full128 d, T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, unaligned); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, unaligned + 2); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved3 + +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[3 * 8 / sizeof(T)]; + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, buf); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Full128 d, + T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved3(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), dh, + unaligned); + StoreInterleaved3(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), dh, + unaligned + 3); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved4 + +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[4 * 8 / sizeof(T)]; + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, buf); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Full128 d, T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved4(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), + LowerHalf(dh, v3), dh, unaligned); + StoreInterleaved4(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), + UpperHalf(dh, v3), dh, unaligned + 4); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_STORE_INT + +// ------------------------------ Lt128 + +template +HWY_INLINE Mask128 Lt128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128 eqHL = Eq(a, b); + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). + const Vec128 ltLx = DupEven(ltHL); + const Vec128 outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE Mask128 Lt128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE Mask128 Eq128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE Mask128 Eq128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE Mask128 Ne128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE Mask128 Ne128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. +template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +namespace detail { // for code folding +#if HWY_ARCH_ARM_V7 +#undef vuzp1_s8 +#undef vuzp1_u8 +#undef vuzp1_s16 +#undef vuzp1_u16 +#undef vuzp1_s32 +#undef vuzp1_u32 +#undef vuzp1_f32 +#undef vuzp1q_s8 +#undef vuzp1q_u8 +#undef vuzp1q_s16 +#undef vuzp1q_u16 +#undef vuzp1q_s32 +#undef vuzp1q_u32 +#undef vuzp1q_f32 +#undef vuzp2_s8 +#undef vuzp2_u8 +#undef vuzp2_s16 +#undef vuzp2_u16 +#undef vuzp2_s32 +#undef vuzp2_u32 +#undef vuzp2_f32 +#undef vuzp2q_s8 +#undef vuzp2q_u8 +#undef vuzp2q_s16 +#undef vuzp2q_u16 +#undef vuzp2q_s32 +#undef vuzp2q_u32 +#undef vuzp2q_f32 +#undef vzip1_s8 +#undef vzip1_u8 +#undef vzip1_s16 +#undef vzip1_u16 +#undef vzip1_s32 +#undef vzip1_u32 +#undef vzip1_f32 +#undef vzip1q_s8 +#undef vzip1q_u8 +#undef vzip1q_s16 +#undef vzip1q_u16 +#undef vzip1q_s32 +#undef vzip1q_u32 +#undef vzip1q_f32 +#undef vzip2_s8 +#undef vzip2_u8 +#undef vzip2_s16 +#undef vzip2_u16 +#undef vzip2_s32 +#undef vzip2_u32 +#undef vzip2_f32 +#undef vzip2q_s8 +#undef vzip2q_u8 +#undef vzip2q_s16 +#undef vzip2q_u16 +#undef vzip2q_s32 +#undef vzip2q_u32 +#undef vzip2q_f32 +#endif + +#undef HWY_NEON_BUILD_ARG_1 +#undef HWY_NEON_BUILD_ARG_2 +#undef HWY_NEON_BUILD_ARG_3 +#undef HWY_NEON_BUILD_PARAM_1 +#undef HWY_NEON_BUILD_PARAM_2 +#undef HWY_NEON_BUILD_PARAM_3 +#undef HWY_NEON_BUILD_RET_1 +#undef HWY_NEON_BUILD_RET_2 +#undef HWY_NEON_BUILD_RET_3 +#undef HWY_NEON_BUILD_TPL_1 +#undef HWY_NEON_BUILD_TPL_2 +#undef HWY_NEON_BUILD_TPL_3 +#undef HWY_NEON_DEF_FUNCTION +#undef HWY_NEON_DEF_FUNCTION_ALL_FLOATS +#undef HWY_NEON_DEF_FUNCTION_ALL_TYPES +#undef HWY_NEON_DEF_FUNCTION_FLOAT_64 +#undef HWY_NEON_DEF_FUNCTION_INTS +#undef HWY_NEON_DEF_FUNCTION_INTS_UINTS +#undef HWY_NEON_DEF_FUNCTION_INT_16 +#undef HWY_NEON_DEF_FUNCTION_INT_32 +#undef HWY_NEON_DEF_FUNCTION_INT_8 +#undef HWY_NEON_DEF_FUNCTION_INT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_TPL +#undef HWY_NEON_DEF_FUNCTION_UIF81632 +#undef HWY_NEON_DEF_FUNCTION_UINTS +#undef HWY_NEON_DEF_FUNCTION_UINT_16 +#undef HWY_NEON_DEF_FUNCTION_UINT_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_8 +#undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32 +#undef HWY_NEON_EVAL +} // namespace detail + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/arm_sve-inl.h b/media/highway/src/hwy/ops/arm_sve-inl.h new file mode 100644 index 000000000..1ccac9e6e --- /dev/null +++ b/media/highway/src/hwy/ops/arm_sve-inl.h @@ -0,0 +1,3151 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ARM SVE[2] vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +// If running on hardware whose vector length is known to be a power of two, we +// can skip fixups for non-power of two sizes. +#undef HWY_SVE_IS_POW2 +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +#define HWY_SVE_IS_POW2 1 +#else +#define HWY_SVE_IS_POW2 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// Unsigned: +#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 64, 32, NAME, OP) + +// Signed: +#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP) + +// Float: +#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 16, 16, NAME, OP) +#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 64, 32, NAME, OP) + +// For all element sizes: +#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// Commonly used type categories for a given element size: +#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// Commonly used type categories: +#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +// Assemble types for use in x-macros +#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t +#define HWY_SVE_D(BASE, BITS, N, POW2) Simd +#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t + +} // namespace detail + +#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using type = ScalableTag; \ + }; + +HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) +#undef HWY_SPECIALIZE + +// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX +// instructions, and we anyway only use it when the predicate is ptrue. + +// vector = f(vector), e.g. Not +#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } +#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +// vector = f(vector, scalar), e.g. detail::AddN +#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +// ------------------------------ Lanes + +namespace detail { + +// Returns actual lanes of a hardware vector without rounding to a power of two. +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<1> /* tag */) { + return svcntb_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<2> /* tag */) { + return svcnth_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<4> /* tag */) { + return svcntw_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<8> /* tag */) { + return svcntd_pat(SV_ALL); +} + +// All-true mask from a macro +#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL) + +#if HWY_SVE_IS_POW2 +#define HWY_SVE_PTRUE(BITS) HWY_SVE_ALL_PTRUE(BITS) +#else +#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2) + +// Returns actual lanes of a hardware vector, rounded down to a power of two. +template +HWY_INLINE size_t HardwareLanes() { + return svcntb_pat(SV_POW2); +} +template +HWY_INLINE size_t HardwareLanes() { + return svcnth_pat(SV_POW2); +} +template +HWY_INLINE size_t HardwareLanes() { + return svcntw_pat(SV_POW2); +} +template +HWY_INLINE size_t HardwareLanes() { + return svcntd_pat(SV_POW2); +} + +#endif // HWY_SVE_IS_POW2 + +} // namespace detail + +// Returns actual number of lanes after capping by N and shifting. May return 0 +// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8). +#if HWY_TARGET == HWY_SVE_256 +template +HWY_API constexpr size_t Lanes(Simd /* d */) { + return HWY_MIN(detail::ScaleByPower(32 / sizeof(T), kPow2), N); +} +#elif HWY_TARGET == HWY_SVE2_128 +template +HWY_API constexpr size_t Lanes(Simd /* d */) { + return HWY_MIN(detail::ScaleByPower(16 / sizeof(T), kPow2), N); +} +#else +template +HWY_API size_t Lanes(Simd d) { + const size_t actual = detail::HardwareLanes(); + // Common case of full vectors: avoid any extra instructions. + if (detail::IsFull(d)) return actual; + return HWY_MIN(detail::ScaleByPower(actual, kPow2), N); +} +#endif // HWY_TARGET + +// ================================================== MASK INIT + +// One mask bit per byte; only the one belonging to the lowest byte is valid. + +// ------------------------------ FirstN +#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) { \ + const size_t limit = detail::IsFull(d) ? count : HWY_MIN(Lanes(d), count); \ + return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast(limit)); \ + } +HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) +#undef HWY_SVE_FIRSTN + +template +using MFromD = decltype(FirstN(D(), 0)); + +namespace detail { + +#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_PTRUE(BITS); \ + } \ + template \ + HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_ALL_PTRUE(BITS); \ + } + +HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true +#undef HWY_SVE_WRAP_PTRUE + +HWY_API svbool_t PFalse() { return svpfalse_b(); } + +// Returns all-true if d is HWY_FULL or FirstN(N) after capping N. +// +// This is used in functions that load/store memory; other functions (e.g. +// arithmetic) can ignore d and use PTrue instead. +template +svbool_t MakeMask(D d) { + return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d)); +} + +} // namespace detail + +// ================================================== INIT + +// ------------------------------ Set +// vector = f(d, scalar), e.g. Set +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) arg) { \ + return sv##OP##_##CHAR##BITS(arg); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) +#undef HWY_SVE_SET + +// Required for Zero and VFromD +template +svuint16_t Set(Simd d, bfloat16_t arg) { + return Set(RebindToUnsigned(), arg.bits); +} + +template +using VFromD = decltype(Set(D(), TFromD())); + +// ------------------------------ Zero + +template +VFromD Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return sv##OP##_##CHAR##BITS(); \ + } + +HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) + +// ------------------------------ BitCast + +namespace detail { + +// u8: no change +#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } + +// All other types +#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u8_##CHAR##BITS(v); \ + } \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \ + return sv##OP##_##CHAR##BITS##_u8(v); \ + } + +HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) +HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) + +#undef HWY_SVE_CAST_NOP +#undef HWY_SVE_CAST + +template +HWY_INLINE svuint16_t BitCastFromByte(Simd /* d */, + svuint8_t v) { + return BitCastFromByte(Simd(), v); +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ================================================== LOGICAL + +// detail::*N() functions accept a scalar argument to avoid extra Set(). + +// ------------------------------ Not +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not ) // NOLINT + +// ------------------------------ And + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and) + +template +HWY_API V And(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, And(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Or + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr) + +template +HWY_API V Or(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Xor + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor) + +template +HWY_API V Xor(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Xor(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ AndNot + +namespace detail { +#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n) +#undef HWY_SVE_RETV_ARGPVN_SWAP +} // namespace detail + +#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic) +#undef HWY_SVE_RETV_ARGPVV_SWAP + +template +HWY_API V AndNot(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Or3 +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +// Need to return original type instead of unsigned. +#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return BitCast(DFromV(), \ + sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt) +#undef HWY_SVE_POPCNT + +// ================================================== SIGN + +// ------------------------------ Neg +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg) + +// ------------------------------ Abs +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) + +// ------------------------------ CopySign[ToAbs] + +template +HWY_API V CopySign(const V magn, const V sign) { + const auto msb = SignBit(DFromV()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + const auto msb = SignBit(DFromV()); + return Or(abs, And(msb, sign)); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Add + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n) +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add) + +// ------------------------------ Sub + +namespace detail { +// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg. +#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(pg, a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n) +#undef HWY_SVE_RETV_ARGPVN_MASK +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub) + +// ------------------------------ SumsOf8 +HWY_API svuint64_t SumsOf8(const svuint8_t v) { + const ScalableTag du32; + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + + const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); + // Compute pairwise sum of u32 and extend to u64. + // TODO(janwas): on SVE2, we can instead use svaddp. + const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) + const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); + return Add(hi, lo); +} + +// ------------------------------ SaturatedAdd + +HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) + +// ------------------------------ SaturatedSub + +HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) + +// ------------------------------ AbsDiff +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) + +// ------------------------------ ShiftLeft[Same] + +#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n) + +// ------------------------------ ShiftRight[Same] + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n) + +#undef HWY_SVE_SHIFT_N + +// ------------------------------ RotateRight + +// TODO(janwas): svxar on SVE2 +template +HWY_API V RotateRight(const V v) { + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shl/r + +#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToUnsigned> du; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + BitCast(du, bits)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl) + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr) + +#undef HWY_SVE_SHIFT + +// ------------------------------ Min/Max + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm) + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) +} // namespace detail + +// ------------------------------ Mul +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, Mul, mul) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_RETV_ARGPVV, Mul, mul) + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// ------------------------------ MulHigh +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) +// Not part of API, used internally: +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) +HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) + +// ------------------------------ MulFixedPoint15 +HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { +#if HWY_TARGET == HWY_SVE2 + return svqrdmulh_s16(a, b); +#else + const DFromV d; + const RebindToUnsigned du; + + const svuint16_t lo = BitCast(du, Mul(a, b)); + const svint16_t hi = MulHigh(a, b); + // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must + // carry that into the result. Instead isolate the top two bits because only + // they can influence the result. + const svuint16_t lo_top2 = ShiftRight<14>(lo); + // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0. + const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1)); + return Add(Add(hi, hi), BitCast(d, rounding)); +#endif +} + +// ------------------------------ Div +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div) + +// ------------------------------ ApproximateReciprocal +HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe) + +// ------------------------------ Sqrt +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) + +// ------------------------------ ApproximateReciprocalSqrt +HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte) + +// ------------------------------ MulAdd +#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ + HWY_SVE_V(BASE, BITS) add) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulAdd, mad) + +// ------------------------------ NegMulAdd +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulAdd, msb) + +// ------------------------------ MulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb) + +// ------------------------------ NegMulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) + +#undef HWY_SVE_FMA + +// ------------------------------ Round etc. + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz) + +// ================================================== MASK + +// ------------------------------ RebindMask +template +HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) { + return mask; +} + +// ------------------------------ Mask logical + +HWY_API svbool_t Not(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Per ARM, such bits are ignored. + return svnot_b_z(HWY_SVE_PTRUE(8), m); +} +HWY_API svbool_t And(svbool_t a, svbool_t b) { + return svand_b_z(b, b, a); // same order as AndNot for consistency +} +HWY_API svbool_t AndNot(svbool_t a, svbool_t b) { + return svbic_b_z(b, b, a); // reversed order like NEON +} +HWY_API svbool_t Or(svbool_t a, svbool_t b) { + return svsel_b(a, a, b); // a ? true : b +} +HWY_API svbool_t Xor(svbool_t a, svbool_t b) { + return svsel_b(a, svnand_b_z(a, a, b), b); // a ? !(a & b) : b. +} + +HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) { + return svnor_b_z(HWY_SVE_PTRUE(8), a, b); // !a && !b, undefined if a && b. +} + +// ------------------------------ CountTrue + +#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \ + return sv##OP##_b##BITS(detail::MakeMask(d), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) +#undef HWY_SVE_COUNT_TRUE + +// For 16-bit Compress: full vector, not limited to SV_POW2. +namespace detail { + +#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \ + return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) +#undef HWY_SVE_COUNT_TRUE_FULL + +} // namespace detail + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, svbool_t m) { + return !svptest_any(detail::MakeMask(d), m); +} + +// ------------------------------ AllTrue +template +HWY_API bool AllTrue(D d, svbool_t m) { + return CountTrue(d, m) == Lanes(d); +} + +// ------------------------------ FindFirstTrue +template +HWY_API intptr_t FindFirstTrue(D d, svbool_t m) { + return AllFalse(d, m) ? intptr_t{-1} + : static_cast( + CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m))); +} + +// ------------------------------ FindKnownFirstTrue +template +HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) { + return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)); +} + +// ------------------------------ IfThenElse +#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(m, yes, no); \ + } + +HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +#undef HWY_SVE_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template +HWY_API V IfThenElseZero(const svbool_t mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV())); +} + +// ------------------------------ IfThenZeroElse +template +HWY_API V IfThenZeroElse(const svbool_t mask, const V no) { + return IfThenElse(mask, Zero(DFromV()), no); +} + +// ================================================== COMPARE + +// mask = f(vector, vector) +#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } + +// ------------------------------ Eq +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) +} // namespace detail + +// ------------------------------ Ne +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) +} // namespace detail + +// ------------------------------ Lt +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) +} // namespace detail + +// ------------------------------ Le +HWY_SVE_FOREACH_F(HWY_SVE_COMPARE, Le, cmple) + +#undef HWY_SVE_COMPARE +#undef HWY_SVE_COMPARE_N + +// ------------------------------ Gt/Ge (swapped order) +template +HWY_API svbool_t Gt(const V a, const V b) { + return Lt(b, a); +} +template +HWY_API svbool_t Ge(const V a, const V b) { + return Le(b, a); +} + +// ------------------------------ TestBit +template +HWY_API svbool_t TestBit(const V a, const V bit) { + return detail::NeN(And(a, bit), 0); +} + +// ------------------------------ MaskFromVec (Ne) +template +HWY_API svbool_t MaskFromVec(const V v) { + return detail::NeN(v, static_cast>(0)); +} + +// ------------------------------ VecFromMask +template +HWY_API VFromD VecFromMask(const D d, svbool_t mask) { + const RebindToSigned di; + // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which + // requires an extra instruction plus M0 pipeline. + return BitCast(d, IfThenElseZero(mask, Set(di, -1))); +} + +// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) + +#if HWY_TARGET == HWY_SVE2 + +#define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \ + HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(yes, no, mask); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl) +#undef HWY_SVE_IF_VEC + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#else + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +#endif // HWY_TARGET == HWY_SVE2 + +// ------------------------------ Floating-point classification (Ne) + +template +HWY_API svbool_t IsNaN(const V v) { + return Ne(v, v); // could also use cmpuo +} + +template +HWY_API svbool_t IsInf(const V v) { + using T = TFromV; + const DFromV d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqN(Add(vi, vi), hwy::MaxExponentTimes2())); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API svbool_t IsFinite(const V v) { + using T = TFromV; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField())); +} + +// ================================================== MEMORY + +// ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream + +#define HWY_SVE_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return sv##OP##_##CHAR##BITS(detail::MakeMask(d), p); \ + } + +#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return sv##OP##_##CHAR##BITS(m, p); \ + } + +#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + /* All-true predicate to load all 128 bits. */ \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p); \ + } + +#define HWY_SVE_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), p, v); \ + } + +#define HWY_SVE_BLENDED_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(m, p, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1) +HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1) +HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDup128, ld1rq) +HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1) +HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1) +HWY_SVE_FOREACH(HWY_SVE_BLENDED_STORE, BlendedStore, st1) + +#undef HWY_SVE_LOAD +#undef HWY_SVE_MASKED_LOAD +#undef HWY_SVE_LOAD_DUP128 +#undef HWY_SVE_STORE +#undef HWY_SVE_BLENDED_STORE + +// BF16 is the same as svuint16_t because BF16 is optional before v8.6. +template +HWY_API svuint16_t Load(Simd d, + const bfloat16_t* HWY_RESTRICT p) { + return Load(RebindToUnsigned(), + reinterpret_cast(p)); +} + +template +HWY_API void Store(svuint16_t v, Simd d, + bfloat16_t* HWY_RESTRICT p) { + Store(v, RebindToUnsigned(), + reinterpret_cast(p)); +} + +// ------------------------------ Load/StoreU + +// SVE only requires lane alignment, not natural alignment of the entire +// vector. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +template +HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ ScatterOffset/Index + +#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ + v); \ + } + +#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, HWY_SVE_V(int, BITS) index) { \ + sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, index, v); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_INDEX, ScatterIndex, st1_scatter) +#undef HWY_SVE_SCATTER_OFFSET +#undef HWY_SVE_SCATTER_INDEX + +// ------------------------------ GatherOffset/Index + +#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ + offset); \ + } +#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) index) { \ + return sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, \ + index); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather) +#undef HWY_SVE_GATHER_OFFSET +#undef HWY_SVE_GATHER_INDEX + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ + const sv##BASE##BITS##x2_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget2(tuple, 0); \ + v1 = svget2(tuple, 1); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) + +#undef HWY_SVE_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2) { \ + const sv##BASE##BITS##x3_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget3(tuple, 0); \ + v1 = svget3(tuple, 1); \ + v2 = svget3(tuple, 2); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) + +#undef HWY_SVE_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ + const sv##BASE##BITS##x4_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget4(tuple, 0); \ + v1 = svget4(tuple, 1); \ + v2 = svget4(tuple, 2); \ + v3 = svget4(tuple, 3); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) + +#undef HWY_SVE_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x2_t tuple = svcreate2##_##CHAR##BITS(v0, v1); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, tuple); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) + +#undef HWY_SVE_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x3_t triple = svcreate3##_##CHAR##BITS(v0, v1, v2); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, triple); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) + +#undef HWY_SVE_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x4_t quad = \ + svcreate4##_##CHAR##BITS(v0, v1, v2, v3); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, quad); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) + +#undef HWY_SVE_STORE4 + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// Same sign +#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) + +// 2x +template +HWY_API svuint32_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} + +// Sign change +template +HWY_API svint16_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RebindToUnsigned du; + return BitCast(dto, PromoteTo(du, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svuint16_t vfrom) { + const RebindToUnsigned du; + return BitCast(dto, PromoteTo(du, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svuint8_t vfrom) { + const Repartition> du16; + const Repartition di16; + return PromoteTo(dto, BitCast(di16, PromoteTo(du16, vfrom))); +} + +// ------------------------------ PromoteTo F + +// Unlike Highway's ZipLower, this returns the same type. +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1) +} // namespace detail + +template +HWY_API svfloat32_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. + const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f32_f16_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svint32_t v) { + const svint32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_s32_x(detail::PTrue(Simd()), vv); +} + +// For 16-bit Compress +namespace detail { +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +#undef HWY_SVE_PROMOTE_TO + +template +HWY_API svfloat32_t PromoteUpperTo(Simd df, svfloat16_t v) { + const RebindToUnsigned du; + const RepartitionToNarrow dn; + return BitCast(df, PromoteUpperTo(du, BitCast(dn, v))); +} + +} // namespace detail + +// ------------------------------ DemoteTo U + +namespace detail { + +// Saturates unsigned vectors to half/quarter-width TN. +template +VU SaturateU(VU v) { + return detail::MinN(v, static_cast>(LimitsMax())); +} + +// Saturates unsigned vectors to half/quarter-width TN. +template +VI SaturateI(VI v) { + return detail::MinN(detail::MaxN(v, LimitsMin()), LimitsMax()); +} + +} // namespace detail + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint16_t v) { + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint8_t vn = BitCast(dn, detail::SaturateU(clamped)); + return svuzp1_u8(vn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svint32_t v) { + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint16_t vn = BitCast(dn, detail::SaturateU(clamped)); + return svuzp1_u16(vn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint32_t v) { + const DFromV di; + const RebindToUnsigned du; + const RepartitionToNarrow d2; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and quarter the width. + const svuint16_t cast16 = BitCast(d2, detail::SaturateU(clamped)); + const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); + return svuzp1_u8(x2, x2); +} + +HWY_API svuint8_t U8FromU32(const svuint32_t v) { + const DFromV du32; + const RepartitionToNarrow du16; + const RepartitionToNarrow du8; + + const svuint16_t cast16 = BitCast(du16, v); + const svuint16_t x2 = svuzp1_u16(cast16, cast16); + const svuint8_t cast8 = BitCast(du8, x2); + return svuzp1_u8(cast8, cast8); +} + +// ------------------------------ Truncations + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + const svuint8_t v3 = svuzp1_u8(v2, v2); + return svuzp1_u8(v3, v3); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + const svuint16_t v2 = svuzp1_u16(v1, v1); + return svuzp1_u16(v2, v2); +} + +template +HWY_API svuint32_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint32_t v1 = BitCast(d, v); + return svuzp1_u32(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + return svuzp1_u8(v2, v2); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + return svuzp1_u16(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint16_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + return svuzp1_u8(v1, v1); +} + +// ------------------------------ DemoteTo I + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint16_t v) { +#if HWY_TARGET == HWY_SVE2 + const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); +#else + using TN = TFromD; + const svint8_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s8(vn, vn); +} + +template +HWY_API svint16_t DemoteTo(Simd dn, const svint32_t v) { +#if HWY_TARGET == HWY_SVE2 + const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); +#else + using TN = TFromD; + const svint16_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s16(vn, vn); +} + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint32_t v) { + const RepartitionToWide d2; +#if HWY_TARGET == HWY_SVE2 + const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); +#else + using TN = TFromD; + const svint16_t cast16 = BitCast(d2, detail::SaturateI(v)); +#endif + const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); + return BitCast(dn, svuzp1_s8(v2, v2)); +} + +// ------------------------------ ConcatEven/ConcatOdd + +// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the +// full vector length, not rounded down to a power of two as we require). +namespace detail { + +#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) +#endif +#undef HWY_SVE_CONCAT_EVERY_SECOND + +// Used to slide up / shift whole register left; mask indicates which range +// to take from lo, and the rest is filled from hi starting at its lowest. +#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) +#undef HWY_SVE_SPLICE + +} // namespace detail + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + (void)d; + return detail::ConcatOddFull(hi, lo); +#else + const VFromD hi_odd = detail::ConcatOddFull(hi, hi); + const VFromD lo_odd = detail::ConcatOddFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + (void)d; + return detail::ConcatEvenFull(hi, lo); +#else + const VFromD hi_odd = detail::ConcatEvenFull(hi, hi); + const VFromD lo_odd = detail::ConcatEvenFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +#endif +} + +// ------------------------------ DemoteTo F + +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { + const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svuint16_t DemoteTo(Simd /* d */, svfloat32_t v) { + const svuint16_t in_even = BitCast(ScalableTag(), v); + return detail::ConcatOddFull(in_even, in_even); // lower half +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svint32_t DemoteTo(Simd d, const svfloat64_t v) { + const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +// ------------------------------ ConvertTo F + +#define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* signed integers */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* unsigned integers */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Truncates (rounds toward zero). */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } + +// API only requires f32 but we provide f64 for use by Iota. +HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) +#undef HWY_SVE_CONVERT + +// ------------------------------ NearestInt (Round, ConvertTo) +template >> +HWY_API VFromD NearestInt(VF v) { + // No single instruction, round then truncate. + return ConvertTo(DI(), Round(v)); +} + +// ------------------------------ Iota (Add, ConvertTo) + +#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) first) { \ + return sv##OP##_##CHAR##BITS(first, 1); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) +#undef HWY_SVE_IOTA + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToSigned di; + return detail::AddN(ConvertTo(d, Iota(di, 0)), first); +} + +// ------------------------------ InterleaveLower + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipLowerSame(a, b); +#else + // Move lower halves of blocks to lower half of vector. + const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatEvenFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper + +// Only use zip2 if vector are a powers of two, otherwise getting the actual +// "upper half" requires MaskUpperHalf. +#if HWY_TARGET == HWY_SVE2_128 +namespace detail { +// Unlike Highway's ZipUpper, this returns the same type. +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) +} // namespace detail +#endif + +// Full vector: guaranteed to have at least one block +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipUpperSame(a, b); +#else + // Move upper halves of blocks to lower half of vector. + const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatOddFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +// Capped/fraction: need runtime check +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { + // Less than one block: treat as capped + if (Lanes(d) * sizeof(TFromD) < 16) { + const Half d2; + return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); + } + return InterleaveUpper(DFromV(), a, b); +} + +// ================================================== COMBINE + +namespace detail { + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 32: + return svptrue_pat_b8(SV_VL16); + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 16: + return svptrue_pat_b16(SV_VL8); + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 8: + return svptrue_pat_b32(SV_VL4); + case 4: + return svptrue_pat_b32(SV_VL2); + default: + return svptrue_pat_b32(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 4: + return svptrue_pat_b64(SV_VL2); + default: + return svptrue_pat_b64(SV_VL1); + } +} +#endif +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + return svptrue_pat_b32(Lanes(d) == 4 ? SV_VL2 : SV_VL1); +} +template +svbool_t MaskLowerHalf(D /*d*/) { + return svptrue_pat_b64(SV_VL1); +} +#endif // HWY_TARGET == HWY_SVE2_128 +#if HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128 +template +svbool_t MaskLowerHalf(D d) { + return FirstN(d, Lanes(d) / 2); +} +#endif + +template +svbool_t MaskUpperHalf(D d) { + // TODO(janwas): WHILEGE on pow2 SVE2 + if (HWY_SVE_IS_POW2 && IsFull(d)) { + return Not(MaskLowerHalf(d)); + } + + // For Splice to work as intended, make sure bits above Lanes(d) are zero. + return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); +} + +// Right-shift vector pair by constexpr; can be used to slide down (=N) or up +// (=Lanes()-N). +#define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \ + } +HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext) +#undef HWY_SVE_EXT + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) { + return IfThenElse(detail::MaskLowerHalf(d), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatEvenBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi))); +#endif + } + return detail::Splice(hi, lo, detail::MaskLowerHalf(d)); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes + if (detail::IsFull(d)) { + return detail::Ext(hi, lo); + } +#endif + return detail::Splice(hi, lo, detail::MaskUpperHalf(d)); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatOddBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi))); +#endif + } + const svbool_t mask_upper = detail::MaskUpperHalf(d); + const V lo_upper = detail::Splice(lo, lo, mask_upper); + return IfThenElse(mask_upper, hi, lo_upper); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(const D d, const V2 hi, const V2 lo) { + return ConcatLowerLower(d, hi, lo); +} + +// ------------------------------ ZeroExtendVector +template +HWY_API V ZeroExtendVector(const D d, const V lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ Lower/UpperHalf + +template +HWY_API V LowerHalf(D2 /* tag */, const V v) { + return v; +} + +template +HWY_API V LowerHalf(const V v) { + return v; +} + +template +HWY_API V UpperHalf(const DH dh, const V v) { + const Twice d; + // Cast so that we support bfloat16_t. + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes + return BitCast(d, detail::Ext(vu, vu)); +#else + const MFromD mask = detail::MaskUpperHalf(du); + return BitCast(d, detail::Splice(vu, vu, mask)); +#endif +} + +// ================================================== REDUCE + +// These return T, whereas the Highway op returns a broadcasted vector. +namespace detail { +#define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + /* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. */ \ + using T = HWY_SVE_T(BASE, BITS); \ + using TU = MakeUnsigned; \ + constexpr uint64_t kMask = LimitsMax(); \ + return static_cast(static_cast( \ + static_cast(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \ + } + +#define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(pg, v); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv) + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv) +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv) +// NaN if all are +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv) + +#undef HWY_SVE_REDUCE +#undef HWY_SVE_REDUCE_ADD +} // namespace detail + +template +V SumOfLanes(D d, V v) { + return Set(d, detail::SumOfLanesM(detail::MakeMask(d), v)); +} + +template +V MinOfLanes(D d, V v) { + return Set(d, detail::MinOfLanesM(detail::MakeMask(d), v)); +} + +template +V MaxOfLanes(D d, V v) { + return Set(d, detail::MaxOfLanesM(detail::MakeMask(d), v)); +} + + +// ================================================== SWIZZLE + +// ------------------------------ GetLane + +namespace detail { +#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_T(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta) +#undef HWY_SVE_GET_LANE +} // namespace detail + +template +HWY_API TFromV GetLane(V v) { + return detail::GetLaneM(v, detail::PFalse()); +} + +// ------------------------------ ExtractLane +template +HWY_API TFromV ExtractLane(V v, size_t i) { + return detail::GetLaneM(v, FirstN(DFromV(), i)); +} + +// ------------------------------ InsertLane (IfThenElse) +template +HWY_API V InsertLane(const V v, size_t i, TFromV t) { + const DFromV d; + const auto is_i = detail::EqN(Iota(d, 0), static_cast>(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ DupEven + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) +} // namespace detail + +template +HWY_API V DupEven(const V v) { + return detail::InterleaveEven(v, v); +} + +// ------------------------------ DupOdd + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) +} // namespace detail + +template +HWY_API V DupOdd(const V v) { + return detail::InterleaveOdd(v, v); +} + +// ------------------------------ OddEven + +#if HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE2 + +#define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \ + return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n) +#undef HWY_SVE_ODD_EVEN + +template +HWY_API V OddEven(const V odd, const V even) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even))); +} + +#else + +template +HWY_API V OddEven(const V odd, const V even) { + const auto odd_in_even = detail::Ext<1>(odd, odd); + return detail::InterleaveEven(even, odd_in_even); +} + +#endif // HWY_TARGET + +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V odd, const V even) { + const DFromV d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatUpperLower(d, odd, even); +#elif HWY_TARGET == HWY_SVE2_128 + (void)odd; + (void)d; + return even; +#else + const RebindToUnsigned du; + using TU = TFromD; + constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); + const auto idx_block = ShiftRight(Iota(du, 0)); + const auto lsb = detail::AndN(idx_block, static_cast(1)); + const svbool_t is_even = detail::EqN(lsb, static_cast(0)); + return IfThenElse(is_even, even, odd); +#endif +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + using TI = TFromV; + static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size mismatch"); + const RebindToUnsigned du; + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllTrue(du, detail::LtN(indices, static_cast(Lanes(d))))); +#else + (void)d; +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +// <32bit are not part of Highway API, but used in Broadcast. +#define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ + return sv##OP##_##CHAR##BITS(v, idx); \ + } + +HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) +#undef HWY_SVE_TABLE + +// ------------------------------ SwapAdjacentBlocks (TableLookupLanes) + +namespace detail { + +template +constexpr size_t LanesPerBlock(Simd /* tag */) { + // We might have a capped vector smaller than a block, so honor that. + return HWY_MIN(16 / sizeof(T), detail::ScaleByPower(N, kPow2)); +} + +} // namespace detail + +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatLowerUpper(d, v, v); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#else + const RebindToUnsigned du; + constexpr auto kLanesPerBlock = + static_cast>(detail::LanesPerBlock(d)); + const VFromD idx = detail::XorN(Iota(du, 0), kLanesPerBlock); + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ Reverse + +namespace detail { + +#define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) +#undef HWY_SVE_REVERSE + +} // namespace detail + +template +HWY_API V Reverse(D d, V v) { + using T = TFromD; + const auto reversed = detail::ReverseFull(v); + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed; + // Shift right to remove extra (non-pow2 and remainder) lanes. + // TODO(janwas): on SVE2, use WHILEGE. + // Avoids FirstN truncating to the return vector size. Must also avoid Not + // because that is limited to SV_POW2. + const ScalableTag dfull; + const svbool_t all_true = detail::AllPTrue(dfull); + const size_t all_lanes = detail::AllHardwareLanes(hwy::SizeTag()); + const svbool_t mask = + svnot_b_z(all_true, FirstN(dfull, all_lanes - Lanes(d))); + return detail::Splice(reversed, reversed, mask); +} + +// ------------------------------ Reverse2 + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { // 3210 +#if HWY_TARGET == HWY_SVE2_128 + if (detail::IsFull(d)) { + return detail::Ext<1>(v, v); + } +#endif + (void)d; + const auto odd_in_even = detail::Ext<1>(v, v); // x321 + return detail::InterleaveEven(odd_in_even, v); // 2301 +} +// ------------------------------ Reverse4 (TableLookupLanes) +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + if (HWY_TARGET == HWY_SVE_256 && sizeof(TFromD) == 8 && + detail::IsFull(d)) { + return detail::ReverseFull(v); + } + // TODO(janwas): is this approach faster than Shuffle0123? + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse8 (TableLookupLanes) +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 7); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Compress (PromoteTo) + +template +struct CompressIsPartition { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + // Optimization for 64-bit lanes (could also be applied to 32-bit, but that + // requires a larger table). + enum { value = (sizeof(T) == 8) }; +#else + enum { value = 0 }; +#endif // HWY_TARGET == HWY_SVE_256 +}; + +#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact) +HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact) +#else +HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact) +#endif +#undef HWY_SVE_COMPRESS + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompress64x4Tables + 0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2, + 1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2, + 0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +} +#endif // HWY_TARGET == HWY_SVE_256 +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + // If mask == 10: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot keep 10 + // unchanged and map everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(maskLL, mask)); +} +#endif // HWY_TARGET == HWY_SVE_256 + +template +HWY_API V Compress(V v, svbool_t mask16) { + static_assert(!IsSame(), "Must use overload"); + const DFromV d16; + + // Promote vector and mask to 32-bit + const RepartitionToWide dw; + const auto v32L = PromoteTo(dw, v); + const auto v32H = detail::PromoteUpperTo(dw, v); + const svbool_t mask32L = svunpklo_b(mask16); + const svbool_t mask32H = svunpkhi_b(mask16); + + const auto compressedL = Compress(v32L, mask32L); + const auto compressedH = Compress(v32H, mask32H); + + // Demote to 16-bit (already in range) - separately so we can splice + const V evenL = BitCast(d16, compressedL); + const V evenH = BitCast(d16, compressedH); + const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half + const V v16H = detail::ConcatEvenFull(evenH, evenH); + + // We need to combine two vectors of non-constexpr length, so the only option + // is Splice, which requires us to synthesize a mask. NOTE: this function uses + // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt. + const size_t countL = detail::CountTrueFull(dw, mask32L); + const auto compressed_maskL = FirstN(d16, countL); + return detail::Splice(v16H, v16L, compressed_maskL); +} + +// Must treat float16_t as integers so we can ConcatEven. +HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) { + const DFromV df; + const RebindToSigned di; + return BitCast(df, Compress(BitCast(di, v), mask16)); +} + +// ------------------------------ CompressNot + +template +HWY_API V CompressNot(V v, const svbool_t mask) { + return Compress(v, Not(mask)); +} + +template +HWY_API V CompressNot(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE + // If mask == 01: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot map + // 01 to 10, and everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(mask, maskLL)); +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompressNot64x4Tables + 0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3, + 0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3, + 2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif // HWY_TARGET == HWY_SVE_256 + + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 + (void)mask; + return v; +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + uint64_t bits = 0; // predicate reg is 32-bit + CopyBytes<4>(&mask, &bits); // not same size - 64-bit more efficient + // Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx. + const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u); + // See CompressIsPartition. Manually generated; flip halves if mask = [0, 1]. + alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1, + 0, 1, 2, 3, 0, 1, 2, 3}; + const ScalableTag d; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif + + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + const svbool_t store_mask = FirstN(d, count); + BlendedStore(Compress(v, mask), store_mask, d, unaligned); + return count; +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes + +// Prevent accidentally using these for 128-bit vectors - should not be +// necessary. +#if HWY_TARGET != HWY_SVE2_128 +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return detail::AndNotN(static_cast(LanesPerBlock(d) - 1), iota0); +} + +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint8_t idx_mod = + svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock, + 9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock, + 12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock, + 15 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint16_t idx_mod = + svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint32_t idx_mod = + svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint64_t idx_mod = + svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 + +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, detail::Ext(hi8, lo8)); +#else + const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes)); + const auto lo_down = detail::Ext(lo8, lo8); + const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +#endif +} + +// ------------------------------ Shuffle2301 +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return Reverse2(d, v); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8)); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0123 +template +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { +#if HWY_TARGET == HWY_SVE_256 + if (detail::IsFull(d)) { + return SwapAdjacentBlocks(v); + } else if (detail::IsFull(Twice())) { + return v; + } +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#endif + const Repartition du64; + return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); +} + +// ------------------------------ TableLookupBytes + +template +HWY_API VI TableLookupBytes(const V v, const VI idx) { + const DFromV d; + const Repartition du8; +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx))); +#else + const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0)); + const auto idx8 = Add(BitCast(du8, idx), offsets128); + return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8)); +#endif +} + +template +HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { + const DFromV d; + // Mask size must match vector type, so cast everything to this type. + const Repartition di8; + + auto idx8 = BitCast(di8, idx); + const auto msb = detail::LtN(idx8, 0); + + const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Broadcast + +#if HWY_TARGET == HWY_SVE2_128 +namespace detail { +#define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v, kLane); \ + } + +HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane) +#undef HWY_SVE_BROADCAST +} // namespace detail +#endif + +template +HWY_API V Broadcast(const V v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); +#if HWY_TARGET == HWY_SVE2_128 + return detail::BroadcastLane(v); +#else + auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0)); + if (kLane != 0) { + idx = detail::AddN(idx, kLane); + } + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(D d, const V v) { + const auto zero = Zero(d); + const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); +#if HWY_TARGET == HWY_SVE2_128 + return shifted; +#else + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + return IfThenElse(detail::FirstNPerBlock(d), zero, shifted); +#endif +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template > +HWY_API V ShiftRightLanes(D d, V v) { + // For capped/fractional vectors, clear upper lanes so we shift in zeros. + if (!detail::IsFull(d)) { + v = IfThenElseZero(detail::MakeMask(d), v); + } + +#if HWY_TARGET == HWY_SVE2_128 + return detail::Ext(Zero(d), v); +#else + const auto shifted = detail::Ext(v, v); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + const svbool_t mask = detail::FirstNPerBlock(d); + return IfThenElseZero(mask, shifted); +#endif +} + +// ------------------------------ ShiftLeftBytes + +template > +HWY_API V ShiftLeftBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(const V a, const V b) { + return BitCast(DW(), InterleaveLower(D(), a, b)); +} + +// ------------------------------ ZipUpper +template >> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== Ops with dependencies + +// ------------------------------ PromoteTo bfloat16 (ZipLower) +template +HWY_API svfloat32_t PromoteTo(Simd df32, + const svuint16_t v) { + return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), v)); +} + +// ------------------------------ ReorderDemote2To (OddEven) + +template +HWY_API svuint16_t ReorderDemote2To(Simd dbf16, + svfloat32_t a, svfloat32_t b) { + const RebindToUnsigned du16; + const Repartition du32; + const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +template +HWY_API svint16_t ReorderDemote2To(Simd d16, svint32_t a, + svint32_t b) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + (void)d16; + const svint16_t a_in_even = svqxtnb_s32(a); + return svqxtnt_s32(a_in_even, b); +#else + const Half dh; + const svint16_t a16 = BitCast(dh, detail::SaturateI(a)); + const svint16_t b16 = BitCast(dh, detail::SaturateI(b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +// ------------------------------ ZeroIfNegative (Lt, IfThenElse) +template +HWY_API V ZeroIfNegative(const V v) { + return IfThenZeroElse(detail::LtN(v, 0), v); +} + +// ------------------------------ BroadcastSignBit (ShiftRight) +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + const svbool_t m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ AverageRound (ShiftRight) + +#if HWY_TARGET == HWY_SVE2 +HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +#else +template +V AverageRound(const V a, const V b) { + return ShiftRight<1>(detail::AddN(Add(a, b), 1)); +} +#endif // HWY_TARGET == HWY_SVE2 + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const svuint8_t iota = Iota(du, 0); + + // Load correct number of bytes (bits/8) with 7 zeros after each. + const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits)); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota)); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // There may be up to 128 bits; avoid reading past the end. + const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits); + + // Replicate bytes 16x such that each lane contains the bit that governs it. + const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0))); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, rep16), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable, + // so we can skip computing the actual length (Lanes(du)+7)/8. + const svuint8_t bytes = svld1(FirstN(du8, 8), bits); + + // Replicate bytes 32x such that each lane contains the bit that governs it. + const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0))); + + // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 .. + const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7)); + + return TestBit(BitCast(du, rep32), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + + // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane. + // The "at least 8 byte" guarantee in quick_reference ensures this is safe. + uint32_t mask_bits; + CopyBytes<4>(bits, &mask_bits); // copy from bytes + const auto vbits = Set(du, mask_bits); + + // 2 ^ {0,1, .., 31}, will not have more lanes than that. + const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0)); + + return TestBit(vbits, bit); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +// For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return svdup_n_u8_z(m, 1); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d8; + const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); + return detail::ConcatEvenFull(b16, b16); // lower half +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return U8FromU32(svdup_n_u32_z(m, 1)); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d32; + const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); + return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half +} + +// Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. +HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { + const ScalableTag d8; + const ScalableTag d16; + const ScalableTag d32; + const ScalableTag d64; + // TODO(janwas): could use SVE2 BDEP, but it's optional. + x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); + x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); + x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); + return BitCast(d64, x); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +// TODO(janwas): specialize for HWY_SVE_256 +template +HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { + svuint64_t bits_in_u64 = + detail::BitsFromBool(detail::BoolFromMask>(m)); + + const size_t num_bits = Lanes(d); + const size_t num_bytes = (num_bits + 8 - 1) / 8; // Round up, see below + + // Truncate each u64 to 8 bits and store to u8. + svst1b_u64(FirstN(ScalableTag(), num_bytes), bits, bits_in_u64); + + // Non-full byte, need to clear the undefined upper bits. Can happen for + // capped/fractional vectors or large T and small hardware vectors. + if (num_bits < 8) { + const int mask = static_cast((1ull << num_bits) - 1); + bits[0] = static_cast(bits[0] & mask); + } + // Else: we wrote full bytes because num_bits is a power of two >= 8. + + return num_bytes; +} + +// ------------------------------ CompressBits (LoadMaskBits) +template +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ MulEven (InterleaveEven) + +#if HWY_TARGET == HWY_SVE2 +namespace detail { +#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +#undef HWY_SVE_MUL_EVEN +} // namespace detail +#endif + +template >> +HWY_API VFromD MulEven(const V a, const V b) { +#if HWY_TARGET == HWY_SVE2 + return BitCast(DW(), detail::MulEvenNative(a, b)); +#else + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), detail::InterleaveEven(lo, hi)); +#endif +} + +HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveEven(lo, hi); +} + +HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd df32, + svuint16_t a, svuint16_t b, + const svfloat32_t sum0, + svfloat32_t& sum1) { + // TODO(janwas): svbfmlalb_f32 if __ARM_FEATURE_SVE_BF16. + const Repartition du16; + const RebindToUnsigned du32; + const svuint16_t zero = Zero(du16); + const svuint32_t a0 = ZipLower(du32, zero, BitCast(du16, a)); + const svuint32_t a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const svuint32_t b0 = ZipLower(du32, zero, BitCast(du16, b)); + const svuint32_t b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +template +HWY_API svint32_t ReorderWidenMulAccumulate(Simd d32, + svint16_t a, svint16_t b, + const svint32_t sum0, + svint32_t& sum1) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + (void)d32; + sum1 = svmlalt_s32(sum1, a, b); + return svmlalb_s32(sum0, a, b); +#else + const svbool_t pg = detail::PTrue(d32); + const svint32_t a0 = svunpklo_s32(a); + const svint32_t b0 = svunpklo_s32(b); + svint32_t a1, b1; + if (detail::IsFull(d32)) { + a1 = svunpkhi_s32(a); + b1 = svunpkhi_s32(b); + } else { + const Rebind d16h; + a1 = svunpklo_s32(UpperHalf(d16h, a)); + b1 = svunpklo_s32(UpperHalf(d16h, b)); + } + sum1 = svmla_s32_x(pg, sum1, a1, b1); + return svmla_s32_x(pg, sum0, a0, b0); +#endif +} + +// ------------------------------ AESRound / CLMul + +#if defined(__ARM_FEATURE_SVE2_AES) || \ + ((HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128) && \ + HWY_HAVE_RUNTIME_DISPATCH) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) { + // It is not clear whether E and MC fuse like they did on NEON. + const svuint8_t zero = svdup_n_u8(0); + return Xor(svaesmc_u8(svaese_u8(state, zero)), round_key); +} + +HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) { + return Xor(svaese_u8(state, svdup_n_u8(0)), round_key); +} + +HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) { + return svpmullb_pair(a, b); +} + +HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) { + return svpmullt_pair(a, b); +} + +#endif // __ARM_FEATURE_SVE2_AES + +// ------------------------------ Lt128 + +namespace detail { +#define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \ + return sv##OP##_b##BITS(m, m); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1) // actually for bool +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2) // actually for bool +#undef HWY_SVE_DUP + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t ltHL = VecFromMask(d, Lt(a, b)); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated. + const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL); + // Duplicate upper lane into lower. + return DupOdd(ltHx); +} +#endif +} // namespace detail + +template +HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Lt128Vec(d, a, b)); +#else + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + const svbool_t ltHL = Lt(a, b); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL); + // Duplicate upper lane into lower. + return detail::DupOddB(d, ltHx); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Lt128Upper + +template +HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t ltHL = Lt(a, b); + return detail::DupOddB(d, ltHL); +} + +// ------------------------------ Eq128, Ne128 + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +namespace detail { + +template +HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t eqHL = VecFromMask(d, Eq(a, b)); + // Duplicate upper and lower. + const svuint64_t eqHH = DupOdd(eqHL); + const svuint64_t eqLL = DupEven(eqHL); + return And(eqLL, eqHH); +} + +template +HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t neHL = VecFromMask(d, Ne(a, b)); + // Duplicate upper and lower. + const svuint64_t neHH = DupOdd(neHL); + const svuint64_t neLL = DupEven(neHL); + return Or(neLL, neHH); +} + +} // namespace detail +#endif + +template +HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Eq128Vec(d, a, b)); +#else + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHL = Eq(a, b); + const svbool_t eqHH = detail::DupOddB(d, eqHL); + const svbool_t eqLL = detail::DupEvenB(d, eqHL); + return And(eqLL, eqHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +template +HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Ne128Vec(d, a, b)); +#else + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t neHL = Ne(a, b); + const svbool_t neHH = detail::DupOddB(d, neHL); + const svbool_t neLL = detail::DupEvenB(d, neHL); + return Or(neLL, neHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Eq128Upper, Ne128Upper + +template +HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t eqHL = Eq(a, b); + return detail::DupOddB(d, eqHL); +} + +template +HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const svbool_t neHL = Ne(a, b); + return detail::DupOddB(d, neHL); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +#else + return IfThenElse(Lt128(d, a, b), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +#else + return IfThenElse(Lt128(d, b, a), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +namespace detail { // for code folding +#undef HWY_IF_FLOAT_V +#undef HWY_IF_LANE_SIZE_V +#undef HWY_SVE_ALL_PTRUE +#undef HWY_SVE_D +#undef HWY_SVE_FOREACH +#undef HWY_SVE_FOREACH_F +#undef HWY_SVE_FOREACH_F16 +#undef HWY_SVE_FOREACH_F32 +#undef HWY_SVE_FOREACH_F64 +#undef HWY_SVE_FOREACH_I +#undef HWY_SVE_FOREACH_I08 +#undef HWY_SVE_FOREACH_I16 +#undef HWY_SVE_FOREACH_I32 +#undef HWY_SVE_FOREACH_I64 +#undef HWY_SVE_FOREACH_IF +#undef HWY_SVE_FOREACH_U +#undef HWY_SVE_FOREACH_U08 +#undef HWY_SVE_FOREACH_U16 +#undef HWY_SVE_FOREACH_U32 +#undef HWY_SVE_FOREACH_U64 +#undef HWY_SVE_FOREACH_UI +#undef HWY_SVE_FOREACH_UI08 +#undef HWY_SVE_FOREACH_UI16 +#undef HWY_SVE_FOREACH_UI32 +#undef HWY_SVE_FOREACH_UI64 +#undef HWY_SVE_FOREACH_UIF3264 +#undef HWY_SVE_PTRUE +#undef HWY_SVE_RETV_ARGPV +#undef HWY_SVE_RETV_ARGPVN +#undef HWY_SVE_RETV_ARGPVV +#undef HWY_SVE_RETV_ARGV +#undef HWY_SVE_RETV_ARGVN +#undef HWY_SVE_RETV_ARGVV +#undef HWY_SVE_T +#undef HWY_SVE_UNDEFINED +#undef HWY_SVE_V + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/emu128-inl.h b/media/highway/src/hwy/ops/emu128-inl.h new file mode 100644 index 000000000..5063a6d95 --- /dev/null +++ b/media/highway/src/hwy/ops/emu128-inl.h @@ -0,0 +1,2511 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +using Full128 = Simd; + +// (Wrapper class required for overloading comparison operators.) +template +struct Vec128 { + HWY_INLINE Vec128() = default; + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + // Behave like wasm128 (vectors can always hold 128 bits). generic_ops-inl.h + // relies on this for LoadInterleaved*. CAVEAT: this method of padding + // prevents using range for, especially in SumOfLanes, where it would be + // incorrect. Moving padding to another field would require handling the case + // where N = 16 / sizeof(T) (i.e. there is no padding), which is also awkward. + T raw[16 / sizeof(T)] = {}; +}; + +// 0 or FF..FF, same size as Vec128. +template +struct Mask128 { + using Raw = hwy::MakeUnsigned; + static HWY_INLINE Raw FromBool(bool b) { + return b ? static_cast(~Raw{0}) : 0; + } + + // Must match the size of Vec128. + Raw bits[16 / sizeof(T)] = {}; +}; + +namespace detail { + +// Deduce Simd from Vec128 +struct Deduce128 { + template + Simd operator()(Vec128) const { + return Simd(); + } +}; + +} // namespace detail + +template +using DFromV = decltype(detail::Deduce128()(V())); + +template +using TFromV = TFromD>; + +// ------------------------------ BitCast + +template +HWY_API Vec128 BitCast(Simd /* tag */, Vec128 v) { + Vec128 to; + CopySameSize(&v, &to); + return to; +} + +// ------------------------------ Set + +template +HWY_API Vec128 Zero(Simd /* tag */) { + Vec128 v; + ZeroBytes(v.raw); + return v; +} + +template +using VFromD = decltype(Zero(D())); + +template +HWY_API Vec128 Set(Simd /* tag */, const T2 t) { + Vec128 v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(t); + } + return v; +} + +template +HWY_API Vec128 Undefined(Simd d) { + return Zero(d); +} + +namespace detail { + +template +HWY_INLINE constexpr T IncrementWithWraparound(hwy::FloatTag /*tag*/, T t) { + return t + T{1}; +} + +template +HWY_INLINE constexpr T IncrementWithWraparound(hwy::NonFloatTag /*tag*/, T t) { + using TU = MakeUnsigned; + return static_cast(static_cast(static_cast(t) + TU{1}) & + hwy::LimitsMax()); +} + +} // namespace detail + +template +HWY_API Vec128 Iota(const Simd /* tag */, T2 first) { + Vec128 v; + T counter = static_cast(first); + for (size_t i = 0; i < N; ++i) { + v.raw[i] = counter; + counter = detail::IncrementWithWraparound(hwy::IsFloatTag(), counter); + } + return v; +} + +// ================================================== LOGICAL + +// ------------------------------ Not +template +HWY_API Vec128 Not(const Vec128 v) { + const Simd d; + const RebindToUnsigned du; + using TU = TFromD; + VFromD vu = BitCast(du, v); + for (size_t i = 0; i < N; ++i) { + vu.raw[i] = static_cast(~vu.raw[i]); + } + return BitCast(d, vu); +} + +// ------------------------------ And +template +HWY_API Vec128 And(const Vec128 a, const Vec128 b) { + const Simd d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] &= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +// ------------------------------ AndNot +template +HWY_API Vec128 AndNot(const Vec128 a, const Vec128 b) { + return And(Not(a), b); +} + +// ------------------------------ Or +template +HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { + const Simd d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] |= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +// ------------------------------ Xor +template +HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { + const Simd d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] ^= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(const Vec128 o, const Vec128 a1, + const Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +// ------------------------------ CopySign +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd()), sign)); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API Vec128 BroadcastSignBit(Vec128 v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[i] < 0 ? T(-1) : T(0); + } + return v; +} + +// ------------------------------ Mask + +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 mask) { + Mask128 to; + CopySameSize(&mask, &to); + return to; +} + +// v must be 0 or FF..FF. +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + Mask128 mask; + CopySameSize(&v, &mask); + return mask; +} + +template +Vec128 VecFromMask(const Mask128 mask) { + Vec128 v; + CopySameSize(&mask, &v); + return v; +} + +template +Vec128 VecFromMask(Simd /* tag */, const Mask128 mask) { + return VecFromMask(mask); +} + +template +HWY_API Mask128 FirstN(Simd /*tag*/, size_t n) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(i < n); + } + return m; +} + +// Returns mask ? yes : no. +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, const Vec128 no) { + return IfVecThenElse(VecFromMask(mask), yes, no); +} + +template +HWY_API Vec128 IfThenElseZero(const Mask128 mask, + const Vec128 yes) { + return IfVecThenElse(VecFromMask(mask), yes, Zero(Simd())); +} + +template +HWY_API Vec128 IfThenZeroElse(const Mask128 mask, + const Vec128 no) { + return IfVecThenElse(VecFromMask(mask), Zero(Simd()), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[i] < 0 ? yes.raw[i] : no.raw[i]; + } + return v; +} + +template +HWY_API Vec128 ZeroIfNegative(const Vec128 v) { + return IfNegativeThenElse(v, Zero(Simd()), v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec128 ShiftLeft(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) << kBits; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> kBits); + } +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU shifted = static_cast(static_cast(v.raw[i]) >> kBits); + const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); + const TU upper = static_cast(sign << sign_shift); + v.raw[i] = static_cast(shifted | upper); + } + } else { // T is unsigned + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> kBits); + } + } +#endif + return v; +} + +// ------------------------------ RotateRight (ShiftRight) + +namespace detail { + +// For partial specialization: kBits == 0 results in an invalid shift count +template +struct RotateRight { + template + HWY_INLINE Vec128 operator()(const Vec128 v) const { + return Or(ShiftRight(v), ShiftLeft(v)); + } +}; + +template <> +struct RotateRight<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) const { + return v; + } +}; + +} // namespace detail + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return detail::RotateRight()(v); +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(Vec128 v, int bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) << bits; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits); + } +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU shifted = static_cast(static_cast(v.raw[i]) >> bits); + const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); + const TU upper = static_cast(sign << sign_shift); + v.raw[i] = static_cast(shifted | upper); + } + } else { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits); // unsigned, logical shift + } + } +#endif + return v; +} + +// ------------------------------ Shl + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) + << bits.raw[i]; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); + } +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU shifted = + static_cast(static_cast(v.raw[i]) >> bits.raw[i]); + const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; + const size_t sign_shift = static_cast( + static_cast(sizeof(TU)) * 8 - 1 - bits.raw[i]); + const TU upper = static_cast(sign << sign_shift); + v.raw[i] = static_cast(shifted | upper); + } + } else { // T is unsigned + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); + } + } +#endif + return v; +} + +// ================================================== ARITHMETIC + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Add(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 + b64) & static_cast(~T(0))); + } + return a; +} +template +HWY_INLINE Vec128 Sub(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 - b64) & static_cast(~T(0))); + } + return a; +} + +template +HWY_INLINE Vec128 Add(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] += b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Sub(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] -= b.raw[i]; + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 operator-(Vec128 a, const Vec128 b) { + return detail::Sub(hwy::IsFloatTag(), a, b); +} +template +HWY_API Vec128 operator+(Vec128 a, const Vec128 b) { + return detail::Add(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ SumsOf8 + +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + Vec128 sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + +// ------------------------------ SaturatedAdd +template +HWY_API Vec128 SaturatedAdd(Vec128 a, const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast( + HWY_MIN(HWY_MAX(hwy::LowestValue(), a.raw[i] + b.raw[i]), + hwy::HighestValue())); + } + return a; +} + +// ------------------------------ SaturatedSub +template +HWY_API Vec128 SaturatedSub(Vec128 a, const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast( + HWY_MIN(HWY_MAX(hwy::LowestValue(), a.raw[i] - b.raw[i]), + hwy::HighestValue())); + } + return a; +} + +// ------------------------------ AverageRound +template +HWY_API Vec128 AverageRound(Vec128 a, const Vec128 b) { + static_assert(!IsSigned(), "Only for unsigned"); + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((a.raw[i] + b.raw[i] + 1) / 2); + } + return a; +} + +// ------------------------------ Abs + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Abs(SignedTag /*tag*/, Vec128 a) { + for (size_t i = 0; i < N; ++i) { + const T s = a.raw[i]; + const T min = hwy::LimitsMin(); + a.raw[i] = static_cast((s >= 0 || s == min) ? a.raw[i] : -s); + } + return a; +} + +template +HWY_INLINE Vec128 Abs(hwy::FloatTag /*tag*/, Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = std::abs(v.raw[i]); + } + return v; +} + +} // namespace detail + +template +HWY_API Vec128 Abs(Vec128 a) { + return detail::Abs(hwy::TypeTag(), a); +} + +// ------------------------------ Min/Max + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Min(hwy::NonFloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::NonFloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + return a; +} + +template +HWY_INLINE Vec128 Min(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (std::isnan(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (std::isnan(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (std::isnan(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (std::isnan(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 Min(Vec128 a, const Vec128 b) { + return detail::Min(hwy::IsFloatTag(), a, b); +} + +template +HWY_API Vec128 Max(Vec128 a, const Vec128 b) { + return detail::Max(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API Vec128 Neg(hwy::NonFloatTag /*tag*/, Vec128 v) { + return Zero(Simd()) - v; +} + +template +HWY_API Vec128 Neg(hwy::FloatTag /*tag*/, Vec128 v) { + return Xor(v, SignBit(Simd())); +} + +} // namespace detail + +template +HWY_API Vec128 Neg(Vec128 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Mul/Div + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Mul(hwy::FloatTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] *= b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Mul(SignedTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * b.raw[i]); + } + return a; +} + +template +HWY_INLINE Vec128 Mul(UnsignedTag /*tag*/, Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * b.raw[i]); + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 operator*(Vec128 a, const Vec128 b) { + return detail::Mul(hwy::TypeTag(), a, b); +} + +template +HWY_API Vec128 operator/(Vec128 a, const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] /= b.raw[i]; + } + return a; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((int32_t{a.raw[i]} * b.raw[i]) >> 16); + } + return a; +} +template +HWY_API Vec128 MulHigh(Vec128 a, + const Vec128 b) { + for (size_t i = 0; i < N; ++i) { + // Cast to uint32_t first to prevent overflow. Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the + // result is the same but this way it is also defined. + a.raw[i] = static_cast( + (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> + 16); + } + return a; +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((2 * a.raw[i] * b.raw[i] + 32768) >> 16); + } + return a; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const int64_t a64 = a.raw[i]; + mul.raw[i / 2] = a64 * b.raw[i]; + } + return mul; +} +template +HWY_API Vec128 MulEven(Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const uint64_t a64 = a.raw[i]; + mul.raw[i / 2] = a64 * b.raw[i]; + } + return mul; +} + +template +HWY_API Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const int64_t a64 = a.raw[i + 1]; + mul.raw[i / 2] = a64 * b.raw[i + 1]; + } + return mul; +} +template +HWY_API Vec128 MulOdd(Vec128 a, + const Vec128 b) { + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const uint64_t a64 = a.raw[i + 1]; + mul.raw[i / 2] = a64 * b.raw[i + 1]; + } + return mul; +} + +template +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The result is arbitrary. + v.raw[i] = (std::abs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; + } + return v; +} + +template +HWY_API Vec128 AbsDiff(Vec128 a, const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec128 MulAdd(Vec128 mul, const Vec128 x, + const Vec128 add) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, const Vec128 x, + const Vec128 add) { + return add - mul * x; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, const Vec128 x, + const Vec128 sub) { + return mul * x - sub; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, const Vec128 x, + const Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + const float half = v.raw[i] * 0.5f; + uint32_t bits; + CopySameSize(&v.raw[i], &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &v.raw[i]); + // One Newton-Raphson iteration + v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); + } + return v; +} + +template +HWY_API Vec128 Sqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = std::sqrt(v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec128 Round(Vec128 v) { + using TI = MakeSigned; + const Vec128 a = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(a.raw[i] < MantissaEnd())) { // Huge or NaN + continue; + } + const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw[i] + bias); + if (rounded == 0) { + v.raw[i] = v.raw[i] < 0 ? T{-0} : T{0}; + continue; + } + const T rounded_f = static_cast(rounded); + // Round to even + if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { + v.raw[i] = static_cast(rounded - (v.raw[i] < T(0) ? -1 : 1)); + continue; + } + v.raw[i] = rounded_f; + } + return v; +} + +// Round-to-nearest even. +template +HWY_API Vec128 NearestInt(const Vec128 v) { + using T = float; + using TI = int32_t; + + const Vec128 abs = Abs(v); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = std::signbit(v.raw[i]); + + if (!(abs.raw[i] < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs.raw[i] <= static_cast(LimitsMax()))) { + ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); + continue; + } + ret.raw[i] = static_cast(v.raw[i]); + continue; + } + const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = static_cast(rounded); + // Round to even + if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { + ret.raw[i] = rounded - (signbit ? -1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template +HWY_API Vec128 Trunc(Vec128 v) { + using TI = MakeSigned; + const Vec128 abs = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(abs.raw[i] <= MantissaEnd())) { // Huge or NaN + continue; + } + const TI truncated = static_cast(v.raw[i]); + if (truncated == 0) { + v.raw[i] = v.raw[i] < 0 ? -T{0} : T{0}; + continue; + } + v.raw[i] = static_cast(truncated); + } + return v; +} + +// Toward +infinity, aka ceiling +template +Vec128 Ceil(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool positive = v.raw[i] > Float(0.0); + + Bits bits; + CopySameSize(&v.raw[i], &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => 0 or 1. + if (exponent < 0) { + v.raw[i] = positive ? Float{1} : Float{-0.0}; + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &v.raw[i]); + } + return v; +} + +// Toward -infinity, aka floor +template +Vec128 Floor(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool negative = v.raw[i] < Float(0.0); + + Bits bits; + CopySameSize(&v.raw[i], &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => -1 or 0. + if (exponent < 0) { + v.raw[i] = negative ? Float(-1.0) : Float(0.0); + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + Mask128 ret; + for (size_t i = 0; i < N; ++i) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + MakeUnsigned bits; + CopySameSize(&v.raw[i], &bits); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + ret.bits[i] = Mask128::FromBool(bits > ExponentMask()); + } + return ret; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + using VI = VFromD; + using VU = VFromD; + const VU vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VI exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] == b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] != b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] < b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] > b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator<=(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] <= b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>=(const Vec128 a, const Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] >= b.raw[i]); + } + return m; +} + +// ------------------------------ Lt128 + +// Only makes sense for full vectors of u64. +HWY_API Mask128 Lt128(Simd /* tag */, + Vec128 a, const Vec128 b) { + const bool lt = + (a.raw[1] < b.raw[1]) || (a.raw[1] == b.raw[1] && a.raw[0] < b.raw[0]); + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +HWY_API Mask128 Lt128Upper(Simd /* tag */, + Vec128 a, + const Vec128 b) { + const bool lt = a.raw[1] < b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +// ------------------------------ Eq128 + +// Only makes sense for full vectors of u64. +HWY_API Mask128 Eq128(Simd /* tag */, + Vec128 a, const Vec128 b) { + const bool eq = a.raw[1] == b.raw[1] && a.raw[0] == b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +HWY_API Mask128 Ne128(Simd /* tag */, + Vec128 a, const Vec128 b) { + const bool ne = a.raw[1] != b.raw[1] || a.raw[0] != b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +HWY_API Mask128 Eq128Upper(Simd /* tag */, + Vec128 a, + const Vec128 b) { + const bool eq = a.raw[1] == b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +HWY_API Mask128 Ne128Upper(Simd /* tag */, + Vec128 a, + const Vec128 b) { + const bool ne = a.raw[1] != b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +// ------------------------------ Min128, Max128 (Lt128) + +template > +HWY_API V Min128(D d, const V a, const V b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template > +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template > +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Simd /* tag */, + const T* HWY_RESTRICT aligned) { + Vec128 v; + CopyBytes(aligned, v.raw); // copy from array + return v; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template +HWY_API Vec128 LoadDup128(Simd d, + const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template +HWY_API void Store(const Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + CopyBytes(v.raw, aligned); // copy to array +} + +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(const Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + for (size_t i = 0; i < N; ++i) { + if (m.bits[i]) p[i] = v.raw[i]; + } +} + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +// We implement those here because scalar code is likely faster than emulation +// via shuffles. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(Simd d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, + Vec128& v2) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + alignas(16) T buf2[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, + Vec128& v2, Vec128& v3) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + alignas(16) T buf2[N]; + alignas(16) T buf3[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + buf3[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); + v3 = Load(d, buf3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(const Vec128 v0, const Vec128 v1, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + } +} + +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + } +} + +template +HWY_API void StoreInterleaved4(const Vec128 v0, const Vec128 v1, + const Vec128 v2, const Vec128 v3, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + *unaligned++ = v3.raw[i]; + } +} + +// ------------------------------ Stream + +template +HWY_API void Stream(const Vec128 v, Simd d, + T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, T* base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + for (size_t i = 0; i < N; ++i) { + uint8_t* const base8 = reinterpret_cast(base) + offset.raw[i]; + CopyBytes(&v.raw[i], base8); // copy to bytes + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT base, const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + for (size_t i = 0; i < N; ++i) { + base[index.raw[i]] = v.raw[i]; + } +} + +// ------------------------------ Gather + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, const T* base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + Vec128 v; + for (size_t i = 0; i < N; ++i) { + const uint8_t* base8 = + reinterpret_cast(base) + offset.raw[i]; + CopyBytes(base8, &v.raw[i]); // copy from bytes + } + return v; +} + +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + Vec128 v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = base[index.raw[i]]; + } + return v; +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + Vec128 from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // For bits Y > X, floatX->floatY and intX->intY are always representable. + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + Vec128 from) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // Prevent ubsan errors when converting float to narrower integer/float + if (std::isinf(from.raw[i]) || + std::fabs(from.raw[i]) > static_cast(HighestValue())) { + ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue() + : HighestValue(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + Vec128 from) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (std::isinf(from.raw[i]) || + std::fabs(from.raw[i]) > static_cast(HighestValue())) { + ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue() + : HighestValue(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + Vec128 from) { + static_assert(!IsFloat(), "FromT=double are handled above"); + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw[i] = + HWY_MIN(HWY_MAX(LimitsMin(), from.raw[i]), LimitsMax()); + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + const Repartition du32; + const Vec128 b_in_lower = ShiftRight<16>(BitCast(du32, b)); + // Avoid OddEven - we want the upper half of `a` even on big-endian systems. + const Vec128 a_mask = Set(du32, 0xFFFF0000); + return BitCast(dbf16, IfVecThenElse(a_mask, BitCast(du32, a), b_in_lower)); +} + +template +HWY_API Vec128 ReorderDemote2To(Simd /*d16*/, + Vec128 a, + Vec128 b) { + const int16_t min = LimitsMin(); + const int16_t max = LimitsMax(); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(HWY_MIN(HWY_MAX(min, a.raw[i]), max)); + } + for (size_t i = 0; i < N; ++i) { + ret.raw[N + i] = static_cast(HWY_MIN(HWY_MAX(min, b.raw[i]), max)); + } + return ret; +} + +namespace detail { + +HWY_INLINE void StoreU16ToF16(const uint16_t val, + hwy::float16_t* HWY_RESTRICT to) { + CopySameSize(&val, to); +} + +HWY_INLINE uint16_t U16FromF16(const hwy::float16_t* HWY_RESTRICT from) { + uint16_t bits16; + CopySameSize(from, &bits16); + return bits16; +} + +} // namespace detail + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + const uint16_t bits16 = detail::U16FromF16(&v.raw[i]); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + ret.raw[i] = sign ? -subnormal : subnormal; + continue; + } + + // Normalized: convert the representation directly (faster than + // ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + CopySameSize(&bits32, &ret.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = F32FromBF16(v.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + uint32_t bits32; + CopySameSize(&v.raw[i], &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); + + // Tiny or zero => zero. + if (exp < -24) { + ZeroBytes(&ret.raw[i]); + continue; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = static_cast((1u << (10 - sub_exp)) + + (mantissa32 >> (13 + sub_exp))); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + detail::StoreU16ToF16(narrowed, &ret.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = BF16FromF32(v.raw[i]); + } + return ret; +} + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API Vec128 ConvertTo(hwy::FloatTag /*tag*/, + Simd /* tag */, + Vec128 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax in FromT, so use double. + const double f = static_cast(from.raw[i]); + if (std::isinf(from.raw[i]) || + std::fabs(f) > static_cast(LimitsMax())) { + ret.raw[i] = + std::signbit(from.raw[i]) ? LimitsMin() : LimitsMax(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 ConvertTo(hwy::NonFloatTag /*tag*/, + Simd /* tag */, + Vec128 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + // int## -> float##: no check needed + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +} // namespace detail + +template +HWY_API Vec128 ConvertTo(Simd d, Vec128 from) { + return detail::ConvertTo(hwy::IsFloatTag(), d, from); +} + +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + return DemoteTo(Simd(), v); +} + +// ------------------------------ Truncations + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFFFFFFu); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +// ================================================== COMBINE + +template +HWY_API Vec128 LowerHalf(Vec128 v) { + Vec128 ret; + CopyBytes(v.raw, ret.raw); + return ret; +} + +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return LowerHalf(v); +} + +template +HWY_API Vec128 UpperHalf(Simd /* tag */, + Vec128 v) { + Vec128 ret; + CopyBytes(&v.raw[N / 2], ret.raw); + return ret; +} + +template +HWY_API Vec128 ZeroExtendVector(Simd /* tag */, + Vec128 v) { + Vec128 ret; + CopyBytes(v.raw, ret.raw); + return ret; +} + +template +HWY_API Vec128 Combine(Simd /* tag */, Vec128 hi_half, + Vec128 lo_half) { + Vec128 ret; + CopyBytes(lo_half.raw, &ret.raw[0]); + CopyBytes(hi_half.raw, &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatLowerLower(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatUpperUpper(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + CopyBytes(&lo.raw[N / 2], &ret.raw[0]); + CopyBytes(&hi.raw[N / 2], &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatLowerUpper(Simd /* tag */, + const Vec128 hi, + const Vec128 lo) { + Vec128 ret; + CopyBytes(&lo.raw[N / 2], &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatUpperLower(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(&hi.raw[N / 2], &ret.raw[N / 2]); + return ret; +} + +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[i] = lo.raw[2 * i]; + } + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[N / 2 + i] = hi.raw[2 * i]; + } + return ret; +} + +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[i] = lo.raw[2 * i + 1]; + } + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[N / 2 + i] = hi.raw[2 * i + 1]; + } + return ret; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Simd /* tag */, V hi, V lo) { + V ret; + const uint8_t* HWY_RESTRICT lo8 = + reinterpret_cast(lo.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(lo8 + kBytes, ret8); + CopyBytes(hi.raw, ret8 + sizeof(T) * N - kBytes); + return ret; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + Vec128 ret; + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + ZeroBytes(ret8); + CopyBytes(v.raw, ret8 + kBytes); + return ret; +} + +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + Vec128 ret; + const uint8_t* HWY_RESTRICT v8 = + reinterpret_cast(v.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(v8 + kBytes, ret8); + ZeroBytes(ret8 + sizeof(T) * N - kBytes); + return ret; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(const Vec128 v) { + return v.raw[0]; +} + +template +HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { + v.raw[i] = t; + return v; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + return v.raw[i]; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i + 1] = v.raw[i]; + } + return v; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i] = v.raw[i + 1]; + } + return v; +} + +template +HWY_API Vec128 OddEven(Vec128 odd, Vec128 even) { + for (size_t i = 0; i < N; i += 2) { + odd.raw[i] = even.raw[i]; + } + return odd; +} + +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + MakeSigned raw[N]; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + Indices128 ret; + CopyBytes(vec.raw, ret.raw); + return ret; +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + return IndicesFromVec(d, LoadU(Simd(), idx)); +} + +template +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[idx.raw[i]]; + } + return ret; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Simd /* tag */, + const Vec128 v) { + return v; +} + +// ------------------------------ Reverse + +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[N - 1 - i]; + } + return ret; +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; i += 2) { + ret.raw[i + 0] = v.raw[i + 1]; + ret.raw[i + 1] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; i += 4) { + ret.raw[i + 0] = v.raw[i + 3]; + ret.raw[i + 1] = v.raw[i + 2]; + ret.raw[i + 2] = v.raw[i + 1]; + ret.raw[i + 3] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API Vec128 Reverse8(Simd /* tag */, const Vec128 v) { + Vec128 ret; + for (size_t i = 0; i < N; i += 8) { + ret.raw[i + 0] = v.raw[i + 7]; + ret.raw[i + 1] = v.raw[i + 6]; + ret.raw[i + 2] = v.raw[i + 5]; + ret.raw[i + 3] = v.raw[i + 4]; + ret.raw[i + 4] = v.raw[i + 3]; + ret.raw[i + 5] = v.raw[i + 2]; + ret.raw[i + 6] = v.raw[i + 1]; + ret.raw[i + 7] = v.raw[i + 0]; + } + return ret; +} + +// ================================================== BLOCKWISE + +// ------------------------------ Shuffle* + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Reverse2(DFromV(), v); +} + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + Vec128 ret; + ret.raw[3] = v.raw[1]; + ret.raw[2] = v.raw[0]; + ret.raw[1] = v.raw[3]; + ret.raw[0] = v.raw[2]; + return ret; +} +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit"); + return Reverse2(DFromV(), v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[0]; + ret.raw[2] = v.raw[3]; + ret.raw[1] = v.raw[2]; + ret.raw[0] = v.raw[1]; + return ret; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[2]; + ret.raw[2] = v.raw[1]; + ret.raw[1] = v.raw[0]; + ret.raw[0] = v.raw[3]; + return ret; +} + +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Reverse4(DFromV(), v); +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[kLane]; + } + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec128 TableLookupBytes(const Vec128 v, + const Vec128 indices) { + const uint8_t* HWY_RESTRICT v_bytes = + reinterpret_cast(v.raw); + const uint8_t* HWY_RESTRICT idx_bytes = + reinterpret_cast(indices.raw); + Vec128 ret; + uint8_t* HWY_RESTRICT ret_bytes = + reinterpret_cast(ret.raw); + for (size_t i = 0; i < NI * sizeof(TI); ++i) { + const size_t idx = idx_bytes[i]; + // Avoid out of bounds reads. + ret_bytes[i] = idx < sizeof(T) * N ? v_bytes[idx] : 0; + } + return ret; +} + +template +HWY_API Vec128 TableLookupBytesOr0(const Vec128 v, + const Vec128 indices) { + // Same as TableLookupBytes, which already returns 0 if out of bounds. + return TableLookupBytes(v, indices); +} + +// ------------------------------ InterleaveLower/InterleaveUpper + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[i]; + ret.raw[2 * i + 1] = b.raw[i]; + } + return ret; +} + +// Additional overload for the optional tag (also for 256/512). +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +template +HWY_API Vec128 InterleaveUpper(Simd /* tag */, + const Vec128 a, + const Vec128 b) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[N / 2 + i]; + ret.raw[2 * i + 1] = b.raw[N / 2 + i]; + } + return ret; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== MASK + +template +HWY_API bool AllFalse(Simd /* tag */, const Mask128 mask) { + typename Mask128::Raw or_sum = 0; + for (size_t i = 0; i < N; ++i) { + or_sum |= mask.bits[i]; + } + return or_sum == 0; +} + +template +HWY_API bool AllTrue(Simd /* tag */, const Mask128 mask) { + constexpr uint64_t kAll = LimitsMax::Raw>(); + uint64_t and_sum = kAll; + for (size_t i = 0; i < N; ++i) { + and_sum &= mask.bits[i]; + } + return and_sum == kAll; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask128 LoadMaskBits(Simd /* tag */, + const uint8_t* HWY_RESTRICT bits) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + m.bits[i] = Mask128::FromBool((bits[idx_byte] & bit) != 0); + } + return m; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(Simd /* tag */, const Mask128 mask, + uint8_t* bits) { + bits[0] = 0; + if (N > 8) bits[1] = 0; // N <= 16, so max two bytes + for (size_t i = 0; i < N; ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + if (mask.bits[i]) { + bits[idx_byte] = static_cast(bits[idx_byte] | bit); + } + } + return N > 8 ? 2 : 1; +} + +template +HWY_API size_t CountTrue(Simd /* tag */, const Mask128 mask) { + size_t count = 0; + for (size_t i = 0; i < N; ++i) { + count += mask.bits[i] != 0; + } + return count; +} + +template +HWY_API size_t FindKnownFirstTrue(Simd /* tag */, + const Mask128 mask) { + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i] != 0) return i; + } + HWY_DASSERT(false); + return 0; +} + +template +HWY_API intptr_t FindFirstTrue(Simd /* tag */, + const Mask128 mask) { + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i] != 0) return static_cast(i); + } + return intptr_t{-1}; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressNot +template +HWY_API Vec128 CompressNot(Vec128 v, const Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Simd(), bits)); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + size_t count = 0; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + unaligned[count++] = v.raw[i]; + } + } + return count; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec128 v, const Mask128 mask, + Simd d, + T* HWY_RESTRICT unaligned) { + return CompressStore(v, mask, d, unaligned); +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + const Mask128 mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { + const Rebind dbf16; + // Avoid ZipLower/Upper so this also works on big-endian systems. + const Vec128 a0 = PromoteTo(df32, LowerHalf(dbf16, a)); + const Vec128 a1 = PromoteTo(df32, UpperHalf(dbf16, a)); + const Vec128 b0 = PromoteTo(df32, LowerHalf(dbf16, b)); + const Vec128 b1 = PromoteTo(df32, UpperHalf(dbf16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +template +HWY_API Vec128 ReorderWidenMulAccumulate( + Simd d32, Vec128 a, Vec128 b, + const Vec128 sum0, Vec128& sum1) { + const Rebind d16; + // Avoid ZipLower/Upper so this also works on big-endian systems. + const Vec128 a0 = PromoteTo(d32, LowerHalf(d16, a)); + const Vec128 a1 = PromoteTo(d32, UpperHalf(d16, a)); + const Vec128 b0 = PromoteTo(d32, LowerHalf(d16, b)); + const Vec128 b1 = PromoteTo(d32, UpperHalf(d16, b)); + sum1 = MulAdd(BitCast(d32, a1), BitCast(d32, b1), sum1); + return MulAdd(BitCast(d32, a0), BitCast(d32, b0), sum0); +} + +// ================================================== REDUCTIONS + +template +HWY_API Vec128 SumOfLanes(Simd d, const Vec128 v) { + T sum = T{0}; + for (size_t i = 0; i < N; ++i) { + sum += v.raw[i]; + } + return Set(d, sum); +} +template +HWY_API Vec128 MinOfLanes(Simd d, const Vec128 v) { + T min = HighestValue(); + for (size_t i = 0; i < N; ++i) { + min = HWY_MIN(min, v.raw[i]); + } + return Set(d, min); +} +template +HWY_API Vec128 MaxOfLanes(Simd d, const Vec128 v) { + T max = LowestValue(); + for (size_t i = 0; i < N; ++i) { + max = HWY_MAX(max, v.raw[i]); + } + return Set(d, max); +} + +// ================================================== OPS WITH DEPENDENCIES + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +} + +HWY_INLINE Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + const Half> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/generic_ops-inl.h b/media/highway/src/hwy/ops/generic_ops-inl.h new file mode 100644 index 000000000..b01c5de0f --- /dev/null +++ b/media/highway/src/hwy/ops/generic_ops-inl.h @@ -0,0 +1,1357 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-independent types/functions defined after target-specific ops. + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The lane type of a vector type, e.g. float for Vec>. +template +using LaneType = decltype(GetLane(V())); + +// Vector type, e.g. Vec128 for CappedTag. Useful as the return +// type of functions that do not take a vector argument, or as an argument type +// if the function only has a template argument for D, or for explicit type +// names instead of auto. This may be a built-in type. +template +using Vec = decltype(Zero(D())); + +// Mask type. Useful as the return type of functions that do not take a mask +// argument, or as an argument type if the function only has a template argument +// for D, or for explicit type names instead of auto. +template +using Mask = decltype(MaskFromVec(Zero(D()))); + +// Returns the closest value to v within [lo, hi]. +template +HWY_API V Clamp(const V v, const V lo, const V hi) { + return Min(Max(lo, v), hi); +} + +// CombineShiftRightBytes (and -Lanes) are not available for the scalar target, +// and RVV has its own implementation of -Lanes. +#if HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV + +template > +HWY_API V CombineShiftRightLanes(D d, const V hi, const V lo) { + constexpr size_t kBytes = kLanes * sizeof(LaneType); + static_assert(kBytes < 16, "Shift count is per-block"); + return CombineShiftRightBytes(d, hi, lo); +} + +#endif + +// Returns lanes with the most significant bit set and all other bits zero. +template +HWY_API Vec SignBit(D d) { + const RebindToUnsigned du; + return BitCast(d, Set(du, SignMask>())); +} + +// Returns quiet NaN. +template +HWY_API Vec NaN(D d) { + const RebindToSigned di; + // LimitsMax sets all exponent and mantissa bits to 1. The exponent plus + // mantissa MSB (to indicate quiet) would be sufficient. + return BitCast(d, Set(di, LimitsMax>())); +} + +// Returns positive infinity. +template +HWY_API Vec Inf(D d) { + const RebindToUnsigned du; + using T = TFromD; + using TU = TFromD; + const TU max_x2 = static_cast(MaxExponentTimes2()); + return BitCast(d, Set(du, max_x2 >> 1)); +} + +// ------------------------------ SafeFillN + +template > +HWY_API void SafeFillN(const size_t num, const T value, D d, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = value; + } +#else + BlendedStore(Set(d, value), FirstN(d, num), d, to); +#endif +} + +// ------------------------------ SafeCopyN + +template > +HWY_API void SafeCopyN(const size_t num, D d, const T* HWY_RESTRICT from, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = from[i]; + } +#else + const Mask mask = FirstN(d, num); + BlendedStore(MaskedLoad(mask, d, from), mask, d, to); +#endif +} + +// "Include guard": skip if native instructions are available. The generic +// implementation is currently shared between x86_* and wasm_*, and is too large +// to duplicate. + +#if (defined(HWY_NATIVE_LOAD_STORE_INTERLEAVED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// ------------------------------ LoadInterleaved2 + +template +HWY_API void LoadInterleaved2(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1) { + const V A = LoadU(d, unaligned + 0 * N); // v1[1] v0[1] v1[0] v0[0] + const V B = LoadU(d, unaligned + 1 * N); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +template +HWY_API void LoadInterleaved2(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +// ------------------------------ LoadInterleaved3 (CombineShiftRightBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void LoadTransposedBlocks3(Simd d, + const T* HWY_RESTRICT unaligned, V& A, V& B, + V& C) { + A = LoadU(d, unaligned + 0 * N); + B = LoadU(d, unaligned + 1 * N); + C = LoadU(d, unaligned + 2 * N); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned du; + // Compact notation so these fit on one line: 12 := v1[2]. + V A; // 05 24 14 04 23 13 03 22 12 02 21 11 01 20 10 00 + V B; // 1a 0a 29 19 09 28 18 08 27 17 07 26 16 06 25 15 + V C; // 2f 1f 0f 2e 1e 0e 2d 1d 0d 2c 1c 0c 2b 1b 0b 2a + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. + constexpr uint8_t Z = 0x80; + alignas(16) constexpr uint8_t kIdx_v0A[16] = {0, 3, 6, 9, 12, 15, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0B[16] = {Z, Z, Z, Z, Z, Z, 2, 5, + 8, 11, 14, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, Z, 1, 4, 7, 10, 13}; + alignas(16) constexpr uint8_t kIdx_v1A[16] = {1, 4, 7, 10, 13, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1B[16] = {Z, Z, Z, Z, Z, 0, 3, 6, + 9, 12, 15, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, Z, 2, 5, 8, 11, 14}; + alignas(16) constexpr uint8_t kIdx_v2A[16] = {2, 5, 8, 11, 14, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2B[16] = {Z, Z, Z, Z, Z, 1, 4, 7, + 10, 13, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, 0, 3, 6, 9, 12, 15}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Or3(v0L, v0M, v0U); + v1 = Or3(v1L, v1M, v1U); + v2 = Or3(v2L, v2M, v2U); +} + +// 8-bit lanes x8 +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned du; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. + constexpr uint8_t Z = 0x80; + alignas(16) constexpr uint8_t kIdx_v0A[16] = {0, 3, 6, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0B[16] = {Z, Z, Z, 1, 4, 7, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0C[16] = {Z, Z, Z, Z, Z, Z, 2, 5}; + alignas(16) constexpr uint8_t kIdx_v1A[16] = {1, 4, 7, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1B[16] = {Z, Z, Z, 2, 5, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1C[16] = {Z, Z, Z, Z, Z, 0, 3, 6}; + alignas(16) constexpr uint8_t kIdx_v2A[16] = {2, 5, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2B[16] = {Z, Z, 0, 3, 6, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2C[16] = {Z, Z, Z, Z, Z, 1, 4, 7}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Or3(v0L, v0M, v0U); + v1 = Or3(v1L, v1M, v1U); + v2 = Or3(v2L, v2M, v2U); +} + +// 16-bit lanes x8 +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned du; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. Same as above, + // but each element of the array contains two byte indices for a lane. + constexpr uint16_t Z = 0x8080; + alignas(16) constexpr uint16_t kIdx_v0A[8] = {0x0100, 0x0706, 0x0D0C, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v0B[8] = {Z, Z, Z, 0x0302, + 0x0908, 0x0F0E, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v0C[8] = {Z, Z, Z, Z, + Z, Z, 0x0504, 0x0B0A}; + alignas(16) constexpr uint16_t kIdx_v1A[8] = {0x0302, 0x0908, 0x0F0E, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v1B[8] = {Z, Z, Z, 0x0504, + 0x0B0A, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v1C[8] = {Z, Z, Z, Z, + Z, 0x0100, 0x0706, 0x0D0C}; + alignas(16) constexpr uint16_t kIdx_v2A[8] = {0x0504, 0x0B0A, Z, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v2B[8] = {Z, Z, 0x0100, 0x0706, + 0x0D0C, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v2C[8] = {Z, Z, Z, Z, + Z, 0x0302, 0x0908, 0x0F0E}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Or3(v0L, v0M, v0U); + v1 = Or3(v1L, v1M, v1U); + v2 = Or3(v2L, v2M, v2U); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + V A; // v0[1] v2[0] v1[0] v0[0] + V B; // v1[2] v0[2] v2[1] v1[1] + V C; // v2[3] v1[3] v0[3] v2[2] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + + const V vxx_02_03_xx = OddEven(C, B); + v0 = detail::Shuffle1230(A, vxx_02_03_xx); + + // Shuffle2301 takes the upper/lower halves of the output from one input, so + // we cannot just combine 13 and 10 with 12 and 11 (similar to v0/v2). Use + // OddEven because it may have higher throughput than Shuffle. + const V vxx_xx_10_11 = OddEven(A, B); + const V v12_13_xx_xx = OddEven(B, C); + v1 = detail::Shuffle2301(vxx_xx_10_11, v12_13_xx_xx); + + const V vxx_20_21_xx = OddEven(B, A); + v2 = detail::Shuffle3012(vxx_20_21_xx, C); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + V A; // v1[0] v0[0] + V B; // v0[1] v2[0] + V C; // v2[1] v1[1] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + v0 = OddEven(B, A); + v1 = CombineShiftRightBytes(d, C, A); + v2 = OddEven(C, B); +} + +template +HWY_API void LoadInterleaved3(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +// ------------------------------ LoadInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void LoadTransposedBlocks4(Simd d, + const T* HWY_RESTRICT unaligned, V& A, V& B, + V& C, V& D) { + A = LoadU(d, unaligned + 0 * N); + B = LoadU(d, unaligned + 1 * N); + C = LoadU(d, unaligned + 2 * N); + D = LoadU(d, unaligned + 3 * N); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + const Repartition d64; + using V64 = VFromD; + // 16 lanes per block; the lowest four blocks are at the bottom of A,B,C,D. + // Here int[i] means the four interleaved values of the i-th 4-tuple and + // int[3..0] indicates four consecutive 4-tuples (0 = least-significant). + V A; // int[13..10] int[3..0] + V B; // int[17..14] int[7..4] + V C; // int[1b..18] int[b..8] + V D; // int[1f..1c] int[f..c] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + + // For brevity, the comments only list the lower block (upper = lower + 0x10) + const V v5140 = InterleaveLower(d, A, B); // int[5,1,4,0] + const V vd9c8 = InterleaveLower(d, C, D); // int[d,9,c,8] + const V v7362 = InterleaveUpper(d, A, B); // int[7,3,6,2] + const V vfbea = InterleaveUpper(d, C, D); // int[f,b,e,a] + + const V v6420 = InterleaveLower(d, v5140, v7362); // int[6,4,2,0] + const V veca8 = InterleaveLower(d, vd9c8, vfbea); // int[e,c,a,8] + const V v7531 = InterleaveUpper(d, v5140, v7362); // int[7,5,3,1] + const V vfdb9 = InterleaveUpper(d, vd9c8, vfbea); // int[f,d,b,9] + + const V64 v10L = BitCast(d64, InterleaveLower(d, v6420, v7531)); // v10[7..0] + const V64 v10U = BitCast(d64, InterleaveLower(d, veca8, vfdb9)); // v10[f..8] + const V64 v32L = BitCast(d64, InterleaveUpper(d, v6420, v7531)); // v32[7..0] + const V64 v32U = BitCast(d64, InterleaveUpper(d, veca8, vfdb9)); // v32[f..8] + + v0 = BitCast(d, InterleaveLower(d64, v10L, v10U)); + v1 = BitCast(d, InterleaveUpper(d64, v10L, v10U)); + v2 = BitCast(d, InterleaveLower(d64, v32L, v32U)); + v3 = BitCast(d, InterleaveUpper(d64, v32L, v32U)); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + // In the last step, we interleave by half of the block size, which is usually + // 8 bytes but half that for 8-bit x8 vectors. + using TW = hwy::UnsignedFromSize; + const Repartition dw; + using VW = VFromD; + + // (Comments are for 256-bit vectors.) + // 8 lanes per block; the lowest four blocks are at the bottom of A,B,C,D. + V A; // v3210[9]v3210[8] v3210[1]v3210[0] + V B; // v3210[b]v3210[a] v3210[3]v3210[2] + V C; // v3210[d]v3210[c] v3210[5]v3210[4] + V D; // v3210[f]v3210[e] v3210[7]v3210[6] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + + const V va820 = InterleaveLower(d, A, B); // v3210[a,8] v3210[2,0] + const V vec64 = InterleaveLower(d, C, D); // v3210[e,c] v3210[6,4] + const V vb931 = InterleaveUpper(d, A, B); // v3210[b,9] v3210[3,1] + const V vfd75 = InterleaveUpper(d, C, D); // v3210[f,d] v3210[7,5] + + const VW v10_b830 = // v10[b..8] v10[3..0] + BitCast(dw, InterleaveLower(d, va820, vb931)); + const VW v10_fc74 = // v10[f..c] v10[7..4] + BitCast(dw, InterleaveLower(d, vec64, vfd75)); + const VW v32_b830 = // v32[b..8] v32[3..0] + BitCast(dw, InterleaveUpper(d, va820, vb931)); + const VW v32_fc74 = // v32[f..c] v32[7..4] + BitCast(dw, InterleaveUpper(d, vec64, vfd75)); + + v0 = BitCast(d, InterleaveLower(dw, v10_b830, v10_fc74)); + v1 = BitCast(d, InterleaveUpper(dw, v10_b830, v10_fc74)); + v2 = BitCast(d, InterleaveLower(dw, v32_b830, v32_fc74)); + v3 = BitCast(d, InterleaveUpper(dw, v32_b830, v32_fc74)); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + V A; // v3210[4] v3210[0] + V B; // v3210[5] v3210[1] + V C; // v3210[6] v3210[2] + V D; // v3210[7] v3210[3] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + const V v10_ev = InterleaveLower(d, A, C); // v1[6,4] v0[6,4] v1[2,0] v0[2,0] + const V v10_od = InterleaveLower(d, B, D); // v1[7,5] v0[7,5] v1[3,1] v0[3,1] + const V v32_ev = InterleaveUpper(d, A, C); // v3[6,4] v2[6,4] v3[2,0] v2[2,0] + const V v32_od = InterleaveUpper(d, B, D); // v3[7,5] v2[7,5] v3[3,1] v2[3,1] + + v0 = InterleaveLower(d, v10_ev, v10_od); + v1 = InterleaveUpper(d, v10_ev, v10_od); + v2 = InterleaveLower(d, v32_ev, v32_od); + v3 = InterleaveUpper(d, v32_ev, v32_od); +} + +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + V A, B, C, D; + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + v0 = InterleaveLower(d, A, C); + v1 = InterleaveUpper(d, A, C); + v2 = InterleaveLower(d, B, D); + v3 = InterleaveUpper(d, B, D); +} + +// Any T x1 +template +HWY_API void LoadInterleaved4(Simd d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void StoreTransposedBlocks2(const V A, const V B, Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); +} + +} // namespace detail + +// >= 128 bit vector +template +HWY_API void StoreInterleaved2(const V v0, const V v1, Simd d, + T* HWY_RESTRICT unaligned) { + const auto v10L = InterleaveLower(d, v0, v1); // .. v1[0] v0[0] + const auto v10U = InterleaveUpper(d, v0, v1); // .. v1[N/2] v0[N/2] + detail::StoreTransposedBlocks2(v10L, v10U, d, unaligned); +} + +// 64 bits +template +HWY_API void StoreInterleaved2(const Vec64 part0, const Vec64 part1, + Full64 /*tag*/, T* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128 d_full; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const auto v10 = InterleaveLower(d_full, v0, v1); + StoreU(v10, d_full, unaligned); +} + +// <= 32 bits +template +HWY_API void StoreInterleaved2(const Vec128 part0, + const Vec128 part1, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128 d_full; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const auto v10 = InterleaveLower(d_full, v0, v1); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(v10, d_full, buf); + CopyBytes<2 * N * sizeof(T)>(buf, unaligned); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void StoreTransposedBlocks3(const V A, const V B, const V C, + Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); + StoreU(C, d, unaligned + 2 * N); +} + +} // namespace detail + +// >= 128-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + const auto k5 = Set(du, 5); + const auto k6 = Set(du, 6); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v0[5], v2[4],v1[4],v0[4] .. v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = LoadDup128(du, tbl_v0); + const auto shuf_A1 = LoadDup128(du, tbl_v1); // cannot reuse shuf_A0 (has 5) + const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const V A = BitCast(d, A0 | A1 | A2); + + // B: v1[10],v0[10], v2[9],v1[9],v0[9] .. , v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_B0 = shuf_A2 + k6; // .A..9..8..7..6.. + const auto shuf_B1 = shuf_A0 + k5; // A..9..8..7..6..5 + const auto shuf_B2 = shuf_A1 + k5; // ..9..8..7..6..5. + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B = BitCast(d, B0 | B1 | B2); + + // C: v2[15],v1[15],v0[15], v2[11],v1[11],v0[11], v2[10] + const auto shuf_C0 = shuf_B2 + k6; // ..F..E..D..C..B. + const auto shuf_C1 = shuf_B0 + k5; // .F..E..D..C..B.. + const auto shuf_C2 = shuf_B1 + k5; // F..E..D..C..B..A + const auto C0 = TableLookupBytesOr0(v0, shuf_C0); + const auto C1 = TableLookupBytesOr0(v1, shuf_C1); + const auto C2 = TableLookupBytesOr0(v2, shuf_C2); + const V C = BitCast(d, C0 | C1 | C2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const Repartition du8; + const auto k2 = Set(du8, 2 * sizeof(T)); + const auto k3 = Set(du8, 3 * sizeof(T)); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. Note that these are byte + // indices for 16-bit lanes. + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A1 = LoadDup128(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const auto shuf_A2 = LoadDup128(du8, tbl_v2); // ..1..0.. + + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); + const V A = BitCast(d, A0 | A1 | A2); + + // B: v0[5] v2[4],v1[4],v0[4], v2[3],v1[3],v0[3], v2[2] + const auto shuf_B0 = shuf_A1 + k3; // 5..4..3. + const auto shuf_B1 = shuf_A2 + k3; // ..4..3.. + const auto shuf_B2 = shuf_A0 + k2; // .4..3..2 + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B = BitCast(d, B0 | B1 | B2); + + // C: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_C0 = shuf_B1 + k3; // ..7..6.. + const auto shuf_C1 = shuf_B2 + k3; // .7..6..5 + const auto shuf_C2 = shuf_B0 + k2; // 7..6..5. + const auto C0 = TableLookupBytesOr0(v0, shuf_C0); + const auto C1 = TableLookupBytesOr0(v1, shuf_C1); + const auto C2 = TableLookupBytesOr0(v2, shuf_C2); + const V C = BitCast(d, C0 | C1 | C2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + + const V v10_v00 = InterleaveLower(d, v0, v1); + const V v01_v20 = OddEven(v0, v2); + // A: v0[1], v2[0],v1[0],v0[0] (<- lane 0) + const V A = BitCast( + d, InterleaveLower(dw, BitCast(dw, v10_v00), BitCast(dw, v01_v20))); + + const V v1_321 = ShiftRightLanes<1>(d, v1); + const V v0_32 = ShiftRightLanes<2>(d, v0); + const V v21_v11 = OddEven(v2, v1_321); + const V v12_v02 = OddEven(v1_321, v0_32); + // B: v1[2],v0[2], v2[1],v1[1] + const V B = BitCast( + d, InterleaveLower(dw, BitCast(dw, v21_v11), BitCast(dw, v12_v02))); + + // Notation refers to the upper 2 lanes of the vector for InterleaveUpper. + const V v23_v13 = OddEven(v2, v1_321); + const V v03_v22 = OddEven(v0, v2); + // C: v2[3],v1[3],v0[3], v2[2] + const V C = BitCast( + d, InterleaveUpper(dw, BitCast(dw, v03_v22), BitCast(dw, v23_v13))); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd d, T* HWY_RESTRICT unaligned) { + const V A = InterleaveLower(d, v0, v1); + const V B = OddEven(v0, v2); + const V C = InterleaveUpper(d, v1, v2); + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// 64-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(const Vec64 part0, const Vec64 part1, + const Vec64 part2, Full64 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors for the shuffles and first result. + const Full128 du; + const Full128 d_full; + const auto k5 = Set(du, 5); + const auto k6 = Set(du, 6); + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = Load(du, tbl_v0); + const auto shuf_A1 = Load(du, tbl_v1); // cannot reuse shuf_A0 (5 in MSB) + const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned + 0 * N); + + // Second (HALF) vector: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_B0 = shuf_A2 + k6; // ..7..6.. + const auto shuf_B1 = shuf_A0 + k5; // .7..6..5 + const auto shuf_B2 = shuf_A1 + k5; // 7..6..5. + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const Vec64 B{(B0 | B1 | B2).raw}; + StoreU(B, d, unaligned + 1 * N); +} + +// 64-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(const Vec64 part0, const Vec64 part1, + const Vec64 part2, Full64 dh, + T* HWY_RESTRICT unaligned) { + const Full128 d; + const Full128 du8; + constexpr size_t N = 16 / sizeof(T); + const auto k2 = Set(du8, 2 * sizeof(T)); + const auto k3 = Set(du8, 3 * sizeof(T)); + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave part (v0,v1,v2) to full (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A1 = Load(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const auto shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); + const Vec128 A = BitCast(d, A0 | A1 | A2); + StoreU(A, d, unaligned + 0 * N); + + // Second (HALF) vector: v2[3],v1[3],v0[3], v2[2] + const auto shuf_B0 = shuf_A1 + k3; // ..3. + const auto shuf_B1 = shuf_A2 + k3; // .3.. + const auto shuf_B2 = shuf_A0 + k2; // 3..2 + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const Vec128 B = BitCast(d, B0 | B1 | B2); + StoreU(Vec64{B.raw}, dh, unaligned + 1 * N); +} + +// 64-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(const Vec64 v0, const Vec64 v1, + const Vec64 v2, Full64 d, + T* HWY_RESTRICT unaligned) { + // (same code as 128-bit vector, 64-bit lanes) + constexpr size_t N = 2; + const Vec64 v10_v00 = InterleaveLower(d, v0, v1); + const Vec64 v01_v20 = OddEven(v0, v2); + const Vec64 v21_v11 = InterleaveUpper(d, v1, v2); + StoreU(v10_v00, d, unaligned + 0 * N); + StoreU(v01_v20, d, unaligned + 1 * N); + StoreU(v21_v11, d, unaligned + 2 * N); +} + +// 64-bit lanes are handled by the N=1 case below. + +// <= 32-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(const Vec128 part0, + const Vec128 part1, + const Vec128 part2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 du; + const Full128 d_full; + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, + 0x80, 3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = Load(du, tbl_v0); + const auto shuf_A1 = CombineShiftRightBytes<15>(du, shuf_A0, shuf_A0); + const auto shuf_A2 = CombineShiftRightBytes<14>(du, shuf_A0, shuf_A0); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ......3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .....3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // ....3..2..1..0.. + const Vec128 A = BitCast(d_full, A0 | A1 | A2); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// 32-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(const Vec128 part0, + const Vec128 part1, + const Vec128 part2, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 4 / sizeof(T); + // Use full vectors for the shuffles and result. + const Full128 du8; + const Full128 d_full; + + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A2 = // ..1..0.. + Load(du8, tbl_v2); + const auto shuf_A1 = // ...1..0. + CombineShiftRightBytes<2>(du8, shuf_A2, shuf_A2); + const auto shuf_A0 = // ....1..0 + CombineShiftRightBytes<4>(du8, shuf_A2, shuf_A2); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // 1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// Single-element vector, any lane size: just store directly +template +HWY_API void StoreInterleaved3(const Vec128 v0, const Vec128 v1, + const Vec128 v2, Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +// ------------------------------ StoreInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_API void StoreTransposedBlocks4(const V A, const V B, const V C, const V D, + Simd d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); + StoreU(C, d, unaligned + 2 * N); + StoreU(D, d, unaligned + 3 * N); +} + +} // namespace detail + +// >= 128-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(const V v0, const V v1, const V v2, const V v3, + Simd d, T* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + const auto v10L = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32L = ZipLower(dw, v2, v3); + const auto v10U = ZipUpper(dw, v0, v1); + const auto v32U = ZipUpper(dw, v2, v3); + // The interleaved vectors are A, B, C, D. + const auto A = BitCast(d, InterleaveLower(dw, v10L, v32L)); // 3210 + const auto B = BitCast(d, InterleaveUpper(dw, v10L, v32L)); + const auto C = BitCast(d, InterleaveLower(dw, v10U, v32U)); + const auto D = BitCast(d, InterleaveUpper(dw, v10U, v32U)); + detail::StoreTransposedBlocks4(A, B, C, D, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved4(const V v0, const V v1, const V v2, const V v3, + Simd d, T* HWY_RESTRICT unaligned) { + // The interleaved vectors are A, B, C, D. + const auto A = InterleaveLower(d, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d, v2, v3); + const auto C = InterleaveUpper(d, v0, v1); + const auto D = InterleaveUpper(d, v2, v3); + detail::StoreTransposedBlocks4(A, B, C, D, d, unaligned); +} + +// 64-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(const Vec64 part0, const Vec64 part1, + const Vec64 part2, const Vec64 part3, + Full64 /*tag*/, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors to reduce the number of stores. + const Full128 d_full; + const RepartitionToWide dw; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + const Vec128 v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto A = BitCast(d_full, InterleaveLower(dw, v10, v32)); + const auto B = BitCast(d_full, InterleaveUpper(dw, v10, v32)); + StoreU(A, d_full, unaligned + 0 * N); + StoreU(B, d_full, unaligned + 1 * N); +} + +// 64-bit vector, 64-bit lane +template +HWY_API void StoreInterleaved4(const Vec64 part0, const Vec64 part1, + const Vec64 part2, const Vec64 part3, + Full64 /*tag*/, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors to reduce the number of stores. + const Full128 d_full; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + const Vec128 v3{part3.raw}; + const auto A = InterleaveLower(d_full, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d_full, v2, v3); + StoreU(A, d_full, unaligned + 0 * N); + StoreU(B, d_full, unaligned + 1 * N); +} + +// <= 32-bit vectors +template +HWY_API void StoreInterleaved4(const Vec128 part0, + const Vec128 part1, + const Vec128 part2, + const Vec128 part3, Simd /*tag*/, + T* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128 d_full; + const RepartitionToWide dw; + const Vec128 v0{part0.raw}; + const Vec128 v1{part1.raw}; + const Vec128 v2{part2.raw}; + const Vec128 v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto v3210 = BitCast(d_full, InterleaveLower(dw, v10, v32)); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(v3210, d_full, buf); + CopyBytes<4 * N * sizeof(T)>(buf, unaligned); +} + +#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED + +// ------------------------------ AESRound + +// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. +#if HWY_TARGET != HWY_SCALAR + +// Define for white-box testing, even if native instructions are available. +namespace detail { + +// Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with +// Vector Permute Instructions" and the accompanying assembly language +// implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan: +// https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html . +// +// A brute-force 256 byte table lookup can also be made constant-time, and +// possibly competitive on NEON, but this is more performance-portable +// especially for x86 and large vectors. +template // u8 +HWY_INLINE V SubBytes(V state) { + const DFromV du; + const auto mask = Set(du, 0xF); + + // Change polynomial basis to GF(2^4) + { + alignas(16) static constexpr uint8_t basisL[16] = { + 0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2, + 0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA}; + alignas(16) static constexpr uint8_t basisU[16] = { + 0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C, + 0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD}; + const auto sL = And(state, mask); + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto gf4L = TableLookupBytes(LoadDup128(du, basisL), sL); + const auto gf4U = TableLookupBytes(LoadDup128(du, basisU), sU); + state = Xor(gf4L, gf4U); + } + + // Inversion in GF(2^4). Elements 0 represent "infinity" (division by 0) and + // cause TableLookupBytesOr0 to return 0. + alignas(16) static constexpr uint8_t kZetaInv[16] = { + 0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3}; + alignas(16) static constexpr uint8_t kInv[16] = { + 0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4}; + const auto tbl = LoadDup128(du, kInv); + const auto sL = And(state, mask); // L=low nibble, U=upper + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto sX = Xor(sU, sL); + const auto invL = TableLookupBytes(LoadDup128(du, kZetaInv), sL); + const auto invU = TableLookupBytes(tbl, sU); + const auto invX = TableLookupBytes(tbl, sX); + const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU))); + const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX))); + + // Linear skew (cannot bake 0x63 bias into the table because out* indices + // may have the infinity flag set). + alignas(16) static constexpr uint8_t kAffineL[16] = { + 0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0, + 0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15}; + alignas(16) static constexpr uint8_t kAffineU[16] = { + 0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF, + 0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E}; + const auto affL = TableLookupBytesOr0(LoadDup128(du, kAffineL), outL); + const auto affU = TableLookupBytesOr0(LoadDup128(du, kAffineU), outU); + return Xor(Xor(affL, affU), Set(du, 0x63)); +} + +} // namespace detail + +#endif // HWY_TARGET != HWY_SCALAR + +// "Include guard": skip if native AES instructions are available. +#if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR + +namespace detail { + +template // u8 +HWY_API V ShiftRows(const V state) { + const DFromV du; + alignas(16) static constexpr uint8_t kShiftRow[16] = { + 0, 5, 10, 15, // transposed: state is column major + 4, 9, 14, 3, // + 8, 13, 2, 7, // + 12, 1, 6, 11}; + const auto shift_row = LoadDup128(du, kShiftRow); + return TableLookupBytes(state, shift_row); +} + +template // u8 +HWY_API V MixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 2 3 1 1 // Let s := state*1, d := state*2, t := state*3. + // 1 2 3 1 // d are on diagonal, no permutation needed. + // 1 1 2 3 // t1230 indicates column indices of threes for the 4 rows. + // 3 1 1 2 // We also need to compute s2301 and s3012 (=1230 o 2301). + alignas(16) static constexpr uint8_t k2301[16] = { + 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}; + alignas(16) static constexpr uint8_t k1230[16] = { + 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12}; + const RebindToSigned di; // can only do signed comparisons + const auto msb = Lt(BitCast(di, state), Zero(di)); + const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, 0x1B))); + const auto d = Xor(Add(state, state), overflow); // = state*2 in GF(2^8). + const auto s2301 = TableLookupBytes(state, LoadDup128(du, k2301)); + const auto d_s2301 = Xor(d, s2301); + const auto t_s2301 = Xor(state, d_s2301); // t(s*3) = XOR-sum {s, d(s*2)} + const auto t1230_s3012 = TableLookupBytes(t_s2301, LoadDup128(du, k1230)); + return Xor(d_s2301, t1230_s3012); // XOR-sum of 4 terms +} + +} // namespace detail + +template // u8 +HWY_API V AESRound(V state, const V round_key) { + // Intel docs swap the first two steps, but it does not matter because + // ShiftRows is a permutation and SubBytes is independent of lane index. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = detail::MixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template // u8 +HWY_API V AESLastRound(V state, const V round_key) { + // LIke AESRound, but without MixColumns. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +// Constant-time implementation inspired by +// https://www.bearssl.org/constanttime.html, but about half the cost because we +// use 64x64 multiplies and 128-bit XORs. +template +HWY_API V CLMulLower(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3)); + auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0)); + auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1)); + auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2)); + m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1))); + m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2))); + m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3))); + m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +template +HWY_API V CLMulUpper(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3)); + auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0)); + auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1)); + auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2)); + m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1))); + m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2))); + m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3))); + m3 = Xor(m3, Xor(MulOdd(a2, b1), MulOdd(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +#endif // HWY_NATIVE_AES +#endif // HWY_TARGET != HWY_SCALAR + +// "Include guard": skip if native POPCNT-related instructions are available. +#if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +#undef HWY_MIN_POW2_FOR_128 +#if HWY_TARGET == HWY_RVV +#define HWY_MIN_POW2_FOR_128 1 +#else +// All other targets except HWY_SCALAR (which is excluded by HWY_IF_GE128_D) +// guarantee 128 bits anyway. +#define HWY_MIN_POW2_FOR_128 0 +#endif + +// This algorithm requires vectors to be at least 16 bytes, which is the case +// for LMUL >= 2. If not, use the fallback below. +template , HWY_IF_LANE_SIZE_D(D, 1), + HWY_IF_GE128_D(D), HWY_IF_POW2_GE(D, HWY_MIN_POW2_FOR_128)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint8_t>(), "V must be u8"); + const D d; + HWY_ALIGN constexpr uint8_t kLookup[16] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + }; + const auto lo = And(v, Set(d, 0xF)); + const auto hi = ShiftRight<4>(v); + const auto lookup = LoadDup128(d, kLookup); + return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo)); +} + +// RVV has a specialization that avoids the Set(). +#if HWY_TARGET != HWY_RVV +// Slower fallback for capped vectors. +template , HWY_IF_LANE_SIZE_D(D, 1), + HWY_IF_LT128_D(D)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint8_t>(), "V must be u8"); + const D d; + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + v = Sub(v, And(ShiftRight<1>(v), Set(d, 0x55))); + v = Add(And(ShiftRight<2>(v), Set(d, 0x33)), And(v, Set(d, 0x33))); + return And(Add(v, ShiftRight<4>(v)), Set(d, 0x0F)); +} +#endif // HWY_TARGET != HWY_RVV + +template , HWY_IF_LANE_SIZE_D(D, 2)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint16_t>(), "V must be u16"); + const D d; + const Repartition d8; + const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); + return Add(ShiftRight<8>(vals), And(vals, Set(d, 0xFF))); +} + +template , HWY_IF_LANE_SIZE_D(D, 4)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint32_t>(), "V must be u32"); + const D d; + Repartition d16; + auto vals = BitCast(d, PopulationCount(BitCast(d16, v))); + return Add(ShiftRight<16>(vals), And(vals, Set(d, 0xFF))); +} + +#if HWY_HAVE_INTEGER64 +template , HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame, uint64_t>(), "V must be u64"); + const D d; + Repartition d32; + auto vals = BitCast(d, PopulationCount(BitCast(d32, v))); + return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFF))); +} +#endif + +#endif // HWY_NATIVE_POPCNT + +template , HWY_IF_LANE_SIZE_D(D, 8), + HWY_IF_LT128_D(D)> +HWY_API V operator*(V x, V y) { + return Set(D(), GetLane(x) * GetLane(y)); +} + +// "Include guard": skip if native 64-bit mul instructions are available. +#if (defined(HWY_NATIVE_I64MULLO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +template , typename T = LaneType, + HWY_IF_LANE_SIZE(T, 8), HWY_IF_UNSIGNED(T), HWY_IF_GE128_D(D64)> +HWY_API V operator*(V x, V y) { + RepartitionToNarrow d32; + auto x32 = BitCast(d32, x); + auto y32 = BitCast(d32, y); + auto lolo = BitCast(d32, MulEven(x32, y32)); + auto lohi = BitCast(d32, MulEven(x32, BitCast(d32, ShiftRight<32>(y)))); + auto hilo = BitCast(d32, MulEven(BitCast(d32, ShiftRight<32>(x)), y32)); + auto hi = BitCast(d32, ShiftLeft<32>(BitCast(D64{}, lohi + hilo))); + return BitCast(D64{}, lolo + hi); +} +template , typename T = LaneType, + HWY_IF_LANE_SIZE(T, 8), HWY_IF_SIGNED(T), HWY_IF_GE128_D(DI64)> +HWY_API V operator*(V x, V y) { + RebindToUnsigned du64; + return BitCast(DI64{}, BitCast(du64, x) * BitCast(du64, y)); +} + +#endif // HWY_NATIVE_I64MULLO + +// ================================================== Operator wrapper + +// These targets currently cannot define operators and have already defined +// (only) the corresponding functions such as Add. +#if HWY_TARGET != HWY_RVV && HWY_TARGET != HWY_SVE && \ + HWY_TARGET != HWY_SVE2 && HWY_TARGET != HWY_SVE_256 && \ + HWY_TARGET != HWY_SVE2_128 + +template +HWY_API V Add(V a, V b) { + return a + b; +} +template +HWY_API V Sub(V a, V b) { + return a - b; +} + +template +HWY_API V Mul(V a, V b) { + return a * b; +} +template +HWY_API V Div(V a, V b) { + return a / b; +} + +template +V Shl(V a, V b) { + return a << b; +} +template +V Shr(V a, V b) { + return a >> b; +} + +template +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template +HWY_API auto Ne(V a, V b) -> decltype(a == b) { + return a != b; +} +template +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +#endif // HWY_TARGET for operators + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/rvv-inl.h b/media/highway/src/hwy/ops/rvv-inl.h new file mode 100644 index 000000000..2a8fb5243 --- /dev/null +++ b/media/highway/src/hwy/ops/rvv-inl.h @@ -0,0 +1,3405 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RISC-V V vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +// Enables the overload if Pow2 is in [min, max]. +#define HWY_RVV_IF_POW2_IN(D, min, max) \ + hwy::EnableIf<(min) <= Pow2(D()) && Pow2(D()) <= (max)>* = nullptr + +template +constexpr size_t MLenFromD(Simd /* tag */) { + // Returns divisor = type bits / LMUL. Folding *8 into the ScaleByPower + // argument enables fractional LMUL < 1. Limit to 64 because that is the + // largest value for which vbool##_t are defined. + return HWY_MIN(64, sizeof(T) * 8 * 8 / detail::ScaleByPower(8, kPow2)); +} + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// For all mask sizes MLEN: (1/Nth of a register, one bit per lane) +// The first two arguments are SEW and SHIFT such that SEW >> SHIFT = MLEN. +#define HWY_RVV_FOREACH_B(X_MACRO, NAME, OP) \ + X_MACRO(64, 0, 64, NAME, OP) \ + X_MACRO(32, 0, 32, NAME, OP) \ + X_MACRO(16, 0, 16, NAME, OP) \ + X_MACRO(8, 0, 8, NAME, OP) \ + X_MACRO(8, 1, 4, NAME, OP) \ + X_MACRO(8, 2, 2, NAME, OP) \ + X_MACRO(8, 3, 1, NAME, OP) + +// For given SEW, iterate over one of LMULS: _TRUNC, _EXT, _ALL. This allows +// reusing type lists such as HWY_RVV_FOREACH_U for _ALL (the usual case) or +// _EXT (for Combine). To achieve this, we HWY_CONCAT with the LMULS suffix. +// +// Precompute SEW/LMUL => MLEN to allow token-pasting the result. For the same +// reason, also pass the double-width and half SEW and LMUL (suffixed D and H, +// respectively). "__" means there is no corresponding LMUL (e.g. LMULD for m8). +// Args: BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP + +// LMULS = _TRUNC: truncatable (not the smallest LMUL) +#define HWY_RVV_FOREACH_08_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH. +#define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _LE2: <= 2 +#define HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) + +// LMULS = _EXT: not the largest LMUL +#define HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) + +// LMULS = _ALL (2^MinPow2() <= LMUL <= 8) +#define HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// 'Virtual' LMUL. This upholds the Highway guarantee that vectors are at least +// 128 bit and LowerHalf is defined whenever there are at least 2 lanes, even +// though RISC-V LMUL must be at least SEW/64 (notice that this rules out +// LMUL=1/2 for SEW=64). To bridge the gap, we add overloads for kPow2 equal to +// one less than should be supported, with all other parameters (vector type +// etc.) unchanged. For D with the lowest kPow2 ('virtual LMUL'), Lanes() +// returns half of what it usually would. +// +// Notice that we can only add overloads whenever there is a D argument: those +// are unique with respect to non-virtual-LMUL overloads because their kPow2 +// template argument differs. Otherwise, there is no actual vuint64mf2_t, and +// defining another overload with the same LMUL would be an error. Thus we have +// a separate _VIRT category for HWY_RVV_FOREACH*, and the common case is +// _ALL_VIRT (meaning the regular LMUL plus the VIRT overloads), used in most +// functions that take a D. + +#define HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, -1, /*MLEN=*/64, NAME, OP) + +// ALL + VIRT +#define HWY_RVV_FOREACH_08_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// LE2 + VIRT +#define HWY_RVV_FOREACH_08_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// EXT + VIRT +#define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// DEMOTE + VIRT +#define HWY_RVV_FOREACH_08_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// SEW for unsigned: +#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, uint, u, NAME, OP) + +// SEW for signed: +#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, int, i, NAME, OP) + +// SEW for float: +#if HWY_HAVE_FLOAT16 +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, float, f, NAME, OP) +#else +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) +#endif +#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, float, f, NAME, OP) + +// Commonly used type/SEW groups: +#define HWY_RVV_FOREACH_UI08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) + +// For all combinations of SEW: +#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) + +// Commonly used type categories: +#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) + +// Assemble types for use in x-macros +#define HWY_RVV_T(BASE, SEW) BASE##SEW##_t +#define HWY_RVV_D(BASE, SEW, N, SHIFT) Simd +#define HWY_RVV_V(BASE, SEW, LMUL) v##BASE##SEW##LMUL##_t +#define HWY_RVV_M(MLEN) vbool##MLEN##_t + +} // namespace detail + +// Until we have full intrinsic support for fractional LMUL, mixed-precision +// code can use LMUL 1..8 (adequate unless they need many registers). +#define HWY_SPECIALIZE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using Lane = HWY_RVV_T(BASE, SEW); \ + using type = ScalableTag; \ + }; + +HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) +#undef HWY_SPECIALIZE + +// ------------------------------ Lanes + +// WARNING: we want to query VLMAX/sizeof(T), but this actually changes VL! +// vlenb is not exposed through intrinsics and vreadvl is not VLMAX. +#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + size_t actual = v##OP##SEW##LMUL(); \ + /* Common case of full vectors: avoid any extra instructions. */ \ + /* actual includes LMUL, so do not shift again. */ \ + if (detail::IsFull(d)) return actual; \ + /* Check for virtual LMUL, e.g. "uint16mf8_t" (not provided by */ \ + /* intrinsics). In this case the actual LMUL is 1/4, so divide by */ \ + /* another factor of two. */ \ + if (detail::ScaleByPower(128 / SEW, SHIFT) == 1) actual >>= 1; \ + return HWY_MIN(actual, N); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL_VIRT) +#undef HWY_RVV_LANES + +template +HWY_API size_t Lanes(Simd /* tag*/) { + return Lanes(Simd()); +} + +// ------------------------------ Common x-macros + +// Last argument to most intrinsics. Use when the op has no d arg of its own, +// which means there is no user-specified cap. +#define HWY_RVV_AVL(SEW, SHIFT) \ + Lanes(ScalableTag()) + +// vector = f(vector), e.g. Not +#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, scalar), e.g. detail::AddS +#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(mask) +#define HWY_RVV_RETM_ARGM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ + return vm##OP##_m_b##MLEN(m, ~0ull); \ + } + +// ================================================== INIT + +// ------------------------------ Set + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_T(BASE, SEW) arg) { \ + return v##OP##_##CHAR##SEW##LMUL(arg, Lanes(d)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f, _ALL_VIRT) +#undef HWY_RVV_SET + +// Treat bfloat16_t as uint16_t (using the previously defined Set overloads); +// required for Zero and VFromD. +template +decltype(Set(Simd(), 0)) Set(Simd d, + bfloat16_t arg) { + return Set(RebindToUnsigned(), arg.bits); +} + +template +using VFromD = decltype(Set(D(), TFromD())); + +// ------------------------------ Zero + +template +HWY_API VFromD> Zero(Simd d) { + // Cast to support bfloat16_t. + const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +// RVV vundefined is 'poisoned' such that even XORing a _variable_ initialized +// by it gives unpredictable results. It should only be used for maskoff, so +// keep it internal. For the Highway op, just use Zero (single instruction). +namespace detail { +#define HWY_RVV_UNDEFINED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) /* tag */) { \ + return v##OP##_##CHAR##SEW##LMUL(); /* no AVL */ \ + } + +HWY_RVV_FOREACH(HWY_RVV_UNDEFINED, Undefined, undefined, _ALL) +#undef HWY_RVV_UNDEFINED +} // namespace detail + +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +// ------------------------------ BitCast + +namespace detail { + +// Halves LMUL. (Use LMUL arg for the source so we can use _TRUNC.) +#define HWY_RVV_TRUNC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH(v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +#undef HWY_RVV_TRUNC + +// Doubles LMUL to `d2` (the arg is only necessary for _VIRT). +#define HWY_RVV_EXT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULD) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULD(v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT, Ext, lmul_ext, _EXT) +#undef HWY_RVV_EXT + +// For virtual LMUL e.g. 'uint32mf4_t', the return type should be mf2, which is +// the same as the actual input type. +#define HWY_RVV_EXT_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v; \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT_VIRT, Ext, lmul_ext, _VIRT) +#undef HWY_RVV_EXT_VIRT + +// For BitCastToByte, the D arg is only to prevent duplicate definitions caused +// by _ALL_VIRT. + +// There is no reinterpret from u8 <-> u8, so just return. +#define HWY_RVV_CAST_U8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vuint8##LMUL##_t v) { \ + return v; \ + } \ + template \ + HWY_API vuint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v; \ + } + +// For i8, need a single reinterpret (HWY_RVV_CAST_IF does two). +#define HWY_RVV_CAST_I8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vint8##LMUL##_t v) { \ + return vreinterpret_v_i8##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API vint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return vreinterpret_v_u8##LMUL##_i8##LMUL(v); \ + } + +// Separate u/i because clang only provides signed <-> unsigned reinterpret for +// the same SEW. +#define HWY_RVV_CAST_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + v##OP##_v_u8##LMUL##_u##SEW##LMUL(v)); \ + } + +// Additional versions for virtual LMUL using LMULH for byte vectors. +#define HWY_RVV_CAST_VIRT_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v2); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_VIRT_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v))); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + v##OP##_v_u8##LMUL##_u##SEW##LMUL(v2)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_CAST_U8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I08(HWY_RVV_CAST_I8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_VIRT_U, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) + +#undef HWY_RVV_CAST_U8 +#undef HWY_RVV_CAST_I8 +#undef HWY_RVV_CAST_U +#undef HWY_RVV_CAST_IF +#undef HWY_RVV_CAST_VIRT_U +#undef HWY_RVV_CAST_VIRT_IF + +template +HWY_INLINE VFromD> BitCastFromByte( + Simd /* d */, VFromD> v) { + return BitCastFromByte(Simd(), v); +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(d, v)); +} + +namespace detail { + +template >> +HWY_INLINE VFromD BitCastToUnsigned(V v) { + return BitCast(DU(), v); +} + +} // namespace detail + +// ------------------------------ Iota + +namespace detail { + +#define HWY_RVV_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + return v##OP##_##CHAR##SEW##LMUL(Lanes(d)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_IOTA, Iota0, id_v, _ALL_VIRT) +#undef HWY_RVV_IOTA + +template > +HWY_INLINE VFromD Iota0(const D /*d*/) { + return BitCastToUnsigned(Iota0(DU())); +} + +} // namespace detail + +// ================================================== LOGICAL + +// ------------------------------ Not + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not, _ALL) + +template +HWY_API V Not(const V v) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Not(BitCast(DU(), v))); +} + +// ------------------------------ And + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AndS, and_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and, _ALL) + +template +HWY_API V And(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), And(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Or + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or, _ALL) + +template +HWY_API V Or(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Or(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Xor + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, XorS, xor_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor, _ALL) + +template +HWY_API V Xor(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Xor(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ AndNot + +template +HWY_API V AndNot(const V not_a, const V b) { + return And(Not(not_a), b); +} + +// ------------------------------ Or3 + +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ CopySign + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj, _ALL) + +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + // RVV can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Add + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AddS, add_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, AddS, fadd_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, ReverseSubS, rsub_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, ReverseSubS, frsub_vf, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd, _ALL) + +// ------------------------------ Sub +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub, _ALL) + +// ------------------------------ SaturatedAdd + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) + +// ------------------------------ SaturatedSub + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) + +// ------------------------------ AverageRound + +// TODO(janwas): check vxrm rounding mode +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) + +// ------------------------------ ShiftLeft[Same] + +// Intrinsics do not define .vi forms, so use .vx instead. +#define HWY_RVV_SHIFT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(v, kBits, HWY_RVV_AVL(SEW, SHIFT)); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(v, static_cast(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll, _ALL) + +// ------------------------------ ShiftRight[Same] + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra, _ALL) + +#undef HWY_RVV_SHIFT + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API VFromD>> SumsOf8(const VU8 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = detail::AndS(BitCast(du16, v), 0xFF); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return detail::AndS(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), 0xFFFFull); +} + +// ------------------------------ RotateRight +template +HWY_API V RotateRight(const V v) { + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shl +#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, bits, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll, _ALL) + +#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, detail::BitCastToUnsigned(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll, _ALL) + +// ------------------------------ Shr + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra, _ALL) + +#undef HWY_RVV_SHIFT_II +#undef HWY_RVV_SHIFT_VV + +// ------------------------------ Min + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin, _ALL) + +// ------------------------------ Max + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MaxS, maxu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MaxS, max_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MaxS, fmax_vf, _ALL) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax, _ALL) + +// ------------------------------ Mul + +HWY_RVV_FOREACH_UI163264(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul, _ALL) + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// ------------------------------ MulHigh + +// Only for internal use (Highway only promises MulHigh for 16-bit inputs). +// Used by MulEven; vwmul does not work for m8. +namespace detail { +HWY_RVV_FOREACH_I32(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) +HWY_RVV_FOREACH_U32(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_U64(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) + +// ------------------------------ MulFixedPoint15 +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulFixedPoint15, smul, _ALL) + +// ------------------------------ Div +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv, _ALL) + +// ------------------------------ ApproximateReciprocal +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frec7, _ALL) + +// ------------------------------ Sqrt +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt, _ALL) + +// ------------------------------ ApproximateReciprocalSqrt +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrt7, _ALL) + +// ------------------------------ MulAdd +// Note: op is still named vv, not vvv. +#define HWY_RVV_FMA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) mul, HWY_RVV_V(BASE, SEW, LMUL) x, \ + HWY_RVV_V(BASE, SEW, LMUL) add) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(add, mul, x, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc, _ALL) + +// ------------------------------ NegMulAdd +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac, _ALL) + +// ------------------------------ MulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac, _ALL) + +// ------------------------------ NegMulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL) + +#undef HWY_RVV_FMA + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. The XX in +// vboolXX_t is a power of two divisor for vector bits. SLEN 8 / LMUL 1 = 1/8th +// of all bits; SLEN 8 / LMUL 4 = half of all bits. + +// mask = f(vector, vector) +#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN(a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(vector, scalar) +#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##LMUL##_b##MLEN(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// ------------------------------ Eq +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, EqS, mseq_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, EqS, mfeq_vf, _ALL) +} // namespace detail + +// ------------------------------ Ne +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, NeS, msne_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, NeS, mfne_vf, _ALL) +} // namespace detail + +// ------------------------------ Lt +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt, _ALL) + +namespace detail { +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt_vx, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, LtS, msltu_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, LtS, mflt_vf, _ALL) +} // namespace detail + +// ------------------------------ Le +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle, _ALL) + +#undef HWY_RVV_RETM_ARGVV +#undef HWY_RVV_RETM_ARGVS + +// ------------------------------ Gt/Ge + +template +HWY_API auto Ge(const V a, const V b) -> decltype(Le(a, b)) { + return Le(b, a); +} + +template +HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) { + return Lt(b, a); +} + +// ------------------------------ TestBit +template +HWY_API auto TestBit(const V a, const V bit) -> decltype(Eq(a, bit)) { + return detail::NeS(And(a, bit), 0); +} + +// ------------------------------ Not +// NOLINTNEXTLINE +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, Not, not ) + +// ------------------------------ And + +// mask = f(mask_a, mask_b) (note arg2,arg1 order!) +#define HWY_RVV_RETM_ARGMM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) a, HWY_RVV_M(MLEN) b) { \ + return vm##OP##_mm_b##MLEN(b, a, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, And, and) + +// ------------------------------ AndNot +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, AndNot, andn) + +// ------------------------------ Or +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Or, or) + +// ------------------------------ Xor +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Xor, xor) + +// ------------------------------ ExclusiveNeither +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, ExclusiveNeither, xnor) + +#undef HWY_RVV_RETM_ARGMM + +// ------------------------------ IfThenElse +#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) yes, \ + HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_vvm_##CHAR##SEW##LMUL(m, no, yes, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge, _ALL) + +#undef HWY_RVV_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template +HWY_API V IfThenElseZero(const M mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV())); +} + +// ------------------------------ IfThenZeroElse + +#define HWY_RVV_IF_THEN_ZERO_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_##CHAR##SEW##LMUL(m, no, 0, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, merge_vxm, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL) + +#undef HWY_RVV_IF_THEN_ZERO_ELSE + +// ------------------------------ MaskFromVec + +template +HWY_API auto MaskFromVec(const V v) -> decltype(Eq(v, v)) { + return detail::NeS(v, 0); +} + +template +using MFromD = decltype(MaskFromVec(Zero(D()))); + +template +HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { + // No need to check lane size/LMUL are the same: if not, casting MFrom to + // MFromD would fail. + return mask; +} + +// ------------------------------ VecFromMask + +namespace detail { +#define HWY_RVV_VEC_FROM_MASK(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_M(MLEN) m) { \ + return v##OP##_##CHAR##SEW##LMUL##_m(m, v0, v0, 1, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, SubS, sub_vx, _ALL) +#undef HWY_RVV_VEC_FROM_MASK +} // namespace detail + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + return detail::SubS(Zero(d), mask); +} + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + return BitCast(d, VecFromMask(RebindToUnsigned(), mask)); +} + +// ------------------------------ IfVecThenElse (MaskFromVec) + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ ZeroIfNegative +template +HWY_API V ZeroIfNegative(const V v) { + return IfThenZeroElse(detail::LtS(v, 0), v); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + MFromD m = + MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ FindFirstTrue + +#define HWY_RVV_FIND_FIRST_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API intptr_t FindFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return vfirst_m_b##MLEN(m, Lanes(d)); \ + } \ + template \ + HWY_API size_t FindKnownFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return static_cast(vfirst_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_FIND_FIRST_TRUE, , _) +#undef HWY_RVV_FIND_FIRST_TRUE + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, MFromD m) { + return FindFirstTrue(d, m) < 0; +} + +// ------------------------------ AllTrue + +#define HWY_RVV_ALL_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API bool AllTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return AllFalse(d, vmnot_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_ALL_TRUE, _, _) +#undef HWY_RVV_ALL_TRUE + +// ------------------------------ CountTrue + +#define HWY_RVV_COUNT_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t CountTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return vcpop_m_b##MLEN(m, Lanes(d)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) +#undef HWY_RVV_COUNT_TRUE + +// ================================================== MEMORY + +// ------------------------------ Load + +#define HWY_RVV_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le, _ALL_VIRT) +#undef HWY_RVV_LOAD + +// There is no native BF16, treat as uint16_t. +template +HWY_API VFromD> Load( + Simd d, const bfloat16_t* HWY_RESTRICT p) { + return Load(RebindToUnsigned(), + reinterpret_cast(p)); +} + +template +HWY_API void Store(VFromD> v, + Simd d, bfloat16_t* HWY_RESTRICT p) { + Store(v, RebindToUnsigned(), + reinterpret_cast(p)); +} + +// ------------------------------ LoadU + +// RVV only requires lane alignment, not natural alignment of the entire vector. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ MaskedLoad + +#define HWY_RVV_MASKED_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, Zero(d), p, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL_VIRT) +#undef HWY_RVV_MASKED_LOAD + +// ------------------------------ Store + +#define HWY_RVV_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL_VIRT) +#undef HWY_RVV_STORE + +// ------------------------------ BlendedStore + +#define HWY_RVV_BLENDED_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, p, v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_BLENDED_STORE, BlendedStore, se, _ALL_VIRT) +#undef HWY_RVV_BLENDED_STORE + +namespace detail { + +#define HWY_RVV_STOREN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, count); \ + } +HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL_VIRT) +#undef HWY_RVV_STOREN + +} // namespace detail + +// ------------------------------ StoreU + +// RVV only requires lane alignment, not natural alignment of the entire vector. +template +HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ Stream +template +HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ ScatterOffset + +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + base, detail::BitCastToUnsigned(offset), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux, _ALL_VIRT) +#undef HWY_RVV_SCATTER + +// ------------------------------ ScatterIndex + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + const VFromD> index) { + return ScatterOffset(v, d, base, ShiftLeft<2>(index)); +} + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + const VFromD> index) { + return ScatterOffset(v, d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ GatherOffset + +#define HWY_RVV_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + base, detail::BitCastToUnsigned(offset), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux, _ALL_VIRT) +#undef HWY_RVV_GATHER + +// ------------------------------ GatherIndex + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + return GatherOffset(d, base, ShiftLeft<2>(index)); +} + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + return GatherOffset(d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_RVV_LOAD2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, unaligned, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD2, LoadInterleaved2, lseg2, _LE2_VIRT) +#undef HWY_RVV_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_RVV_LOAD3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, &v2, unaligned, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD3, LoadInterleaved3, lseg3, _LE2_VIRT) +#undef HWY_RVV_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_RVV_LOAD4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2, HWY_RVV_V(BASE, SEW, LMUL) & v3) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, &v2, &v3, aligned, \ + Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD4, LoadInterleaved4, lseg4, _LE2_VIRT) +#undef HWY_RVV_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_RVV_STORE2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, \ + HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(unaligned, v0, v1, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE2, StoreInterleaved2, sseg2, _LE2_VIRT) +#undef HWY_RVV_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_RVV_STORE3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(unaligned, v0, v1, v2, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE3, StoreInterleaved3, sseg3, _LE2_VIRT) +#undef HWY_RVV_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_RVV_STORE4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(aligned, v0, v1, v2, v3, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE4, StoreInterleaved4, sseg4, _LE2_VIRT) +#undef HWY_RVV_STORE4 + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// SEW is for the input so we can use F16 (no-op if not supported). +#define HWY_RVV_PROMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWD##LMULD(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U16(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I08(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_F16(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT_VIRT) +HWY_RVV_FOREACH_F32(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT_VIRT) +#undef HWY_RVV_PROMOTE + +// The above X-macro cannot handle 4x promotion nor type switching. +// TODO(janwas): use BASE2 arg to allow the latter. +#define HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, LMUL, LMUL_IN, \ + SHIFT, ADD) \ + template \ + HWY_API HWY_RVV_V(BASE, BITS, LMUL) \ + PromoteTo(HWY_RVV_D(BASE, BITS, N, SHIFT + ADD) d, \ + HWY_RVV_V(BASE_IN, BITS_IN, LMUL_IN) v) { \ + return OP##CHAR##BITS##LMUL(v, Lanes(d)); \ + } + +#define HWY_RVV_PROMOTE_X2(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -2, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, m1, 0, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m2, 1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m4, 2, 1) + +#define HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, mf2, mf8, -3, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf4, -2, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf2, -1, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m1, 0, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m2, 1, 2) + +HWY_RVV_PROMOTE_X4(vzext_vf4_, uint, u, 32, uint, 8) +HWY_RVV_PROMOTE_X4(vsext_vf4_, int, i, 32, int, 8) + +// i32 to f64 +HWY_RVV_PROMOTE_X2(vfwcvt_f_x_v_, float, f, 64, int, 32) + +#undef HWY_RVV_PROMOTE_X4 +#undef HWY_RVV_PROMOTE_X2 +#undef HWY_RVV_PROMOTE + +// Unsigned to signed: cast for unsigned promote. +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + const RebindToSigned di32; + const Rebind du16; + return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ DemoteTo U + +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 0, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME##Shr16( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 16, Lanes(d)); \ + } + +// Unsigned -> unsigned (also used for bf16) +namespace detail { +HWY_RVV_FOREACH_U16(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE_VIRT) +} // namespace detail + +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE_I_TO_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(uint, SEWH, LMULH) NAME( \ + HWY_RVV_D(uint, SEWH, N, SHIFT - 1) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + /* First clamp negative numbers to zero to match x86 packus. */ \ + return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); \ + } +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_I_TO_U + +template +HWY_API vuint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return vnclipu_wx_u8mf8(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return vnclipu_wx_u8mf4(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return vnclipu_wx_u8mf2(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return vnclipu_wx_u8m1(DemoteTo(Simd(), v), 0, Lanes(d)); +} +template +HWY_API vuint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return vnclipu_wx_u8m2(DemoteTo(Simd(), v), 0, Lanes(d)); +} + +HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf8(vnclipu_wx_u16mf4(v, 0, avl), 0, avl); +} +HWY_API vuint8mf4_t U8FromU32(const vuint32m1_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf4(vnclipu_wx_u16mf2(v, 0, avl), 0, avl); +} +HWY_API vuint8mf2_t U8FromU32(const vuint32m2_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8mf2(vnclipu_wx_u16m1(v, 0, avl), 0, avl); +} +HWY_API vuint8m1_t U8FromU32(const vuint32m4_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8m1(vnclipu_wx_u16m2(v, 0, avl), 0, avl); +} +HWY_API vuint8m2_t U8FromU32(const vuint32m8_t v) { + const size_t avl = Lanes(ScalableTag()); + return vnclipu_wx_u8m2(vnclipu_wx_u16m4(v, 0, avl), 0, avl); +} + +// ------------------------------ Truncations + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFF, avl); + const vuint32mf2_t v2 = vnclipu_wx_u32mf2(v1, 0, avl); + const vuint16mf4_t v3 = vnclipu_wx_u16mf4(v2, 0, avl); + return vnclipu_wx_u8mf8(v3, 0, avl); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFF, avl); + const vuint32m1_t v2 = vnclipu_wx_u32m1(v1, 0, avl); + const vuint16mf2_t v3 = vnclipu_wx_u16mf2(v2, 0, avl); + return vnclipu_wx_u8mf4(v3, 0, avl); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFF, avl); + const vuint32m2_t v2 = vnclipu_wx_u32m2(v1, 0, avl); + const vuint16m1_t v3 = vnclipu_wx_u16m1(v2, 0, avl); + return vnclipu_wx_u8mf2(v3, 0, avl); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFF, avl); + const vuint32m4_t v2 = vnclipu_wx_u32m4(v1, 0, avl); + const vuint16m2_t v3 = vnclipu_wx_u16m2(v2, 0, avl); + return vnclipu_wx_u8m1(v3, 0, avl); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFFFF, avl); + const vuint32mf2_t v2 = vnclipu_wx_u32mf2(v1, 0, avl); + return vnclipu_wx_u16mf4(v2, 0, avl); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFFFF, avl); + const vuint32m1_t v2 = vnclipu_wx_u32m1(v1, 0, avl); + return vnclipu_wx_u16mf2(v2, 0, avl); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFFFF, avl); + const vuint32m2_t v2 = vnclipu_wx_u32m2(v1, 0, avl); + return vnclipu_wx_u16m1(v2, 0, avl); +} + +template +HWY_API vuint16m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFFFF, avl); + const vuint32m4_t v2 = vnclipu_wx_u32m4(v1, 0, avl); + return vnclipu_wx_u16m2(v2, 0, avl); +} + +template +HWY_API vuint32mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32mf2(v1, 0, avl); +} + +template +HWY_API vuint32m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m1(v1, 0, avl); +} + +template +HWY_API vuint32m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m2(v1, 0, avl); +} + +template +HWY_API vuint32m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m4(v1, 0, avl); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = vand(v, 0xFF, avl); + const vuint16mf4_t v2 = vnclipu_wx_u16mf4(v1, 0, avl); + return vnclipu_wx_u8mf8(v2, 0, avl); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = vand(v, 0xFF, avl); + const vuint16mf2_t v2 = vnclipu_wx_u16mf2(v1, 0, avl); + return vnclipu_wx_u8mf4(v2, 0, avl); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = vand(v, 0xFF, avl); + const vuint16m1_t v2 = vnclipu_wx_u16m1(v1, 0, avl); + return vnclipu_wx_u8mf2(v2, 0, avl); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = vand(v, 0xFF, avl); + const vuint16m2_t v2 = vnclipu_wx_u16m2(v1, 0, avl); + return vnclipu_wx_u8m1(v2, 0, avl); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = vand(v, 0xFF, avl); + const vuint16m4_t v2 = vnclipu_wx_u16m4(v1, 0, avl); + return vnclipu_wx_u8m2(v2, 0, avl); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16mf4(v1, 0, avl); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16mf2(v1, 0, avl); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m1(v1, 0, avl); +} + +template +HWY_API vuint16m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m2(v1, 0, avl); +} + +template +HWY_API vuint16m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m4(v1, 0, avl); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf4_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf8(v1, 0, avl); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf2_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf4(v1, 0, avl); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m1_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf2(v1, 0, avl); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m2_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m1(v1, 0, avl); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m4_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m2(v1, 0, avl); +} + +template +HWY_API vuint8m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m8_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m4(v1, 0, avl); +} + +// ------------------------------ DemoteTo I + +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE_VIRT) + +template +HWY_API vint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} + +#undef HWY_RVV_DEMOTE + +// ------------------------------ DemoteTo F + +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE_F(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##SEWH##LMULH(v, Lanes(d)); \ + } + +#if HWY_HAVE_FLOAT16 +HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, + _DEMOTE_VIRT) +#endif +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, + _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_F + +// TODO(janwas): add BASE2 arg to allow generating this via DEMOTE_F. +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32m1_t DemoteTo(Simd d, const vfloat64m2_t v) { + return vfncvt_rtz_x_f_w_i32m1(v, Lanes(d)); +} +template +HWY_API vint32m2_t DemoteTo(Simd d, const vfloat64m4_t v) { + return vfncvt_rtz_x_f_w_i32m2(v, Lanes(d)); +} +template +HWY_API vint32m4_t DemoteTo(Simd d, const vfloat64m8_t v) { + return vfncvt_rtz_x_f_w_i32m4(v, Lanes(d)); +} + +template +HWY_API VFromD> DemoteTo( + Simd d, VFromD> v) { + const RebindToUnsigned du16; + const Rebind du32; + return detail::DemoteToShr16(du16, BitCast(du32, v)); +} + +// ------------------------------ ConvertTo F + +#define HWY_RVV_CONVERT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + return vfcvt_f_x_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(uint, SEW, LMUL) v) {\ + return vfcvt_f_xu_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + /* Truncates (rounds toward zero). */ \ + template \ + HWY_API HWY_RVV_V(int, SEW, LMUL) ConvertTo(HWY_RVV_D(int, SEW, N, SHIFT) d, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_rtz_x_f_v_i##SEW##LMUL(v, Lanes(d)); \ + } \ +// API only requires f32 but we provide f64 for internal use. +HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL_VIRT) +#undef HWY_RVV_CONVERT + +// Uses default rounding mode. Must be separate because there is no D arg. +#define HWY_RVV_NEAREST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) NearestInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_x_f_v_i##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_NEAREST, _, _, _ALL) +#undef HWY_RVV_NEAREST + +// ================================================== COMBINE + +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template +size_t LanesPerBlock(Simd d) { + size_t lpb = 16 / sizeof(T); + if (IsFull(d)) return lpb; + // Also honor the user-specified (constexpr) N limit. + lpb = HWY_MIN(lpb, N); + // No fraction, we're done. + if (kPow2 >= 0) return lpb; + // Fractional LMUL: Lanes(d) may be smaller than lpb, so honor that. + return HWY_MIN(lpb, Lanes(d)); +} + +template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return AndS(iota0, static_cast(~(LanesPerBlock(d) - 1))); +} + +template +HWY_INLINE MFromD FirstNPerBlock(D /* tag */) { + const RebindToUnsigned du; + const RebindToSigned di; + using TU = TFromD; + const auto idx_mod = AndS(Iota0(du), static_cast(LanesPerBlock(du) - 1)); + return LtS(BitCast(di, idx_mod), static_cast>(kLanes)); +} + +// vector = f(vector, vector, size_t) +#define HWY_RVV_SLIDE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dst, HWY_RVV_V(BASE, SEW, LMUL) src, \ + size_t lanes) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(dst, src, lanes, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideUp, slideup, _ALL) +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideDown, slidedown, _ALL) + +#undef HWY_RVV_SLIDE + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(D d, const V hi, const V lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(D d, const V hi, const V lo) { + return detail::SlideUp(lo, hi, Lanes(d) / 2); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) { + // Move upper half into lower + const auto lo_down = detail::SlideDown(lo, lo, Lanes(d) / 2); + return ConcatUpperLower(d, hi, lo_down); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { + // Move half of both inputs to the other half + const auto hi_up = detail::SlideUp(hi, hi, Lanes(d) / 2); + const auto lo_down = detail::SlideDown(lo, lo, Lanes(d) / 2); + return ConcatUpperLower(d, hi_up, lo_down); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(D2 d2, const V hi, const V lo) { + return detail::SlideUp(detail::Ext(d2, lo), detail::Ext(d2, hi), + Lanes(d2) / 2); +} + +// ------------------------------ ZeroExtendVector + +template +HWY_API VFromD ZeroExtendVector(D2 d2, const V lo) { + return Combine(d2, Xor(lo, lo), lo); +} + +// ------------------------------ Lower/UpperHalf + +namespace detail { + +// RVV may only support LMUL >= SEW/64; returns whether that holds for D. Note +// that SEW = sizeof(T)*8 and LMUL = 1 << Pow2(). +template +constexpr bool IsSupportedLMUL(D d) { + return (size_t{1} << (Pow2(d) + 3)) >= sizeof(TFromD); +} + +} // namespace detail + +// If IsSupportedLMUL, just 'truncate' i.e. halve LMUL. +template * = nullptr> +HWY_API VFromD LowerHalf(const DH /* tag */, const VFromD> v) { + return detail::Trunc(v); +} + +// Otherwise, there is no corresponding intrinsic type (e.g. vuint64mf2_t), and +// the hardware may set "vill" if we attempt such an LMUL. However, the V +// extension on application processors requires Zvl128b, i.e. VLEN >= 128, so it +// still makes sense to have half of an SEW=64 vector. We instead just return +// the vector, and rely on the kPow2 in DH to halve the return value of Lanes(). +template * = nullptr> +HWY_API V LowerHalf(const DH /* tag */, const V v) { + return v; +} + +// Same, but without D arg +template +HWY_API VFromD>> LowerHalf(const V v) { + return LowerHalf(Half>(), v); +} + +template +HWY_API VFromD UpperHalf(const DH d2, const VFromD> v) { + return LowerHalf(d2, detail::SlideDown(v, v, Lanes(d2))); +} + +// ================================================== SWIZZLE + +namespace detail { +// Special instruction for 1 lane is presumably faster? +#define HWY_RVV_SLIDE1(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_##CHAR##SEW##LMUL(v, 0, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Up, slide1up_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Up, fslide1up_vf, _ALL) +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Down, slide1down_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Down, fslide1down_vf, _ALL) +#undef HWY_RVV_SLIDE1 +} // namespace detail + +// ------------------------------ GetLane + +#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_s_##CHAR##SEW##LMUL##_##CHAR##SEW(v); /* no AVL */ \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f, _ALL) +#undef HWY_RVV_GET_LANE + +// ------------------------------ ExtractLane +template +HWY_API TFromV ExtractLane(const V v, size_t i) { + return GetLane(detail::SlideDown(v, v, i)); +} + +// ------------------------------ InsertLane + +template +HWY_API V InsertLane(const V v, size_t i, TFromV t) { + const DFromV d; + const RebindToUnsigned du; // Iota0 is unsigned only + using TU = TFromD; + const auto is_i = detail::EqS(detail::Iota0(du), static_cast(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +namespace detail { +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetOnlyFirst, sof) +} // namespace detail + +// For 8-bit lanes, Iota0 might overflow. +template +HWY_API V InsertLane(const V v, size_t i, TFromV t) { + const DFromV d; + const auto zero = Zero(d); + const auto one = Set(d, 1); + const auto ge_i = Eq(detail::SlideUp(zero, one, i), one); + const auto is_i = detail::SetOnlyFirst(ge_i); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ OddEven +template +HWY_API V OddEven(const V a, const V b) { + const RebindToUnsigned> du; // Iota0 is unsigned only + const auto is_even = detail::EqS(detail::AndS(detail::Iota0(du), 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ DupEven (OddEven) +template +HWY_API V DupEven(const V v) { + const V up = detail::Slide1Up(v); + return OddEven(up, v); +} + +// ------------------------------ DupOdd (OddEven) +template +HWY_API V DupOdd(const V v) { + const V down = detail::Slide1Down(v); + return OddEven(v, down); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V a, const V b) { + const RebindToUnsigned> du; // Iota0 is unsigned only + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV)); + const auto idx_block = ShiftRight(detail::Iota0(du)); + const auto is_even = detail::EqS(detail::AndS(idx_block, 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; + const size_t lpb = detail::LanesPerBlock(d); + const V down = detail::SlideDown(v, v, lpb); + const V up = detail::SlideUp(v, v, lpb); + return OddEvenBlocks(up, down); +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + static_assert(sizeof(TFromD) == sizeof(TFromV), "Index != lane"); + const RebindToUnsigned du; // instead of : avoids unused d. + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllTrue(du, detail::LtS(indices, Lanes(d)))); +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +// <32bit are not part of Highway API, but used in Broadcast. This limits VLMAX +// to 2048! We could instead use vrgatherei16. +#define HWY_RVV_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, idx, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather, _ALL) +#undef HWY_RVV_TABLE + +// ------------------------------ ConcatOdd (TableLookupLanes) +template +HWY_API V ConcatOdd(D d, const V hi, const V lo) { + const RebindToUnsigned du; // Iota0 is unsigned only + const auto iota = detail::Iota0(du); + const auto idx = detail::AddS(Add(iota, iota), 1); + const auto lo_odd = TableLookupLanes(lo, idx); + const auto hi_odd = TableLookupLanes(hi, idx); + return detail::SlideUp(lo_odd, hi_odd, Lanes(d) / 2); +} + +// ------------------------------ ConcatEven (TableLookupLanes) +template +HWY_API V ConcatEven(D d, const V hi, const V lo) { + const RebindToUnsigned du; // Iota0 is unsigned only + const auto iota = detail::Iota0(du); + const auto idx = Add(iota, iota); + const auto lo_even = TableLookupLanes(lo, idx); + const auto hi_even = TableLookupLanes(hi, idx); + return detail::SlideUp(lo_even, hi_even, Lanes(d) / 2); +} + +// ------------------------------ Reverse (TableLookupLanes) +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + const RebindToUnsigned du; + using TU = TFromD; + const size_t N = Lanes(du); + const auto idx = + detail::ReverseSubS(detail::Iota0(du), static_cast(N - 1)); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse2 (RotateRight, OddEven) + +// Shifting and adding requires fewer instructions than blending, but casting to +// u32 only works for LMUL in [1/2, 8]. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} +// For LMUL < 1/2, we can extend and then truncate. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Twice d2; + const Twice d4; + const Repartition du32; + const auto vx = detail::Ext(d4, detail::Ext(d2, v)); + const auto rx = BitCast(d4, RotateRight<16>(BitCast(du32, vx))); + return detail::Trunc(detail::Trunc(rx)); +} + +// Shifting and adding requires fewer instructions than blending, but casting to +// u64 does not work for LMUL < 1. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Repartition du64; + return BitCast(d, RotateRight<32>(BitCast(du64, v))); +} + +// For fractions, we can extend and then truncate. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const Twice d2; + const Twice d4; + const Repartition du64; + const auto vx = detail::Ext(d4, detail::Ext(d2, v)); + const auto rx = BitCast(d4, RotateRight<32>(BitCast(du64, vx))); + return detail::Trunc(detail::Trunc(rx)); +} + +template , HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API V Reverse2(D /* tag */, const V v) { + const V up = detail::Slide1Up(v); + const V down = detail::Slide1Down(v); + return OddEven(up, down); +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 3); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 7); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { + const Repartition du64; + const size_t N = Lanes(du64); + const auto rev = + detail::ReverseSubS(detail::Iota0(du64), static_cast(N - 1)); + // Swap lo/hi u64 within each block + const auto idx = detail::XorS(rev, 1); + return BitCast(d, TableLookupLanes(BitCast(du64, v), idx)); +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = 0 }; +}; + +#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) mask) { \ + return v##OP##_vm_##CHAR##SEW##LMUL(mask, v, v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI163264(HWY_RVV_COMPRESS, Compress, compress, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_COMPRESS, Compress, compress, _ALL) +#undef HWY_RVV_COMPRESS + +// ------------------------------ CompressNot +template +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +template +HWY_API V CompressBlocksNot(V v, const M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + detail::StoreN(count, Compress(v, mask), d, unaligned); + return count; +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); + const auto hi_up = detail::SlideUp(hi8, hi8, 16 - kBytes); + const auto lo_down = detail::SlideDown(lo8, lo8, kBytes); + const auto is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +} + +// ------------------------------ CombineShiftRightLanes +template > +HWY_API V CombineShiftRightLanes(const D d, const V hi, V lo) { + constexpr size_t kLanesUp = 16 / sizeof(TFromV) - kLanes; + const auto hi_up = detail::SlideUp(hi, hi, kLanesUp); + const auto lo_down = detail::SlideDown(lo, lo, kLanes); + const auto is_lo = detail::FirstNPerBlock(d); + return IfThenElse(is_lo, lo_down, hi_up); +} + +// ------------------------------ Shuffle2301 (ShiftLeft) +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const Repartition du64; + const auto v64 = BitCast(du64, v); + return BitCast(d, Or(ShiftRight<32>(v64), ShiftLeft<32>(v64))); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<3>(d, v, v); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<2>(d, v, v); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle0123 +template +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ TableLookupBytes + +// Extends or truncates a vector to match the given d. +namespace detail { + +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + const Simd dh; + const Simd dhh; + return Ext(d, Ext(dh, Ext(dhh, v))); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + const Simd dh; + return Ext(d, Ext(dh, v)); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Ext(d, v); +} + +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD v) + -> VFromD { + return v; +} + +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Trunc(v); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Trunc(Trunc(v)); +} +template +HWY_INLINE auto ChangeLMUL(Simd d, VFromD> v) + -> VFromD { + return Trunc(Trunc(Trunc(v))); +} + +} // namespace detail + +template +HWY_API VI TableLookupBytes(const VT vt, const VI vi) { + const DFromV dt; // T=table, I=index. + const DFromV di; + const Repartition dt8; + const Repartition di8; + // Required for producing half-vectors with table lookups from a full vector. + // If we instead run at the LMUL of the index vector, lookups into the table + // would be truncated. Thus we run at the larger of the two LMULs and truncate + // the result vector to the original index LMUL. + constexpr int kPow2T = Pow2(dt8); + constexpr int kPow2I = Pow2(di8); + const Simd dm8; // m=max + const auto vmt = detail::ChangeLMUL(dm8, BitCast(dt8, vt)); + const auto vmi = detail::ChangeLMUL(dm8, BitCast(di8, vi)); + auto offsets = detail::OffsetsOf128BitBlocks(dm8, detail::Iota0(dm8)); + // If the table is shorter, wrap around offsets so they do not reference + // undefined lanes in the newly extended vmt. + if (kPow2T < kPow2I) { + offsets = detail::AndS(offsets, static_cast(Lanes(dt8) - 1)); + } + const auto out = TableLookupLanes(vmt, Add(vmi, offsets)); + return BitCast(di, detail::ChangeLMUL(di8, out)); +} + +template +HWY_API VI TableLookupBytesOr0(const VT vt, const VI idx) { + const DFromV di; + const Repartition di8; + const auto idx8 = BitCast(di8, idx); + const auto lookup = TableLookupBytes(vt, idx8); + return BitCast(di, IfThenZeroElse(detail::LtS(idx8, 0), lookup)); +} + +// ------------------------------ Broadcast +template +HWY_API V Broadcast(const V v) { + const DFromV d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + auto idx = detail::OffsetsOf128BitBlocks(d, detail::Iota0(d)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return TableLookupLanes(v, idx); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(const D d, const V v) { + const RebindToSigned di; + using TI = TFromD; + const auto shifted = detail::SlideUp(v, v, kLanes); + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(di)), + static_cast(detail::LanesPerBlock(di) - 1)); + const auto clear = detail::LtS(idx_mod, static_cast(kLanes)); + return IfThenZeroElse(clear, shifted); +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, const VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template >> +HWY_API V ShiftRightLanes(const Simd d, V v) { + const RebindToSigned di; + using TI = TFromD; + // For partial vectors, clear upper lanes so we shift in zeros. + if (N <= 16 / sizeof(T)) { + v = IfThenElseZero(FirstN(d, N), v); + } + + const auto shifted = detail::SlideDown(v, v, kLanes); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + const size_t lpb = detail::LanesPerBlock(di); + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(di)), static_cast(lpb - 1)); + const auto keep = detail::LtS(idx_mod, static_cast(lpb - kLanes)); + return IfThenElseZero(keep, shifted); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ InterleaveLower + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const RebindToUnsigned du; + using TU = TFromD; + const auto i = detail::Iota0(du); + const auto idx_mod = ShiftRight<1>( + detail::AndS(i, static_cast(detail::LanesPerBlock(du) - 1))); + const auto idx = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper + +template +HWY_API V InterleaveUpper(const D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const RebindToUnsigned du; + using TU = TFromD; + const size_t lpb = detail::LanesPerBlock(du); + const auto i = detail::Iota0(du); + const auto idx_mod = ShiftRight<1>(detail::AndS(i, static_cast(lpb - 1))); + const auto idx_lower = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto idx = detail::AddS(idx_lower, static_cast(lpb / 2)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} + +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} + +// ------------------------------ ZipUpper +template +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== REDUCE + +// vector = f(vector, zero_m1) +#define HWY_RVV_REDUCE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(D d, HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, m1) v0) { \ + return Set(d, GetLane(v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ + v0, v, v0, Lanes(d)))); \ + } + +// ------------------------------ SumOfLanes + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL) +} // namespace detail + +template +HWY_API VFromD SumOfLanes(D d, const VFromD v) { + const auto v0 = Zero(ScalableTag>()); // always m1 + return detail::RedSum(d, v, v0); +} + +// ------------------------------ MinOfLanes +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL) +} // namespace detail + +template +HWY_API VFromD MinOfLanes(D d, const VFromD v) { + using T = TFromD; + const ScalableTag d1; // always m1 + const auto neutral = Set(d1, HighestValue()); + return detail::RedMin(d, v, neutral); +} + +// ------------------------------ MaxOfLanes +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL) +} // namespace detail + +template +HWY_API VFromD MaxOfLanes(D d, const VFromD v) { + using T = TFromD; + const ScalableTag d1; // always m1 + const auto neutral = Set(d1, LowestValue()); + return detail::RedMax(d, v, neutral); +} + +#undef HWY_RVV_REDUCE + +// ================================================== Ops with dependencies + +// ------------------------------ PopulationCount (ShiftRight) + +// Handles LMUL >= 2 or capped vectors, which generic_ops-inl cannot. +template , HWY_IF_LANE_SIZE_D(D, 1), + hwy::EnableIf* = nullptr> +HWY_API V PopulationCount(V v) { + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + v = Sub(v, detail::AndS(ShiftRight<1>(v), 0x55)); + v = Add(detail::AndS(ShiftRight<2>(v), 0x33), detail::AndS(v, 0x33)); + return detail::AndS(Add(v, ShiftRight<4>(v)), 0x0F); +} + +// ------------------------------ LoadDup128 + +template +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + const VFromD loaded = Load(d, p); + // idx must be unsigned for TableLookupLanes. + using TU = MakeUnsigned>; + const TU mask = static_cast(detail::LanesPerBlock(d) - 1); + // Broadcast the first block. + const VFromD> idx = detail::AndS(detail::Iota0(d), mask); + return TableLookupLanes(loaded, idx); +} + +// ------------------------------ LoadMaskBits + +// Support all combinations of T and SHIFT(LMUL) without explicit overloads for +// each. First overload for MLEN=1..64. +namespace detail { + +// Maps D to MLEN (wrapped in SizeTag), such that #mask_bits = VLEN/MLEN. MLEN +// increases with lane size and decreases for increasing LMUL. Cap at 64, the +// largest supported by HWY_RVV_FOREACH_B (and intrinsics), for virtual LMUL +// e.g. vuint16mf8_t: (8*2 << 3) == 128. +template +using MaskTag = hwy::SizeTag), -Pow2(D())))>; + +#define HWY_RVV_LOAD_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_INLINE HWY_RVV_M(MLEN) \ + NAME(hwy::SizeTag /* tag */, const uint8_t* bits, size_t N) { \ + return OP##_v_b##MLEN(bits, N); \ + } +HWY_RVV_FOREACH_B(HWY_RVV_LOAD_MASK_BITS, LoadMaskBits, vlm) +#undef HWY_RVV_LOAD_MASK_BITS +} // namespace detail + +template > +HWY_API auto LoadMaskBits(D d, const uint8_t* bits) + -> decltype(detail::LoadMaskBits(MT(), bits, Lanes(d))) { + return detail::LoadMaskBits(MT(), bits, Lanes(d)); +} + +// ------------------------------ StoreMaskBits +#define HWY_RVV_STORE_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(D d, HWY_RVV_M(MLEN) m, uint8_t* bits) { \ + const size_t N = Lanes(d); \ + OP##_v_b##MLEN(bits, m, N); \ + /* Non-full byte, need to clear the undefined upper bits. */ \ + /* Use MaxLanes and sizeof(T) to move some checks to compile-time. */ \ + constexpr bool kLessThan8 = \ + detail::ScaleByPower(16 / sizeof(TFromD), Pow2(d)) < 8; \ + if (MaxLanes(d) < 8 || (kLessThan8 && N < 8)) { \ + const int mask = (1 << N) - 1; \ + bits[0] = static_cast(bits[0] & mask); \ + } \ + return (N + 7) / 8; \ + } +HWY_RVV_FOREACH_B(HWY_RVV_STORE_MASK_BITS, StoreMaskBits, vsm) +#undef HWY_RVV_STORE_MASK_BITS + +// ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits) + +template +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ FirstN (Iota0, Lt, RebindMask, SlideUp) + +// Disallow for 8-bit because Iota is likely to overflow. +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const RebindToSigned di; + using TI = TFromD; + return RebindMask( + d, detail::LtS(BitCast(di, detail::Iota0(d)), static_cast(n))); +} + +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const auto zero = Zero(d); + const auto one = Set(d, 1); + return Eq(detail::SlideUp(one, zero, n), one); +} + +// ------------------------------ Neg (Sub) + +template +HWY_API V Neg(const V v) { + return detail::ReverseSubS(v, 0); +} + +// vector = f(vector), but argument is repeated +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) + +// ------------------------------ Abs (Max, Neg) + +template +HWY_API V Abs(const V v) { + return Max(v, Neg(v)); +} + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx, _ALL) + +#undef HWY_RVV_RETV_ARGV2 + +// ------------------------------ AbsDiff (Abs, Sub) +template +HWY_API V AbsDiff(const V a, const V b) { + return Abs(Sub(a, b)); +} + +// ------------------------------ Round (NearestInt, ConvertTo, CopySign) + +// IEEE-754 roundToIntegralTiesToEven returns floating-point, but we do not have +// a dedicated instruction for that. Rounding to integer and converting back to +// float is correct except when the input magnitude is large, in which case the +// input was already an integer (because mantissa >> exponent is zero). + +namespace detail { +enum RoundingModes { kNear, kTrunc, kDown, kUp }; + +template +HWY_INLINE auto UseInt(const V v) -> decltype(MaskFromVec(v)) { + return detail::LtS(Abs(v), MantissaEnd>()); +} + +} // namespace detail + +template +HWY_API V Round(const V v) { + const DFromV df; + + const auto integer = NearestInt(v); // round using current mode + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Trunc (ConvertTo) +template +HWY_API V Trunc(const V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Ceil +template +HWY_API V Ceil(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kUp)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Floor +template +HWY_API V Floor(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kDown)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Floating-point classification (Ne) + +// vfclass does not help because it would require 3 instructions (to AND and +// then compare the bits), whereas these are just 1-3 integer instructions. + +template +HWY_API MFromD> IsNaN(const V v) { + return Ne(v, v); +} + +template > +HWY_API MFromD IsInf(const V v) { + const D d; + const RebindToSigned di; + using T = TFromD; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqS(Add(vi, vi), hwy::MaxExponentTimes2())); +} + +// Returns whether normal/subnormal/zero. +template > +HWY_API MFromD IsFinite(const V v) { + const D d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + using T = TFromD; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtS(exp, hwy::MaxExponentField())); +} + +// ------------------------------ Iota (ConvertTo) + +template +HWY_API VFromD Iota(const D d, TFromD first) { + return detail::AddS(detail::Iota0(d), first); +} + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToUnsigned du; + return detail::AddS(BitCast(d, detail::Iota0(du)), first); +} + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToUnsigned du; + const RebindToSigned di; + return detail::AddS(ConvertTo(d, BitCast(di, detail::Iota0(du))), first); +} + +// ------------------------------ MulEven/Odd (Mul, OddEven) + +template , + class DW = RepartitionToWide> +HWY_API VFromD MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return BitCast(DW(), OddEven(detail::Slide1Up(hi), lo)); +} + +// There is no 64x64 vwmul. +template +HWY_INLINE V MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return OddEven(detail::Slide1Up(hi), lo); +} + +template +HWY_INLINE V MulOdd(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return OddEven(hi, detail::Slide1Down(lo)); +} + +// ------------------------------ ReorderDemote2To (OddEven, Combine) + +template +HWY_API VFromD> ReorderDemote2To( + Simd dbf16, + VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + const RebindToUnsigned> du32; + const VFromD b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// If LMUL is not the max, Combine first to avoid another DemoteTo. +template * = nullptr, + class D32 = RepartitionToWide>> +HWY_API VFromD> ReorderDemote2To( + Simd d16, VFromD a, VFromD b) { + const Twice d32t; + const VFromD ab = Combine(d32t, a, b); + return DemoteTo(d16, ab); +} + +// Max LMUL: must DemoteTo first, then Combine. +template >>> +HWY_API VFromD> ReorderDemote2To(Simd d16, + V32 a, V32 b) { + const Half d16h; + const VFromD a16 = DemoteTo(d16h, a); + const VFromD b16 = DemoteTo(d16h, b); + return Combine(d16, a16, b16); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +namespace detail { + +// Non-overloaded wrapper function so we can define DF32 in template args. +template < + size_t N, int kPow2, class DF32 = Simd, + class VF32 = VFromD, + class DU16 = RepartitionToNarrow>>> +HWY_API VF32 ReorderWidenMulAccumulateBF16(Simd df32, + VFromD a, VFromD b, + const VF32 sum0, VF32& sum1) { + const DU16 du16; + const RebindToUnsigned du32; + using VU32 = VFromD; + const VFromD zero = Zero(du16); + const VU32 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const VU32 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const VU32 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const VU32 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +#define HWY_RVV_WIDEN_MACC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEWD, LMULD) sum, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return OP##CHAR##SEWD##LMULD(sum, a, b, Lanes(d)); \ + } + +HWY_RVV_FOREACH_I16(HWY_RVV_WIDEN_MACC, WidenMulAcc, vwmacc_vv_, _EXT_VIRT) +#undef HWY_RVV_WIDEN_MACC + +// If LMUL is not the max, we can WidenMul first (3 instructions). +template * = nullptr, + class D32 = Simd, class V32 = VFromD, + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(Simd d32, + VFromD a, VFromD b, + const V32 sum0, V32& sum1) { + const Twice d32t; + using V32T = VFromD; + V32T sum = Combine(d32t, sum0, sum1); + sum = detail::WidenMulAcc(d32t, sum, a, b); + sum1 = UpperHalf(d32, sum); + return LowerHalf(d32, sum); +} + +// Max LMUL: must LowerHalf first (4 instructions). +template , class V32 = VFromD, + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(Simd d32, + VFromD a, VFromD b, + const V32 sum0, V32& sum1) { + const Half d16h; + using V16H = VFromD; + const V16H a0 = LowerHalf(d16h, a); + const V16H a1 = UpperHalf(d16h, a); + const V16H b0 = LowerHalf(d16h, b); + const V16H b1 = UpperHalf(d16h, b); + sum1 = detail::WidenMulAcc(d32, sum1, a1, b1); + return detail::WidenMulAcc(d32, sum0, a0, b0); +} + +} // namespace detail + +template +HWY_API VW ReorderWidenMulAccumulate(Simd d32, VN a, VN b, + const VW sum0, VW& sum1) { + return detail::ReorderWidenMulAccumulateBF16(d32, a, b, sum0, sum1); +} + +template +HWY_API VW ReorderWidenMulAccumulate(Simd d32, VN a, VN b, + const VW sum0, VW& sum1) { + return detail::ReorderWidenMulAccumulateI16(d32, a, b, sum0, sum1); +} + +// ------------------------------ Lt128 +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Truth table of Eq and Compare for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // Shift leftward so L can influence H. + const VFromD ltLx = detail::Slide1Up(ltHL); + const VFromD vecHx = OrAnd(ltHL, eqHL, ltLx); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(vecHx, detail::Slide1Down(vecHx))); +} + +// ------------------------------ Lt128Upper +template +HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(ltHL, detail::Slide1Down(ltHL))); +} + +// ------------------------------ Eq128 +template +HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD eqLH = Reverse2(d, eqHL); + return MaskFromVec(And(eqHL, eqLH)); +} + +// ------------------------------ Eq128Upper +template +HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(eqHL, detail::Slide1Down(eqHL))); +} + +// ------------------------------ Ne128 +template +HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + const VFromD neLH = Reverse2(d, neHL); + return MaskFromVec(Or(neHL, neLH)); +} + +// ------------------------------ Ne128Upper +template +HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(neHL, detail::Slide1Down(neHL))); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE VFromD Min128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD minHL = Min(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, a, b); + // The upper lane is just minHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(minHL, IfThenElse(eqXH, minHL, lo)); +} + +template +HWY_INLINE VFromD Max128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD maxHL = Max(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, b, a); + // The upper lane is just maxHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(maxHL, IfThenElse(eqXH, maxHL, lo)); +} + +template +HWY_INLINE VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +namespace detail { // for code folding +#undef HWY_RVV_AVL +#undef HWY_RVV_D +#undef HWY_RVV_FOREACH +#undef HWY_RVV_FOREACH_08_ALL +#undef HWY_RVV_FOREACH_08_ALL_VIRT +#undef HWY_RVV_FOREACH_08_DEMOTE +#undef HWY_RVV_FOREACH_08_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_08_EXT +#undef HWY_RVV_FOREACH_08_EXT_VIRT +#undef HWY_RVV_FOREACH_08_TRUNC +#undef HWY_RVV_FOREACH_08_VIRT +#undef HWY_RVV_FOREACH_16_ALL +#undef HWY_RVV_FOREACH_16_ALL_VIRT +#undef HWY_RVV_FOREACH_16_DEMOTE +#undef HWY_RVV_FOREACH_16_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_16_EXT +#undef HWY_RVV_FOREACH_16_EXT_VIRT +#undef HWY_RVV_FOREACH_16_TRUNC +#undef HWY_RVV_FOREACH_16_VIRT +#undef HWY_RVV_FOREACH_32_ALL +#undef HWY_RVV_FOREACH_32_ALL_VIRT +#undef HWY_RVV_FOREACH_32_DEMOTE +#undef HWY_RVV_FOREACH_32_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_32_EXT +#undef HWY_RVV_FOREACH_32_EXT_VIRT +#undef HWY_RVV_FOREACH_32_TRUNC +#undef HWY_RVV_FOREACH_32_VIRT +#undef HWY_RVV_FOREACH_64_ALL +#undef HWY_RVV_FOREACH_64_ALL_VIRT +#undef HWY_RVV_FOREACH_64_DEMOTE +#undef HWY_RVV_FOREACH_64_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_64_EXT +#undef HWY_RVV_FOREACH_64_EXT_VIRT +#undef HWY_RVV_FOREACH_64_TRUNC +#undef HWY_RVV_FOREACH_64_VIRT +#undef HWY_RVV_FOREACH_B +#undef HWY_RVV_FOREACH_F +#undef HWY_RVV_FOREACH_F16 +#undef HWY_RVV_FOREACH_F32 +#undef HWY_RVV_FOREACH_F3264 +#undef HWY_RVV_FOREACH_F64 +#undef HWY_RVV_FOREACH_I +#undef HWY_RVV_FOREACH_I08 +#undef HWY_RVV_FOREACH_I16 +#undef HWY_RVV_FOREACH_I163264 +#undef HWY_RVV_FOREACH_I32 +#undef HWY_RVV_FOREACH_I64 +#undef HWY_RVV_FOREACH_U +#undef HWY_RVV_FOREACH_U08 +#undef HWY_RVV_FOREACH_U16 +#undef HWY_RVV_FOREACH_U163264 +#undef HWY_RVV_FOREACH_U32 +#undef HWY_RVV_FOREACH_U64 +#undef HWY_RVV_FOREACH_UI +#undef HWY_RVV_FOREACH_UI08 +#undef HWY_RVV_FOREACH_UI16 +#undef HWY_RVV_FOREACH_UI163264 +#undef HWY_RVV_FOREACH_UI32 +#undef HWY_RVV_FOREACH_UI3264 +#undef HWY_RVV_FOREACH_UI64 +#undef HWY_RVV_M +#undef HWY_RVV_RETM_ARGM +#undef HWY_RVV_RETV_ARGV +#undef HWY_RVV_RETV_ARGVS +#undef HWY_RVV_RETV_ARGVV +#undef HWY_RVV_T +#undef HWY_RVV_V +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/scalar-inl.h b/media/highway/src/hwy/ops/scalar-inl.h new file mode 100644 index 000000000..8b11828e6 --- /dev/null +++ b/media/highway/src/hwy/ops/scalar-inl.h @@ -0,0 +1,1571 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Single instruction, single data. +template +using Sisd = Simd; + +// (Wrapper class required for overloading comparison operators.) +template +struct Vec1 { + HWY_INLINE Vec1() = default; + Vec1(const Vec1&) = default; + Vec1& operator=(const Vec1&) = default; + HWY_INLINE explicit Vec1(const T t) : raw(t) {} + + HWY_INLINE Vec1& operator*=(const Vec1 other) { + return *this = (*this * other); + } + HWY_INLINE Vec1& operator/=(const Vec1 other) { + return *this = (*this / other); + } + HWY_INLINE Vec1& operator+=(const Vec1 other) { + return *this = (*this + other); + } + HWY_INLINE Vec1& operator-=(const Vec1 other) { + return *this = (*this - other); + } + HWY_INLINE Vec1& operator&=(const Vec1 other) { + return *this = (*this & other); + } + HWY_INLINE Vec1& operator|=(const Vec1 other) { + return *this = (*this | other); + } + HWY_INLINE Vec1& operator^=(const Vec1 other) { + return *this = (*this ^ other); + } + + T raw; +}; + +// 0 or FF..FF, same size as Vec1. +template +class Mask1 { + using Raw = hwy::MakeUnsigned; + + public: + static HWY_INLINE Mask1 FromBool(bool b) { + Mask1 mask; + mask.bits = b ? static_cast(~Raw{0}) : 0; + return mask; + } + + Raw bits; +}; + +namespace detail { + +// Deduce Sisd from Vec1 +struct Deduce1 { + template + Sisd operator()(Vec1) const { + return Sisd(); + } +}; + +} // namespace detail + +template +using DFromV = decltype(detail::Deduce1()(V())); + +template +using TFromV = TFromD>; + +// ------------------------------ BitCast + +template +HWY_API Vec1 BitCast(Sisd /* tag */, Vec1 v) { + static_assert(sizeof(T) <= sizeof(FromT), "Promoting is undefined"); + T to; + CopyBytes(&v.raw, &to); // not same size - ok to shrink + return Vec1(to); +} + +// ------------------------------ Set + +template +HWY_API Vec1 Zero(Sisd /* tag */) { + return Vec1(T(0)); +} + +template +HWY_API Vec1 Set(Sisd /* tag */, const T2 t) { + return Vec1(static_cast(t)); +} + +template +HWY_API Vec1 Undefined(Sisd d) { + return Zero(d); +} + +template +HWY_API Vec1 Iota(const Sisd /* tag */, const T2 first) { + return Vec1(static_cast(first)); +} + +template +using VFromD = decltype(Zero(D())); + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec1 Not(const Vec1 v) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, v).raw))); +} + +// ------------------------------ And + +template +HWY_API Vec1 And(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw & BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator&(const Vec1 a, const Vec1 b) { + return And(a, b); +} + +// ------------------------------ AndNot + +template +HWY_API Vec1 AndNot(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, a).raw & + BitCast(du, b).raw))); +} + +// ------------------------------ Or + +template +HWY_API Vec1 Or(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw | BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator|(const Vec1 a, const Vec1 b) { + return Or(a, b); +} + +// ------------------------------ Xor + +template +HWY_API Vec1 Xor(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw ^ BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator^(const Vec1 a, const Vec1 b) { + return Xor(a, b); +} + +// ------------------------------ Or3 + +template +HWY_API Vec1 Or3(Vec1 o1, Vec1 o2, Vec1 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec1 OrAnd(const Vec1 o, const Vec1 a1, const Vec1 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec1 IfVecThenElse(Vec1 mask, Vec1 yes, Vec1 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ CopySign + +template +HWY_API Vec1 CopySign(const Vec1 magn, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Sisd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec1 CopySignToAbs(const Vec1 abs, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Sisd()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec1 BroadcastSignBit(const Vec1 v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + return v.raw < 0 ? Vec1(T(-1)) : Vec1(0); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template +HWY_API Vec1 PopulationCount(Vec1 v) { + return Vec1(static_cast(PopCount(v.raw))); +} + +// ------------------------------ Mask + +template +HWY_API Mask1 RebindMask(Sisd /*tag*/, Mask1 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask1{m.bits}; +} + +// v must be 0 or FF..FF. +template +HWY_API Mask1 MaskFromVec(const Vec1 v) { + Mask1 mask; + CopySameSize(&v, &mask); + return mask; +} + +template +Vec1 VecFromMask(const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +Vec1 VecFromMask(Sisd /* tag */, const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +HWY_API Mask1 FirstN(Sisd /*tag*/, size_t n) { + return Mask1::FromBool(n != 0); +} + +// Returns mask ? yes : no. +template +HWY_API Vec1 IfThenElse(const Mask1 mask, const Vec1 yes, + const Vec1 no) { + return mask.bits ? yes : no; +} + +template +HWY_API Vec1 IfThenElseZero(const Mask1 mask, const Vec1 yes) { + return mask.bits ? yes : Vec1(0); +} + +template +HWY_API Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { + return mask.bits ? Vec1(0) : no; +} + +template +HWY_API Vec1 IfNegativeThenElse(Vec1 v, Vec1 yes, Vec1 no) { + return v.raw < 0 ? yes : no; +} + +template +HWY_API Vec1 ZeroIfNegative(const Vec1 v) { + return v.raw < 0 ? Vec1(0) : v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask1 Not(const Mask1 m) { + return MaskFromVec(Not(VecFromMask(Sisd(), m))); +} + +template +HWY_API Mask1 And(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 AndNot(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Or(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Xor(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 ExclusiveNeither(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeft(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1( + static_cast(static_cast>(v.raw) << kBits)); +} + +template +HWY_API Vec1 ShiftRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1(static_cast(v.raw >> kBits)); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = BitCast(du, v).raw >> kBits; + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); + const TU upper = static_cast(sign << sign_shift); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { // T is unsigned + return Vec1(static_cast(v.raw >> kBits)); + } +#endif +} + +// ------------------------------ RotateRight (ShiftRight) + +namespace detail { + +// For partial specialization: kBits == 0 results in an invalid shift count +template +struct RotateRight { + template + HWY_INLINE Vec1 operator()(const Vec1 v) const { + return Or(ShiftRight(v), ShiftLeft(v)); + } +}; + +template <> +struct RotateRight<0> { + template + HWY_INLINE Vec1 operator()(const Vec1 v) const { + return v; + } +}; + +} // namespace detail + +template +HWY_API Vec1 RotateRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return detail::RotateRight()(v); +} + +// ------------------------------ ShiftLeftSame (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeftSame(const Vec1 v, int bits) { + return Vec1( + static_cast(static_cast>(v.raw) << bits)); +} + +template +HWY_API Vec1 ShiftRightSame(const Vec1 v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1(static_cast(v.raw >> bits)); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = BitCast(du, v).raw >> bits; + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); + const TU upper = static_cast(sign << sign_shift); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { // T is unsigned + return Vec1(static_cast(v.raw >> bits)); + } +#endif +} + +// ------------------------------ Shl + +// Single-lane => same as ShiftLeftSame except for the argument type. +template +HWY_API Vec1 operator<<(const Vec1 v, const Vec1 bits) { + return ShiftLeftSame(v, static_cast(bits.raw)); +} + +template +HWY_API Vec1 operator>>(const Vec1 v, const Vec1 bits) { + return ShiftRightSame(v, static_cast(bits.raw)); +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec1 operator+(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 + b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} + +template +HWY_API Vec1 operator-(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 - b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} + +// ------------------------------ SumsOf8 + +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedAdd(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw + b.raw), 32767))); +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedSub(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw - b.raw), 32767))); +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +HWY_API Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} +HWY_API Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} + +// ------------------------------ Absolute value + +template +HWY_API Vec1 Abs(const Vec1 a) { + const T i = a.raw; + return (i >= 0 || i == hwy::LimitsMin()) ? a : Vec1(-i); +} +HWY_API Vec1 Abs(const Vec1 a) { + return Vec1(std::abs(a.raw)); +} +HWY_API Vec1 Abs(const Vec1 a) { + return Vec1(std::abs(a.raw)); +} + +// ------------------------------ min/max + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + if (std::isnan(a.raw)) return b; + if (std::isnan(b.raw)) return a; + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + if (std::isnan(a.raw)) return b; + if (std::isnan(b.raw)) return a; + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +// ------------------------------ Floating-point negate + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Xor(v, SignBit(Sisd())); +} + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Zero(Sisd()) - v; +} + +// ------------------------------ mul/div + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(double(a.raw) * b.raw)); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(int64_t(a.raw) * b.raw)); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(uint64_t(a.raw) * b.raw)); +} + +template +HWY_API Vec1 operator/(const Vec1 a, const Vec1 b) { + return Vec1(a.raw / b.raw); +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((a.raw * b.raw) >> 16)); +} +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + // Cast to uint32_t first to prevent overflow. Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the result + // is the same but this way it is also defined. + return Vec1(static_cast( + (static_cast(a.raw) * static_cast(b.raw)) >> 16)); +} + +HWY_API Vec1 MulFixedPoint15(Vec1 a, Vec1 b) { + return Vec1(static_cast((2 * a.raw * b.raw + 32768) >> 16)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +HWY_API Vec1 MulEven(const Vec1 a, const Vec1 b) { + const int64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} +HWY_API Vec1 MulEven(const Vec1 a, const Vec1 b) { + const uint64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} + +// Approximate reciprocal +HWY_API Vec1 ApproximateReciprocal(const Vec1 v) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The return value is arbitrary. + if (v.raw == 0.0f) return Vec1(0.0f); + return Vec1(1.0f / v.raw); +} + +// Absolute value of difference. +HWY_API Vec1 AbsDiff(const Vec1 a, const Vec1 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec1 MulAdd(const Vec1 mul, const Vec1 x, const Vec1 add) { + return mul * x + add; +} + +template +HWY_API Vec1 NegMulAdd(const Vec1 mul, const Vec1 x, + const Vec1 add) { + return add - mul * x; +} + +template +HWY_API Vec1 MulSub(const Vec1 mul, const Vec1 x, const Vec1 sub) { + return mul * x - sub; +} + +template +HWY_API Vec1 NegMulSub(const Vec1 mul, const Vec1 x, + const Vec1 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_API Vec1 ApproximateReciprocalSqrt(const Vec1 v) { + float f = v.raw; + const float half = f * 0.5f; + uint32_t bits; + CopySameSize(&f, &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &f); + // One Newton-Raphson iteration + return Vec1(f * (1.5f - (half * f * f))); +} + +// Square root +HWY_API Vec1 Sqrt(const Vec1 v) { + return Vec1(std::sqrt(v.raw)); +} +HWY_API Vec1 Sqrt(const Vec1 v) { + return Vec1(std::sqrt(v.raw)); +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec1 Round(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw < MantissaEnd())) { // Huge or NaN + return v; + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1(0), v); + // Round to even + if ((rounded & 1) && std::abs(static_cast(rounded) - v.raw) == T(0.5)) { + return Vec1(static_cast(rounded - (v.raw < T(0) ? -1 : 1))); + } + return Vec1(static_cast(rounded)); +} + +// Round-to-nearest even. +HWY_API Vec1 NearestInt(const Vec1 v) { + using T = float; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool signbit = std::signbit(v.raw); + + if (!(abs < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs <= static_cast(LimitsMax()))) { + return Vec1(signbit ? LimitsMin() : LimitsMax()); + } + return Vec1(static_cast(v.raw)); + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return Vec1(0); + // Round to even + if ((rounded & 1) && std::abs(static_cast(rounded) - v.raw) == T(0.5)) { + return Vec1(rounded - (signbit ? -1 : 1)); + } + return Vec1(rounded); +} + +template +HWY_API Vec1 Trunc(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw <= MantissaEnd())) { // Huge or NaN + return v; + } + const TI truncated = static_cast(v.raw); + if (truncated == 0) return CopySignToAbs(Vec1(0), v); + return Vec1(static_cast(truncated)); +} + +template +V Ceiling(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool positive = f > Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => 0 or 1. + if (exponent < 0) return positive ? V(1) : V(-0.0); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +template +V Floor(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool negative = f < Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => -1 or 0. + if (exponent < 0) return V(negative ? Float(-1.0) : Float(0.0)); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +// Toward +infinity, aka ceiling +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} + +// Toward -infinity, aka floor +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} + +// ================================================== COMPARE + +template +HWY_API Mask1 operator==(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw == b.raw); +} + +template +HWY_API Mask1 operator!=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw != b.raw); +} + +template +HWY_API Mask1 TestBit(const Vec1 v, const Vec1 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask1 operator<(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw < b.raw); +} +template +HWY_API Mask1 operator>(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw > b.raw); +} + +template +HWY_API Mask1 operator<=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw <= b.raw); +} +template +HWY_API Mask1 operator>=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw >= b.raw); +} + +// ------------------------------ Floating-point classification (==) + +template +HWY_API Mask1 IsNaN(const Vec1 v) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + MakeUnsigned bits; + CopySameSize(&v, &bits); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + return Mask1::FromBool(bits > ExponentMask()); +} + +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFF000000u)); +} +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFFE0000000000000ull)); +} + +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFF000000u); +} +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFFE0000000000000ull); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec1 Load(Sisd /* tag */, const T* HWY_RESTRICT aligned) { + T t; + CopySameSize(aligned, &t); + return Vec1(t); +} + +template +HWY_API Vec1 MaskedLoad(Mask1 m, Sisd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template +HWY_API Vec1 LoadU(Sisd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template +HWY_API Vec1 LoadDup128(Sisd d, const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template +HWY_API void Store(const Vec1 v, Sisd /* tag */, + T* HWY_RESTRICT aligned) { + CopySameSize(&v.raw, aligned); +} + +template +HWY_API void StoreU(const Vec1 v, Sisd d, T* HWY_RESTRICT p) { + return Store(v, d, p); +} + +template +HWY_API void BlendedStore(const Vec1 v, Mask1 m, Sisd d, + T* HWY_RESTRICT p) { + if (!m.bits) return; + StoreU(v, d, p); +} + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining StoreInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +template +HWY_API void LoadInterleaved3(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1, Vec1& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +template +HWY_API void LoadInterleaved4(Sisd d, const T* HWY_RESTRICT unaligned, + Vec1& v0, Vec1& v1, Vec1& v2, + Vec1& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(const Vec1 v0, const Vec1 v1, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); +} + +template +HWY_API void StoreInterleaved3(const Vec1 v0, const Vec1 v1, + const Vec1 v2, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +template +HWY_API void StoreInterleaved4(const Vec1 v0, const Vec1 v1, + const Vec1 v2, const Vec1 v3, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); + StoreU(v3, d, unaligned + 3); +} + +// ------------------------------ Stream + +template +HWY_API void Stream(const Vec1 v, Sisd d, T* HWY_RESTRICT aligned) { + return Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template +HWY_API void ScatterOffset(Vec1 v, Sisd d, T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + uint8_t* const base8 = reinterpret_cast(base) + offset.raw; + return Store(v, d, reinterpret_cast(base8)); +} + +template +HWY_API void ScatterIndex(Vec1 v, Sisd d, T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Store(v, d, base + index.raw); +} + +// ------------------------------ Gather + +template +HWY_API Vec1 GatherOffset(Sisd d, const T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + return Load(d, reinterpret_cast(addr)); +} + +template +HWY_API Vec1 GatherIndex(Sisd d, const T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Load(d, base + index.raw); +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +template +HWY_API Vec1 PromoteTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + // For bits Y > X, floatX->floatY and intX->intY are always representable. + return Vec1(static_cast(from.raw)); +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + // Prevent ubsan errors when converting float to narrower integer/float + if (std::isinf(from.raw) || + std::fabs(from.raw) > static_cast(HighestValue())) { + return Vec1(std::signbit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (std::isinf(from.raw) || + std::fabs(from.raw) > static_cast(HighestValue())) { + return Vec1(std::signbit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_API Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + static_assert(!IsFloat(), "FromT=double are handled above"); + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw = HWY_MIN(HWY_MAX(LimitsMin(), from.raw), LimitsMax()); + return Vec1(static_cast(from.raw)); +} + +HWY_API Vec1 PromoteTo(Sisd /* tag */, const Vec1 v) { + uint16_t bits16; + CopySameSize(&v.raw, &bits16); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + return Vec1(sign ? -subnormal : subnormal); + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + float out; + CopySameSize(&bits32, &out); + return Vec1(out); +} + +HWY_API Vec1 PromoteTo(Sisd d, const Vec1 v) { + return Set(d, F32FromBF16(v.raw)); +} + +HWY_API Vec1 DemoteTo(Sisd /* tag */, + const Vec1 v) { + uint32_t bits32; + CopySameSize(&v.raw, &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); + + // Tiny or zero => zero. + Vec1 out; + if (exp < -24) { + const uint16_t zero = 0; + CopySameSize(&zero, &out.raw); + return out; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = static_cast((1u << (10 - sub_exp)) + + (mantissa32 >> (13 + sub_exp))); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + CopySameSize(&narrowed, &out.raw); + return out; +} + +HWY_API Vec1 DemoteTo(Sisd d, const Vec1 v) { + return Set(d, BF16FromF32(v.raw)); +} + +template +HWY_API Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax in FromT, so use double. + const double f = static_cast(from.raw); + if (std::isinf(from.raw) || + std::fabs(f) > static_cast(LimitsMax())) { + return Vec1(std::signbit(from.raw) ? LimitsMin() + : LimitsMax()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_API Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // int## -> float##: no check needed + return Vec1(static_cast(from.raw)); +} + +HWY_API Vec1 U8FromU32(const Vec1 v) { + return DemoteTo(Sisd(), v); +} + +// ------------------------------ Truncations + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFFFFFFu)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +HWY_API Vec1 TruncateTo(Sisd /* tag */, + const Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +// ================================================== COMBINE +// UpperHalf, ZeroExtendVector, Combine, Concat* are unsupported. + +template +HWY_API Vec1 LowerHalf(Vec1 v) { + return v; +} + +template +HWY_API Vec1 LowerHalf(Sisd /* tag */, Vec1 v) { + return v; +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(const Vec1 v) { + return v.raw; +} + +template +HWY_API T ExtractLane(const Vec1 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return v.raw; +} + +template +HWY_API Vec1 InsertLane(Vec1 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + v.raw = t; + return v; +} + +template +HWY_API Vec1 DupEven(Vec1 v) { + return v; +} +// DupOdd is unsupported. + +template +HWY_API Vec1 OddEven(Vec1 /* odd */, Vec1 even) { + return even; +} + +template +HWY_API Vec1 OddEvenBlocks(Vec1 /* odd */, Vec1 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec1 SwapAdjacentBlocks(Vec1 v) { + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices1 { + MakeSigned raw; +}; + +template +HWY_API Indices1 IndicesFromVec(Sisd, Vec1 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + HWY_DASSERT(vec.raw == 0); + return Indices1{vec.raw}; +} + +template +HWY_API Indices1 SetTableIndices(Sisd d, const TI* idx) { + return IndicesFromVec(d, LoadU(Sisd(), idx)); +} + +template +HWY_API Vec1 TableLookupLanes(const Vec1 v, const Indices1 /* idx */) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec1 ReverseBlocks(Sisd /* tag */, const Vec1 v) { + return v; +} + +// ------------------------------ Reverse + +template +HWY_API Vec1 Reverse(Sisd /* tag */, const Vec1 v) { + return v; +} + +// Must not be called: +template +HWY_API Vec1 Reverse2(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse4(Sisd /* tag */, const Vec1 v) { + return v; +} + +template +HWY_API Vec1 Reverse8(Sisd /* tag */, const Vec1 v) { + return v; +} + +// ================================================== BLOCKWISE +// Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported. + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec1 Broadcast(const Vec1 v) { + static_assert(kLane == 0, "Scalar only has one lane"); + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec1 TableLookupBytes(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +template +HWY_API Vec1 TableLookupBytesOr0(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = idx_bytes[i] & 0x80 ? 0 : in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +// ------------------------------ ZipLower + +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((uint32_t(b.raw) << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint32_t(b.raw) << 16) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint64_t(b.raw) << 32) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((int32_t(b.raw) << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1((int32_t(b.raw) << 16) + a.raw); +} +HWY_API Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1((int64_t(b.raw) << 32) + a.raw); +} + +template , class VW = Vec1> +HWY_API VW ZipLower(Sisd /* tag */, Vec1 a, Vec1 b) { + return VW(static_cast((TW{b.raw} << (sizeof(T) * 8)) + a.raw)); +} + +// ================================================== MASK + +template +HWY_API bool AllFalse(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0; +} + +template +HWY_API bool AllTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits != 0; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask1 LoadMaskBits(Sisd /* tag */, + const uint8_t* HWY_RESTRICT bits) { + return Mask1::FromBool((bits[0] & 1) != 0); +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(Sisd d, const Mask1 mask, uint8_t* bits) { + *bits = AllTrue(d, mask); + return 1; +} + +template +HWY_API size_t CountTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0 ? 0 : 1; +} + +template +HWY_API intptr_t FindFirstTrue(Sisd /* tag */, const Mask1 mask) { + return mask.bits == 0 ? -1 : 0; +} + +template +HWY_API size_t FindKnownFirstTrue(Sisd /* tag */, const Mask1 /* m */) { + return 0; // There is only one lane and we know it is true. +} + +// ------------------------------ Compress, CompressBits + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +template +HWY_API Vec1 Compress(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +template +HWY_API Vec1 CompressNot(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec1 v, const Mask1 mask, Sisd d, + T* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec1 v, const Mask1 mask, Sisd d, + T* HWY_RESTRICT unaligned) { + if (!mask.bits) return 0; + StoreU(v, d, unaligned); + return 1; +} + +// ------------------------------ CompressBits +template +HWY_API Vec1 CompressBits(Vec1 v, const uint8_t* HWY_RESTRICT /*bits*/) { + return v; +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(Vec1 v, const uint8_t* HWY_RESTRICT bits, + Sisd d, T* HWY_RESTRICT unaligned) { + const Mask1 mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +HWY_API Vec1 ReorderWidenMulAccumulate(Sisd /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return MulAdd(Vec1(F32FromBF16(a.raw)), + Vec1(F32FromBF16(b.raw)), sum0); +} + +HWY_API Vec1 ReorderWidenMulAccumulate(Sisd /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return Vec1(a.raw * b.raw + sum0.raw); +} + +// ================================================== REDUCTIONS + +// Sum of all lanes, i.e. the only one. +template +HWY_API Vec1 SumOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} +template +HWY_API Vec1 MinOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} +template +HWY_API Vec1 MaxOfLanes(Sisd /* tag */, const Vec1 v) { + return v; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/set_macros-inl.h b/media/highway/src/hwy/ops/set_macros-inl.h new file mode 100644 index 000000000..c1189604b --- /dev/null +++ b/media/highway/src/hwy/ops/set_macros-inl.h @@ -0,0 +1,444 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Sets macros based on HWY_TARGET. + +// This include guard is toggled by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. +#if defined(HWY_SET_MACROS_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif + +#endif // HWY_SET_MACROS_PER_TARGET + +#include "hwy/detect_targets.h" + +#undef HWY_NAMESPACE +#undef HWY_ALIGN +#undef HWY_MAX_BYTES +#undef HWY_LANES + +#undef HWY_HAVE_SCALABLE +#undef HWY_HAVE_INTEGER64 +#undef HWY_HAVE_FLOAT16 +#undef HWY_HAVE_FLOAT64 +#undef HWY_MEM_OPS_MIGHT_FAULT +#undef HWY_NATIVE_FMA +#undef HWY_CAP_GE256 +#undef HWY_CAP_GE512 + +#undef HWY_TARGET_STR + +#if defined(HWY_DISABLE_PCLMUL_AES) +#define HWY_TARGET_STR_PCLMUL_AES "" +#else +#define HWY_TARGET_STR_PCLMUL_AES ",pclmul,aes" +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) +#define HWY_TARGET_STR_BMI2_FMA "" +#else +#define HWY_TARGET_STR_BMI2_FMA ",bmi,bmi2,fma" +#endif + +#if defined(HWY_DISABLE_F16C) +#define HWY_TARGET_STR_F16C "" +#else +#define HWY_TARGET_STR_F16C ",f16c" +#endif + +#define HWY_TARGET_STR_SSSE3 "sse2,ssse3" + +#define HWY_TARGET_STR_SSE4 \ + HWY_TARGET_STR_SSSE3 ",sse4.1,sse4.2" HWY_TARGET_STR_PCLMUL_AES +// Include previous targets, which are the half-vectors of the next target. +#define HWY_TARGET_STR_AVX2 \ + HWY_TARGET_STR_SSE4 ",avx,avx2" HWY_TARGET_STR_BMI2_FMA HWY_TARGET_STR_F16C +#define HWY_TARGET_STR_AVX3 \ + HWY_TARGET_STR_AVX2 ",avx512f,avx512vl,avx512dq,avx512bw" + +// Before include guard so we redefine HWY_TARGET_STR on each include, +// governed by the current HWY_TARGET. + +//----------------------------------------------------------------------------- +// SSSE3 +#if HWY_TARGET == HWY_SSSE3 + +#define HWY_NAMESPACE N_SSSE3 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSSE3 + +//----------------------------------------------------------------------------- +// SSE4 +#elif HWY_TARGET == HWY_SSE4 + +#define HWY_NAMESPACE N_SSE4 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSE4 + +//----------------------------------------------------------------------------- +// AVX2 +#elif HWY_TARGET == HWY_AVX2 + +#define HWY_NAMESPACE N_AVX2 +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#ifdef HWY_DISABLE_BMI2_FMA +#define HWY_NATIVE_FMA 0 +#else +#define HWY_NATIVE_FMA 1 +#endif + +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_AVX2 + +//----------------------------------------------------------------------------- +// AVX3[_DL] +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL + +#define HWY_ALIGN alignas(64) +#define HWY_MAX_BYTES 64 +#define HWY_LANES(T) (64 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 1 + +#if HWY_TARGET == HWY_AVX3 + +#define HWY_NAMESPACE N_AVX3 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3 + +#elif HWY_TARGET == HWY_AVX3_DL + +#define HWY_NAMESPACE N_AVX3_DL +#define HWY_TARGET_STR \ + HWY_TARGET_STR_AVX3 \ + ",vpclmulqdq,avx512vbmi,avx512vbmi2,vaes,avxvnni,avx512bitalg," \ + "avx512vpopcntdq" + +#else +#error "Logic error" +#endif // HWY_TARGET == HWY_AVX3_DL + +//----------------------------------------------------------------------------- +// PPC8 +#elif HWY_TARGET == HWY_PPC8 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_PPC8 + +#define HWY_TARGET_STR "altivec,vsx" + +//----------------------------------------------------------------------------- +// NEON +#elif HWY_TARGET == HWY_NEON + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 + +#if HWY_ARCH_ARM_A64 +#define HWY_HAVE_FLOAT64 1 +#else +#define HWY_HAVE_FLOAT64 0 +#endif + +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#define HWY_NATIVE_FMA 1 +#else +#define HWY_NATIVE_FMA 0 +#endif + +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_NEON + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_ARCH_ARM_V7 +#define HWY_TARGET_STR "+neon-vfpv4" +#else +#define HWY_TARGET_STR "+crypto" +#endif // HWY_ARCH_ARM_V7 +#else +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// SVE[2] +#elif HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE || \ + HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + +// SVE only requires lane alignment, not natural alignment of the entire vector. +#define HWY_ALIGN alignas(8) + +// Value ensures MaxLanes() is the tightest possible upper bound to reduce +// overallocation. +#define HWY_LANES(T) ((HWY_MAX_BYTES) / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_SVE2 +#define HWY_NAMESPACE N_SVE2 +#define HWY_MAX_BYTES 256 +#elif HWY_TARGET == HWY_SVE_256 +#define HWY_NAMESPACE N_SVE_256 +#define HWY_MAX_BYTES 32 +#elif HWY_TARGET == HWY_SVE2_128 +#define HWY_NAMESPACE N_SVE2_128 +#define HWY_MAX_BYTES 16 +#else +#define HWY_NAMESPACE N_SVE +#define HWY_MAX_BYTES 256 +#endif + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +#define HWY_TARGET_STR "+sve2-aes" +#else +#define HWY_TARGET_STR "+sve" +#endif +#else +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// WASM +#elif HWY_TARGET == HWY_WASM + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// WASM_EMU256 +#elif HWY_TARGET == HWY_WASM_EMU256 + +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM_EMU256 + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// RVV +#elif HWY_TARGET == HWY_RVV + +// RVV only requires lane alignment, not natural alignment of the entire vector, +// and the compiler already aligns builtin types, so nothing to do here. +#define HWY_ALIGN + +// The spec requires VLEN <= 2^16 bits, so the limit is 2^16 bytes (LMUL=8). +#define HWY_MAX_BYTES 65536 + +// = HWY_MAX_BYTES divided by max LMUL=8 because MaxLanes includes the actual +// LMUL. This is the tightest possible upper bound. +#define HWY_LANES(T) (8192 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if defined(__riscv_zvfh) +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif + +#define HWY_NAMESPACE N_RVV + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. +// (rv64gcv is not a valid target) + +//----------------------------------------------------------------------------- +// EMU128 +#elif HWY_TARGET == HWY_EMU128 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_EMU128 + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +//----------------------------------------------------------------------------- +// SCALAR +#elif HWY_TARGET == HWY_SCALAR + +#define HWY_ALIGN +#define HWY_MAX_BYTES 8 +#define HWY_LANES(T) 1 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_SCALAR + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +// Override this to 1 in asan/msan builds, which will still fault. +#if HWY_IS_ASAN || HWY_IS_MSAN +#undef HWY_MEM_OPS_MIGHT_FAULT +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#endif + +// Clang <9 requires this be invoked at file scope, before any namespace. +#undef HWY_BEFORE_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_BEFORE_NAMESPACE() \ + HWY_PUSH_ATTRIBUTES(HWY_TARGET_STR) \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_BEFORE_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +// Clang <9 requires any namespaces be closed before this macro. +#undef HWY_AFTER_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_AFTER_NAMESPACE() \ + HWY_POP_ATTRIBUTES \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_AFTER_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +#undef HWY_ATTR +#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target) +#define HWY_ATTR __attribute__((target(HWY_TARGET_STR))) +#else +#define HWY_ATTR +#endif diff --git a/media/highway/src/hwy/ops/shared-inl.h b/media/highway/src/hwy/ops/shared-inl.h new file mode 100644 index 000000000..29c430388 --- /dev/null +++ b/media/highway/src/hwy/ops/shared-inl.h @@ -0,0 +1,311 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target definitions shared by ops/*.h and user code. + +#include + +#include "hwy/base.h" + +// Separate header because foreach_target.h re-enables its include guard. +#include "hwy/ops/set_macros-inl.h" + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Highway operations are implemented as overloaded functions selected using an +// internal-only tag type D := Simd. T is the lane type. kPow2 is a +// shift count applied to scalable vectors. Instead of referring to Simd<> +// directly, users create D via aliases ScalableTag() (defaults to a +// full vector, or fractions/groups if the argument is negative/positive), +// CappedTag or FixedTag. The actual number of lanes is +// Lanes(D()), a power of two. For scalable vectors, N is either HWY_LANES or a +// cap. For constexpr-size vectors, N is the actual number of lanes. This +// ensures Half> is the same type as Full256, as required by x86. +template +struct Simd { + constexpr Simd() = default; + using T = Lane; + static_assert((N & (N - 1)) == 0 && N != 0, "N must be a power of two"); + + // Only for use by MaxLanes, required by MSVC. Cannot be enum because GCC + // warns when using enums and non-enums in the same expression. Cannot be + // static constexpr function (another MSVC limitation). + static constexpr size_t kPrivateN = N; + static constexpr int kPrivatePow2 = kPow2; + + template + static constexpr size_t NewN() { + // Round up to correctly handle scalars with N=1. + return (N * sizeof(T) + sizeof(NewT) - 1) / sizeof(NewT); + } + +#if HWY_HAVE_SCALABLE + template + static constexpr int Pow2Ratio() { + return (sizeof(NewT) > sizeof(T)) + ? static_cast(CeilLog2(sizeof(NewT) / sizeof(T))) + : -static_cast(CeilLog2(sizeof(T) / sizeof(NewT))); + } +#endif + + // Widening/narrowing ops change the number of lanes and/or their type. + // To initialize such vectors, we need the corresponding tag types: + +// PromoteTo/DemoteTo() with another lane type, but same number of lanes. +#if HWY_HAVE_SCALABLE + template + using Rebind = Simd()>; +#else + template + using Rebind = Simd; +#endif + + // Change lane type while keeping the same vector size, e.g. for MulEven. + template + using Repartition = Simd(), kPow2>; + +// Half the lanes while keeping the same lane type, e.g. for LowerHalf. +// Round up to correctly handle scalars with N=1. +#if HWY_HAVE_SCALABLE + // Reducing the cap (N) is required for SVE - if N is the limiter for f32xN, + // then we expect Half> to have N/2 lanes (rounded up). + using Half = Simd; +#else + using Half = Simd; +#endif + +// Twice the lanes while keeping the same lane type, e.g. for Combine. +#if HWY_HAVE_SCALABLE + using Twice = Simd; +#else + using Twice = Simd; +#endif +}; + +namespace detail { + +template +constexpr bool IsFull(Simd /* d */) { + return N == HWY_LANES(T) && kPow2 == 0; +} + +// Returns the number of lanes (possibly zero) after applying a shift: +// - 0: no change; +// - [1,3]: a group of 2,4,8 [fractional] vectors; +// - [-3,-1]: a fraction of a vector from 1/8 to 1/2. +constexpr size_t ScaleByPower(size_t N, int pow2) { +#if HWY_TARGET == HWY_RVV + return pow2 >= 0 ? (N << pow2) : (N >> (-pow2)); +#else + return pow2 >= 0 ? N : (N >> (-pow2)); +#endif +} + +// Struct wrappers enable validation of arguments via static_assert. +template +struct ScalableTagChecker { + static_assert(-3 <= kPow2 && kPow2 <= 3, "Fraction must be 1/8 to 8"); +#if HWY_TARGET == HWY_RVV + // Only RVV supports register groups. + using type = Simd; +#elif HWY_HAVE_SCALABLE + // For SVE[2], only allow full or fractions. + using type = Simd; +#elif HWY_TARGET == HWY_SCALAR + using type = Simd; +#else + // Only allow full or fractions. + using type = Simd; +#endif +}; + +template +struct CappedTagChecker { + static_assert(kLimit != 0, "Does not make sense to have zero lanes"); + // Safely handle non-power-of-two inputs by rounding down, which is allowed by + // CappedTag. Otherwise, Simd would static_assert. + static constexpr size_t kLimitPow2 = size_t{1} << hwy::FloorLog2(kLimit); + using type = Simd; +}; + +template +struct FixedTagChecker { + static_assert(kNumLanes != 0, "Does not make sense to have zero lanes"); + static_assert(kNumLanes <= HWY_LANES(T), "Too many lanes"); + using type = Simd; +}; + +} // namespace detail + +// Alias for a tag describing a full vector (kPow2 == 0: the most common usage, +// e.g. 1D loops where the application does not care about the vector size) or a +// fraction/multiple of one. Multiples are the same as full vectors for all +// targets except RVV. Fractions (kPow2 < 0) are useful as the argument/return +// value of type promotion and demotion. +template +using ScalableTag = typename detail::ScalableTagChecker::type; + +// Alias for a tag describing a vector with *up to* kLimit active lanes, even on +// targets with scalable vectors and HWY_SCALAR. The runtime lane count +// `Lanes(tag)` may be less than kLimit, and is 1 on HWY_SCALAR. This alias is +// typically used for 1D loops with a relatively low application-defined upper +// bound, e.g. for 8x8 DCTs. However, it is better if data structures are +// designed to be vector-length-agnostic (e.g. a hybrid SoA where there are +// chunks of `M >= MaxLanes(d)` DC components followed by M AC1, .., and M AC63; +// this would enable vector-length-agnostic loops using ScalableTag). +template +using CappedTag = typename detail::CappedTagChecker::type; + +// Alias for a tag describing a vector with *exactly* kNumLanes active lanes, +// even on targets with scalable vectors. Requires `kNumLanes` to be a power of +// two not exceeding `HWY_LANES(T)`. +// +// NOTE: if the application does not need to support HWY_SCALAR (+), use this +// instead of CappedTag to emphasize that there will be exactly kNumLanes lanes. +// This is useful for data structures that rely on exactly 128-bit SIMD, but +// these are discouraged because they cannot benefit from wider vectors. +// Instead, applications would ideally define a larger problem size and loop +// over it with the (unknown size) vectors from ScalableTag. +// +// + e.g. if the baseline is known to support SIMD, or the application requires +// ops such as TableLookupBytes not supported by HWY_SCALAR. +template +using FixedTag = typename detail::FixedTagChecker::type; + +template +using TFromD = typename D::T; + +// Tag for the same number of lanes as D, but with the LaneType T. +template +using Rebind = typename D::template Rebind; + +template +using RebindToSigned = Rebind>, D>; +template +using RebindToUnsigned = Rebind>, D>; +template +using RebindToFloat = Rebind>, D>; + +// Tag for the same total size as D, but with the LaneType T. +template +using Repartition = typename D::template Repartition; + +template +using RepartitionToWide = Repartition>, D>; +template +using RepartitionToNarrow = Repartition>, D>; + +// Tag for the same lane type as D, but half the lanes. +template +using Half = typename D::Half; + +// Tag for the same lane type as D, but twice the lanes. +template +using Twice = typename D::Twice; + +template +using Full32 = Simd; + +template +using Full64 = Simd; + +template +using Full128 = Simd; + +// Same as base.h macros but with a Simd argument instead of T. +#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(TFromD) +#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(TFromD) +#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(TFromD) +#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(TFromD) +#define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_LANE_SIZE(TFromD, bytes) +#define HWY_IF_NOT_LANE_SIZE_D(D, bytes) HWY_IF_NOT_LANE_SIZE(TFromD, bytes) + +// MSVC workaround: use PrivateN directly instead of MaxLanes. +#define HWY_IF_LT128_D(D) \ + hwy::EnableIf) < 16>* = nullptr +#define HWY_IF_GE128_D(D) \ + hwy::EnableIf) >= 16>* = nullptr + +// Same, but with a vector argument. ops/*-inl.h define their own TFromV. +#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV) +#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV) +#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(TFromV) +#define HWY_IF_LANE_SIZE_V(V, bytes) HWY_IF_LANE_SIZE(TFromV, bytes) +#define HWY_IF_NOT_LANE_SIZE_V(V, bytes) HWY_IF_NOT_LANE_SIZE(TFromV, bytes) + +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr int Pow2(D /* d */) { + return D::kPrivatePow2; +} + +// MSVC requires the explicit . +#define HWY_IF_POW2_GE(D, MIN) hwy::EnableIf(D()) >= (MIN)>* = nullptr + +#if HWY_HAVE_SCALABLE + +// Upper bound on the number of lanes. Intended for template arguments and +// reducing code size (e.g. for SSE4, we know at compile-time that vectors will +// not exceed 16 bytes). WARNING: this may be a loose bound, use Lanes() as the +// actual size for allocating storage. WARNING: MSVC might not be able to deduce +// arguments if this is used in EnableIf. See HWY_IF_LT128_D above. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return detail::ScaleByPower(HWY_MIN(D::kPrivateN, HWY_LANES(TFromD)), + D::kPrivatePow2); +} + +#else +// Workaround for MSVC 2017: T,N,kPow2 argument deduction fails, so returning N +// is not an option, nor does a member function work. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return D::kPrivateN; +} + +// (Potentially) non-constant actual size of the vector at runtime, subject to +// the limit imposed by the Simd. Useful for advancing loop counters. +// Targets with scalable vectors define this themselves. +template +HWY_INLINE HWY_MAYBE_UNUSED size_t Lanes(Simd) { + return N; +} + +#endif // !HWY_HAVE_SCALABLE + +// NOTE: GCC generates incorrect code for vector arguments to non-inlined +// functions in two situations: +// - on Windows and GCC 10.3, passing by value crashes due to unaligned loads: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412. +// - on ARM64 and GCC 9.3.0 or 11.2.1, passing by value causes many (but not +// all) tests to fail. +// +// We therefore pass by const& only on GCC and (Windows or ARM64). This alias +// must be used for all vector/mask parameters of functions marked HWY_NOINLINE, +// and possibly also other functions that are not inlined. +#if HWY_COMPILER_GCC_ACTUAL && (HWY_OS_WIN || HWY_ARCH_ARM_A64) +template +using VecArg = const V&; +#else +template +using VecArg = V; +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/wasm_128-inl.h b/media/highway/src/hwy/ops/wasm_128-inl.h new file mode 100644 index 000000000..3831258fc --- /dev/null +++ b/media/highway/src/hwy/ops/wasm_128-inl.h @@ -0,0 +1,4589 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit WASM vectors and operations. +// External include guard in highway.h - see comment there. + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +#ifdef HWY_WASM_OLD_NAMES +#define wasm_i8x16_shuffle wasm_v8x16_shuffle +#define wasm_i16x8_shuffle wasm_v16x8_shuffle +#define wasm_i32x4_shuffle wasm_v32x4_shuffle +#define wasm_i64x2_shuffle wasm_v64x2_shuffle +#define wasm_u16x8_extend_low_u8x16 wasm_i16x8_widen_low_u8x16 +#define wasm_u32x4_extend_low_u16x8 wasm_i32x4_widen_low_u16x8 +#define wasm_i32x4_extend_low_i16x8 wasm_i32x4_widen_low_i16x8 +#define wasm_i16x8_extend_low_i8x16 wasm_i16x8_widen_low_i8x16 +#define wasm_u32x4_extend_high_u16x8 wasm_i32x4_widen_high_u16x8 +#define wasm_i32x4_extend_high_i16x8 wasm_i32x4_widen_high_i16x8 +#define wasm_i32x4_trunc_sat_f32x4 wasm_i32x4_trunc_saturate_f32x4 +#define wasm_u8x16_add_sat wasm_u8x16_add_saturate +#define wasm_u8x16_sub_sat wasm_u8x16_sub_saturate +#define wasm_u16x8_add_sat wasm_u16x8_add_saturate +#define wasm_u16x8_sub_sat wasm_u16x8_sub_saturate +#define wasm_i8x16_add_sat wasm_i8x16_add_saturate +#define wasm_i8x16_sub_sat wasm_i8x16_sub_saturate +#define wasm_i16x8_add_sat wasm_i16x8_add_saturate +#define wasm_i16x8_sub_sat wasm_i16x8_sub_saturate +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { + +template +struct Raw128 { + using type = __v128_u; +}; +template <> +struct Raw128 { + using type = __f32x4; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +// FF..FF or 0. +template +struct Mask128 { + typename detail::Raw128::type raw; +}; + +namespace detail { + +// Deduce Simd from Vec128 +struct DeduceD { + template + Simd operator()(Vec128) const { + return Simd(); + } +}; + +} // namespace detail + +template +using DFromV = decltype(detail::DeduceD()(V())); + +template +using TFromV = TFromD>; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __v128_u BitCastToInteger(__v128_u v) { return v; } +HWY_INLINE __v128_u BitCastToInteger(__f32x4 v) { + return static_cast<__v128_u>(v); +} +HWY_INLINE __v128_u BitCastToInteger(__f64x2 v) { + return static_cast<__v128_u>(v); +} + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __v128_u operator()(__v128_u v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __f32x4 operator()(__v128_u v) { return static_cast<__f32x4>(v); } +}; + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128{BitCastFromInteger128()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{wasm_i32x4_splat(0)}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{wasm_f32x4_splat(0.0f)}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { + return Vec128{wasm_i8x16_splat(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint16_t t) { + return Vec128{wasm_i16x8_splat(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint32_t t) { + return Vec128{wasm_i32x4_splat(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint64_t t) { + return Vec128{wasm_i64x2_splat(static_cast(t))}; +} + +template +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { + return Vec128{wasm_i8x16_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { + return Vec128{wasm_i16x8_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { + return Vec128{wasm_i32x4_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { + return Vec128{wasm_i64x2_splat(t)}; +} + +template +HWY_API Vec128 Set(Simd /* tag */, const float t) { + return Vec128{wasm_f32x4_splat(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd d) { + return Zero(d); +} + +HWY_DIAGNOSTICS(pop) + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_add(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_sub(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_add_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add_sat(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_sub_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub_sat(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_avgr(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_avgr(a.raw, b.raw)}; +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i8x16_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i16x8_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i32x4_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i64x2_abs(v.raw)}; +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_f32x4_abs(v.raw)}; +} + +// ------------------------------ Shift lanes by constant #bits + +// Unsigned +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u64x2_shr(v.raw, kBits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i64x2_shr(v.raw, kBits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +// After https://reviews.llvm.org/D108415 shift argument became unsigned. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unsigned +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u64x2_shr(v.raw, bits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shr(v.raw, bits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ignore Wsign-conversion +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Minimum + +// Unsigned +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. + const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t min[2] = {HWY_MIN(a0, b0), HWY_MIN(a1, b1)}; + return Vec128{wasm_v128_load(min)}; +} + +// Signed +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + alignas(16) int64_t min[4]; + min[0] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + min[1] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(min)}; +} + +// Float +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_f32x4_min(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. + const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t max[2] = {HWY_MAX(a0, b0), HWY_MAX(a1, b1)}; + return Vec128{wasm_v128_load(max)}; +} + +// Signed +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + alignas(16) int64_t max[2]; + max[0] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + max[1] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(max)}; +} + +// Float +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_f32x4_max(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto al = wasm_u32x4_extend_low_u16x8(a.raw); + const auto ah = wasm_u32x4_extend_high_u16x8(a.raw); + const auto bl = wasm_u32x4_extend_low_u16x8(b.raw); + const auto bh = wasm_u32x4_extend_high_u16x8(b.raw); + const auto l = wasm_i32x4_mul(al, bl); + const auto h = wasm_i32x4_mul(ah, bh); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto al = wasm_i32x4_extend_low_i16x8(a.raw); + const auto ah = wasm_i32x4_extend_high_i16x8(a.raw); + const auto bl = wasm_i32x4_extend_low_i16x8(b.raw); + const auto bh = wasm_i32x4_extend_high_i16x8(b.raw); + const auto l = wasm_i32x4_mul(al, bl); + const auto h = wasm_i32x4_mul(ah, bh); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + + const Vec128 lo = BitCast(du, Mul(a, b)); + const Vec128 hi = MulHigh(a, b); + // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must + // carry that into the result. Instead isolate the top two bits because only + // they can influence the result. + const Vec128 lo_top2 = ShiftRight<14>(lo); + // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0. + const Vec128 rounding = ShiftRight<1>(Add(lo_top2, Set(du, 1))); + return Add(Add(hi, hi), BitCast(d, rounding)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-width result. +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} + +// ------------------------------ Negate + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i8x16_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i16x8_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i32x4_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i64x2_neg(v.raw)}; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{wasm_f32x4_mul(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_div(a.raw, b.raw)}; +} + +// Approximate reciprocal +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; + return one / v; +} + +// Absolute value of difference. +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfma? + return mul * x + add; +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + // TODO(eustas): replace, when implemented in WASM. + return add - mul * x; +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfms? + return mul * x - sub; +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + // TODO(eustas): replace, when implemented in WASM. + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{wasm_f32x4_sqrt(v.raw)}; +} + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + // TODO(eustas): find cheaper a way to calculate this. + const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{wasm_f32x4_nearest(v.raw)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{wasm_f32x4_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{wasm_f32x4_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{wasm_f32x4_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128{m.raw}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_eq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// Unsigned +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ne(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d32; + const auto a32 = BitCast(d32, a); + const auto b32 = BitCast(d32, b); + // If the upper halves are not equal, this is the answer. + const auto m_gt = a32 > b32; + + // Otherwise, the lower half decides. + const auto m_eq = a32 == b32; + const auto lo_in_hi = wasm_i32x4_shuffle(m_gt.raw, m_gt.raw, 0, 0, 2, 2); + const auto lo_gt = And(m_eq, MaskFromVec(VFromD{lo_in_hi})); + + const auto gt = Or(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. + return Mask128{wasm_i32x4_shuffle(gt.raw, gt.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float <= >= +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_le(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ge(a.raw, b.raw)}; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask128 FirstN(const Simd d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec128 Not(Vec128 v) { + return Vec128{wasm_v128_not(v.raw)}; +} + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_and(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{wasm_v128_andnot(mask.raw, not_mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_or(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_xor(a.raw, b.raw)}; +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(DFromV()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(DFromV()), sign)); +} + +// ------------------------------ BroadcastSignBit (compare) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(d, v < Zero(d)); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(Simd /* tag */, Mask128 v) { + return Vec128{v.raw}; +} + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + const DFromV d; + const auto zero = Zero(d); + return IfThenElse(Mask128{(v > zero).raw}, v, zero); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) + +// The x86 multiply-by-Pow2() trick will not work because WASM saturates +// float->int correctly to 2^31-1 (not 2^31). Because WASM's shifts take a +// scalar count operand, per-lane shift instructions would require extract_lane +// for each lane, and hoping that shuffle is correctly mapped to a native +// instruction. Using non-vector shifts would incur a store-load forwarding +// stall when loading the result vector. We instead test bits of the shift +// count to "predicate" a shift of the entire vector by a constant. + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + alignas(16) T lanes[2]; + alignas(16) T bits_lanes[2]; + Store(v, d, lanes); + Store(bits, d, bits_lanes); + lanes[0] <<= bits_lanes[0]; + lanes[1] <<= bits_lanes[1]; + return Load(d, lanes); +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{wasm_v128_load(aligned)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// Partial load. +template +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { + Vec128 v; + CopyBytes(p, &v); + return v; +} + +// LoadU == Load. +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// Partial store. +template +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { + CopyBytes(&v, p); +} + +HWY_API void Store(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT p) { + *p = wasm_f32x4_extract_lane(v.raw, 0); +} + +// StoreU == Store. +template +HWY_API void StoreU(Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i8x16_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i16x8_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i32x4_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i64x2_extract_lane(v.raw, kLane)); +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + return wasm_f32x4_extract_lane(v.raw, kLane); +} + +} // namespace detail + +// One overload per vector length just in case *_extract_lane raise compile +// errors if their argument is out of bounds (even if that would never be +// reached at runtime). +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ GetLane +template +HWY_API T GetLane(const Vec128 v) { + return detail::ExtractLane<0>(v); +} + +// ------------------------------ InsertLane + +namespace detail { + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i8x16_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i16x8_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i32x4_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i64x2_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{wasm_f32x4_replace_lane(v.raw, kLane, t)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + return Vec128{wasm_f64x2_replace_lane(v.raw, kLane, t)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls wasm_f64x2_replace_lane. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return LowerHalf(Simd(), v); +} + +// ------------------------------ ShiftLeftBytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14)}; + + case 2: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13)}; + + case 3: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)}; + + case 4: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)}; + + case 5: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)}; + + case 6: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + + case 7: + return Vec128{wasm_i8x16_shuffle( + v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; + + case 8: + return Vec128{wasm_i8x16_shuffle( + v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; + + case 9: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, + 6)}; + + case 10: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, + 5)}; + + case 11: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, + 4)}; + + case 12: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 0, 1, + 2, 3)}; + + case 13: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 0, + 1, 2)}; + + case 14: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0, 1)}; + + case 15: + return Vec128{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 0)}; + } + return Vec128{zero}; +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(Simd(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +namespace detail { + +// Helper function allows zeroing invalid lanes in caller. +template +HWY_API __i8x16 ShrBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + + switch (kBytes) { + case 0: + return v.raw; + + case 1: + return wasm_i8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16); + + case 2: + return wasm_i8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16); + + case 3: + return wasm_i8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 16); + + case 4: + return wasm_i8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 16, 16, 16); + + case 5: + return wasm_i8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16); + + case 6: + return wasm_i8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16); + + case 7: + return wasm_i8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16, 16); + + case 8: + return wasm_i8x16_shuffle(v.raw, zero, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 9: + return wasm_i8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, 15, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 10: + return wasm_i8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 11: + return wasm_i8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 12: + return wasm_i8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 13: + return wasm_i8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 14: + return wasm_i8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 15: + return wasm_i8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + case 16: + return zero; + } +} + +} // namespace detail + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + // For partial vectors, clear upper lanes so we shift in zeros. + if (N != 16 / sizeof(T)) { + const Vec128 vfull{v.raw}; + v = Vec128{IfThenElseZero(FirstN(Full128(), N), vfull).raw}; + } + return Vec128{detail::ShrBytes(v)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). +template +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} + +// Partial +template +HWY_API Vec128 UpperHalf(Half> /* tag */, + Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); + return Vec128{upper.raw}; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full128 /* tag */, V hi, V lo) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + switch (kBytes) { + case 0: + return lo; + + case 1: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16)}; + + case 2: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17)}; + + case 3: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18)}; + + case 4: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19)}; + + case 5: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20)}; + + case 6: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21)}; + + case 7: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22)}; + + case 8: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23)}; + + case 9: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24)}; + + case 10: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25)}; + + case 11: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26)}; + + case 12: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27)}; + + case 13: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28)}; + + case 14: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29)}; + + case 15: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30)}; + } + return hi; +} + +template > +HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full128 d_full8; + using V8 = VFromD; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); + return V{BitCast(Full128(), r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, kLane, kLane)}; +} + +// ------------------------------ TableLookupBytes + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { +// Not yet available in all engines, see +// https://github.com/WebAssembly/simd/blob/bdcc304b2d379f4601c2c44ea9b44ed9484fde7e/proposals/simd/ImplementationStatus.md +// V8 implementation of this had a bug, fixed on 2021-04-03: +// https://chromium-review.googlesource.com/c/v8/v8/+/2822951 +#if 0 + return Vec128{wasm_i8x16_swizzle(bytes.raw, from.raw)}; +#else + alignas(16) uint8_t control[16]; + alignas(16) uint8_t input[16]; + alignas(16) uint8_t output[16]; + wasm_v128_store(control, from.raw); + wasm_v128_store(input, bytes.raw); + for (size_t i = 0; i < 16; ++i) { + output[i] = control[i] < 16 ? input[control[i]] : 0; + } + return Vec128{wasm_v128_load(output)}; +#endif +} + +template +HWY_API Vec128 TableLookupBytesOr0(const Vec128 bytes, + const Vec128 from) { + const Simd d; + // Mask size must match vector type, so cast everything to this type. + Repartition di8; + Repartition> d_bytes8; + const auto msb = BitCast(di8, from) < Zero(di8); + const auto lookup = + TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. +namespace detail { + +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 1, 0, 3 + 16, 2 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 1, 0, 3 + 8, 2 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 1, 0, 3 + 4, 2 + 4)}; +} + +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 0, 3, 2 + 16, 1 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 0, 3, 2 + 8, 1 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 3, 2 + 4, 1 + 4)}; +} + +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 2, 1, 0 + 16, 3 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 2, 1, 0 + 8, 3 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 1, 0 + 4, 3 + 4)}; +} + +} // namespace detail + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + __v128_u raw; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#endif + + const Repartition d8; + using V8 = VFromD; + const Repartition d16; + + // Broadcast each lane index to all bytes of T and shift to bytes + static_assert(sizeof(T) == 4 || sizeof(T) == 8, ""); + if (sizeof(T) == 4) { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; + } else { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<3>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; + } +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + using TI = MakeSigned; + const DFromV d; + const Rebind di; + return BitCast(d, TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return v; +} + +// Two lanes: shuffle +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return Vec128{Shuffle2301(Vec128{v.raw}).raw}; +} + +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + return BitCast(d, Vec128{wasm_i16x8_shuffle(v.raw, v.raw, 3, 2, + 1, 0, 7, 6, 5, 4)}); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, const Vec128) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { + return Reverse(d, v); +} + +template +HWY_API Vec128 Reverse8(Simd, const Vec128) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +// Additional overload for the optional tag. +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +} // namespace detail + +// Full +template > +HWY_API V InterleaveUpper(Full128 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template > +HWY_API V InterleaveUpper(Simd d, V a, V b) { + const Half d2; + return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template +HWY_API Vec128 Combine(Simd d, Vec128 hi_half, + Vec128 lo_half) { + const Half d2; + const RebindToUnsigned du2; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { + return IfThenElseZero(FirstN(d, N / 2), Vec128{lo.raw}); +} + +// ------------------------------ ConcatLowerLower + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec128 ConcatLowerLower(Full128 /* tag */, const Vec128 hi, + const Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; +} +template +HWY_API Vec128 ConcatLowerLower(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatUpperUpper + +template +HWY_API Vec128 ConcatUpperUpper(Full128 /* tag */, const Vec128 hi, + const Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; +} +template +HWY_API Vec128 ConcatUpperUpper(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +// ------------------------------ ConcatLowerUpper + +template +HWY_API Vec128 ConcatLowerUpper(Full128 d, const Vec128 hi, + const Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} +template +HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +// ------------------------------ ConcatUpperLower +template +HWY_API Vec128 ConcatUpperLower(Simd d, const Vec128 hi, + const Vec128 lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} + +// 8-bit x8 +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 17, 19, 21, + 23, 1, 3, 5, 7, 17, 19, 21, 23)}; +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper 3/4. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 17, 19, 1, 3, 17, + 19, 1, 3, 17, 19, 1, 3, 17, 19)}; +} + +// 16-bit full +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +// 16-bit x4 +template +HWY_API Vec128 ConcatOdd(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 9, 11, 1, 3, 9, 11)}; +} + +// 32-bit full +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; +} + +// Any T x2 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30)}; +} + +// 8-bit x8 +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 16, 18, 20, + 22, 0, 2, 4, 6, 16, 18, 20, 22)}; +} + +// 8-bit x4 +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper 3/4. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 16, 18, 0, 2, 16, + 18, 0, 2, 16, 18, 0, 2, 16, 18)}; +} + +// 16-bit full +template +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14)}; +} + +// 16-bit x4 +template +HWY_API Vec128 ConcatEven(Simd /* tag */, Vec128 hi, + Vec128 lo) { + // Don't care about upper half. + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 8, 10, 0, 2, 8, 10)}; +} + +// 32-bit full +template +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; +} + +// Any T x2 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 0, 0, 2, 2)}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<1> /* tag */, const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<2> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<4> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<8> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 2, 1)}; +} + +} // namespace detail + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +template +HWY_API Vec128 OddEven(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u32x4_extend_low_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_extend_low_i8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_extend_low_i16x8(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i64x2_extend_low_i32x4(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_f64x2_convert_low_i32x4(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* di */, + const Vec128 v) { + return Vec128{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd df16, + const Vec128 v) { + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128{DemoteTo(du16, bits16).raw}; +} + +template +HWY_API Vec128 DemoteTo(Simd dbf16, + const Vec128 v) { + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + const RebindToUnsigned du16; + const Repartition du32; + const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// Specializations for partial vectors because i16x8_narrow_i32x4 sets lanes +// above 2*N. +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Full128 /*d16*/, + Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; +} + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +// ------------------------------ Truncations + +template * = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return Vec128{v1.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + const auto v4 = ConcatEven(d, v2, v2); + return LowerHalf(LowerHalf(LowerHalf(ConcatEven(d, v4, v4)))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + return LowerHalf(LowerHalf(ConcatEven(d, v2, v2))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + const auto v3 = ConcatEven(d, v2, v2); + return Vec128{v3.raw}; +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return Vec128{v2.raw}; +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return Vec128{v2.raw}; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_f32x4_convert_i32x4(v.raw)}; +} +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_f32x4_convert_u32x4(v.raw)}; +} +// Truncates (rounds toward zero). +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_trunc_sat_f32x4(v.raw)}; +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return ConvertTo(Simd(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = And(BitCast(du16, v), Set(du16, 0xFF)); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return And(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), Set(du64, 0xFFFF)); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec128 vbits{wasm_i32x4_splat(static_cast(bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask128 LoadMaskBits(Simd d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(N + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Full +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, mask.raw); + + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t lo = ((lanes[0] * kMagic) >> 56); + const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; + return (hi + lo); +} + +// 64-bit +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return (static_cast(wasm_i64x2_extract_lane(mask.raw, 0)) * + kMagic) >> + 56; +} + +// 32-bit or less: need masking +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + uint64_t bytes = static_cast(wasm_i64x2_extract_lane(mask.raw, 0)); + // Clear potentially undefined bytes. + bytes &= (1ULL << (N * 8)) - 1; + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return (bytes * kMagic) >> 56; +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const __i16x8 zero = wasm_i16x8_splat(0); + const Mask128 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + return BitsFromMask(hwy::SizeTag<1>(), mask8); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); + const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint32_t lanes[4]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1] | lanes[2] | lanes[3]; +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 mask) { + const __i64x2 mask_i = static_cast<__i64x2>(mask.raw); + const __i64x2 slice = wasm_i64x2_make(1, 2); + const __i64x2 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1]; +} + +// Returns the lowest N bits for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) == 16) ? bits : bits & ((1ull << N) - 1); +} + +// Returns 0xFF for bytes with index >= N, otherwise 0. +template +constexpr __i8x16 BytesAbove() { + return /**/ + (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) + : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) + : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) + : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) + : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) + : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) + : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) + : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) + : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) + : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1) + : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1) + : (N == 9) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, + -1, -1, -1) + : (N == 11) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) + : (N == 13) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) + : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, shifted_bits); + return PopCount(lanes[0] | lanes[1]); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + alignas(16) int64_t lanes[2]; + wasm_v128_store(lanes, m.raw); + return static_cast(-(lanes[0] + lanes[1])); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API size_t CountTrue(const Simd /* tag */, const Mask128 m) { + return detail::CountTrue(hwy::SizeTag(), m); +} + +// Partial vector +template +HWY_API size_t CountTrue(const Simd d, const Mask128 m) { + // Ensure all undefined bytes are 0. + const Mask128 mask{detail::BytesAbove()}; + return CountTrue(d, Mask128{AndNot(mask, m).raw}); +} + +// Full vector +template +HWY_API bool AllFalse(const Full128 d, const Mask128 m) { +#if 0 + // Casting followed by wasm_i8x16_any_true results in wasm error: + // i32.eqz[0] expected type i32, found i8x16.popcnt of type s128 + const auto v8 = BitCast(Full128(), VecFromMask(d, m)); + return !wasm_i8x16_any_true(v8.raw); +#else + (void)d; + return (wasm_i64x2_extract_lane(m.raw, 0) | + wasm_i64x2_extract_lane(m.raw, 1)) == 0; +#endif +} + +// Full vector +namespace detail { +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + return wasm_i8x16_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + return wasm_i16x8_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + return wasm_i32x4_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + return wasm_i64x2_all_true(m.raw); +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 m) { + return detail::AllTrue(hwy::SizeTag(), m); +} + +// Partial vectors + +template +HWY_API bool AllFalse(Simd /* tag */, const Mask128 m) { + // Ensure all undefined bytes are 0. + const Mask128 mask{detail::BytesAbove()}; + return AllFalse(Full128(), Mask128{AndNot(mask, m).raw}); +} + +template +HWY_API bool AllTrue(const Simd /* d */, const Mask128 m) { + // Ensure all undefined bytes are FF. + const Mask128 mask{detail::BytesAbove()}; + return AllTrue(Full128(), Mask128{Or(mask, m).raw}); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return bits ? static_cast(Num0BitsBelowLS1Bit_Nonzero64(bits)) : -1; +} + +// ------------------------------ Compress + +namespace detail { + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +// Helper functions called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template +HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromNotBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return detail::Compress(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNot(v, detail::BitsFromMask(mask)); +} +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, + Simd d, T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + using TU = TFromD; + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Vec128 compressed = detail::Compress(BitCast(du, v), mask_bits); + const Mask128 store_mask = RebindMask(d, FirstN(du, count)); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ MulEven/Odd (Load) + +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = + Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), + static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); + return Load(Full128(), mul); +} + +HWY_INLINE Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = + Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), + static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); + return Load(Full128(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { + const Repartition du16; + const RebindToUnsigned du32; + const Vec128 zero = Zero(du16); + const Vec128 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const Vec128 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const Vec128 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const Vec128 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +// Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is +// safe. +template +HWY_API Vec128 ReorderWidenMulAccumulate( + Simd /*d32*/, Vec128 a, + Vec128 b, const Vec128 sum0, + Vec128& /*sum1*/) { + return sum0 + Vec128{wasm_i32x4_dot_i16x8(a.raw, b.raw)}; +} + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} + +// u32/i32/f32: + +// N=2 +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return v10 + Vec128{Shuffle2301(Vec128{v10.raw}).raw}; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Min(v10, Vec128{Shuffle2301(Vec128{v10.raw}).raw}); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Max(v10, Vec128{Shuffle2301(Vec128{v10.raw}).raw}); +} + +// N=4 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = v3210 + v1032; + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +// N=2 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. +template +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ Lt128 + +template +HWY_INLINE Mask128 Lt128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128 eqHL = Eq(a, b); + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). + const Vec128 ltLx = DupEven(ltHL); + const Vec128 outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE Mask128 Lt128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE Mask128 Eq128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE Mask128 Eq128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE Mask128 Ne128(Simd d, Vec128 a, + Vec128 b) { + static_assert(!IsSigned() && sizeof(T) == 8, "T must be u64"); + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE Mask128 Ne128Upper(Simd d, Vec128 a, + Vec128 b) { + const Vec128 neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. +template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/wasm_256-inl.h b/media/highway/src/hwy/ops/wasm_256-inl.h new file mode 100644 index 000000000..42f4fb2f4 --- /dev/null +++ b/media/highway/src/hwy/ops/wasm_256-inl.h @@ -0,0 +1,3060 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit WASM vectors and operations. Experimental. +// External include guard in highway.h - see comment there. + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" +#include "hwy/ops/wasm_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +using Full256 = Simd; + +template +using Full128 = Simd; + +// TODO(richardwinterton): add this to DeduceD in wasm_128 similar to x86_128. +template +class Vec256 { + public: + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Vec128 v0; + Vec128 v1; +}; + +template +struct Mask256 { + Mask128 m0; + Mask128 m1; +}; + +// ------------------------------ BitCast + +template +HWY_API Vec256 BitCast(Full256 d, Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = BitCast(dh, v.v0); + ret.v1 = BitCast(dh, v.v1); + return ret; + + // TODO(richardwinterton): implement other ops like this +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{wasm_i32x4_splat(0)}; +} +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{wasm_f32x4_splat(0.0f)}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +HWY_API Vec256 Set(Full256 /* tag */, const uint8_t t) { + return Vec256{wasm_i8x16_splat(static_cast(t))}; +} +HWY_API Vec256 Set(Full256 /* tag */, const uint16_t t) { + return Vec256{wasm_i16x8_splat(static_cast(t))}; +} +HWY_API Vec256 Set(Full256 /* tag */, const uint32_t t) { + return Vec256{wasm_i32x4_splat(static_cast(t))}; +} +HWY_API Vec256 Set(Full256 /* tag */, const uint64_t t) { + return Vec256{wasm_i64x2_splat(static_cast(t))}; +} + +HWY_API Vec256 Set(Full256 /* tag */, const int8_t t) { + return Vec256{wasm_i8x16_splat(t)}; +} +HWY_API Vec256 Set(Full256 /* tag */, const int16_t t) { + return Vec256{wasm_i16x8_splat(t)}; +} +HWY_API Vec256 Set(Full256 /* tag */, const int32_t t) { + return Vec256{wasm_i32x4_splat(t)}; +} +HWY_API Vec256 Set(Full256 /* tag */, const int64_t t) { + return Vec256{wasm_i64x2_splat(t)}; +} + +HWY_API Vec256 Set(Full256 /* tag */, const float t) { + return Vec256{wasm_f32x4_splat(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec256 Undefined(Full256 d) { + return Zero(d); +} + +HWY_DIAGNOSTICS(pop) + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec256 Iota(const Full256 d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_add(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_add(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_add(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_add(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_add(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_add(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_add(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_sub(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{wasm_i16x8_sub(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_sub(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_sub(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_sub(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_sub(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_sub(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + HWY_ABORT("not implemented"); +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u8x16_add_sat(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_add_sat(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_add_sat(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_add_sat(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u8x16_sub_sat(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_sub_sat(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i8x16_sub_sat(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_sub_sat(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u8x16_avgr(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_avgr(a.raw, b.raw)}; +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i8x16_abs(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i16x8_abs(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i32x4_abs(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_i62x2_abs(v.raw)}; +} + +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{wasm_f32x4_abs(v.raw)}; +} + +// ------------------------------ Shift lanes by constant #bits + +// Unsigned +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_u16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_u32x4_shr(v.raw, kBits)}; +} + +// Signed +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_i16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{wasm_i32x4_shr(v.raw, kBits)}; +} + +// 8-bit +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec256 RotateRight(const Vec256 v) { + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +// Unsigned +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i16x8_shl(v.raw, bits)}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_u16x8_shr(v.raw, bits)}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i32x4_shl(v.raw, bits)}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_u32x4_shr(v.raw, bits)}; +} + +// Signed +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{wasm_i16x8_shl(v.raw, bits)}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i16x8_shr(v.raw, bits)}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{wasm_i32x4_shl(v.raw, bits)}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{wasm_i32x4_shr(v.raw, bits)}; +} + +// 8-bit +template +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, (0xFF << bits) & 0xFF); +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Minimum + +// Unsigned +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_u8x16_min(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_min(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u32x4_min(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + alignas(32) float min[4]; + min[0] = + HWY_MIN(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); + min[1] = + HWY_MIN(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); + return Vec256{wasm_v128_load(min)}; +} + +// Signed +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i8x16_min(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i16x8_min(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i32x4_min(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + alignas(32) float min[4]; + min[0] = + HWY_MIN(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); + min[1] = + HWY_MIN(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); + return Vec256{wasm_v128_load(min)}; +} + +// Float +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_min(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_u8x16_max(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u16x8_max(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_u32x4_max(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + alignas(32) float max[4]; + max[0] = + HWY_MAX(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); + max[1] = + HWY_MAX(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); + return Vec256{wasm_v128_load(max)}; +} + +// Signed +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i8x16_max(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i16x8_max(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i32x4_max(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + alignas(32) float max[4]; + max[0] = + HWY_MAX(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); + max[1] = + HWY_MAX(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); + return Vec256{wasm_v128_load(max)}; +} + +// Float +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_max(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_mul(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_mul(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec256 MulHigh(const Vec256 a, + const Vec256 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto al = wasm_u32x4_extend_low_u16x8(a.raw); + const auto ah = wasm_u32x4_extend_high_u16x8(a.raw); + const auto bl = wasm_u32x4_extend_low_u16x8(b.raw); + const auto bh = wasm_u32x4_extend_high_u16x8(b.raw); + const auto l = wasm_i32x4_mul(al, bl); + const auto h = wasm_i32x4_mul(ah, bh); + // TODO(eustas): shift-right + narrow? + return Vec256{wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +HWY_API Vec256 MulHigh(const Vec256 a, + const Vec256 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto al = wasm_i32x4_extend_low_i16x8(a.raw); + const auto ah = wasm_i32x4_extend_high_i16x8(a.raw); + const auto bl = wasm_i32x4_extend_low_i16x8(b.raw); + const auto bh = wasm_i32x4_extend_high_i16x8(b.raw); + const auto l = wasm_i32x4_mul(al, bl); + const auto h = wasm_i32x4_mul(ah, bh); + // TODO(eustas): shift-right + narrow? + return Vec256{wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +HWY_API Vec256 MulFixedPoint15(Vec256, Vec256) { + HWY_ASSERT(0); // Not implemented +} + +// Multiplies even lanes (0, 2 ..) and returns the double-width result. +HWY_API Vec256 MulEven(const Vec256 a, + const Vec256 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec256{wasm_i64x2_mul(ae, be)}; +} +HWY_API Vec256 MulEven(const Vec256 a, + const Vec256 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec256{wasm_i64x2_mul(ae, be)}; +} + +// ------------------------------ Negate + +template +HWY_API Vec256 Neg(const Vec256 v) { + return Xor(v, SignBit(Full256())); +} + +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i8x16_neg(v.raw)}; +} +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i16x8_neg(v.raw)}; +} +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i32x4_neg(v.raw)}; +} +HWY_API Vec256 Neg(const Vec256 v) { + return Vec256{wasm_i64x2_neg(v.raw)}; +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{wasm_f32x4_mul(a.raw, b.raw)}; +} + +HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { + return Vec256{wasm_f32x4_div(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { + const Vec256 one = Vec256{wasm_f32x4_splat(1.0f)}; + return one / v; +} + +// Absolute value of difference. +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfma? + return mul * x + add; +} + +// Returns add - mul * x +HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { + // TODO(eustas): replace, when implemented in WASM. + return add - mul * x; +} + +// Returns mul * x - sub +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfms? + return mul * x - sub; +} + +// Returns -mul * x - sub +HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { + // TODO(eustas): replace, when implemented in WASM. + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec256 Sqrt(const Vec256 v) { + return Vec256{wasm_f32x4_sqrt(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { + // TODO(eustas): find cheaper a way to calculate this. + const Vec256 one = Vec256{wasm_f32x4_splat(1.0f)}; + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +HWY_API Vec256 Round(const Vec256 v) { + return Vec256{wasm_f32x4_nearest(v.raw)}; +} + +// Toward zero, aka truncate +HWY_API Vec256 Trunc(const Vec256 v) { + return Vec256{wasm_f32x4_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256 Ceil(const Vec256 v) { + return Vec256{wasm_f32x4_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +HWY_API Vec256 Floor(const Vec256 v) { + return Vec256{wasm_f32x4_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask256 IsNaN(const Vec256 v) { + return v != v; +} + +template +HWY_API Mask256 IsInf(const Vec256 v) { + const Full256 d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask256 IsFinite(const Vec256 v) { + const Full256 d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask256{m.raw}; +} + +template +HWY_API Mask256 TestBit(Vec256 v, Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_eq(a.raw, b.raw)}; +} +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i16x8_eq(a.raw, b.raw)}; +} +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_eq(a.raw, b.raw)}; +} + +// Signed +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_eq(a.raw, b.raw)}; +} +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{wasm_i16x8_eq(a.raw, b.raw)}; +} +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_eq(a.raw, b.raw)}; +} + +// Float +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_eq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// Unsigned +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_ne(a.raw, b.raw)}; +} +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i16x8_ne(a.raw, b.raw)}; +} +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_ne(a.raw, b.raw)}; +} + +// Signed +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_ne(a.raw, b.raw)}; +} +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{wasm_i16x8_ne(a.raw, b.raw)}; +} +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_ne(a.raw, b.raw)}; +} + +// Float +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_ne(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i8x16_gt(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i16x8_gt(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_i32x4_gt(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + const Rebind < int32_t, DFromV d32; + const auto a32 = BitCast(d32, a); + const auto b32 = BitCast(d32, b); + // If the upper half is less than or greater, this is the answer. + const auto m_gt = a32 < b32; + + // Otherwise, the lower half decides. + const auto m_eq = a32 == b32; + const auto lo_in_hi = wasm_i32x4_shuffle(m_gt, m_gt, 2, 2, 0, 0); + const auto lo_gt = And(m_eq, lo_in_hi); + + const auto gt = Or(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. + return Mask256{wasm_i32x4_shuffle(gt, gt, 3, 3, 1, 1)}; +} + +template +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + const Full256 du; + const RebindToSigned di; + const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); + return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); +} + +HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { + return Mask256{wasm_f32x4_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float <= >= +HWY_API Mask256 operator<=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_le(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{wasm_f32x4_ge(a.raw, b.raw)}; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask256 FirstN(const Full256 d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec256 Not(Vec256 v) { + return Vec256{wasm_v128_not(v.raw)}; +} + +// ------------------------------ And + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{wasm_v128_and(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{wasm_v128_andnot(mask.raw, not_mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{wasm_v128_or(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{wasm_v128_xor(a.raw, b.raw)}; +} + +// ------------------------------ Or3 + +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Full256()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Full256()), sign)); +} + +// ------------------------------ BroadcastSignBit (compare) + +template +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight(v); +} +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return VecFromMask(Full256(), v < Zero(Full256())); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(Full256 /* tag */, Mask256 v) { + return Vec256{v.raw}; +} + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + return Vec256{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(Full256(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(Full256(), mask), no); +} + +template + HWY_API Vec256 < + T IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + HWY_ASSERT(0); // Not implemented +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + const Full256 d; + const auto zero = Zero(d); + return IfThenElse(Mask256{(v > zero).raw}, v, zero); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) + +// The x86 multiply-by-Pow2() trick will not work because WASM saturates +// float->int correctly to 2^31-1 (not 2^31). Because WASM's shifts take a +// scalar count operand, per-lane shift instructions would require extract_lane +// for each lane, and hoping that shuffle is correctly mapped to a native +// instruction. Using non-vector shifts would incur a store-load forwarding +// stall when loading the result vector. We instead test bits of the shift +// count to "predicate" a shift of the entire vector by a constant. + +template +HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { + const Full256 d; + Mask256 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { + const Full256 d; + Mask256 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) + +template +HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { + const Full256 d; + Mask256 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { + const Full256 d; + Mask256 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec256 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec256{wasm_v128_load(aligned)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// LoadU == Load. +template +HWY_API Vec256 LoadU(Full256 d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec256 LoadDup128(Full256 d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// StoreU == Store. +template +HWY_API void StoreU(Vec256 v, Full256 d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(Vec256 v, Full256 /* tag */, + T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[32 / sizeof(T)]; + Store(offset, Full256(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[32 / sizeof(T)]; + Store(index, Full256(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec256 GatherOffset(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) Offset offset_lanes[32 / sizeof(T)]; + Store(offset, Full256(), offset_lanes); + + alignas(32) T lanes[32 / sizeof(T)]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec256 GatherIndex(const Full256 d, const T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) Index index_lanes[32 / sizeof(T)]; + Store(index, Full256(), index_lanes); + + alignas(32) T lanes[32 / sizeof(T)]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ InsertLane +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ GetLane +// Gets the single value stored in a vector/part. +HWY_API uint8_t GetLane(const Vec256 v) { + return wasm_i8x16_extract_lane(v.raw, 0); +} +HWY_API int8_t GetLane(const Vec256 v) { + return wasm_i8x16_extract_lane(v.raw, 0); +} +HWY_API uint16_t GetLane(const Vec256 v) { + return wasm_i16x8_extract_lane(v.raw, 0); +} +HWY_API int16_t GetLane(const Vec256 v) { + return wasm_i16x8_extract_lane(v.raw, 0); +} +HWY_API uint32_t GetLane(const Vec256 v) { + return wasm_i32x4_extract_lane(v.raw, 0); +} +HWY_API int32_t GetLane(const Vec256 v) { + return wasm_i32x4_extract_lane(v.raw, 0); +} +HWY_API uint64_t GetLane(const Vec256 v) { + return wasm_i64x2_extract_lane(v.raw, 0); +} +HWY_API int64_t GetLane(const Vec256 v) { + return wasm_i64x2_extract_lane(v.raw, 0); +} + +HWY_API float GetLane(const Vec256 v) { + return wasm_f32x4_extract_lane(v.raw, 0); +} + +// ------------------------------ LowerHalf + +template +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + return LowerHalf(Full128(), v); +} + +// ------------------------------ ShiftLeftBytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API Vec256 ShiftLeftBytes(Full256 /* tag */, Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14)}; + + case 2: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13)}; + + case 3: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12)}; + + case 4: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11)}; + + case 5: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10)}; + + case 6: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + + case 7: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; + + case 8: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; + + case 9: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 0, 1, 2, 3, 4, 5, 6)}; + + case 10: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 0, 1, 2, 3, 4, 5)}; + + case 11: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 0, 1, 2, 3, 4)}; + + case 12: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 0, 1, 2, 3)}; + + case 13: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 0, 1, 2)}; + + case 14: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 0, + 1)}; + + case 15: + return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0)}; + } + return Vec256{zero}; +} + +template +HWY_API Vec256 ShiftLeftBytes(Vec256 v) { + return ShiftLeftBytes(Full256(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + return ShiftLeftLanes(Full256(), v); +} + +// ------------------------------ ShiftRightBytes +namespace detail { + +// Helper function allows zeroing invalid lanes in caller. +template +HWY_API __i8x16 ShrBytes(const Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + + switch (kBytes) { + case 0: + return v.raw; + + case 1: + return wasm_i8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16); + + case 2: + return wasm_i8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16); + + case 3: + return wasm_i8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 16); + + case 4: + return wasm_i8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 16, 16, 16); + + case 5: + return wasm_i8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16); + + case 6: + return wasm_i8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16); + + case 7: + return wasm_i8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16, 16); + + case 8: + return wasm_i8x16_shuffle(v.raw, zero, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 9: + return wasm_i8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, 15, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 10: + return wasm_i8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 11: + return wasm_i8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 12: + return wasm_i8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 13: + return wasm_i8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 14: + return wasm_i8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 15: + return wasm_i8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + case 16: + return zero; + } +} + +} // namespace detail + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, Vec256 v) { + return Vec256{detail::ShrBytes(v)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). +template +HWY_API Vec128 UpperHalf(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} +HWY_API Vec128 UpperHalf(Full128 /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full256 /* tag */, V hi, V lo) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + switch (kBytes) { + case 0: + return lo; + + case 1: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16)}; + + case 2: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17)}; + + case 3: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18)}; + + case 4: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19)}; + + case 5: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20)}; + + case 6: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21)}; + + case 7: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22)}; + + case 8: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23)}; + + case 9: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24)}; + + case 10: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25)}; + + case 11: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26)}; + + case 12: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27)}; + + case 13: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28)}; + + case 14: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29)}; + + case 15: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30)}; + } + return hi; +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec256{wasm_i16x8_shuffle( + v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec256{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +// Signed +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec256{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec256{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +// Float +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec256{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +// ------------------------------ TableLookupBytes + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, + const Vec256 from) { +// Not yet available in all engines, see +// https://github.com/WebAssembly/simd/blob/bdcc304b2d379f4601c2c44ea9b44ed9484fde7e/proposals/simd/ImplementationStatus.md +// V8 implementation of this had a bug, fixed on 2021-04-03: +// https://chromium-review.googlesource.com/c/v8/v8/+/2822951 +#if 0 + return Vec256{wasm_i8x16_swizzle(bytes.raw, from.raw)}; +#else + alignas(32) uint8_t control[16]; + alignas(32) uint8_t input[16]; + alignas(32) uint8_t output[16]; + wasm_v128_store(control, from.raw); + wasm_v128_store(input, bytes.raw); + for (size_t i = 0; i < 16; ++i) { + output[i] = control[i] < 16 ? input[control[i]] : 0; + } + return Vec256{wasm_v128_load(output)}; +#endif +} + +template +HWY_API Vec256 TableLookupBytesOr0(const Vec256 bytes, + const Vec256 from) { + const Full256 d; + // Mask size must match vector type, so cast everything to this type. + Repartition di8; + Repartition> d_bytes8; + const auto msb = BitCast(di8, from) < Zero(di8); + const auto lookup = + TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices256 { + __v128_u raw; +}; + +template +HWY_API Indices256 IndicesFromVec(Full256 d, Vec256 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + return Indices256{}; +} + +template +HWY_API Indices256 SetTableIndices(Full256 d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + using TI = MakeSigned; + const Full256 d; + const Full256 di; + return BitCast(d, TableLookupBytes(BitCast(di, v), Vec256{idx.raw})); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +template +HWY_API Vec256 Reverse(Full256 /* tag */, const Vec256 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec256 Reverse(Full256 /* tag */, const Vec256 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ InterleaveLower + +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 1, 17, 2, 18, + 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, + Vec256 b) { + return Vec256{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, + Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, + Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, + 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} + +// Additional overload for the optional tag. +template > +HWY_API V InterleaveLower(Full256 /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, 26, + 11, 27, 12, 28, 13, 29, 14, 30, 15, + 31)}; +} +HWY_API Vec256 InterleaveUpper(Vec256 a, + Vec256 b) { + return Vec256{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +HWY_API Vec256 InterleaveUpper(Vec256 a, + Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +HWY_API Vec256 InterleaveUpper(Vec256 a, + Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, 26, + 11, 27, 12, 28, 13, 29, 14, 30, 15, + 31)}; +} +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} + +} // namespace detail + +template > +HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(Vec256 a, Vec256 b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template +HWY_API Vec256 Combine(Full256 d, Vec128 hi_half, Vec128 lo_half) { + const Half d2; + const RebindToUnsigned du2; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template +HWY_API Vec256 ZeroExtendVector(Full256 d, Vec128 lo) { + return IfThenElseZero(FirstN(d, 16 / sizeof(T)), Vec256{lo.raw}); +} + +// ------------------------------ ConcatLowerLower + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec256 ConcatLowerLower(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; +} + +// ------------------------------ ConcatUpperUpper + +template +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; +} + +// ------------------------------ ConcatLowerUpper + +template +HWY_API Vec256 ConcatLowerUpper(Full256 d, const Vec256 hi, + const Vec256 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// ------------------------------ ConcatUpperLower +template +HWY_API Vec256 ConcatUpperLower(Full256 d, const Vec256 hi, + const Vec256 lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd + +// 32-bit +template +HWY_API Vec256 ConcatOdd(Full256 /* tag */, Vec256 hi, Vec256 lo) { + return Vec256{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; +} + +// 64-bit full - no partial because we need at least two inputs to have +// even/odd. +template +HWY_API Vec256 ConcatOdd(Full256 /* tag */, Vec256 hi, Vec256 lo) { + return InterleaveUpper(Full256(), lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 32-bit full +template +HWY_API Vec256 ConcatEven(Full256 /* tag */, Vec256 hi, Vec256 lo) { + return Vec256{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; +} + +// 64-bit full - no partial because we need at least two inputs to have +// even/odd. +template +HWY_API Vec256 ConcatEven(Full256 /* tag */, Vec256 hi, Vec256 lo) { + return InterleaveLower(Full256(), lo, hi); +} + +// ------------------------------ DupEven +template +HWY_API Vec256 DupEven(Vec256 v) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ DupOdd +template +HWY_API Vec256 DupOdd(Vec256 v) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ OddEven + +namespace detail { + +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, + const Vec256 b) { + const Full256 d; + const Repartition d8; + alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<2> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<4> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<8> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 2, 1)}; +} + +} // namespace detail + +template +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec256 OddEvenBlocks(Vec256 /* odd */, Vec256 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return v; +} + +// ------------------------------ ReverseBlocks + +template +HWY_API Vec256 ReverseBlocks(Full256 /* tag */, const Vec256 v) { + return v; +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_u32x4_extend_low_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_i16x8_extend_low_i8x16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{ + wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_i32x4_extend_low_i16x8(v.raw)}; +} + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{wasm_f64x2_convert_low_i32x4(v.raw)}; +} + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + const Full256 di32; + const Full256 du32; + const Full256 df32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec256{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +HWY_API Vec256 PromoteTo(Full256 df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* di */, + const Vec256 v) { + return Vec128{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const Full256 di; + const Full256 du; + const Full256 du16; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128{DemoteTo(du16, bits16).raw}; +} + +HWY_API Vec128 DemoteTo(Full128 dbf16, + const Vec256 v) { + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +HWY_API Vec128 ReorderDemote2To(Full128 dbf16, + Vec256 a, Vec256 b) { + const RebindToUnsigned du16; + const Repartition du32; + const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec512 ReorderDemote2To(Full512 /*d16*/, + Vec512 a, Vec512 b) { + return Vec512{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec256 U8FromU32(const Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec256{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +// ------------------------------ Truncations + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec256 v) { + return Vec256{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24, + 0, 8, 16, 24, 0, 8, 16, 24, 0, 8, + 16, 24)}; +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec256 v) { + return Vec256{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9, + 16, 17, 24, 25, 0, 1, 8, 9, 16, + 17, 24, 25)}; +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec256 v) { + return Vec256{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3, + 8, 9, 10, 11, 16, 17, 18, 19, + 24, 25, 26, 27)}; +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec256 v) { + return Vec256{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12, + 16, 20, 24, 28, 0, 4, 8, 12, 16, + 20, 24, 28)}; +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec256 v) { + return Vec256{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5, + 8, 9, 12, 13, 16, 17, 20, 21, + 24, 25, 28, 29)}; +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec256 v) { + return Vec256{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, + 8, 10, 12, 14, 16, 18, 20, 22, + 24, 26, 28, 30)}; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{wasm_f32x4_convert_i32x4(v.raw)}; +} +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{wasm_f32x4_convert_u32x4(v.raw)}; +} +// Truncates (rounds toward zero). +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{wasm_i32x4_trunc_sat_f32x4(v.raw)}; +} + +HWY_API Vec256 NearestInt(const Vec256 v) { + return ConvertTo(Full256(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec256 vbits{wasm_i32x4_splat(static_cast(bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(32) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(N + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Full +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + alignas(32) uint64_t lanes[2]; + wasm_v128_store(lanes, mask.raw); + + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t lo = ((lanes[0] * kMagic) >> 56); + const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; + return (hi + lo); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask256 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const __i16x8 zero = wasm_i16x8_splat(0); + const Mask256 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + return BitsFromMask(hwy::SizeTag<1>(), mask8); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask256 mask) { + const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); + const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(32) uint32_t lanes[4]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1] | lanes[2] | lanes[3]; +} + +// Returns 0xFF for bytes with index >= N, otherwise 0. +constexpr __i8x16 BytesAbove() { + return /**/ + (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) + : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) + : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) + : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) + : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) + : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) + : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) + : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) + : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) + : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1) + : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1) + : (N == 9) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, + -1, -1, -1) + : (N == 11) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) + : (N == 13) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) + : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + return BitsFromMask(hwy::SizeTag(), mask); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); + alignas(32) uint64_t lanes[2]; + wasm_v128_store(lanes, shifted_bits); + return PopCount(lanes[0] | lanes[1]); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, + uint8_t* bits) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API size_t CountTrue(const Full256 /* tag */, const Mask128 m) { + return detail::CountTrue(hwy::SizeTag(), m); +} + +template +HWY_API bool AllFalse(const Full256 d, const Mask128 m) { +#if 0 + // Casting followed by wasm_i8x16_any_true results in wasm error: + // i32.eqz[0] expected type i32, found i8x16.popcnt of type s128 + const auto v8 = BitCast(Full256(), VecFromMask(d, m)); + return !wasm_i8x16_any_true(v8.raw); +#else + (void)d; + return (wasm_i64x2_extract_lane(m.raw, 0) | + wasm_i64x2_extract_lane(m.raw, 1)) == 0; +#endif +} + +// Full vector +namespace detail { +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + return wasm_i8x16_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + return wasm_i16x8_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + return wasm_i32x4_all_true(m.raw); +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Full256 /* tag */, const Mask128 m) { + return detail::AllTrue(hwy::SizeTag(), m); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return bits ? Num0BitsBelowLS1Bit_Nonzero64(bits) : -1; +} + +// ------------------------------ Compress + +namespace detail { + +template +HWY_INLINE Vec256 Idx16x8FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Full256 d; + const Rebind d8; + const Full256 du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. + alignas(32) constexpr uint8_t table[256 * 8] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, + 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, + 0, 6, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 2, + 6, 0, 0, 0, 0, 0, 4, 6, 0, 0, 0, 0, 0, 0, 0, 4, 6, 0, + 0, 0, 0, 0, 2, 4, 6, 0, 0, 0, 0, 0, 0, 2, 4, 6, 0, 0, + 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, + 2, 8, 0, 0, 0, 0, 0, 0, 0, 2, 8, 0, 0, 0, 0, 0, 4, 8, + 0, 0, 0, 0, 0, 0, 0, 4, 8, 0, 0, 0, 0, 0, 2, 4, 8, 0, + 0, 0, 0, 0, 0, 2, 4, 8, 0, 0, 0, 0, 6, 8, 0, 0, 0, 0, + 0, 0, 0, 6, 8, 0, 0, 0, 0, 0, 2, 6, 8, 0, 0, 0, 0, 0, + 0, 2, 6, 8, 0, 0, 0, 0, 4, 6, 8, 0, 0, 0, 0, 0, 0, 4, + 6, 8, 0, 0, 0, 0, 2, 4, 6, 8, 0, 0, 0, 0, 0, 2, 4, 6, + 8, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, + 0, 0, 2, 10, 0, 0, 0, 0, 0, 0, 0, 2, 10, 0, 0, 0, 0, 0, + 4, 10, 0, 0, 0, 0, 0, 0, 0, 4, 10, 0, 0, 0, 0, 0, 2, 4, + 10, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0, 0, 0, 0, 6, 10, 0, 0, + 0, 0, 0, 0, 0, 6, 10, 0, 0, 0, 0, 0, 2, 6, 10, 0, 0, 0, + 0, 0, 0, 2, 6, 10, 0, 0, 0, 0, 4, 6, 10, 0, 0, 0, 0, 0, + 0, 4, 6, 10, 0, 0, 0, 0, 2, 4, 6, 10, 0, 0, 0, 0, 0, 2, + 4, 6, 10, 0, 0, 0, 8, 10, 0, 0, 0, 0, 0, 0, 0, 8, 10, 0, + 0, 0, 0, 0, 2, 8, 10, 0, 0, 0, 0, 0, 0, 2, 8, 10, 0, 0, + 0, 0, 4, 8, 10, 0, 0, 0, 0, 0, 0, 4, 8, 10, 0, 0, 0, 0, + 2, 4, 8, 10, 0, 0, 0, 0, 0, 2, 4, 8, 10, 0, 0, 0, 6, 8, + 10, 0, 0, 0, 0, 0, 0, 6, 8, 10, 0, 0, 0, 0, 2, 6, 8, 10, + 0, 0, 0, 0, 0, 2, 6, 8, 10, 0, 0, 0, 4, 6, 8, 10, 0, 0, + 0, 0, 0, 4, 6, 8, 10, 0, 0, 0, 2, 4, 6, 8, 10, 0, 0, 0, + 0, 2, 4, 6, 8, 10, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 12, + 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 0, 0, 0, 0, 0, 2, 12, 0, + 0, 0, 0, 0, 4, 12, 0, 0, 0, 0, 0, 0, 0, 4, 12, 0, 0, 0, + 0, 0, 2, 4, 12, 0, 0, 0, 0, 0, 0, 2, 4, 12, 0, 0, 0, 0, + 6, 12, 0, 0, 0, 0, 0, 0, 0, 6, 12, 0, 0, 0, 0, 0, 2, 6, + 12, 0, 0, 0, 0, 0, 0, 2, 6, 12, 0, 0, 0, 0, 4, 6, 12, 0, + 0, 0, 0, 0, 0, 4, 6, 12, 0, 0, 0, 0, 2, 4, 6, 12, 0, 0, + 0, 0, 0, 2, 4, 6, 12, 0, 0, 0, 8, 12, 0, 0, 0, 0, 0, 0, + 0, 8, 12, 0, 0, 0, 0, 0, 2, 8, 12, 0, 0, 0, 0, 0, 0, 2, + 8, 12, 0, 0, 0, 0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 4, 8, 12, + 0, 0, 0, 0, 2, 4, 8, 12, 0, 0, 0, 0, 0, 2, 4, 8, 12, 0, + 0, 0, 6, 8, 12, 0, 0, 0, 0, 0, 0, 6, 8, 12, 0, 0, 0, 0, + 2, 6, 8, 12, 0, 0, 0, 0, 0, 2, 6, 8, 12, 0, 0, 0, 4, 6, + 8, 12, 0, 0, 0, 0, 0, 4, 6, 8, 12, 0, 0, 0, 2, 4, 6, 8, + 12, 0, 0, 0, 0, 2, 4, 6, 8, 12, 0, 0, 10, 12, 0, 0, 0, 0, + 0, 0, 0, 10, 12, 0, 0, 0, 0, 0, 2, 10, 12, 0, 0, 0, 0, 0, + 0, 2, 10, 12, 0, 0, 0, 0, 4, 10, 12, 0, 0, 0, 0, 0, 0, 4, + 10, 12, 0, 0, 0, 0, 2, 4, 10, 12, 0, 0, 0, 0, 0, 2, 4, 10, + 12, 0, 0, 0, 6, 10, 12, 0, 0, 0, 0, 0, 0, 6, 10, 12, 0, 0, + 0, 0, 2, 6, 10, 12, 0, 0, 0, 0, 0, 2, 6, 10, 12, 0, 0, 0, + 4, 6, 10, 12, 0, 0, 0, 0, 0, 4, 6, 10, 12, 0, 0, 0, 2, 4, + 6, 10, 12, 0, 0, 0, 0, 2, 4, 6, 10, 12, 0, 0, 8, 10, 12, 0, + 0, 0, 0, 0, 0, 8, 10, 12, 0, 0, 0, 0, 2, 8, 10, 12, 0, 0, + 0, 0, 0, 2, 8, 10, 12, 0, 0, 0, 4, 8, 10, 12, 0, 0, 0, 0, + 0, 4, 8, 10, 12, 0, 0, 0, 2, 4, 8, 10, 12, 0, 0, 0, 0, 2, + 4, 8, 10, 12, 0, 0, 6, 8, 10, 12, 0, 0, 0, 0, 0, 6, 8, 10, + 12, 0, 0, 0, 2, 6, 8, 10, 12, 0, 0, 0, 0, 2, 6, 8, 10, 12, + 0, 0, 4, 6, 8, 10, 12, 0, 0, 0, 0, 4, 6, 8, 10, 12, 0, 0, + 2, 4, 6, 8, 10, 12, 0, 0, 0, 2, 4, 6, 8, 10, 12, 0, 14, 0, + 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 2, 14, 0, 0, + 0, 0, 0, 0, 0, 2, 14, 0, 0, 0, 0, 0, 4, 14, 0, 0, 0, 0, + 0, 0, 0, 4, 14, 0, 0, 0, 0, 0, 2, 4, 14, 0, 0, 0, 0, 0, + 0, 2, 4, 14, 0, 0, 0, 0, 6, 14, 0, 0, 0, 0, 0, 0, 0, 6, + 14, 0, 0, 0, 0, 0, 2, 6, 14, 0, 0, 0, 0, 0, 0, 2, 6, 14, + 0, 0, 0, 0, 4, 6, 14, 0, 0, 0, 0, 0, 0, 4, 6, 14, 0, 0, + 0, 0, 2, 4, 6, 14, 0, 0, 0, 0, 0, 2, 4, 6, 14, 0, 0, 0, + 8, 14, 0, 0, 0, 0, 0, 0, 0, 8, 14, 0, 0, 0, 0, 0, 2, 8, + 14, 0, 0, 0, 0, 0, 0, 2, 8, 14, 0, 0, 0, 0, 4, 8, 14, 0, + 0, 0, 0, 0, 0, 4, 8, 14, 0, 0, 0, 0, 2, 4, 8, 14, 0, 0, + 0, 0, 0, 2, 4, 8, 14, 0, 0, 0, 6, 8, 14, 0, 0, 0, 0, 0, + 0, 6, 8, 14, 0, 0, 0, 0, 2, 6, 8, 14, 0, 0, 0, 0, 0, 2, + 6, 8, 14, 0, 0, 0, 4, 6, 8, 14, 0, 0, 0, 0, 0, 4, 6, 8, + 14, 0, 0, 0, 2, 4, 6, 8, 14, 0, 0, 0, 0, 2, 4, 6, 8, 14, + 0, 0, 10, 14, 0, 0, 0, 0, 0, 0, 0, 10, 14, 0, 0, 0, 0, 0, + 2, 10, 14, 0, 0, 0, 0, 0, 0, 2, 10, 14, 0, 0, 0, 0, 4, 10, + 14, 0, 0, 0, 0, 0, 0, 4, 10, 14, 0, 0, 0, 0, 2, 4, 10, 14, + 0, 0, 0, 0, 0, 2, 4, 10, 14, 0, 0, 0, 6, 10, 14, 0, 0, 0, + 0, 0, 0, 6, 10, 14, 0, 0, 0, 0, 2, 6, 10, 14, 0, 0, 0, 0, + 0, 2, 6, 10, 14, 0, 0, 0, 4, 6, 10, 14, 0, 0, 0, 0, 0, 4, + 6, 10, 14, 0, 0, 0, 2, 4, 6, 10, 14, 0, 0, 0, 0, 2, 4, 6, + 10, 14, 0, 0, 8, 10, 14, 0, 0, 0, 0, 0, 0, 8, 10, 14, 0, 0, + 0, 0, 2, 8, 10, 14, 0, 0, 0, 0, 0, 2, 8, 10, 14, 0, 0, 0, + 4, 8, 10, 14, 0, 0, 0, 0, 0, 4, 8, 10, 14, 0, 0, 0, 2, 4, + 8, 10, 14, 0, 0, 0, 0, 2, 4, 8, 10, 14, 0, 0, 6, 8, 10, 14, + 0, 0, 0, 0, 0, 6, 8, 10, 14, 0, 0, 0, 2, 6, 8, 10, 14, 0, + 0, 0, 0, 2, 6, 8, 10, 14, 0, 0, 4, 6, 8, 10, 14, 0, 0, 0, + 0, 4, 6, 8, 10, 14, 0, 0, 2, 4, 6, 8, 10, 14, 0, 0, 0, 2, + 4, 6, 8, 10, 14, 0, 12, 14, 0, 0, 0, 0, 0, 0, 0, 12, 14, 0, + 0, 0, 0, 0, 2, 12, 14, 0, 0, 0, 0, 0, 0, 2, 12, 14, 0, 0, + 0, 0, 4, 12, 14, 0, 0, 0, 0, 0, 0, 4, 12, 14, 0, 0, 0, 0, + 2, 4, 12, 14, 0, 0, 0, 0, 0, 2, 4, 12, 14, 0, 0, 0, 6, 12, + 14, 0, 0, 0, 0, 0, 0, 6, 12, 14, 0, 0, 0, 0, 2, 6, 12, 14, + 0, 0, 0, 0, 0, 2, 6, 12, 14, 0, 0, 0, 4, 6, 12, 14, 0, 0, + 0, 0, 0, 4, 6, 12, 14, 0, 0, 0, 2, 4, 6, 12, 14, 0, 0, 0, + 0, 2, 4, 6, 12, 14, 0, 0, 8, 12, 14, 0, 0, 0, 0, 0, 0, 8, + 12, 14, 0, 0, 0, 0, 2, 8, 12, 14, 0, 0, 0, 0, 0, 2, 8, 12, + 14, 0, 0, 0, 4, 8, 12, 14, 0, 0, 0, 0, 0, 4, 8, 12, 14, 0, + 0, 0, 2, 4, 8, 12, 14, 0, 0, 0, 0, 2, 4, 8, 12, 14, 0, 0, + 6, 8, 12, 14, 0, 0, 0, 0, 0, 6, 8, 12, 14, 0, 0, 0, 2, 6, + 8, 12, 14, 0, 0, 0, 0, 2, 6, 8, 12, 14, 0, 0, 4, 6, 8, 12, + 14, 0, 0, 0, 0, 4, 6, 8, 12, 14, 0, 0, 2, 4, 6, 8, 12, 14, + 0, 0, 0, 2, 4, 6, 8, 12, 14, 0, 10, 12, 14, 0, 0, 0, 0, 0, + 0, 10, 12, 14, 0, 0, 0, 0, 2, 10, 12, 14, 0, 0, 0, 0, 0, 2, + 10, 12, 14, 0, 0, 0, 4, 10, 12, 14, 0, 0, 0, 0, 0, 4, 10, 12, + 14, 0, 0, 0, 2, 4, 10, 12, 14, 0, 0, 0, 0, 2, 4, 10, 12, 14, + 0, 0, 6, 10, 12, 14, 0, 0, 0, 0, 0, 6, 10, 12, 14, 0, 0, 0, + 2, 6, 10, 12, 14, 0, 0, 0, 0, 2, 6, 10, 12, 14, 0, 0, 4, 6, + 10, 12, 14, 0, 0, 0, 0, 4, 6, 10, 12, 14, 0, 0, 2, 4, 6, 10, + 12, 14, 0, 0, 0, 2, 4, 6, 10, 12, 14, 0, 8, 10, 12, 14, 0, 0, + 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 2, 8, 10, 12, 14, 0, 0, 0, + 0, 2, 8, 10, 12, 14, 0, 0, 4, 8, 10, 12, 14, 0, 0, 0, 0, 4, + 8, 10, 12, 14, 0, 0, 2, 4, 8, 10, 12, 14, 0, 0, 0, 2, 4, 8, + 10, 12, 14, 0, 6, 8, 10, 12, 14, 0, 0, 0, 0, 6, 8, 10, 12, 14, + 0, 0, 2, 6, 8, 10, 12, 14, 0, 0, 0, 2, 6, 8, 10, 12, 14, 0, + 4, 6, 8, 10, 12, 14, 0, 0, 0, 4, 6, 8, 10, 12, 14, 0, 2, 4, + 6, 8, 10, 12, 14, 0, 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec256 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec256 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec256 Idx32x4FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(32) constexpr uint8_t packed_array[16 * 16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, // + 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, // + 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Full256 d; + const Repartition d8; + return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); +} + +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 + +template +HWY_INLINE Vec256 Idx64x2FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(32) constexpr uint8_t packed_array[4 * 16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Full256 d; + const Repartition d8; + return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); +} + +#endif + +// Helper functions called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template +HWY_INLINE Vec256 Compress(hwy::SizeTag<2> /*tag*/, Vec256 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx16x8FromBits(mask_bits); + using D = Full256; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec256 Compress(hwy::SizeTag<4> /*tag*/, Vec256 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx32x4FromBits(mask_bits); + using D = Full256; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 + +template +HWY_INLINE Vec256 Compress(hwy::SizeTag<8> /*tag*/, + Vec256 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx64x2FromBits(mask_bits); + using D = Full256; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +#endif + +} // namespace detail + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +template +HWY_API Vec256 Compress(Vec256 v, const Mask256 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return detail::Compress(hwy::SizeTag(), v, mask_bits); +} + +// ------------------------------ CompressNot +template +HWY_API Vec256 Compress(Vec256 v, const Mask256 mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + HWY_ASSERT(0); // Not implemented +} + +// ------------------------------ CompressBits + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(hwy::SizeTag(), v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(Vec256 v, const Mask256 mask, Full256 d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const auto c = detail::Compress(hwy::SizeTag(), v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + using TU = TFromD; + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Mask256 store_mask = FirstN(du, count); + const Vec256 compressed = + detail::Compress(hwy::SizeTag(), BitCast(du, v), mask_bits); + const Vec256 prev = BitCast(du, LoadU(d, unaligned)); + StoreU(BitCast(d, IfThenElse(store_mask, compressed, prev)), d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, + Full256 d, T* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + const auto c = detail::Compress(hwy::SizeTag(), v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ MulEven/Odd (Load) + +HWY_INLINE Vec256 MulEven(const Vec256 a, + const Vec256 b) { + alignas(32) uint64_t mul[2]; + mul[0] = + Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), + static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); + return Load(Full256(), mul); +} + +HWY_INLINE Vec256 MulOdd(const Vec256 a, + const Vec256 b) { + alignas(32) uint64_t mul[2]; + mul[0] = + Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), + static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); + return Load(Full256(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 df32, + Vec256 a, + Vec256 b, + const Vec256 sum0, + Vec256& sum1) { + const Repartition du16; + const RebindToUnsigned du32; + const Vec256 zero = Zero(du16); + const Vec256 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const Vec256 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const Vec256 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const Vec256 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 /*d32*/, + Vec256 a, + Vec256 b, + const Vec256 sum0, + Vec256& /*sum1*/) { + return sum0 + Vec256{wasm_i32x4_dot_i16x8(a.raw, b.raw)}; +} + +// ------------------------------ Reductions + +namespace detail { + +// u32/i32/f32: + +template +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const Vec256 v1032 = Shuffle1032(v3210); + const Vec256 v31_20_31_20 = v3210 + v1032; + const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const Vec256 v1032 = Shuffle1032(v3210); + const Vec256 v31_20_31_20 = Min(v3210, v1032); + const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const Vec256 v1032 = Shuffle1032(v3210); + const Vec256 v31_20_31_20 = Max(v3210, v1032); + const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +template +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const Vec256 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const Vec256 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const Vec256 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +// u16/i16 +template +HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec256 /*v*/) { + HWY_ASSERT(0); // Not implemented +} +template +HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec256 /*v*/) { + HWY_ASSERT(0); // Not implemented +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. +template +HWY_API Vec256 SumOfLanes(Full256 /* tag */, const Vec256 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec256 MinOfLanes(Full256 /* tag */, const Vec256 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec256 MaxOfLanes(Full256 /* tag */, const Vec256 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ Lt128 + +template +HWY_INLINE Mask256 Lt128(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Mask256 Lt128Upper(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Mask256 Eq128(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Mask256 Eq128Upper(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Mask256 Ne128(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Mask256 Ne128Upper(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Vec256 Min128(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Vec256 Max128(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Vec256 Min128Upper(Full256 d, Vec256 a, Vec256 b) {} + +template +HWY_INLINE Vec256 Max128Upper(Full256 d, Vec256 a, Vec256 b) {} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/media/highway/src/hwy/ops/x86_128-inl.h b/media/highway/src/hwy/ops/x86_128-inl.h new file mode 100644 index 000000000..68b156e5a --- /dev/null +++ b/media/highway/src/hwy/ops/x86_128-inl.h @@ -0,0 +1,7485 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors and SSE4 instructions, plus some AVX2 and AVX512-VL +// operations when compiling for those targets. +// External include guard in highway.h - see comment there. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_GCC_ACTUAL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's emmintrin.h - see +// https://github.com/google/highway/issues/710 and pull/902) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +#include +#include +#if HWY_TARGET == HWY_SSSE3 +#include // SSSE3 +#else +#include // SSE4 +#include // CLMUL +#endif +#include +#include +#include // memcpy + +#include "hwy/ops/shared-inl.h" + +#if HWY_IS_MSAN +#include +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#if HWY_TARGET <= HWY_AVX2 +template +using Full256 = Simd; +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +using Full512 = Simd; +#endif + +namespace detail { + +template +struct Raw128 { + using type = __m128i; +}; +template <> +struct Raw128 { + using type = __m128; +}; +template <> +struct Raw128 { + using type = __m128d; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +#if HWY_TARGET <= HWY_AVX3 + +// Forward-declare for use by DeduceD, see below. +template +class Vec512; + +namespace detail { + +// Template arg: sizeof(lane type) +template +struct RawMask128 {}; +template <> +struct RawMask128<1> { + using type = __mmask16; +}; +template <> +struct RawMask128<2> { + using type = __mmask8; +}; +template <> +struct RawMask128<4> { + using type = __mmask8; +}; +template <> +struct RawMask128<8> { + using type = __mmask8; +}; + +} // namespace detail + +template +struct Mask128 { + using Raw = typename detail::RawMask128::type; + + static Mask128 FromBits(uint64_t mask_bits) { + return Mask128{static_cast(mask_bits)}; + } + + Raw raw; +}; + +#else // AVX2 or below + +// FF..FF or 0. +template +struct Mask128 { + typename detail::Raw128::type raw; +}; + +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_TARGET <= HWY_AVX2 +// Forward-declare for use by DeduceD, see below. +template +class Vec256; +#endif + +namespace detail { + +// Deduce Simd from Vec* (pointers because Vec256/512 may be +// incomplete types at this point; this is simpler than avoiding multiple +// definitions of DFromV via #if) +struct DeduceD { + template + Simd operator()(const Vec128*) const { + return Simd(); + } +#if HWY_TARGET <= HWY_AVX2 + template + Full256 operator()(const hwy::HWY_NAMESPACE::Vec256*) const { + return Full256(); + } +#endif +#if HWY_TARGET <= HWY_AVX3 + template + Full512 operator()(const hwy::HWY_NAMESPACE::Vec512*) const { + return Full512(); + } +#endif +}; + +// Workaround for MSVC v19.14: alias with a dependent type fails to specialize. +template +struct ExpandDFromV { + using type = decltype(DeduceD()(static_cast(nullptr))); +}; + +} // namespace detail + +template +using DFromV = typename detail::ExpandDFromV::type; + +template +using TFromV = TFromD>; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; } +HWY_INLINE __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } +HWY_INLINE __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128 operator()(__m128i v) { return _mm_castsi128_ps(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128d operator()(__m128i v) { return _mm_castsi128_pd(v); } +}; + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128{BitCastFromInteger128()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_si128()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_ps()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_pd()}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { + return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint16_t t) { + return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint32_t t) { + return Vec128{_mm_set1_epi32(static_cast(t))}; +} +template +HWY_API Vec128 Set(Simd /* tag */, + const uint64_t t) { + return Vec128{ + _mm_set1_epi64x(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { + return Vec128{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { + return Vec128{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { + return Vec128{_mm_set1_epi32(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { + return Vec128{ + _mm_set1_epi64x(static_cast(t))}; // NOLINT +} +template +HWY_API Vec128 Set(Simd /* tag */, const float t) { + return Vec128{_mm_set1_ps(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const double t) { + return Vec128{_mm_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec128{_mm_undefined_si128()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_ps()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +// Gets the single value stored in a vector/part. +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFF); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFFFF); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw)); +} +template +HWY_API float GetLane(const Vec128 v) { + return _mm_cvtss_f32(v.raw); +} +template +HWY_API uint64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) uint64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return static_cast(_mm_cvtsi128_si64(v.raw)); +#endif +} +template +HWY_API int64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) int64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return _mm_cvtsi128_si64(v.raw); +#endif +} +template +HWY_API double GetLane(const Vec128 v) { + return _mm_cvtsd_f64(v.raw); +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_si128(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{_mm_andnot_si128(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_ps(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not + +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_TARGET <= HWY_AVX3 + const __m128i vu = BitCast(du, v).raw; + return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(d, VU{_mm_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast( + d, VU{_mm_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Neg(hwy::FloatTag /*tag*/, const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_INLINE Vec128 Neg(hwy::NonFloatTag /*tag*/, const Vec128 v) { + return Zero(DFromV()) - v; +} + +} // namespace detail + +template +HWY_INLINE Vec128 Neg(const Vec128 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (reaches breakpoint) + const auto zero = Zero(DFromV()); + return Vec128{_mm_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec128{_mm_abs_epi8(v.raw)}; +#endif +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi16(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi32(v.raw)}; +} +// i64 is implemented after BroadcastSignBit. +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(DFromV(), mask); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(DFromV(), mask); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m128i out = _mm_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, VFromD{out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(DFromV()), sign)); +#endif +} + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_ps(mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec128{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// ------------------------------ Mask logical + +// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. +#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) +#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ + HWY_COMPILER_CLANG >= 800 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 +#else +#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 +#endif +#endif // HWY_COMPILER_HAS_MASK_INTRINSICS + +namespace detail { + +template +HWY_INLINE Mask128 And(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Or(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Not(const Mask128 m) { + // Flip only the valid bits. + // TODO(janwas): use _knot intrinsics if N >= 8. + return Xor(m, Mask128::FromBits((1ull << N) - 1)); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +#else // AVX2 or below + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Simd /* tag */, + const Mask128 v) { + return Vec128{v.raw}; +} + +#if HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const auto vmask = VecFromMask(DFromV(), mask); + return Or(And(vmask, yes), AndNot(vmask, no)); +} + +#else // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(Simd(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. + +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + return Vec128{_mm_shuffle_epi8(bytes.raw, from.raw)}; +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; x86 anyway zeroes if >= 0x80. +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. +namespace detail { + +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {1, 0, 7, 6}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0302, 0x0100, 0x0f0e, 0x0d0c}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0, 3, 6, 5}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0100, 0x0706, 0x0d0c, 0x0b0a}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle1230(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {2, 1, 4, 7}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const Twice> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0504, 0x0302, 0x0908, 0x0f0e}; + return Vec128{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template +HWY_API Vec128 Shuffle3012(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<1> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<2> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<4> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<8> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<1> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<2> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<4> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<8> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi8(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi16(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi32(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi64(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(Simd /* tag */, + const Mask128 v) { + return VecFromMask(v); +} + +#else // AVX2 or below + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask128 RebindMask(Simd /*tag*/, + Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(Simd(), VecFromMask(d, m))); +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +#else + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +#endif +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + // Same as signed ==; avoid duplicating the SSSE3 version. + const DFromV d; + RebindToUnsigned du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpneq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpneq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // See https://stackoverflow.com/questions/65166174/: + const Simd d; + const RepartitionToNarrow d32; + const Vec128 m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; + const Vec128 m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; + // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: + // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. + const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; + // Duplicate upper to lower half. + return Mask128{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template +HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_pd(a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask128 FirstN(const Simd d, size_t num) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + const uint64_t all = (1ull << N) - 1; + // BZHI only looks at the lower 8 bits of num! + const uint64_t bits = (num > 255) ? all : _bzhi_u64(all, num); + return Mask128::FromBits(bits); +#else + const RebindToSigned di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); +#endif +} + +template +using MFromD = decltype(FirstN(D(), 0)); + +// ================================================== MEMORY (1) + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". This is a false alarm because msan does not +// raise errors. We work around this by using CopyBytes instead of intrinsics, +// but only for the analyzer to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. +#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{_mm_load_si128(reinterpret_cast(aligned))}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ps(aligned)}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec128{_mm_load_pd(aligned)}; +} + +template +HWY_API Vec128 LoadU(Full128 /* tag */, const T* HWY_RESTRICT p) { + return Vec128{_mm_loadu_si128(reinterpret_cast(p))}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const float* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ps(p)}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const double* HWY_RESTRICT p) { + return Vec128{_mm_loadu_pd(p)}; +} + +template +HWY_API Vec64 Load(Full64 /* tag */, const T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_loadl_epi64(reinterpret_cast(p))}; +#endif +} + +HWY_API Vec128 Load(Full64 /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); // not same size + return Vec128{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec128{_mm_loadl_pi(hi, reinterpret_cast(p))}; +#endif +} + +HWY_API Vec64 Load(Full64 /* tag */, + const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_load_sd(p)}; +#endif +} + +HWY_API Vec128 Load(Full32 /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); // not same size + return Vec128{v}; +#else + return Vec128{_mm_load_ss(p)}; +#endif +} + +// Any <= 32 bit except +template +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { + constexpr size_t kSize = sizeof(T) * N; +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes(p, &v); // not same size + return Vec128{v}; +#else + int32_t bits = 0; + CopyBytes(p, &bits); // not same size + return Vec128{_mm_cvtsi32_si128(bits)}; +#endif +} + +// For < 128 bit, LoadU == Load. +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +HWY_API Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi16(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, + Simd /* tag */, + const float* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, + Simd /* tag */, + const double* HWY_RESTRICT p) { + return Vec128{_mm_maskz_loadu_pd(m.raw, p)}; +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return Vec128{_mm_maskload_epi32(p_p, m.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return Vec128{_mm_maskload_epi64(p_p, m.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const float* HWY_RESTRICT p) { + const Vec128 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec128{_mm_maskload_ps(p, mi.raw)}; +} + +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const double* HWY_RESTRICT p) { + const Vec128 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec128{_mm_maskload_pd(p, mi.raw)}; +} + +// There is no maskload_epi8/16, so blend instead. +template * = nullptr> +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#else // <= SSE4 + +// Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). +template +HWY_API Vec128 MaskedLoad(Mask128 m, Simd d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#endif + +// ------------------------------ Store + +template +HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); +} + +template +HWY_API void Store(Vec64 v, Full64 /* tag */, T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec128 v, Full64 /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec64 v, Full64 /* tag */, + double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pd(p, v.raw); +#endif +} + +// Any <= 32 bit except +template +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { + CopyBytes(&v, p); // not same size +} +HWY_API void Store(const Vec128 v, Full32 /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); // not same size +#else + _mm_store_ss(p, v.raw); +#endif +} + +// For < 128 bit, StoreU == Store. +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ BlendedStore + +namespace detail { + +// There is no maskload_epi8/16 with which we could safely implement +// BlendedStore. Manual blending is also unsafe because loading a full vector +// that crosses the array end causes asan faults. Resort to scalar code; the +// caller should instead use memcpy, assuming m is FirstN(d, n). +template +HWY_API void ScalarMaskedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + const RebindToSigned di; // for testing mask if T=bfloat16_t. + using TI = TFromD; + alignas(16) TI buf[N]; + alignas(16) TI mask[N]; + Store(BitCast(di, v), di, buf); + Store(BitCast(di, VecFromMask(d, m)), di, mask); + for (size_t i = 0; i < N; ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi8(p, m.raw, v.raw); +} +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi16(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd, float* HWY_RESTRICT p) { + _mm_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd, double* HWY_RESTRICT p) { + _mm_mask_storeu_pd(p, m.raw, v.raw); +} + +#elif HWY_TARGET == HWY_AVX2 + +template * = nullptr> +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + detail::ScalarMaskedStore(v, m, d, p); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd d, float* HWY_RESTRICT p) { + using T = float; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + const Vec128, N> mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm_maskstore_ps(p, mi.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, + Simd d, double* HWY_RESTRICT p) { + using T = double; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128 df; + const Mask128 mf{m.raw}; + m = Mask128{And(mf, FirstN(df, N)).raw}; + } + + const Vec128, N> mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm_maskstore_pd(p, mi.raw, v.raw); +} + +#else // <= SSE4 + +template +HWY_API void BlendedStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT p) { + // Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). + detail::ScalarMaskedStore(v, m, d, p); +} + +#endif // SSE4 + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epi16(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epu32(a.raw, b.raw)}; +} + +#if HWY_TARGET == HWY_SSSE3 + +template // N=1 or 2 +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Set(Simd(), + static_cast(GetLane(a)) * GetLane(b)); +} +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) int32_t a_lanes[4]; + alignas(16) int32_t b_lanes[4]; + const Full128 di32; + Store(a, di32, a_lanes); + Store(b, di32, b_lanes); + alignas(16) int64_t mul[2]; + mul[0] = static_cast(a_lanes[0]) * b_lanes[0]; + mul[1] = static_cast(a_lanes[2]) * b_lanes[2]; + return Load(Full128(), mul); +} + +#else // HWY_TARGET == HWY_SSSE3 + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epi32(a.raw, b.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // Not as inefficient as it looks: _mm_mullo_epi32 has 10 cycle latency. + // 64-bit right shift would also work but also needs port 5, so no benefit. + // Notation: x=don't care, z=0. + const __m128i a_x3x1 = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x2x0 = MulEven(a, b); + const __m128i b_x3x1 = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x3x1 = + MulEven(Vec128{a_x3x1}, Vec128{b_x3x1}); + // We could _mm_slli_epi64 by 32 to get 3z1z and OR with z2z0, but generating + // the latter requires one more instruction or a constant. + const __m128i mul_20 = + _mm_shuffle_epi32(mullo_x2x0.raw, _MM_SHUFFLE(2, 0, 2, 0)); + const __m128i mul_31 = + _mm_shuffle_epi32(mullo_x3x1.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(mul_20, mul_31)}; +#else + return Vec128{_mm_mullo_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + // Same as unsigned; avoid duplicating the SSSE3 code. + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +// ------------------------------ RotateRight (ShiftRight, Or) + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<15>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<31>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Vec128{_mm_srai_epi64(v.raw, 63)}; +#elif HWY_TARGET == HWY_AVX2 || HWY_TARGET == HWY_SSE4 + return VecFromMask(v < Zero(d)); +#else + // Efficient Lt() requires SSE4.2 and BLENDVPD requires SSE4.1. 32-bit shift + // avoids generating a zero. + const RepartitionToNarrow d32; + const auto sign = ShiftRight<31>(BitCast(d32, v)); + return Vec128{ + _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +#endif +} + +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_abs_epi64(v.raw)}; +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srai_epi64(v.raw, kBits)}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ ZeroIfNegative (BroadcastSignBit) +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + static_assert(IsFloat(), "Only works for float"); + const DFromV d; +#if HWY_TARGET == HWY_SSSE3 + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#else + const auto mask = MaskFromVec(v); // MSB is sufficient for BLENDVPS +#endif + return IfThenElse(mask, Zero(d), v); +} + +// ------------------------------ IfNegativeThenElse +template +HWY_API Vec128 IfNegativeThenElse(const Vec128 v, + const Vec128 yes, + const Vec128 no) { + // int8: IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToFloat df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + return BitCast(d, IfThenElse(MaskFromVec(BitCast(df, v)), BitCast(df, yes), + BitCast(df, no))); +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator*(const Vec64 a, const Vec64 b) { + return Vec64{_mm_mul_sd(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator/(const Vec64 a, const Vec64 b) { + return Vec64{_mm_div_sd(a.raw, b.raw)}; +} + +// Approximate reciprocal +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ss(v.raw)}; +} + +// Absolute value of difference. +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ps(v.raw)}; +} +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ss(v.raw)}; +} +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_pd(v.raw)}; +} +HWY_API Vec64 Sqrt(const Vec64 v) { + return Vec64{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +} + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ss(v.raw)}; +} + +// ------------------------------ Min (Gt, IfThenElse) + +namespace detail { + +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MinU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epu64(a.raw, b.raw)}; +#else + return detail::MinU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +namespace detail { +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MaxU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epu64(a.raw, b.raw)}; +#else + return detail::MaxU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_pd(a.raw, b.raw)}; +} + +// ================================================== MEMORY (2) + +// ------------------------------ Non-temporal stores + +// On clang6, we see incorrect code generated for _mm_stream_pi, so +// round even partial vectors up to 16 bytes. +template +HWY_API void Stream(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT aligned) { + _mm_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + double* HWY_RESTRICT aligned) { + _mm_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unfortunately the GCC/Clang intrinsics do not accept int64_t*. +using GatherIndex64 = long long int; // NOLINT(runtime/int) +static_assert(sizeof(GatherIndex64) == 8, "Must be 64-bit type"); + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_epi32(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_epi32(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_epi64(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_epi64(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, index.raw, v.raw, 8); + } +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_ps(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_ps(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_API void ScatterOffset(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_pd(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_pd(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, index.raw, v.raw, 8); + } +} +#else // HWY_TARGET <= HWY_AVX3 + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, + T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather (Load/Store) + +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +#else + +namespace detail { + +template +HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<4> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<4> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), index.raw, 4)}; +} + +template +HWY_INLINE Vec128 GatherOffset(hwy::SizeTag<8> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec128 GatherIndex(hwy::SizeTag<8> /* tag */, + Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), index.raw, 8)}; +} + +} // namespace detail + +template +HWY_API Vec128 GatherOffset(Simd d, const T* HWY_RESTRICT base, + const Vec128 offset) { + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec128 GatherIndex(Simd d, const T* HWY_RESTRICT base, + const Vec128 index) { + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_ps(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_ps(base, index.raw, 4)}; +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_pd(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_pd(base, index.raw, 8)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +// Returns upper/lower half of a vector. +template +HWY_API Vec128 LowerHalf(Simd /* tag */, + Vec128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return LowerHalf(Simd(), v); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec128 ShiftLeftBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec128{_mm_slli_si128(v.raw, kBytes)}; +} + +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec128 ShiftLeftLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec128 ShiftRightBytes(Simd /* tag */, Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // For partial vectors, clear upper lanes so we shift in zeros. + if (N != 16 / sizeof(T)) { + const Vec128 vfull{v.raw}; + v = Vec128{IfThenElseZero(FirstN(Full128(), N), vfull).raw}; + } + return Vec128{_mm_srli_si128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec128 ShiftRightLanes(Simd d, const Vec128 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). +template +HWY_API Vec64 UpperHalf(Half> /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_epi64(v.raw, v.raw)}; +} +HWY_API Vec128 UpperHalf(Full64 /* tag */, Vec128 v) { + return Vec128{_mm_movehl_ps(v.raw, v.raw)}; +} +HWY_API Vec64 UpperHalf(Full64 /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_pd(v.raw, v.raw)}; +} + +// Partial +template +HWY_API Vec128 UpperHalf(Half> /* tag */, + Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes(du, vu)); + return Vec128{upper.raw}; +} + +// ------------------------------ ExtractLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const int pair = _mm_extract_epi16(v.raw, kLane / 2); + constexpr int kShift = kLane & 1 ? 8 : 0; + return static_cast((pair >> kShift) & 0xFF); +#else + return static_cast(_mm_extract_epi8(v.raw, kLane) & 0xFF); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + return static_cast(_mm_extract_epi16(v.raw, kLane) & 0xFFFF); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + return static_cast(_mm_extract_epi32(v.raw, kLane)); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + return static_cast(_mm_extract_epi64(v.raw, kLane)); +#endif +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) float lanes[4]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#else + // Bug in the intrinsic, returns int but should be float. + const int32_t bits = _mm_extract_ps(v.raw, kLane); + float ret; + CopySameSize(&bits, &ret); + return ret; +#endif +} + +// There is no extract_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane == 0, "Lane index out of bounds"); + return GetLane(v); +} + +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane < 2, "Lane index out of bounds"); + const Half> dh; + return kLane == 0 ? GetLane(v) : GetLane(UpperHalf(dh, v)); +} + +} // namespace detail + +// Requires one overload per vector length because ExtractLane<3> may be a +// compile error if it calls _mm_extract_epi64. +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128{_mm_insert_epi8(v.raw, t, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{_mm_insert_epi16(v.raw, t, kLane)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + const DFromV d; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. + return Vec128{_mm_insert_epi32(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. + return Vec128{_mm_insert_epi64(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + alignas(16) float lanes[4]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128{_mm_insert_ps(v.raw, _mm_set_ss(t), kLane << 4)}; +#endif +} + +// There is no insert_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane == 0, "Lane index out of bounds"); + return Set(DFromV(), t); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + const DFromV d; + const Vec128 vt = Set(d, t); + if (kLane == 0) { + return Vec128{_mm_shuffle_pd(vt.raw, v.raw, 2)}; + } + return Vec128{_mm_shuffle_pd(v.raw, vt.raw, 0)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls _mm_insert_epi64. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full128 d, V hi, V lo) { + const Repartition d8; + return BitCast(d, Vec128{_mm_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +template > +HWY_API V CombineShiftRightBytes(Simd d, V hi, V lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full128 d_full8; + using V8 = VFromD; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); + return V{BitCast(Full128(), r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 3 * kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template +struct Indices128 { + __m128i raw; +}; + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, N)))); +#endif + +#if HWY_TARGET <= HWY_AVX2 + (void)d; + return Indices128{vec.raw}; +#else + const Repartition d8; + using V8 = VFromD; + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + + // Broadcast each lane index to all 4 bytes of T + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif +} + +template +HWY_API Indices128 IndicesFromVec(Simd d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(N))))); +#else + (void)d; +#endif + + // No change - even without AVX3, we can shuffle+blend. + return Indices128{vec.raw}; +} + +template +HWY_API Indices128 SetTableIndices(Simd d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + const DFromV d; + const RebindToFloat df; + const Vec128 perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; + return BitCast(d, perm); +#else + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128{_mm_permutevar_ps(v.raw, idx.raw)}; +#else + const DFromV df; + const RebindToSigned di; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +#endif +} + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const Full128 d; + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + // There is no _mm_permute[x]var_epi64. + vidx += vidx; // bit1 is the decider (unusual) + const Full128 df; + return BitCast( + d, Vec128{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const Full128 di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + vidx += vidx; // bit1 is the decider (unusual) + return Vec128{_mm_permutevar_pd(v.raw, vidx.raw)}; +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const Full128 d; + const Full128 di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API Vec128 ReverseBlocks(Full128 /* tag */, const Vec128 v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template +HWY_API Vec128 Reverse(Simd /* tag */, const Vec128 v) { + return v; +} + +// Two lanes: shuffle +template +HWY_API Vec128 Reverse(Full64 /* tag */, const Vec128 v) { + return Vec128{Shuffle2301(Vec128{v.raw}).raw}; +} + +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template +HWY_API Vec128 Reverse(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API Vec128 Reverse(Simd d, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + if (N == 1) return v; + if (N == 2) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); + } + const RebindToSigned di; + alignas(16) constexpr int16_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + const Vec128 idx = Load(di, kReverse + (N == 8 ? 0 : 4)); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +#endif +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec128 Reverse2(Simd d, const Vec128 v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec128 Reverse2(Simd /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec128 Reverse4(Simd d, const Vec128 v) { + const RebindToSigned di; + // 4x 16-bit: a single shufflelo suffices. + if (N == 4) { + return BitCast(d, Vec128{_mm_shufflelo_epi16( + BitCast(di, v).raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET <= HWY_AVX3 + alignas(16) constexpr int16_t kReverse4[8] = {3, 2, 1, 0, 7, 6, 5, 4}; + const Vec128 idx = Load(di, kReverse4); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +// 4x 32-bit: use Shuffle0123 +template +HWY_API Vec128 Reverse4(Full128 /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec128 Reverse4(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec128 Reverse8(Simd d, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec128 idx = Load(di, kReverse8); + return BitCast(d, Vec128{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec128 Reverse8(Simd /* tag */, Vec128 /* v */) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +// Additional overload for the optional tag (also for 256/512). +template +HWY_API V InterleaveLower(DFromV /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +// Full +template > +HWY_API V InterleaveUpper(Full128 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template > +HWY_API V InterleaveUpper(Simd d, V a, V b) { + const Half d2; + return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template +HWY_API Vec128 Combine(Simd d, Vec128 hi_half, + Vec128 lo_half) { + const Half d2; + const RebindToUnsigned du2; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 ZeroExtendVector(hwy::NonFloatTag /*tag*/, + Full128 /* d */, Vec64 lo) { + return Vec128{_mm_move_epi64(lo.raw)}; +} + +template +HWY_INLINE Vec128 ZeroExtendVector(hwy::FloatTag /*tag*/, Full128 d, + Vec64 lo) { + const RebindToUnsigned du; + return BitCast(d, ZeroExtendVector(du, BitCast(Half(), lo))); +} + +} // namespace detail + +template +HWY_API Vec128 ZeroExtendVector(Full128 d, Vec64 lo) { + return detail::ZeroExtendVector(hwy::IsFloatTag(), d, lo); +} + +template +HWY_API Vec128 ZeroExtendVector(Simd d, Vec128 lo) { + return IfThenElseZero(FirstN(d, N / 2), Vec128{lo.raw}); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec128 ConcatLowerLower(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec128 ConcatUpperUpper(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API Vec128 ConcatLowerUpper(Full128 d, const Vec128 hi, + const Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec128 ConcatUpperLower(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd(BitCast(dd, lo).raw, BitCast(dd, hi).raw, + _MM_SHUFFLE2(1, 0))}); +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _pd can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128 ConcatUpperLower(Full128 d, Vec128 hi, + Vec128 lo) { +#if HWY_TARGET == HWY_SSSE3 + (void)d; + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 2, 1, 0))}; +#else + // _mm_shuffle_ps has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + const RepartitionToWide dd; + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128 ConcatUpperLower(Full128 /* tag */, + Vec128 hi, Vec128 lo) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_shuffle_pd(lo.raw, hi.raw, _MM_SHUFFLE2(1, 0))}; +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return Vec128{_mm_blend_pd(hi.raw, lo.raw, 1)}; +#endif +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API Vec128 ConcatLowerLower(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatUpperUpper(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatLowerUpper(Simd d, const Vec128 hi, + const Vec128 lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API Vec128 ConcatUpperLower(Simd d, Vec128 hi, + Vec128 lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); + return Vec128{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API Vec64 ConcatOdd(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[8] = {1, 3, 5, 7}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactOddU8)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template +HWY_API Vec32 ConcatOdd(Simd d, Vec32 hi, Vec32 lo) { + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[4] = {1, 3}; + const Vec32 shuf = BitCast(d, Load(Full32(), kCompactOddU8)); + const Vec32 L = TableLookupBytes(lo, shuf); + const Vec32 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. + const Repartition dw; + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); + return Vec128{_mm_packs_epi32(uL.raw, uH.raw)}; +} + +// 16-bit x4 +template +HWY_API Vec64 ConcatOdd(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU16[8] = {2, 3, 6, 7}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactOddU16)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template +HWY_API Vec128 ConcatOdd(Full128 d, Vec128 hi, Vec128 lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(3, 1, 3, 1))}); +} +template +HWY_API Vec128 ConcatOdd(Full128 /* tag */, Vec128 hi, + Vec128 lo) { + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; +} + +// Any type x2 +template +HWY_API Vec128 ConcatOdd(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { + const Repartition dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec128 mask = Set(dw, 0x00FF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return Vec128{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API Vec64 ConcatEven(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[8] = {0, 2, 4, 6}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactEvenU8)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template +HWY_API Vec32 ConcatEven(Simd d, Vec32 hi, Vec32 lo) { + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[4] = {0, 2}; + const Vec32 shuf = BitCast(d, Load(Full32(), kCompactEvenU8)); + const Vec32 L = TableLookupBytes(lo, shuf); + const Vec32 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { +#if HWY_TARGET <= HWY_SSE4 + // Isolate lower 16 bits per u32 so we can pack. + const Repartition dw; + const Vec128 mask = Set(dw, 0x0000FFFF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return Vec128{_mm_packus_epi32(uL.raw, uH.raw)}; +#else + // packs_epi32 saturates 0x8000 to 0x7FFF. Instead ConcatEven within the two + // inputs, then concatenate them. + alignas(16) const T kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + const Vec128 shuf = BitCast(d, Load(d, kCompactEvenU16)); + const Vec128 L = TableLookupBytes(lo, shuf); + const Vec128 H = TableLookupBytes(hi, shuf); + return ConcatLowerLower(d, H, L); +#endif +} + +// 16-bit x4 +template +HWY_API Vec64 ConcatEven(Simd d, Vec64 hi, Vec64 lo) { + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU16[8] = {0, 1, 4, 5}; + const Vec64 shuf = BitCast(d, Load(Full64(), kCompactEvenU16)); + const Vec64 L = TableLookupBytes(lo, shuf); + const Vec64 H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template +HWY_API Vec128 ConcatEven(Full128 d, Vec128 hi, Vec128 lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +HWY_API Vec128 ConcatEven(Full128 /* tag */, Vec128 hi, + Vec128 lo) { + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; +} + +// Any T x2 +template +HWY_API Vec128 ConcatEven(Simd d, Vec128 hi, + Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const DFromV d; + const Repartition d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +#else + return Vec128{_mm_blend_epi16(a.raw, b.raw, 0x55)}; +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i odd = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128i even = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(even, odd)}; +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _ps can do 3/cycle. + const DFromV d; + const RebindToFloat df; + return BitCast(d, Vec128{_mm_blend_ps(BitCast(df, a).raw, + BitCast(df, b).raw, 5)}); +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. + const DFromV d; + const RebindToFloat dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd( + BitCast(dd, b).raw, BitCast(dd, a).raw, _MM_SHUFFLE2(1, 0))}); +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, a).raw, + BitCast(dd, b).raw, 1)}); +#endif +} + +template +HWY_API Vec128 OddEven(Vec128 a, Vec128 b) { +#if HWY_TARGET == HWY_SSSE3 + // SHUFPS must fill the lower half of the output from one input, so we + // need another shuffle. Unpack avoids another immediate byte. + const __m128 odd = _mm_shuffle_ps(a.raw, a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128 even = _mm_shuffle_ps(b.raw, b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_ps(even, odd)}; +#else + return Vec128{_mm_blend_ps(a.raw, b.raw, 5)}; +#endif +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ Shl (ZipLower, Mul) + +// Use AVX2/3 variable shifts where available, otherwise multiply by powers of +// two from loading float exponents, which is considerably faster (according +// to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. + +namespace detail { +#if HWY_TARGET > HWY_AVX3 // AVX2 or older + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const RepartitionToWide dw; + const Rebind df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // See comment below. + const Vec128 bits0{_mm_cvtps_epi32(BitCast(df, f0).raw)}; + const Vec128 bits1{_mm_cvtps_epi32(BitCast(df, f1).raw)}; + return Vec128, N>{_mm_packus_epi32(bits0.raw, bits1.raw)}; +} + +// Same, for 32-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const auto exp = ShiftLeft<23>(v); + const auto f = exp + Set(d, 0x3F800000); // 1.0f + // Do not use ConvertTo because we rely on the native 0x80..00 overflow + // behavior. cvt instead of cvtt should be equivalent, but avoids test + // failure under GCC 10.2.1. + return Vec128, N>{_mm_cvtps_epi32(_mm_castsi128_ps(f.raw))}; +} + +#endif // HWY_TARGET > HWY_AVX3 + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sllv_epi16(v.raw, bits.raw)}; +#else + return v * Pow2(bits); +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { + return Vec128{_mm_sll_epi16(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return v * Pow2(bits); +#else + return Vec128{_mm_sllv_epi32(v.raw, bits.raw)}; +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sll_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128 out0{_mm_sll_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_sll_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128(), out1, out0); +#else + return Vec128{_mm_sllv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 Shl(hwy::UnsignedTag /*tag*/, Vec64 v, + Vec64 bits) { + return Vec64{_mm_sll_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec128 Shl(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (mul, mask, BroadcastSignBit) + +// Use AVX2+ variable shifts except for SSSE3/SSE4 or 16-bit. There, we use +// widening multiplication by powers of two obtained by loading float exponents, +// followed by a constant right-shift. This is still faster than a scalar or +// bit-test approach: https://gcc.godbolt.org/z/9G7Y9v. + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srlv_epi16(in.raw, bits.raw)}; +#else + const Simd d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), in, out); +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { + return Vec128{_mm_srl_epi16(in.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // 32x32 -> 64 bit mul, then shift right by 32. + const Simd d32; + // Move odd lanes into position for the second mul. Shuffle more gracefully + // handles N=1 than repartitioning to u64 and shifting 32 bits right. + const Vec128 in31{_mm_shuffle_epi32(in.raw, 0x31)}; + // For bits=0, we cannot mul by 2^32, so fix the result later. + const auto mul = detail::Pow2(Set(d32, 32) - bits); + const auto out20 = ShiftRight<32>(MulEven(in, mul)); // z 2 z 0 + const Vec128 mul31{_mm_shuffle_epi32(mul.raw, 0x31)}; + // No need to shift right, already in the correct position. + const auto out31 = BitCast(d32, MulEven(in31, mul31)); // 3 ? 1 ? + const Vec128 out = OddEven(out31, BitCast(d32, out20)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d32), in, out); +#else + return Vec128{_mm_srlv_epi32(in.raw, bits.raw)}; +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { + return Vec128{_mm_srl_epi32(in.raw, bits.raw)}; +} + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128 out0{_mm_srl_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_srl_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128(), out1, out0); +#else + return Vec128{_mm_srlv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64{_mm_srl_epi64(v.raw, bits.raw)}; +} + +#if HWY_TARGET > HWY_AVX3 // AVX2 or older +namespace detail { + +// Also used in x86_256-inl.h. +template +HWY_INLINE V SignedShr(const DI di, const V v, const V count_i) { + const RebindToUnsigned du; + const auto count = BitCast(du, count_i); // same type as value to shift + // Clear sign and restore afterwards. This is preferable to shifting the MSB + // downwards because Shr is somewhat more expensive than Shl. + const auto sign = BroadcastSignBit(v); + const auto abs = BitCast(du, v ^ sign); // off by one, but fixed below + return BitCast(di, abs >> count) ^ sign; +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sra_epi16(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi32(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128{_mm_sra_epi32(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +} + +HWY_INLINE Vec128 MulOdd(const Vec128 a, + const Vec128 b) { + alignas(16) uint64_t mul[2]; + const Half> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API Vec128 ReorderWidenMulAccumulate(Simd df32, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { + // TODO(janwas): _mm_dpbf16_ps when available + const Repartition du16; + const RebindToUnsigned du32; + const Vec128 zero = Zero(du16); + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo. + const Vec128 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const Vec128 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const Vec128 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const Vec128 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. +template +HWY_API Vec128 ReorderWidenMulAccumulate( + Simd /*d32*/, Vec128 a, + Vec128 b, const Vec128 sum0, + Vec128& /*sum1*/) { + return sum0 + Vec128{_mm_madd_epi16(a.raw, b.raw)}; +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + return Vec128{_mm_unpacklo_epi8(v.raw, zero)}; +#else + return Vec128{_mm_cvtepu8_epi16(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; +#else + return Vec128{_mm_cvtepu16_epi32(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; +#else + return Vec128{_mm_cvtepu32_epi64(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); + return Vec128{_mm_unpacklo_epi16(u16, zero)}; +#else + return Vec128{_mm_cvtepu8_epi32(v.raw)}; +#endif +} + +// Unsigned to signed: same plus cast. +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} +template +HWY_API Vec128 PromoteTo(Simd di, + const Vec128 v) { + return BitCast(di, PromoteTo(Simd(), v)); +} + +// Signed: replicate sign bit. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<8>(Vec128{_mm_unpacklo_epi8(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi8_epi16(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<16>(Vec128{_mm_unpacklo_epi16(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi16_epi32(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<32>(Vec128{_mm_unpacklo_epi32(v.raw, v.raw)}); +#else + return Vec128{_mm_cvtepi32_epi64(v.raw)}; +#endif +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); + const __m128i x4 = _mm_unpacklo_epi16(x2, x2); + return ShiftRight<24>(Vec128{x4}); +#else + return Vec128{_mm_cvtepi8_epi32(v.raw)}; +#endif +} + +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template +HWY_INLINE_F16 Vec128 PromoteTo(Simd df32, + const Vec128 v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + (void)df32; + return Vec128{_mm_cvtph_ps(v.raw)}; +#endif +} + +template +HWY_API Vec128 PromoteTo(Simd df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtps_pd(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSSE3 + const Simd di32; + const Simd du16; + const auto zero_if_neg = AndNot(ShiftRight<31>(v), v); + const auto too_big = VecFromMask(di32, Gt(v, Set(di32, 0xFFFF))); + const auto clamped = Or(zero_if_neg, too_big); + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. + alignas(16) constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16, kLower2Bytes); + return Vec128{TableLookupBytes(BitCast(du16, clamped), lo2).raw}; +#else + return Vec128{_mm_packus_epi32(v.raw, v.raw)}; +#endif +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi32(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packus_epi16(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi16(v.raw, v.raw)}; +} + +// Work around MSVC warning for _mm_cvtps_ph (8 is actually a valid immediate). +// clang-cl requires a non-empty string, so we 'ignore' the irrelevant -Wmain. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wmain") + +template +HWY_API Vec128 DemoteTo(Simd df16, + const Vec128 v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + (void)df16; + return Vec128{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +HWY_DIAGNOSTICS(pop) + +template +HWY_API Vec128 DemoteTo(Simd dbf16, + const Vec128 v) { + // TODO(janwas): _mm_cvtneps_pbh once we have avx512bf16. + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +template +HWY_API Vec128 ReorderDemote2To( + Simd dbf16, Vec128 a, Vec128 b) { + // TODO(janwas): _mm_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned du16; + const Repartition du32; + const Vec128 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Simd dn, + Vec128 a, + Vec128 b) { + const Half dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128 an{DemoteTo(dnh, a).raw}; + const Vec128 bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128 ReorderDemote2To(Full128 /*d16*/, + Vec128 a, Vec128 b) { + return Vec128{_mm_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtpd_ps(v.raw)}; +} + +namespace detail { + +// For well-defined float->int demotion in all x86_*-inl.h. + +template +HWY_INLINE auto ClampF64ToI32Max(Simd d, decltype(Zero(d)) v) + -> decltype(Zero(d)) { + // The max can be exactly represented in binary64, so clamping beforehand + // prevents x86 conversion from raising an exception and returning 80..00. + return Min(v, Set(d, 2147483647.0)); +} + +// For ConvertTo float->int of same size, clamping before conversion would +// change the result because the max integer value is not exactly representable. +// Instead detect the overflow result after conversion and fix it. +template > +HWY_INLINE auto FixConversionOverflow(DI di, VFromD original, + decltype(Zero(di).raw) converted_raw) + -> VFromD { + // Combinations of original and output sign: + // --: normal <0 or -huge_val to 80..00: OK + // -+: -0 to 0 : OK + // +-: +huge_val to 80..00 : xor with FF..FF to get 7F..FF + // ++: normal >0 : OK + const auto converted = VFromD{converted_raw}; + const auto sign_wrong = AndNot(BitCast(di, original), converted); +#if HWY_COMPILER_GCC_ACTUAL + // Critical GCC 11 compiler bug (possibly also GCC 10): omits the Xor; also + // Add() if using that instead. Work around with one more instruction. + const RebindToUnsigned du; + const VFromD mask = BroadcastSignBit(sign_wrong); + const VFromD max = BitCast(di, ShiftRight<1>(BitCast(du, mask))); + return IfVecThenElse(mask, max, converted); +#else + return Xor(converted, BroadcastSignBit(sign_wrong)); +#endif +} + +} // namespace detail + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto clamped = detail::ClampF64ToI32Max(Simd(), v); + return Vec128{_mm_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const Simd d32; + const Simd d8; + alignas(16) static constexpr uint32_t k8From32[4] = { + 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; + // Also replicate bytes into all 32 bit lanes for safety. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + return LowerHalf(LowerHalf(BitCast(d8, quad))); +} + +// ------------------------------ Truncations + +template * = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + static_assert(!IsSigned() && !IsSigned(), "Unsigned only"); + const Repartition> d; + const auto v1 = BitCast(d, v); + return Vec128{v1.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d8; + alignas(16) static constexpr uint8_t kMap[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + return LowerHalf(LowerHalf(LowerHalf(TableLookupBytes(v, Load(d8, kMap))))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Full128 d16; + alignas(16) static constexpr uint16_t kMap[8] = { + 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d16, kMap)))); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x88)}; +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + alignas(16) static constexpr uint8_t kMap[16] = { + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu, + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d, kMap)))); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template = 2>* = nullptr> +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_ps(v.raw)}; +} + +template +HWY_API Vec128 ConvertTo(HWY_MAYBE_UNUSED Simd df, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_cvtepu32_ps(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned du32; + const RebindToSigned d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +#endif +} + +template +HWY_API Vec128 ConvertTo(Simd dd, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dd; + return Vec128{_mm_cvtepi64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +#endif +} + +template +HWY_API Vec128 ConvertTo(HWY_MAYBE_UNUSED Simd dd, + const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_cvtepu64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + auto uint64_to_double128_fast = [&dd](VU w) HWY_ATTR { + w = Or(w, VU{detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); + return BitCast(dd, w) - Set(dd, 0x0010000000000000); + }; + + const auto v_lo_dbl = uint64_to_double128_fast(v_lo); + return MulAdd(cnst2_32_dbl, uint64_to_double128_fast(v_hi), v_lo_dbl); +#endif +} + +// Truncates (rounds toward zero). +template +HWY_API Vec128 ConvertTo(const Simd di, + const Vec128 v) { + return detail::FixConversionOverflow(di, v, _mm_cvttps_epi32(v.raw)); +} + +// Full (partial handled below) +HWY_API Vec128 ConvertTo(Full128 di, const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 && HWY_ARCH_X86_64 + return detail::FixConversionOverflow(di, v, _mm_cvttpd_epi64(v.raw)); +#elif HWY_ARCH_X86_64 + const __m128i i0 = _mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw)); + const Half> dd2; + const __m128i i1 = _mm_cvtsi64_si128(_mm_cvttsd_si64(UpperHalf(dd2, v).raw)); + return detail::FixConversionOverflow(di, v, _mm_unpacklo_epi64(i0, i1)); +#else + using VI = VFromD; + const VI k0 = Zero(di); + const VI k1 = Set(di, 1); + const VI k51 = Set(di, 51); + + // Exponent indicates whether the number can be represented as int64_t. + const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); + const VI exp = biased_exp - Set(di, 0x3FF); + const auto in_range = exp < Set(di, 63); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + const VI shift_mnt = Max(k51 - exp, k0); + const VI shift_int = Max(exp - k51, k0); + const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); + // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. + const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); + // For inputs larger than 2^52, insert zeros at the bottom. + const VI shifted = int52 << shift_int; + // Restore the one bit lost when shifting in the implicit 1-bit. + const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, restored, limit); + + // If the input was negative, negate the integer (two's complement). + return (magnitude ^ sign_mask) - sign_mask; +#endif +} +HWY_API Vec64 ConvertTo(Full64 di, const Vec64 v) { + // Only need to specialize for non-AVX3, 64-bit (single scalar op) +#if HWY_TARGET > HWY_AVX3 && HWY_ARCH_X86_64 + const Vec64 i0{_mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw))}; + return detail::FixConversionOverflow(di, v, i0.raw); +#else + (void)di; + const auto full = ConvertTo(Full128(), Vec128{v.raw}); + return Vec64{full.raw}; +#endif +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const Simd di; + return detail::FixConversionOverflow(di, v, _mm_cvtps_epi32(v.raw)); +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +#if HWY_TARGET == HWY_SSSE3 + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + // Rely on rounding after addition with a large value such that no mantissa + // bits remain (assuming the current mode is nearest-even). We may need a + // compiler flag for precise floating-point to prevent "optimizing" this out. + const Simd df; + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +namespace detail { + +// Truncating to integer and converting back to float is correct except when the +// input magnitude is large, in which case the input was already an integer +// (because mantissa >> exponent is zero). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#else + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +#endif // !HWY_SSSE3 + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_fpclass_ps_mask(v.raw, 0x81)}; +#else + return Mask128{_mm_cmpunord_ps(v.raw, v.raw)}; +#endif +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_fpclass_pd_mask(v.raw, 0x81)}; +#else + return Mask128{_mm_cmpunord_pd(v.raw, v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_ps_mask(v.raw, 0x18)}; +} +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_pd_mask(v.raw, 0x18)}; +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128{_mm_fpclass_ps_mask(v.raw, 0x99)}); +} +template +HWY_API Mask128 IsFinite(const Vec128 v) { + return Not(Mask128{_mm_fpclass_pd_mask(v.raw, 0x99)}); +} + +#else + +template +HWY_API Mask128 IsInf(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const Simd d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // Shift left to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). MSVC seems to generate + // incorrect code if we instead add vu + vu. + const VFromD exp = + BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenc_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenclast_si128(state.raw, round_key.raw)}; +} + +template +HWY_API Vec128 CLMulLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x00)}; +} + +template +HWY_API Vec128 CLMulUpper(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x11)}; +} + +#endif // !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// ================================================== MISC + +template +struct CompressIsPartition { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 supports native compress, but a table-based approach allows + // 'partitioning' (also moving mask=false lanes to the top), which helps + // vqsort. This is only feasible for eight or less lanes, i.e. sizeof(T) == 8 + // on AVX3. For simplicity, we only use tables for 64-bit lanes (not AVX3 + // u32x8 etc.). + enum { value = (sizeof(T) == 8) }; +#else + enum { value = 1 }; +#endif +}; + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadMaskBits + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask128 LoadMaskBits(Simd /* tag */, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return Mask128::FromBits(mask_bits); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (N < 8) { + const int mask_bits = (1 << N) - 1; + bits[0] = static_cast(bits[0] & mask_bits); + } + + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +template +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return PopCount(mask_bits); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint32_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint32_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + return mask_bits == 0; +} + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { + const uint64_t mask_bits = static_cast(mask.raw) & ((1u << N) - 1); + // Cannot use _kortestc because we may have less than 8 mask bits. + return mask_bits == (1u << N) - 1; +} + +// ------------------------------ Compress + +#if HWY_TARGET != HWY_AVX3_DL +namespace detail { + +// Returns permutevar_epi16 indices for 16-bit Compress. Also used by x86_256. +HWY_INLINE Vec128 IndicesForCompress16(uint64_t mask_bits) { + Full128 du16; + // Table of u16 indices packed into bytes to reduce L1 usage. Will be unpacked + // to u16. Ideally we would broadcast 8*3 (half of the 8 bytes currently used) + // bits into each lane and then varshift, but that does not fit in 16 bits. + Rebind du8; + alignas(16) constexpr uint8_t tbl[2048] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 2, + 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, + 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 2, 3, 0, 0, + 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, + 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 1, 4, 0, 0, 0, 0, + 0, 0, 0, 1, 4, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, + 0, 1, 2, 4, 0, 0, 0, 0, 0, 0, 1, 2, 4, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, + 0, 3, 4, 0, 0, 0, 0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 1, 3, 4, 0, 0, 0, 0, 2, + 3, 4, 0, 0, 0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 1, + 2, 3, 4, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 1, 5, 0, + 0, 0, 0, 0, 0, 0, 1, 5, 0, 0, 0, 0, 0, 2, 5, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, + 0, 0, 0, 0, 1, 2, 5, 0, 0, 0, 0, 0, 0, 1, 2, 5, 0, 0, 0, 0, 3, 5, 0, 0, 0, + 0, 0, 0, 0, 3, 5, 0, 0, 0, 0, 0, 1, 3, 5, 0, 0, 0, 0, 0, 0, 1, 3, 5, 0, 0, + 0, 0, 2, 3, 5, 0, 0, 0, 0, 0, 0, 2, 3, 5, 0, 0, 0, 0, 1, 2, 3, 5, 0, 0, 0, + 0, 0, 1, 2, 3, 5, 0, 0, 0, 4, 5, 0, 0, 0, 0, 0, 0, 0, 4, 5, 0, 0, 0, 0, 0, + 1, 4, 5, 0, 0, 0, 0, 0, 0, 1, 4, 5, 0, 0, 0, 0, 2, 4, 5, 0, 0, 0, 0, 0, 0, + 2, 4, 5, 0, 0, 0, 0, 1, 2, 4, 5, 0, 0, 0, 0, 0, 1, 2, 4, 5, 0, 0, 0, 3, 4, + 5, 0, 0, 0, 0, 0, 0, 3, 4, 5, 0, 0, 0, 0, 1, 3, 4, 5, 0, 0, 0, 0, 0, 1, 3, + 4, 5, 0, 0, 0, 2, 3, 4, 5, 0, 0, 0, 0, 0, 2, 3, 4, 5, 0, 0, 0, 1, 2, 3, 4, + 5, 0, 0, 0, 0, 1, 2, 3, 4, 5, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, + 0, 0, 0, 1, 6, 0, 0, 0, 0, 0, 0, 0, 1, 6, 0, 0, 0, 0, 0, 2, 6, 0, 0, 0, 0, + 0, 0, 0, 2, 6, 0, 0, 0, 0, 0, 1, 2, 6, 0, 0, 0, 0, 0, 0, 1, 2, 6, 0, 0, 0, + 0, 3, 6, 0, 0, 0, 0, 0, 0, 0, 3, 6, 0, 0, 0, 0, 0, 1, 3, 6, 0, 0, 0, 0, 0, + 0, 1, 3, 6, 0, 0, 0, 0, 2, 3, 6, 0, 0, 0, 0, 0, 0, 2, 3, 6, 0, 0, 0, 0, 1, + 2, 3, 6, 0, 0, 0, 0, 0, 1, 2, 3, 6, 0, 0, 0, 4, 6, 0, 0, 0, 0, 0, 0, 0, 4, + 6, 0, 0, 0, 0, 0, 1, 4, 6, 0, 0, 0, 0, 0, 0, 1, 4, 6, 0, 0, 0, 0, 2, 4, 6, + 0, 0, 0, 0, 0, 0, 2, 4, 6, 0, 0, 0, 0, 1, 2, 4, 6, 0, 0, 0, 0, 0, 1, 2, 4, + 6, 0, 0, 0, 3, 4, 6, 0, 0, 0, 0, 0, 0, 3, 4, 6, 0, 0, 0, 0, 1, 3, 4, 6, 0, + 0, 0, 0, 0, 1, 3, 4, 6, 0, 0, 0, 2, 3, 4, 6, 0, 0, 0, 0, 0, 2, 3, 4, 6, 0, + 0, 0, 1, 2, 3, 4, 6, 0, 0, 0, 0, 1, 2, 3, 4, 6, 0, 0, 5, 6, 0, 0, 0, 0, 0, + 0, 0, 5, 6, 0, 0, 0, 0, 0, 1, 5, 6, 0, 0, 0, 0, 0, 0, 1, 5, 6, 0, 0, 0, 0, + 2, 5, 6, 0, 0, 0, 0, 0, 0, 2, 5, 6, 0, 0, 0, 0, 1, 2, 5, 6, 0, 0, 0, 0, 0, + 1, 2, 5, 6, 0, 0, 0, 3, 5, 6, 0, 0, 0, 0, 0, 0, 3, 5, 6, 0, 0, 0, 0, 1, 3, + 5, 6, 0, 0, 0, 0, 0, 1, 3, 5, 6, 0, 0, 0, 2, 3, 5, 6, 0, 0, 0, 0, 0, 2, 3, + 5, 6, 0, 0, 0, 1, 2, 3, 5, 6, 0, 0, 0, 0, 1, 2, 3, 5, 6, 0, 0, 4, 5, 6, 0, + 0, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 1, 4, 5, 6, 0, 0, 0, 0, 0, 1, 4, 5, 6, + 0, 0, 0, 2, 4, 5, 6, 0, 0, 0, 0, 0, 2, 4, 5, 6, 0, 0, 0, 1, 2, 4, 5, 6, 0, + 0, 0, 0, 1, 2, 4, 5, 6, 0, 0, 3, 4, 5, 6, 0, 0, 0, 0, 0, 3, 4, 5, 6, 0, 0, + 0, 1, 3, 4, 5, 6, 0, 0, 0, 0, 1, 3, 4, 5, 6, 0, 0, 2, 3, 4, 5, 6, 0, 0, 0, + 0, 2, 3, 4, 5, 6, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 7, + 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 1, 7, 0, 0, 0, 0, 0, 0, 0, 1, + 7, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0, 1, 2, 7, + 0, 0, 0, 0, 0, 0, 1, 2, 7, 0, 0, 0, 0, 3, 7, 0, 0, 0, 0, 0, 0, 0, 3, 7, 0, + 0, 0, 0, 0, 1, 3, 7, 0, 0, 0, 0, 0, 0, 1, 3, 7, 0, 0, 0, 0, 2, 3, 7, 0, 0, + 0, 0, 0, 0, 2, 3, 7, 0, 0, 0, 0, 1, 2, 3, 7, 0, 0, 0, 0, 0, 1, 2, 3, 7, 0, + 0, 0, 4, 7, 0, 0, 0, 0, 0, 0, 0, 4, 7, 0, 0, 0, 0, 0, 1, 4, 7, 0, 0, 0, 0, + 0, 0, 1, 4, 7, 0, 0, 0, 0, 2, 4, 7, 0, 0, 0, 0, 0, 0, 2, 4, 7, 0, 0, 0, 0, + 1, 2, 4, 7, 0, 0, 0, 0, 0, 1, 2, 4, 7, 0, 0, 0, 3, 4, 7, 0, 0, 0, 0, 0, 0, + 3, 4, 7, 0, 0, 0, 0, 1, 3, 4, 7, 0, 0, 0, 0, 0, 1, 3, 4, 7, 0, 0, 0, 2, 3, + 4, 7, 0, 0, 0, 0, 0, 2, 3, 4, 7, 0, 0, 0, 1, 2, 3, 4, 7, 0, 0, 0, 0, 1, 2, + 3, 4, 7, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 1, 5, 7, 0, + 0, 0, 0, 0, 0, 1, 5, 7, 0, 0, 0, 0, 2, 5, 7, 0, 0, 0, 0, 0, 0, 2, 5, 7, 0, + 0, 0, 0, 1, 2, 5, 7, 0, 0, 0, 0, 0, 1, 2, 5, 7, 0, 0, 0, 3, 5, 7, 0, 0, 0, + 0, 0, 0, 3, 5, 7, 0, 0, 0, 0, 1, 3, 5, 7, 0, 0, 0, 0, 0, 1, 3, 5, 7, 0, 0, + 0, 2, 3, 5, 7, 0, 0, 0, 0, 0, 2, 3, 5, 7, 0, 0, 0, 1, 2, 3, 5, 7, 0, 0, 0, + 0, 1, 2, 3, 5, 7, 0, 0, 4, 5, 7, 0, 0, 0, 0, 0, 0, 4, 5, 7, 0, 0, 0, 0, 1, + 4, 5, 7, 0, 0, 0, 0, 0, 1, 4, 5, 7, 0, 0, 0, 2, 4, 5, 7, 0, 0, 0, 0, 0, 2, + 4, 5, 7, 0, 0, 0, 1, 2, 4, 5, 7, 0, 0, 0, 0, 1, 2, 4, 5, 7, 0, 0, 3, 4, 5, + 7, 0, 0, 0, 0, 0, 3, 4, 5, 7, 0, 0, 0, 1, 3, 4, 5, 7, 0, 0, 0, 0, 1, 3, 4, + 5, 7, 0, 0, 2, 3, 4, 5, 7, 0, 0, 0, 0, 2, 3, 4, 5, 7, 0, 0, 1, 2, 3, 4, 5, + 7, 0, 0, 0, 1, 2, 3, 4, 5, 7, 0, 6, 7, 0, 0, 0, 0, 0, 0, 0, 6, 7, 0, 0, 0, + 0, 0, 1, 6, 7, 0, 0, 0, 0, 0, 0, 1, 6, 7, 0, 0, 0, 0, 2, 6, 7, 0, 0, 0, 0, + 0, 0, 2, 6, 7, 0, 0, 0, 0, 1, 2, 6, 7, 0, 0, 0, 0, 0, 1, 2, 6, 7, 0, 0, 0, + 3, 6, 7, 0, 0, 0, 0, 0, 0, 3, 6, 7, 0, 0, 0, 0, 1, 3, 6, 7, 0, 0, 0, 0, 0, + 1, 3, 6, 7, 0, 0, 0, 2, 3, 6, 7, 0, 0, 0, 0, 0, 2, 3, 6, 7, 0, 0, 0, 1, 2, + 3, 6, 7, 0, 0, 0, 0, 1, 2, 3, 6, 7, 0, 0, 4, 6, 7, 0, 0, 0, 0, 0, 0, 4, 6, + 7, 0, 0, 0, 0, 1, 4, 6, 7, 0, 0, 0, 0, 0, 1, 4, 6, 7, 0, 0, 0, 2, 4, 6, 7, + 0, 0, 0, 0, 0, 2, 4, 6, 7, 0, 0, 0, 1, 2, 4, 6, 7, 0, 0, 0, 0, 1, 2, 4, 6, + 7, 0, 0, 3, 4, 6, 7, 0, 0, 0, 0, 0, 3, 4, 6, 7, 0, 0, 0, 1, 3, 4, 6, 7, 0, + 0, 0, 0, 1, 3, 4, 6, 7, 0, 0, 2, 3, 4, 6, 7, 0, 0, 0, 0, 2, 3, 4, 6, 7, 0, + 0, 1, 2, 3, 4, 6, 7, 0, 0, 0, 1, 2, 3, 4, 6, 7, 0, 5, 6, 7, 0, 0, 0, 0, 0, + 0, 5, 6, 7, 0, 0, 0, 0, 1, 5, 6, 7, 0, 0, 0, 0, 0, 1, 5, 6, 7, 0, 0, 0, 2, + 5, 6, 7, 0, 0, 0, 0, 0, 2, 5, 6, 7, 0, 0, 0, 1, 2, 5, 6, 7, 0, 0, 0, 0, 1, + 2, 5, 6, 7, 0, 0, 3, 5, 6, 7, 0, 0, 0, 0, 0, 3, 5, 6, 7, 0, 0, 0, 1, 3, 5, + 6, 7, 0, 0, 0, 0, 1, 3, 5, 6, 7, 0, 0, 2, 3, 5, 6, 7, 0, 0, 0, 0, 2, 3, 5, + 6, 7, 0, 0, 1, 2, 3, 5, 6, 7, 0, 0, 0, 1, 2, 3, 5, 6, 7, 0, 4, 5, 6, 7, 0, + 0, 0, 0, 0, 4, 5, 6, 7, 0, 0, 0, 1, 4, 5, 6, 7, 0, 0, 0, 0, 1, 4, 5, 6, 7, + 0, 0, 2, 4, 5, 6, 7, 0, 0, 0, 0, 2, 4, 5, 6, 7, 0, 0, 1, 2, 4, 5, 6, 7, 0, + 0, 0, 1, 2, 4, 5, 6, 7, 0, 3, 4, 5, 6, 7, 0, 0, 0, 0, 3, 4, 5, 6, 7, 0, 0, + 1, 3, 4, 5, 6, 7, 0, 0, 0, 1, 3, 4, 5, 6, 7, 0, 2, 3, 4, 5, 6, 7, 0, 0, 0, + 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5, 6, 7}; + return PromoteTo(du16, Load(du8, tbl + mask_bits * 8)); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_AVX3_DL + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const Simd d; + const Rebind du; + const auto vu = BitCast(du, v); // (required for float16_t inputs) + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + const Vec128 cu{_mm_maskz_compress_epi16(mask.raw, vu.raw)}; +#else + const auto idx = detail::IndicesForCompress16(uint64_t{mask.raw}); + const Vec128 cu{_mm_permutexvar_epi16(idx.raw, vu.raw)}; +#endif // HWY_TARGET != HWY_AVX3_DL + return BitCast(d, cu); +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return Vec128{_mm_maskz_compress_epi32(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return Vec128{_mm_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + HWY_DASSERT(mask.raw < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Full128 d; + const Repartition d8; + const auto index = Load(d8, u8_indices + 16 * mask.raw); + return BitCast(d, TableLookupBytes(BitCast(d8, v), index)); +} + +// ------------------------------ CompressNot (Compress) + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits (LoadMaskBits) + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Simd(), bits)); +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd d, T* HWY_RESTRICT unaligned) { + const Rebind du; + const auto vu = BitCast(du, v); // (required for float16_t inputs) + + const uint64_t mask_bits{mask.raw}; + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + _mm_mask_compressstoreu_epi16(unaligned, mask.raw, vu.raw); +#else + const auto idx = detail::IndicesForCompress16(mask_bits); + const Vec128 cu{_mm_permutexvar_epi16(idx.raw, vu.raw)}; + StoreU(BitCast(d, cu), d, unaligned); +#endif // HWY_TARGET == HWY_AVX3_DL + + const size_t count = PopCount(mask_bits & ((1ull << N) - 1)); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + float* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(float)); +#endif + return count; +} + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 mask, + Simd /* tag */, + double* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(double)); +#endif + return count; +} + +// ------------------------------ CompressBlendedStore (CompressStore) +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + // AVX-512 already does the blending at no extra cost (latency 11, + // rthroughput 2 - same as compress plus store). + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { + // We're relying on the mask to blend. Clear the undefined upper bits. + if (N != 16 / sizeof(T)) { + m = And(m, FirstN(d, N)); + } + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + const Vec128 compressed = Compress(v, m); +#if HWY_MEM_OPS_MIGHT_FAULT + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. + alignas(16) T buf[N]; + Store(compressed, d, buf); + memcpy(unaligned, buf, count * sizeof(T)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + // Workaround: as of 2022-02-23 MSAN does not mark the output as + // initialized. +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; + } +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#else // AVX2 or below + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec128 vbits{_mm_cvtsi32_si128(static_cast(mask_bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask128 LoadMaskBits(Simd d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask128 LoadMaskBits(Simd d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { + return static_cast(static_cast(mask_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + const Simd d; + const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const auto sign_bits = _mm_packs_epi16(mask.raw, _mm_setzero_si128()); + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_ps(sign_bits.raw)); +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_pd(sign_bits.raw)); +} + +// Returns the lowest N of the _mm_movemask* bits. +template +constexpr uint64_t OnlyActive(uint64_t mask_bits) { + return ((N * sizeof(T)) == 16) ? mask_bits : mask_bits & ((1ull << N) - 1); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Simd /* tag */, + const Mask128 mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + const uint64_t mask_bits = detail::BitsFromMask(mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(const Simd /* tag */, const Mask128 mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template +HWY_API bool AllTrue(const Simd /* tag */, const Mask128 mask) { + constexpr uint64_t kAllBits = + detail::OnlyActive((1ull << (16 / sizeof(T))) - 1); + return detail::BitsFromMask(mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(const Simd /* tag */, + const Mask128 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Simd /* tag */, + const Mask128 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Simd du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Simd du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromBits(Simd d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IndicesFromNotBits(Simd d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const Simd d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const Simd d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return detail::CompressBits(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNotBits(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(Vec128 v, Mask128 m, Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + + return count; +} + +template +HWY_API size_t CompressBlendedStore(Vec128 v, Mask128 m, + Simd d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressBitsStore(Vec128 v, + const uint8_t* HWY_RESTRICT bits, + Simd d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag /* tag */, + const Vec128 v) { + return v; +} + +// u32/i32/f32: + +// N=2 +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return v10 + Shuffle2301(v10); +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Min(v10, Shuffle2301(v10)); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v10) { + return Max(v10, Shuffle2301(v10)); +} + +// N=4 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = v3210 + v1032; + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +// N=2 (full) +template +HWY_INLINE Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Simd d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. +template +HWY_API Vec128 SumOfLanes(Simd /* tag */, const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(Simd /* tag */, const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. Also used by x86_256/x86_512. +template > +HWY_INLINE V Lt128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const V ltHL = VecFromMask(d, Lt(a, b)); + const V ltLX = ShiftLeftLanes<1>(ltHL); + const V vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +} + +// Returns vector-mask for Eq128. Also used by x86_256/x86_512. +template > +HWY_INLINE V Eq128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +} + +template > +HWY_INLINE V Ne128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned>() && sizeof(TFromD) == 8, + "D must be u64"); + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +} + +template > +HWY_INLINE V Lt128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template > +HWY_INLINE V Eq128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template > +HWY_INLINE V Ne128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template > +HWY_API MFromD Lt128(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template > +HWY_API MFromD Eq128(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template > +HWY_API MFromD Ne128(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template > +HWY_API MFromD Lt128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Eq128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Ne128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template > +HWY_API V Min128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template > +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template > +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/media/highway/src/hwy/ops/x86_256-inl.h b/media/highway/src/hwy/ops/x86_256-inl.h new file mode 100644 index 000000000..12a83cbfc --- /dev/null +++ b/media/highway/src/hwy/ops/x86_256-inl.h @@ -0,0 +1,5619 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when +// compiling for that target. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +// Must come before HWY_COMPILER_CLANGCL +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +#include +// avxintrin defines __m256i and must come before avx2intrin. +#include +#include // _pext_u64 +#include +#include +#include +#endif // HWY_COMPILER_CLANGCL + +#include +#include +#include // memcpy + +#if HWY_IS_MSAN +#include +#endif + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw256 { + using type = __m256i; +}; +template <> +struct Raw256 { + using type = __m256; +}; +template <> +struct Raw256 { + using type = __m256d; +}; + +} // namespace detail + +template +class Vec256 { + using Raw = typename detail::Raw256::type; + + public: + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +#if HWY_TARGET <= HWY_AVX3 + +namespace detail { + +// Template arg: sizeof(lane type) +template +struct RawMask256 {}; +template <> +struct RawMask256<1> { + using type = __mmask32; +}; +template <> +struct RawMask256<2> { + using type = __mmask16; +}; +template <> +struct RawMask256<4> { + using type = __mmask8; +}; +template <> +struct RawMask256<8> { + using type = __mmask8; +}; + +} // namespace detail + +template +struct Mask256 { + using Raw = typename detail::RawMask256::type; + + static Mask256 FromBits(uint64_t mask_bits) { + return Mask256{static_cast(mask_bits)}; + } + + Raw raw; +}; + +#else // AVX2 + +// FF..FF or 0. +template +struct Mask256 { + typename detail::Raw256::type raw; +}; + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } +HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); } +HWY_INLINE __m256i BitCastToInteger(__m256d v) { + return _mm256_castpd_si256(v); +} + +template +HWY_INLINE Vec256 BitCastToByte(Vec256 v) { + return Vec256{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger256 { + HWY_INLINE __m256i operator()(__m256i v) { return v; } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); } +}; + +template +HWY_INLINE Vec256 BitCastFromByte(Full256 /* tag */, Vec256 v) { + return Vec256{BitCastFromInteger256()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 BitCast(Full256 d, Vec256 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector. +template +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{_mm256_setzero_si256()}; +} +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{_mm256_setzero_ps()}; +} +HWY_API Vec256 Zero(Full256 /* tag */) { + return Vec256{_mm256_setzero_pd()}; +} + +// Returns a vector with all lanes set to "t". +HWY_API Vec256 Set(Full256 /* tag */, const uint8_t t) { + return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const uint16_t t) { + return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const uint32_t t) { + return Vec256{_mm256_set1_epi32(static_cast(t))}; +} +HWY_API Vec256 Set(Full256 /* tag */, const uint64_t t) { + return Vec256{ + _mm256_set1_epi64x(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const int8_t t) { + return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const int16_t t) { + return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const int32_t t) { + return Vec256{_mm256_set1_epi32(t)}; +} +HWY_API Vec256 Set(Full256 /* tag */, const int64_t t) { + return Vec256{ + _mm256_set1_epi64x(static_cast(t))}; // NOLINT +} +HWY_API Vec256 Set(Full256 /* tag */, const float t) { + return Vec256{_mm256_set1_ps(t)}; +} +HWY_API Vec256 Set(Full256 /* tag */, const double t) { + return Vec256{_mm256_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec256 Undefined(Full256 /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec256{_mm256_undefined_si256()}; +} +HWY_API Vec256 Undefined(Full256 /* tag */) { + return Vec256{_mm256_undefined_ps()}; +} +HWY_API Vec256 Undefined(Full256 /* tag */) { + return Vec256{_mm256_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{_mm256_and_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 And(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_and_ps(a.raw, b.raw)}; +} +HWY_API Vec256 And(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{_mm256_andnot_si256(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(const Vec256 not_mask, + const Vec256 mask) { + return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(const Vec256 not_mask, + const Vec256 mask) { + return Vec256{_mm256_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{_mm256_or_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_or_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{_mm256_xor_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not + +template +HWY_API Vec256 Not(const Vec256 v) { + using TU = MakeUnsigned; +#if HWY_TARGET <= HWY_AVX3 + const __m256i vu = BitCast(Full256(), v).raw; + return BitCast(Full256(), + Vec256{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(Full256(), Vec256{_mm256_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Or3 + +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { +#if HWY_TARGET <= HWY_AVX3 + const Full256 d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<1> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<2> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<4> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<8> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 PopulationCount(Vec256 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ CopySign + +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const Full256 d; + const auto msb = SignBit(d); + +#if HWY_TARGET <= HWY_AVX3 + const Rebind, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m256i out = _mm256_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(Full256()), sign)); +#endif +} + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return Vec256{_mm256_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec256{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec256 ZeroIfNegative(const Vec256 v) { + static_assert(IsSigned(), "Only for float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenZeroElse(MaskFromVec(v), v); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask256 And(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Or(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Not(const Mask256 m) { + // Flip only the valid bits. + constexpr size_t N = 32 / sizeof(T); + return Xor(m, Mask256::FromBits((1ull << N) - 1)); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +#else // AVX2 + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { + return Vec256{v.raw}; +} + +// ------------------------------ IfThenElse + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(const Mask256 mask, + const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(const Mask256 mask, + const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(Full256(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(Full256(), mask), no); +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + static_assert(IsSigned(), "Only for float"); + const auto zero = Zero(Full256()); + // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes + return IfThenElse(MaskFromVec(v), zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template +HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask256{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<1> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<2> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<4> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<8> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; +} +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi8(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi16(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi32(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi64(v.raw)}; +} + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; +} + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; +} + +template +HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { + return VecFromMask(v); +} + +#else // AVX2 + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask256 RebindMask(Full256 d_to, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MaskFromVec(BitCast(d_to, VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi8(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi16(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi32(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi64(a.raw, b.raw)}; +} + +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Not(a == b); +} +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)}; +} +HWY_API Mask256 operator!=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8 +// to perform an unsigned comparison instead of the intended signed. Workaround +// is to cast to an explicitly signed type. See https://godbolt.org/z/PL7Ujy +#if HWY_COMPILER_GCC != 0 && HWY_COMPILER_GCC < 930 +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 +#else +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 +#endif + +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { +#if HWY_AVX2_GCC_CMPGT8_WORKAROUND + using i8x32 = signed char __attribute__((__vector_size__(32))); + return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) > + reinterpret_cast(b.raw))}; +#else + return Mask256{_mm256_cmpgt_epi8(a.raw, b.raw)}; +#endif +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi16(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi32(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi64(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, Vec256 b) { + const Full256 du; + const RebindToSigned di; + const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); + return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); +} + +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; +} + +} // namespace detail + +template +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { + return b > a; +} + +template +HWY_API Mask256 operator<=(const Vec256 a, const Vec256 b) { + return b >= a; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +#endif +} + +// Signed +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +#endif +} + +// Signed +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API Mask256 FirstN(const Full256 d, size_t n) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + constexpr size_t N = 32 / sizeof(T); +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << N) - 1; + // BZHI only looks at the lower 8 bits of n! + return Mask256::FromBits((n > 255) ? all : _bzhi_u64(all, n)); +#else + const uint32_t all = static_cast((1ull << N) - 1); + // BZHI only looks at the lower 8 bits of n! + return Mask256::FromBits( + (n > 255) ? all : _bzhi_u32(all, static_cast(n))); +#endif // HWY_ARCH_X86_64 +#else + const RebindToSigned di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(n))); +#endif +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_add_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec256 Abs(const Vec256 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (wrong result) + const auto zero = Zero(Full256()); + return Vec256{_mm256_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec256{_mm256_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi16(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi32(v.raw)}; +} +// i64 is implemented after BroadcastSignBit. + +HWY_API Vec256 Abs(const Vec256 v) { + const Vec256 mask{_mm256_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(Full256(), mask); +} +HWY_API Vec256 Abs(const Vec256 v) { + const Vec256 mask{_mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(Full256(), mask); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec256 MulFixedPoint15(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeft + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ShiftRight(Vec256{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. + +// ------------------------------ RotateRight + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return VecFromMask(v < Zero(Full256())); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<31>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { +#if HWY_TARGET == HWY_AVX2 + return VecFromMask(v < Zero(Full256())); +#else + return Vec256{_mm256_srai_epi64(v.raw, 63)}; +#endif +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srai_epi64(v.raw, kBits)}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +HWY_API Vec256 Abs(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_abs_epi64(v.raw)}; +#else + const auto zero = Zero(Full256()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, + Vec256 no) { + // int8: AVX2 IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Full256 d; + const RebindToSigned di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const Full256 d; + const RebindToFloat df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + const MFromD msb = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); +} + +// ------------------------------ ShiftLeftSame + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Neg (Xor, Sub) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec256 Neg(hwy::FloatTag /*tag*/, const Vec256 v) { + return Xor(v, SignBit(Full256())); +} + +// Not floating-point +template +HWY_INLINE Vec256 Neg(hwy::NonFloatTag /*tag*/, const Vec256 v) { + return Zero(Full256()) - v; +} + +} // namespace detail + +template +HWY_API Vec256 Neg(const Vec256 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mul_pd(a.raw, b.raw)}; +} + +HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_div_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { + return Vec256{_mm256_rcp_ps(v.raw)}; +} + +// Absolute value of difference. +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 NegMulAdd(const Vec256 mul, + const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 NegMulSub(const Vec256 mul, + const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec256 Sqrt(const Vec256 v) { + return Vec256{_mm256_sqrt_ps(v.raw)}; +} +HWY_API Vec256 Sqrt(const Vec256 v) { + return Vec256{_mm256_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { + return Vec256{_mm256_rsqrt_ps(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, tie to even +HWY_API Vec256 Round(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Round(const Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +HWY_API Vec256 Trunc(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Trunc(const Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256 Ceil(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Ceil(const Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +HWY_API Vec256 Floor(const Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Floor(const Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +// ------------------------------ Floating-point classification + +HWY_API Mask256 IsNaN(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_ps_mask(v.raw, 0x81)}; +#else + return Mask256{_mm256_cmp_ps(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} +HWY_API Mask256 IsNaN(const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_pd_mask(v.raw, 0x81)}; +#else + return Mask256{_mm256_cmp_pd(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +HWY_API Mask256 IsInf(const Vec256 v) { + return Mask256{_mm256_fpclass_ps_mask(v.raw, 0x18)}; +} +HWY_API Mask256 IsInf(const Vec256 v) { + return Mask256{_mm256_fpclass_pd_mask(v.raw, 0x18)}; +} + +HWY_API Mask256 IsFinite(const Vec256 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask256{_mm256_fpclass_ps_mask(v.raw, 0x99)}); +} +HWY_API Mask256 IsFinite(const Vec256 v) { + return Not(Mask256{_mm256_fpclass_pd_mask(v.raw, 0x99)}); +} + +#else + +template +HWY_API Mask256 IsInf(const Vec256 v) { + static_assert(IsFloat(), "Only for float"); + const Full256 d; + const RebindToSigned di; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask256 IsFinite(const Vec256 v) { + static_assert(IsFloat(), "Only for float"); + const Full256 d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // Shift left to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). MSVC seems to generate + // incorrect code if we instead add vu + vu. + const VFromD exp = + BitCast(di, ShiftRight() + 1>(ShiftLeft<1>(vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec256 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec256{ + _mm256_load_si256(reinterpret_cast(aligned))}; +} +HWY_API Vec256 Load(Full256 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_ps(aligned)}; +} +HWY_API Vec256 Load(Full256 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_pd(aligned)}; +} + +template +HWY_API Vec256 LoadU(Full256 /* tag */, const T* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_si256(reinterpret_cast(p))}; +} +HWY_API Vec256 LoadU(Full256 /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_ps(p)}; +} +HWY_API Vec256 LoadU(Full256 /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi16(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_epi64(m.raw, p)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_ps(m.raw, p)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_pd(m.raw, p)}; +} + +#else // AVX2 + +// There is no maskload_epi8/16, so blend instead. +template * = nullptr> +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return Vec256{_mm256_maskload_epi32(pi, m.raw)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, + const T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return Vec256{_mm256_maskload_epi64(pi, m.raw)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_ps(p, mi.raw)}; +} + +HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, + const double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_pd(p, mi.raw)}; +} + +#endif + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. +template +HWY_API Vec256 LoadDup128(Full256 /* tag */, const T* HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + // Workaround for incorrect results with _mm256_broadcastsi128_si256. Note + // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the + // upper half undefined) is fine because we're overwriting that anyway. + // This workaround seems in turn to generate incorrect code in MSVC 2022 + // (19.31), so use broadcastsi128 there. + const __m128i v128 = LoadU(Full128(), p).raw; + return Vec256{ + _mm256_inserti128_si256(_mm256_castsi128_si256(v128), v128, 1)}; +#else + return Vec256{_mm256_broadcastsi128_si256(LoadU(Full128(), p).raw)}; +#endif +} +HWY_API Vec256 LoadDup128(Full256 /* tag */, + const float* const HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const __m128 v128 = LoadU(Full128(), p).raw; + return Vec256{ + _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)}; +#else + return Vec256{_mm256_broadcast_ps(reinterpret_cast(p))}; +#endif +} +HWY_API Vec256 LoadDup128(Full256 /* tag */, + const double* const HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const __m128d v128 = LoadU(Full128(), p).raw; + return Vec256{ + _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)}; +#else + return Vec256{ + _mm256_broadcast_pd(reinterpret_cast(p))}; +#endif +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { + _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Store(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT p) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); +} +HWY_API void StoreU(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT p) { + _mm256_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT p) { + _mm256_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi16(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi64(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, + Full256 /* tag */, float* HWY_RESTRICT p) { + _mm256_mask_storeu_ps(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, + Full256 /* tag */, double* HWY_RESTRICT p) { + _mm256_mask_storeu_pd(p, m.raw, v.raw); +} + +#else // AVX2 + +// Intel SDM says "No AC# reported for any mask bit combinations". However, AMD +// allows AC# if "Alignment checking enabled and: 256-bit memory operand not +// 32-byte aligned". Fortunately AC# is not enabled by default and requires both +// OS support (CR0) and the application to set rflags.AC. We assume these remain +// disabled because x86/x64 code and compiler output often contain misaligned +// scalar accesses, which would also fault. +// +// Caveat: these are slow on AMD Jaguar/Bulldozer. + +template * = nullptr> +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT p) { + // There is no maskload_epi8/16. Blending is also unsafe because loading a + // full vector that crosses the array end causes asan faults. Resort to scalar + // code; the caller should instead use memcpy, assuming m is FirstN(d, n). + const RebindToUnsigned du; + using TU = TFromD; + alignas(32) TU buf[32 / sizeof(T)]; + alignas(32) TU mask[32 / sizeof(T)]; + Store(BitCast(du, v), du, buf); + Store(BitCast(du, VecFromMask(d, m)), du, mask); + for (size_t i = 0; i < 32 / sizeof(T); ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 /* tag */, + T* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi64(pi, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, Full256 d, + float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_ps(p, mi.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256 v, Mask256 m, + Full256 d, double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_pd(p, mi.raw, v.raw); +} + +#endif + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(Vec256 v, Full256 /* tag */, + T* HWY_RESTRICT aligned) { + _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Stream(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_stream_ps(aligned, v.raw); +} +HWY_API void Stream(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i32scatter_ps(base, index.raw, v.raw, 4); +} + +HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i64scatter_pd(base, index.raw, v.raw, 8); +} + +#else + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[N]; + Store(index, Full256(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<4> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<4> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), index.raw, 4)}; +} + +template +HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<8> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<8> /* tag */, + Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), index.raw, 8)}; +} + +} // namespace detail + +template +HWY_API Vec256 GatherOffset(Full256 d, const T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec256 GatherIndex(Full256 d, const T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +HWY_API Vec256 GatherOffset(Full256 /* tag */, + const float* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i32gather_ps(base, offset.raw, 1)}; +} +HWY_API Vec256 GatherIndex(Full256 /* tag */, + const float* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i32gather_ps(base, index.raw, 4)}; +} + +HWY_API Vec256 GatherOffset(Full256 /* tag */, + const double* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i64gather_pd(base, offset.raw, 1)}; +} +HWY_API Vec256 GatherIndex(Full256 /* tag */, + const double* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i64gather_pd(base, index.raw, 8)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_castsi256_si128(v.raw)}; +} +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_castps256_ps128(v.raw)}; +} +HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_castpd256_pd128(v.raw)}; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + return LowerHalf(Full128(), v); +} + +// ------------------------------ UpperHalf + +template +HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_extracti128_si256(v.raw, 1)}; +} +HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_extractf128_ps(v.raw, 1)}; +} +HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { + return Vec128{_mm256_extractf128_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + const Full256 d; + HWY_DASSERT(i < Lanes(d)); + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + const Full256 d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +// Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper +// bits undefined. Although it makes sense for them to be zero (VEX encoded +// 128-bit instructions zero the upper lanes to avoid large penalties), a +// compiler could decide to optimize out code that relies on this. +// +// The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the +// zeroing, but it is not available on MSVC until 15.7 nor GCC until 10.1. For +// older GCC, we can still obtain the desired code thanks to pattern +// recognition; note that the expensive insert instruction is not actually +// generated, see https://gcc.godbolt.org/z/1MKGaP. + +#if !defined(HWY_HAVE_ZEXT) +#if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC >= 1915) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 500) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1000) +#define HWY_HAVE_ZEXT 1 +#else +#define HWY_HAVE_ZEXT 0 +#endif +#endif // defined(HWY_HAVE_ZEXT) + +template +HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { +#if HWY_HAVE_ZEXT +return Vec256{_mm256_zextsi128_si256(lo.raw)}; +#else + return Vec256{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; +#endif +} +HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, + Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextps128_ps256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; +#endif +} +HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, + Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextpd128_pd256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ Combine + +template +HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_inserti128_si256(lo256.raw, hi.raw, 1)}; +} +HWY_API Vec256 Combine(Full256 d, Vec128 hi, + Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; +} +HWY_API Vec256 Combine(Full256 d, Vec128 hi, + Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec256 ShiftLeftBytes(Full256 /* tag */, const Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bslli_epi128. + return Vec256{_mm256_slli_si256(v.raw, kBytes)}; +} + +template +HWY_API Vec256 ShiftLeftBytes(const Vec256 v) { + return ShiftLeftBytes(Full256(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + return ShiftLeftLanes(Full256(), v); +} + +// ------------------------------ ShiftRightBytes + +template +HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, const Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bsrli_epi128. + return Vec256{_mm256_srli_si256(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ CombineShiftRightBytes + +// Extracts 128 bits from by skipping the least-significant kBytes. +template > +HWY_API V CombineShiftRightBytes(Full256 d, V hi, V lo) { + const Repartition d8; + return BitCast(d, Vec256{_mm256_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Float +template +HWY_API Vec256 Broadcast(Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +namespace detail { + +template +HWY_API Vec256 Shuffle2301(const Vec256 a, const Vec256 b) { + const Full256 d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 Shuffle1230(const Vec256 a, const Vec256 b) { + const Full256 d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 Shuffle3012(const Vec256 a, const Vec256 b) { + const Full256 d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + // Shorter encoding than _mm256_permute_ps. + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + // Shorter encoding than _mm256_permute_pd. + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 5)}; +} + +// Rotate right 32 bits +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template +struct Indices256 { + __m256i raw; +}; + +// Native 8x32 instruction: indices remain unchanged +template +HWY_API Indices256 IndicesFromVec(Full256 /* tag */, Vec256 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(32 / sizeof(T)))))); +#endif + return Indices256{vec.raw}; +} + +// 64-bit lanes: convert indices to 8x32 unless AVX3 is available +template +HWY_API Indices256 IndicesFromVec(Full256 d, Vec256 idx64) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + const Rebind di; + (void)di; // potentially unused +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) && + AllTrue(di, Lt(idx64, Set(di, static_cast(32 / sizeof(T)))))); +#endif + +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Indices256{idx64.raw}; +#else + const Repartition df; // 32-bit! + // Replicate 64-bit index into upper 32 bits + const Vec256 dup = + BitCast(di, Vec256{_mm256_moveldup_ps(BitCast(df, idx64).raw)}); + // For each idx64 i, idx32 are 2*i and 2*i+1. + const Vec256 idx32 = dup + dup + Set(di, TI(1) << 32); + return Indices256{idx32.raw}; +#endif +} + +template +HWY_API Indices256 SetTableIndices(const Full256 d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_epi64(idx.raw, v.raw)}; +#else + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +#endif +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { + return Vec256{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_pd(idx.raw, v.raw)}; +#else + const Full256 df; + const Full256 du; + return BitCast(df, Vec256{_mm256_permutevar8x32_epi32( + BitCast(du, v).raw, idx.raw)}); +#endif +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute2x128_si256(v.raw, v.raw, 0x01)}; +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute2f128_ps(v.raw, v.raw, 0x01)}; +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute2f128_pd(v.raw, v.raw, 0x01)}; +} + +// ------------------------------ Reverse (RotateRight) + +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { + alignas(32) constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { + alignas(32) constexpr int64_t kReverse[4] = {3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec256 idx = Load(di, kReverse); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide> du32; + const Vec256 rev32 = Reverse(du32, BitCast(du32, v)); + return BitCast(d, RotateRight<16>(rev32)); +#endif +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { + const Full256 du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 (SwapAdjacentBlocks) + +template +HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse4[16] = {3, 2, 1, 0, 7, 6, 5, 4, + 11, 10, 9, 8, 15, 14, 13, 12}; + const Vec256 idx = Load(di, kReverse4); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { + // Could also use _mm256_permute4x64_epi64. + return SwapAdjacentBlocks(Shuffle01(v)); +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec256 idx = Load(di, kReverse8); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template +HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { + return Reverse(d, v); +} + +template +HWY_API Vec256 Reverse8(Full256 /* tag */, const Vec256 /* v */) { + HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template > +HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template > +HWY_API Vec256 ZipLower(Vec256 a, Vec256 b) { + return BitCast(Full256(), InterleaveLower(a, b)); +} +template > +HWY_API Vec256 ZipLower(Full256 dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveLower(a, b)); +} + +template > +HWY_API Vec256 ZipUpper(Full256 dw, Vec256 a, Vec256 b) { + return BitCast(dw, InterleaveUpper(Full256(), a, b)); +} + +// ------------------------------ Blocks (LowerHalf, ZeroExtendVector) + +// _mm256_broadcastsi128_si256 has 7 cycle latency on ICL. +// _mm256_permute2x128_si256 is slow on Zen1 (8 uops), so we avoid it (at no +// extra cost) for LowerLower and UpperLower. + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, + const Vec256 lo) { + const Half d2; + return Vec256{_mm256_inserti128_si256(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, + const Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +HWY_API Vec256 ConcatLowerLower(Full256 d, + const Vec256 hi, + const Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x21)}; +} +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; +} +HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_epi32(hi.raw, lo.raw, 0x0F)}; +} +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; +} +HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_pd(hi.raw, lo.raw, 3)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x31)}; +} +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x31)}; +} +HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, + const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x31)}; +} + +// ------------------------------ ConcatOdd + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(32) constexpr uint8_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 8-bit shift so we can pack. + const Vec256 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<8>(BitCast(dw, lo)); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint16_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 16-bit shift so we can pack. + const Vec256 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<16>(BitCast(dw, lo)); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v3131{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return Vec256{_mm256_permute4x64_epi64(BitCast(du, v3131).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, + Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + const Vec256 v3131{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v31{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)}; + return Vec256{ + _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, + Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + (void)d; + const Vec256 v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)}; + return Vec256{ + _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ ConcatEven + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec256 mask = Set(dw, 0x00FF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint16_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 16 bits per u32 so we can pack. + const Vec256 mask = Set(dw, 0x0000FFFF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return Vec256{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v2020{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return Vec256{_mm256_permute4x64_epi64(BitCast(du, v2020).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, + Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + const Vec256 v2020{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); + +#endif +} + +template +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v20{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; + return Vec256{ + _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, + Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + (void)d; + const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; + return Vec256{ + _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec256 DupEven(const Vec256 v) { + return InterleaveLower(Full256(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec256 DupOdd(const Vec256 v) { + return InterleaveUpper(Full256(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, + const Vec256 b) { + const Full256 d; + const Full256 d8; + alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, LoadDup128(d8, mask))), b, a); +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<2> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi16(a.raw, b.raw, 0x55)}; +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<4> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; +} +template +HWY_INLINE Vec256 OddEven(hwy::SizeTag<8> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; +} + +} // namespace detail + +template +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_blend_ps(a.raw, b.raw, 0x55)}; +} + +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; +} + +// ------------------------------ OddEvenBlocks + +template +Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_epi32(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_ps(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_pd(odd.raw, even.raw, 0x3u)}; +} + +// ------------------------------ ReverseBlocks (ConcatLowerUpper) + +template +HWY_API Vec256 ReverseBlocks(Full256 d, Vec256 v) { + return ConcatLowerUpper(d, v, v); +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, + const Vec256 from) { + return Vec256{_mm256_shuffle_epi8(bytes.raw, from.raw)}; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(const Vec256 bytes, + const Vec128 from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(Full128(), tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(const Vec128 bytes, + const Vec256 from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by x86_128. + +// ------------------------------ Shl (Mul, ZipLower) + +namespace detail { + +#if HWY_TARGET > HWY_AVX3 // AVX2 or older + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_INLINE Vec256> Pow2(const Vec256 v) { + static_assert(sizeof(T) == 2, "Only for 16-bit"); + const Full256 d; + const RepartitionToWide dw; + const Rebind df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // Do not use ConvertTo because it checks for overflow, which is redundant + // because we only care about v in [0, 16). + const Vec256 bits0{_mm256_cvttps_epi32(BitCast(df, f0).raw)}; + const Vec256 bits1{_mm256_cvttps_epi32(BitCast(df, f1).raw)}; + return Vec256>{_mm256_packus_epi32(bits0.raw, bits1.raw)}; +} + +#endif // HWY_TARGET > HWY_AVX3 + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_sllv_epi16(v.raw, bits.raw)}; +#else + return v * Pow2(bits); +#endif +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi64(v.raw, bits.raw)}; +} + +template +HWY_INLINE Vec256 Shl(hwy::SignedTag /*tag*/, Vec256 v, Vec256 bits) { + // Signed left shifts are the same as unsigned. + const Full256 di; + const Full256> du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec256 operator<<(Vec256 v, Vec256 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (MulHigh, IfThenElse, Not) + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srlv_epi16(v.raw, bits.raw)}; +#else + Full256 d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + auto out = MulHigh(v, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), v, out); +#endif +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256(), v, bits); +#endif +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256(), v, bits); +#endif +} + +HWY_INLINE Vec256 MulEven(const Vec256 a, + const Vec256 b) { + const DFromV du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveLower(mulL, mulH); +} + +HWY_INLINE Vec256 MulOdd(const Vec256 a, + const Vec256 b) { + const DFromV du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need bits [95:64] (= upper half of input) + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Same as above, but we're using the odd results (upper 64 bits per block). + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveUpper(du64, mulL, mulH); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 df32, + Vec256 a, + Vec256 b, + const Vec256 sum0, + Vec256& sum1) { + // TODO(janwas): _mm256_dpbf16_ps when available + const Repartition du16; + const RebindToUnsigned du32; + const Vec256 zero = Zero(du16); + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo. + const Vec256 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const Vec256 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const Vec256 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const Vec256 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +HWY_API Vec256 ReorderWidenMulAccumulate(Full256 /*d32*/, + Vec256 a, + Vec256 b, + const Vec256 sum0, + Vec256& /*sum1*/) { + return sum0 + Vec256{_mm256_madd_epi16(a.raw, b.raw)}; +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{_mm256_cvtps_pd(v.raw)}; +} + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{_mm256_cvtepi32_pd(v.raw)}; +} + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi32_epi64(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); + // Concatenating lower halves of both 128-bit blocks afterward is more + // efficient than an extra input with low block = high block of v. + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const __m256i u16_blocks = _mm256_packus_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i u16_concat = _mm256_permute4x64_epi64(u16_blocks, 0x88); + const __m128i u16 = _mm256_castsi256_si128(u16_concat); + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const __m128i i16 = _mm_and_si128(u16, _mm_set1_epi16(0x7FFF)); + return Vec128{_mm_packus_epi16(i16, i16)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full64 /* tag */, + const Vec256 v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return Vec128{_mm_packs_epi16(i16, i16)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; +} + + // Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'". + // 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") + +HWY_API Vec128 DemoteTo(Full128 df16, + const Vec256 v) { +#ifdef HWY_DISABLE_F16C + const RebindToUnsigned du16; + const Rebind du; + const RebindToSigned di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + (void)df16; + return Vec128{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +HWY_DIAGNOSTICS(pop) + +HWY_API Vec128 DemoteTo(Full128 dbf16, + const Vec256 v) { + // TODO(janwas): _mm256_cvtneps_pbh once we have avx512bf16. + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +HWY_API Vec256 ReorderDemote2To(Full256 dbf16, + Vec256 a, Vec256 b) { + // TODO(janwas): _mm256_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned du16; + const Repartition du32; + const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec256 ReorderDemote2To(Full256 /*d16*/, + Vec256 a, Vec256 b) { + return Vec256{_mm256_packs_epi32(a.raw, b.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{_mm256_cvtpd_ps(v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const auto clamped = detail::ClampF64ToI32Max(Full256(), v); + return Vec128{_mm256_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec256 v) { + const Full256 d32; + alignas(32) static constexpr uint32_t k8From32[8] = { + 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; + // Place first four bytes in lo[0], remaining 4 in hi[1]. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Full128(), quad); + const auto pair = LowerHalf(lo | hi); + return BitCast(Full64(), pair); +} + +// ------------------------------ Truncations + +namespace detail { + +// LO and HI each hold four indices of bytes within a 128-bit block. +template +HWY_INLINE Vec128 LookupAndConcatHalves(Vec256 v) { + const Full256 d32; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint32_t kMap[8] = { + LO, HI, 0x10101010 + LO, 0x10101010 + HI, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); +#else + alignas(32) static constexpr uint32_t kMap[8] = {LO, HI, ~0u, ~0u, + ~0u, ~0u, LO, HI}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto result = _mm256_permute4x64_epi64(quad.raw, 0xCC); + // Possible alternative: + // const auto lo = LowerHalf(quad); + // const auto hi = UpperHalf(Full128(), quad); + // const auto result = lo | hi; +#endif + + return Vec128{_mm256_castsi256_si128(result)}; +} + +// LO and HI each hold two indices of bytes within a 128-bit block. +template +HWY_INLINE Vec128 LookupAndConcatQuarters(Vec256 v) { + const Full256 d16; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint16_t kMap[16] = { + LO, HI, 0x1010 + LO, 0x1010 + HI, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d16, kMap).raw); + return LowerHalf(Vec128{_mm256_castsi256_si128(result)}); +#else + constexpr uint16_t ff = static_cast(~0u); + alignas(32) static constexpr uint16_t kMap[16] = { + LO, ff, HI, ff, ff, ff, ff, ff, ff, ff, ff, ff, LO, ff, HI, ff}; + const auto quad = TableLookupBytes(v, Load(d16, kMap)); + const auto mixed = _mm256_permute4x64_epi64(quad.raw, 0xCC); + const auto half = _mm256_castsi256_si128(mixed); + return LowerHalf(Vec128{_mm_packus_epi32(half, half)}); +#endif +} + +} // namespace detail + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const Full256 d32; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint32_t kMap[8] = {0x18100800u, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); + return LowerHalf(LowerHalf(LowerHalf(Vec256{result}))); +#else + alignas(32) static constexpr uint32_t kMap[8] = {0xFFFF0800u, ~0u, ~0u, ~0u, + 0x0800FFFFu, ~0u, ~0u, ~0u}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Full128(), quad); + const auto result = lo | hi; + return LowerHalf(LowerHalf(Vec128{result.raw})); +#endif +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto result = detail::LookupAndConcatQuarters<0x100, 0x908>(v); + return Vec128{result.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const Full256 d32; + alignas(32) constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto v32 = + TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven)); + return LowerHalf(Vec256{v32.raw}); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto full = detail::LookupAndConcatQuarters<0x400, 0xC08>(v); + return Vec128{full.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x05040100, 0x0D0C0908>(v); + return Vec128{full.raw}; +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x06040200, 0x0E0C0A08>(v); + return Vec128{full.raw}; +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{_mm256_cvtepi32_ps(v.raw)}; +} + +HWY_API Vec256 ConvertTo(Full256 dd, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dd; + return Vec256{_mm256_cvtepi64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +#endif +} + +HWY_API Vec256 ConvertTo(HWY_MAYBE_UNUSED Full256 df, + const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_cvtepu32_ps(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned du32; + const RebindToSigned d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +#endif +} + +HWY_API Vec256 ConvertTo(HWY_MAYBE_UNUSED Full256 dd, + const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_cvtepu64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFFULL); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + auto uint64_to_double256_fast = [&dd](Vec256 w) HWY_ATTR { + w = Or(w, Vec256{ + detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); + return BitCast(dd, w) - Set(dd, 0x0010000000000000); + }; + + const auto v_lo_dbl = uint64_to_double256_fast(v_lo); + return MulAdd(cnst2_32_dbl, uint64_to_double256_fast(v_hi), v_lo_dbl); +#endif +} + +// Truncates (rounds toward zero). +HWY_API Vec256 ConvertTo(Full256 d, const Vec256 v) { + return detail::FixConversionOverflow(d, v, _mm256_cvttps_epi32(v.raw)); +} + +HWY_API Vec256 ConvertTo(Full256 di, const Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return detail::FixConversionOverflow(di, v, _mm256_cvttpd_epi64(v.raw)); +#else + using VI = decltype(Zero(di)); + const VI k0 = Zero(di); + const VI k1 = Set(di, 1); + const VI k51 = Set(di, 51); + + // Exponent indicates whether the number can be represented as int64_t. + const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); + const VI exp = biased_exp - Set(di, 0x3FF); + const auto in_range = exp < Set(di, 63); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + const VI shift_mnt = Max(k51 - exp, k0); + const VI shift_int = Max(exp - k51, k0); + const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); + // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. + const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); + // For inputs larger than 2^52, insert zeros at the bottom. + const VI shifted = int52 << shift_int; + // Restore the one bit lost when shifting in the implicit 1-bit. + const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, restored, limit); + + // If the input was negative, negate the integer (two's complement). + return (magnitude ^ sign_mask) - sign_mask; +#endif +} + +HWY_API Vec256 NearestInt(const Vec256 v) { + const Full256 di; + return detail::FixConversionOverflow(di, v, _mm256_cvtps_epi32(v.raw)); +} + + +HWY_API Vec256 PromoteTo(Full256 df32, + const Vec128 v) { +#ifdef HWY_DISABLE_F16C + const RebindToSigned di32; + const RebindToUnsigned du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + (void)df32; + return Vec256{_mm256_cvtph_ps(v.raw)}; +#endif +} + +HWY_API Vec256 PromoteTo(Full256 df32, + const Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec256 AESRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESLastRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 CLMulLower(Vec256 a, Vec256 b) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulLower(LowerHalf(a), LowerHalf(b))); +#endif +} + +HWY_API Vec256 CLMulUpper(Vec256 a, Vec256 b) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulUpper(LowerHalf(a), LowerHalf(b))); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +HWY_API Vec256 Iota(const Full256 d, const T2 first) { + HWY_ALIGN T lanes[32 / sizeof(T)]; + for (size_t i = 0; i < 32 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadMaskBits + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask256 LoadMaskBits(const Full256 /* tag */, + const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return Mask256::FromBits(mask_bits); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, + uint8_t* bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (N < 8) { + const int mask_bits = static_cast((1ull << N) - 1); + bits[0] = static_cast(bits[0] & mask_bits); + } + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(const Full256 d, const Mask256 mask) { + return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +// Beware: the suffix indicates the number of mask bits, not lane size! + +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + return (uint64_t{mask.raw} & 0xF) == 0; +} + +} // namespace detail + +template +HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { + return detail::AllFalse(hwy::SizeTag(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + // Cannot use _kortestc because we have less than 8 mask bits. + return mask.raw == 0xFu; +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { + return detail::AllTrue(hwy::SizeTag(), mask); +} + +// ------------------------------ Compress + +// 16-bit is defined in x86_512 so we can use 512-bit vectors. + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} + +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + // See CompressIsPartition. + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintCompress64x4NibbleTables + 0x00003210, 0x00003210, 0x00003201, 0x00003210, 0x00003102, 0x00003120, + 0x00003021, 0x00003210, 0x00002103, 0x00002130, 0x00002031, 0x00002310, + 0x00001032, 0x00001320, 0x00000321, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const Full256 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressNot (Compress) + +template +HWY_API Vec256 CompressNot(Vec256 v, const Mask256 mask) { + return Compress(v, Not(mask)); +} + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 mask) { + // See CompressIsPartition. + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintCompressNot64x4NibbleTables + 0x00003210, 0x00000321, 0x00001320, 0x00001032, 0x00002310, 0x00002031, + 0x00002130, 0x00002103, 0x00003210, 0x00003021, 0x00003120, 0x00003102, + 0x00003210, 0x00003201, 0x00003210, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const Full256 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressBits (LoadMaskBits) +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Full256(), bits)); +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 d, + T* HWY_RESTRICT unaligned) { + const Rebind du; + const auto vu = BitCast(du, v); // (required for float16_t inputs) + + const uint64_t mask_bits{mask.raw}; + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + _mm256_mask_compressstoreu_epi16(unaligned, mask.raw, vu.raw); +#else + // Split into halves to keep the table size manageable. + const Half duh; + const auto vL = LowerHalf(duh, vu); + const auto vH = UpperHalf(duh, vu); + + const uint64_t mask_bitsL = mask_bits & 0xFF; + const uint64_t mask_bitsH = mask_bits >> 8; + + const auto idxL = detail::IndicesForCompress16(mask_bitsL); + const auto idxH = detail::IndicesForCompress16(mask_bitsH); + + // Compress and 128-bit halves. + const Vec128 cL{_mm_permutexvar_epi16(idxL.raw, vL.raw)}; + const Vec128 cH{_mm_permutexvar_epi16(idxH.raw, vH.raw)}; + const Half dh; + StoreU(BitCast(dh, cL), dh, unaligned); + StoreU(BitCast(dh, cH), dh, unaligned + PopCount(mask_bitsL)); +#endif // HWY_TARGET == HWY_AVX3_DL + + return PopCount(mask_bits); +} + +template +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, + T* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, + T* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, + Full256 /* tag */, + float* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(float)); +#endif + return count; +} + +HWY_API size_t CompressStore(Vec256 v, Mask256 mask, + Full256 /* tag */, + double* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(double)); +#endif + return count; +} + +// ------------------------------ CompressBlendedStore (CompressStore) + +template +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + // Native (32 or 64-bit) AVX-512 instruction already does the blending at no + // extra cost (latency 11, rthroughput 2 - same as compress plus store). + return CompressStore(v, m, d, unaligned); +} + +template +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3_DL + return CompressStore(v, m, d, unaligned); // also native +#else + const size_t count = CountTrue(d, m); + BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +#endif +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) + +template +HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, + Full256 d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#else // AVX2 + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_LE128 there. +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + const Repartition du32; + const auto vbits = BitCast(du, Set(du32, static_cast(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. + const Repartition du64; + alignas(32) constexpr uint64_t kRep8[4] = { + 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, + 0x0303030303030303ull}; + const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); + + alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint16_t kBit[16] = { + 1, 2, 4, 8, 16, 32, 64, 128, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(32) constexpr uint64_t kBit[8] = {1, 2, 4, 8}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask256 LoadMaskBits(Full256 d, + const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::LoadMaskBits256(d, mask_bits); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + const Full256 d; + const Full256 d8; + const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; + // Prevent sign-extension of 32-bit masks because the intrinsic returns int. + return static_cast(_mm256_movemask_epi8(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { +#if HWY_ARCH_X86_64 + const Full256 d; + const Full256 d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + const uint64_t sign_bits8 = BitsFromMask(mask8); + // Skip the bits from the lower byte of each u16 (better not to use the + // same packs_epi16 as SSE4, because that requires an extra swizzle here). + return _pext_u64(sign_bits8, 0xAAAAAAAAull); +#else + // Slow workaround for 32-bit builds, which lack _pext_u64. + // Remove useless lower half of each u16 while preserving the sign bit. + // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. + const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); + // Move odd qwords (value zero) to top so they don't affect the mask value. + const auto compressed = + _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0)); + return static_cast(_mm256_movemask_epi8(compressed)); +#endif // HWY_ARCH_X86_64 +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + const Full256 d; + const Full256 df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_ps(sign_bits)); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { + const Full256 d; + const Full256 df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_pd(sign_bits)); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, + uint8_t* bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + const uint64_t mask_bits = detail::BitsFromMask(mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask +// lane is 0 or ~0. +template +HWY_API bool AllFalse(const Full256 d, const Mask256 mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return detail::BitsFromMask(mask8) == 0; +} + +template +HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template +HWY_API bool AllTrue(const Full256 d, const Mask256 mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return detail::BitsFromMask(mask8) == (1ull << 32) - 1; +} +template +HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { + constexpr uint64_t kAllBits = (1ull << (32 / sizeof(T))) - 1; + return detail::BitsFromMask(mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(const Full256 d, const Mask256 mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return PopCount(detail::BitsFromMask(mask8)) >> 1; +} +template +HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, + const Mask256 mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +template +HWY_INLINE Vec256 IndicesFromBits(Full256 d, uint64_t mask_bits) { + const RebindToUnsigned d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintCompress32x8Tables + 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, + 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, + 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, + 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, + 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, + 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, + 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, + 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, + 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, + 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, + 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, + 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, + 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, + 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, + 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, + 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, + 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, + 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, + 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, + 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, + 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, + 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, + 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, + 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, + 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 0x6531fca8, + 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98, + 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, + 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, + 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, + 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, + 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, + 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, + 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, + 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, + 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, + 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, + 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, + 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, + 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, + 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, + 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, + 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, + 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; + + // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. + const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromBits(Full256 d, uint64_t mask_bits) { + const Repartition d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) constexpr uint32_t u32_indices[128] = { + // PrintCompress64x4PairTables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, + 10, 11, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7, + 12, 13, 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 2, 3, 6, 7, + 10, 11, 12, 13, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 6, 7, + 14, 15, 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 2, 3, 4, 5, + 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 4, 5, + 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 2, 3, + 10, 11, 12, 13, 14, 15, 0, 1, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits(Full256 d, + uint64_t mask_bits) { + const RebindToUnsigned d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintCompressNot32x8Tables + 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, + 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, + 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, + 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, + 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, + 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, + 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, + 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, + 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, + 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, + 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, + 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, + 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, + 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, + 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, + 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, + 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9, + 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, + 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, + 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, + 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, + 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, + 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, + 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, + 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9, + 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda, + 0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, + 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb, + 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, + 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, + 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, + 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, + 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, + 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, + 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, + 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, + 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, + 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, + 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, + 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, + 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, + 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, + 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; + + // No need to mask because <_mm256_permutevar8x32_epi32> ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. + const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits(Full256 d, + uint64_t mask_bits) { + const Repartition d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) constexpr uint32_t u32_indices[128] = { + // PrintCompressNot64x4PairTables + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, + 8, 9, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, + 8, 9, 10, 11, 14, 15, 12, 13, 10, 11, 14, 15, 8, 9, 12, 13, + 8, 9, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 8, 9, 14, 15, + 8, 9, 12, 13, 10, 11, 14, 15, 12, 13, 8, 9, 10, 11, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 8, 9, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} +template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const Full256 d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromBits(d, mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + const auto vu16 = BitCast(du, v); // (required for float16_t inputs) + const Half duh; + const auto half0 = LowerHalf(duh, vu16); + const auto half1 = UpperHalf(duh, vu16); + + const uint64_t mask_bits0 = mask_bits & 0xFF; + const uint64_t mask_bits1 = mask_bits >> 8; + const auto compressed0 = detail::CompressBits(half0, mask_bits0); + const auto compressed1 = detail::CompressBits(half1, mask_bits1); + + alignas(32) uint16_t all_true[16] = {}; + // Store mask=true lanes, left to right. + const size_t num_true0 = PopCount(mask_bits0); + Store(compressed0, duh, all_true); + StoreU(compressed1, duh, all_true + num_true0); + + if (hwy::HWY_NAMESPACE::CompressIsPartition::value) { + // Store mask=false lanes, right to left. The second vector fills the upper + // half with right-aligned false lanes. The first vector is shifted + // rightwards to overwrite the true lanes of the second. + alignas(32) uint16_t all_false[16] = {}; + const size_t num_true1 = PopCount(mask_bits1); + Store(compressed1, duh, all_false + 8); + StoreU(compressed0, duh, all_false + num_true1); + + const auto mask = FirstN(du, num_true0 + num_true1); + return BitCast(d, + IfThenElse(mask, Load(du, all_true), Load(du, all_false))); + } else { + // Only care about the mask=true lanes. + return BitCast(d, Load(du, all_true)); + } +} + +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + const Full256 d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromNotBits(d, mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + // Compress ensures only the lower 16 bits are set, so flip those. + return Compress(v, mask_bits ^ 0xFFFF); +} + +} // namespace detail + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 m) { + return detail::Compress(v, detail::BitsFromMask(m)); +} + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 m) { + return detail::CompressNot(v, detail::BitsFromMask(m)); +} + +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + return CompressNot(v, mask); +} + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + + const Repartition du32; + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). Nibble MSB encodes FirstN. + const Vec256 idx_and_mask = detail::IndicesFromBits(d, mask_bits); + // Shift nibble MSB into MSB + const Mask256 mask32 = MaskFromVec(ShiftLeft<28>(idx_and_mask)); + // First cast to unsigned (RebindMask cannot change lane size) + const Mask256> mask_u{mask32.raw}; + const Mask256 mask = RebindMask(d, mask_u); + const Vec256 compressed = + BitCast(d, TableLookupLanes(BitCast(du32, v), + Indices256{idx_and_mask.raw})); + + BlendedStore(compressed, mask, d, unaligned); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Vec256 compressed = detail::Compress(v, mask_bits); + +#if HWY_MEM_OPS_MIGHT_FAULT // true if HWY_IS_MSAN + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. + alignas(32) T buf[16]; + Store(compressed, d, buf); + memcpy(unaligned, buf, count * sizeof(T)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + return count; +} + +template +HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, + Full256 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + // Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. + +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template +HWY_API void LoadTransposedBlocks3(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template +HWY_API void LoadTransposedBlocks4(Full256 d, + const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C, + Vec256& D) { + constexpr size_t N = 32 / sizeof(T); + const Vec256 v10 = LoadU(d, unaligned + 0 * N); + const Vec256 v32 = LoadU(d, unaligned + 1 * N); + const Vec256 v54 = LoadU(d, unaligned + 2 * N); + const Vec256 v76 = LoadU(d, unaligned + 3 * N); + + A = ConcatLowerLower(d, v54, v10); + B = ConcatUpperUpper(d, v54, v10); + C = ConcatLowerLower(d, v76, v32); + D = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template +HWY_API void StoreTransposedBlocks2(const Vec256 i, const Vec256 j, + const Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template +HWY_API void StoreTransposedBlocks3(const Vec256 i, const Vec256 j, + const Vec256 k, Full256 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template +HWY_API void StoreTransposedBlocks4(const Vec256 i, const Vec256 j, + const Vec256 k, const Vec256 l, + Full256 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + // Write lower halves, then upper. + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ Reductions + +namespace detail { + +// Returns sum{lane[i]} in each lane. "v3210" is a replicated 128-bit block. +// Same logic as x86/128.h, but with Vec256 arguments. +template +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = v3210 + v1032; + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Min(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Max(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +template +HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return Max(v10, v01); +} + +HWY_API Vec256 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +HWY_API Vec256 SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec256 v) { + const Full256 d; + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for {uif}{32,64},{ui}16. Returns the broadcasted result. +template +HWY_API Vec256 SumOfLanes(Full256 d, const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::SumOfLanes(hwy::SizeTag(), vLH + vHL); +} +template +HWY_API Vec256 MinOfLanes(Full256 d, const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::MinOfLanes(hwy::SizeTag(), Min(vLH, vHL)); +} +template +HWY_API Vec256 MaxOfLanes(Full256 d, const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::MaxOfLanes(hwy::SizeTag(), Max(vLH, vHL)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/media/highway/src/hwy/ops/x86_512-inl.h b/media/highway/src/hwy/ops/x86_512-inl.h new file mode 100644 index 000000000..09b14a937 --- /dev/null +++ b/media/highway/src/hwy/ops/x86_512-inl.h @@ -0,0 +1,4412 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 512-bit AVX512 vectors and operations. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +// clang-format off +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// clang-format on +#endif // HWY_COMPILER_CLANGCL + +#include +#include + +#if HWY_IS_MSAN +#include +#endif + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_256-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { + +template +struct Raw512 { + using type = __m512i; +}; +template <> +struct Raw512 { + using type = __m512; +}; +template <> +struct Raw512 { + using type = __m512d; +}; + +// Template arg: sizeof(lane type) +template +struct RawMask512 {}; +template <> +struct RawMask512<1> { + using type = __mmask64; +}; +template <> +struct RawMask512<2> { + using type = __mmask32; +}; +template <> +struct RawMask512<4> { + using type = __mmask16; +}; +template <> +struct RawMask512<8> { + using type = __mmask8; +}; + +} // namespace detail + +template +class Vec512 { + using Raw = typename detail::Raw512::type; + + public: + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec512& operator*=(const Vec512 other) { + return *this = (*this * other); + } + HWY_INLINE Vec512& operator/=(const Vec512 other) { + return *this = (*this / other); + } + HWY_INLINE Vec512& operator+=(const Vec512 other) { + return *this = (*this + other); + } + HWY_INLINE Vec512& operator-=(const Vec512 other) { + return *this = (*this - other); + } + HWY_INLINE Vec512& operator&=(const Vec512 other) { + return *this = (*this & other); + } + HWY_INLINE Vec512& operator|=(const Vec512 other) { + return *this = (*this | other); + } + HWY_INLINE Vec512& operator^=(const Vec512 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +// Mask register: one bit per lane. +template +struct Mask512 { + typename detail::RawMask512::type raw; +}; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; } +HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); } +HWY_INLINE __m512i BitCastToInteger(__m512d v) { + return _mm512_castpd_si512(v); +} + +template +HWY_INLINE Vec512 BitCastToByte(Vec512 v) { + return Vec512{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger512 { + HWY_INLINE __m512i operator()(__m512i v) { return v; } +}; +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } +}; +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } +}; + +template +HWY_INLINE Vec512 BitCastFromByte(Full512 /* tag */, Vec512 v) { + return Vec512{BitCastFromInteger512()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 BitCast(Full512 d, Vec512 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector. +template +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_si512()}; +} +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_ps()}; +} +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_pd()}; +} + +// Returns a vector with all lanes set to "t". +HWY_API Vec512 Set(Full512 /* tag */, const uint8_t t) { + return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const uint16_t t) { + return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const uint32_t t) { + return Vec512{_mm512_set1_epi32(static_cast(t))}; +} +HWY_API Vec512 Set(Full512 /* tag */, const uint64_t t) { + return Vec512{ + _mm512_set1_epi64(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const int8_t t) { + return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const int16_t t) { + return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const int32_t t) { + return Vec512{_mm512_set1_epi32(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const int64_t t) { + return Vec512{ + _mm512_set1_epi64(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const float t) { + return Vec512{_mm512_set1_ps(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const double t) { + return Vec512{_mm512_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec512 Undefined(Full512 /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec512{_mm512_undefined_epi32()}; +} +HWY_API Vec512 Undefined(Full512 /* tag */) { + return Vec512{_mm512_undefined_ps()}; +} +HWY_API Vec512 Undefined(Full512 /* tag */) { + return Vec512{_mm512_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec512 Not(const Vec512 v) { + using TU = MakeUnsigned; + const __m512i vu = BitCast(Full512(), v).raw; + return BitCast(Full512(), + Vec512{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); +} + +// ------------------------------ And + +template +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_ps(a.raw, b.raw)}; +} +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { + return Vec512{_mm512_andnot_si512(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Or3 + +template +HWY_API Vec512 Or3(Vec512 o1, Vec512 o2, Vec512 o3) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec512 OrAnd(Vec512 o, Vec512 a1, Vec512 a2) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec512 IfVecThenElse(Vec512 mask, Vec512 yes, Vec512 no) { + const Full512 d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec512 operator&(const Vec512 a, const Vec512 b) { + return And(a, b); +} + +template +HWY_API Vec512 operator|(const Vec512 a, const Vec512 b) { + return Or(a, b); +} + +template +HWY_API Vec512 operator^(const Vec512 a, const Vec512 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<1> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<2> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<4> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<8> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 PopulationCount(Vec512 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ CopySign + +template +HWY_API Vec512 CopySign(const Vec512 magn, const Vec512 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const Full512 d; + const auto msb = SignBit(d); + + const Rebind, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m512i out = _mm512_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +} + +template +HWY_API Vec512 CopySignToAbs(const Vec512 abs, const Vec512 sign) { + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== MASK + +// ------------------------------ FirstN + +// Possibilities for constructing a bitmask of N ones: +// - kshift* only consider the lowest byte of the shift count, so they would +// not correctly handle large n. +// - Scalar shifts >= 64 are UB. +// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, +// we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. + +#if HWY_ARCH_X86_32 +namespace detail { + +// 32 bit mask is sufficient for lane size >= 2. +template +HWY_INLINE Mask512 FirstN(size_t n) { + Mask512 m; + const uint32_t all = ~uint32_t{0}; + // BZHI only looks at the lower 8 bits of n! + m.raw = static_cast((n > 255) ? all : _bzhi_u32(all, n)); + return m; +} + +template +HWY_INLINE Mask512 FirstN(size_t n) { + const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0}; + return Mask512{static_cast<__mmask64>(bits)}; +} + +} // namespace detail +#endif // HWY_ARCH_X86_32 + +template +HWY_API Mask512 FirstN(const Full512 /*tag*/, size_t n) { +#if HWY_ARCH_X86_64 + Mask512 m; + const uint64_t all = ~uint64_t{0}; + // BZHI only looks at the lower 8 bits of n! + m.raw = static_cast((n > 255) ? all : _bzhi_u64(all, n)); + return m; +#else + return detail::FirstN(n); +#endif // HWY_ARCH_X86_64 +} + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +HWY_API Vec512 IfThenElse(const Mask512 mask, + const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElse(const Mask512 mask, + const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<1> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<2> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<4> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<8> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec512 IfThenElseZero(const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElseZero(const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec512{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, + const Vec512 no) { + return Vec512{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, + const Vec512 no) { + return Vec512{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { + static_assert(IsSigned(), "Only works for signed/float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec512 ZeroIfNegative(const Vec512 v) { + // AVX3 MaskFromVec only looks at the MSB + return IfThenZeroElse(MaskFromVec(v), v); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_add_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec512 SumsOf8(const Vec512 v) { + return Vec512{_mm512_sad_epu8(v.raw, _mm512_setzero_si512())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec512 AverageRound(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 AverageRound(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec512 Abs(const Vec512 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (untested due to internal compiler error) + const auto zero = Zero(Full512()); + return Vec512{_mm512_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec512{_mm512_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi16(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi32(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi64(v.raw)}; +} + +// These aren't native instructions, they also involve AND with constant. +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_ps(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_pd(v.raw)}; +} +// ------------------------------ ShiftLeft + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + const Full512 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const Full512 d8; + // Use raw instead of BitCast to support N=1. + const Vec512 shifted{ShiftRight(Vec512{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const Full512 di; + const Full512 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + return Vec512{_mm512_ror_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + return Vec512{_mm512_ror_epi64(v.raw, kBits)}; +} + +// ------------------------------ ShiftLeftSame + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + const Full512 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const Full512 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { + return Vec512{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const Full512 di; + const Full512 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Shl + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { + const Full512 di; + const Full512> du; + return BitCast(di, BitCast(du, v) << BitCast(du, bits)); +} + +// ------------------------------ Shr + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi64(v.raw, bits.raw)}; +} + +// ------------------------------ Minimum + +// Unsigned +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_min_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_max_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// Signed +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec512 MulFixedPoint15(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ Neg (Sub) + +template +HWY_API Vec512 Neg(const Vec512 v) { + return Xor(v, SignBit(Full512())); +} + +template +HWY_API Vec512 Neg(const Vec512 v) { + return Zero(Full512()) - v; +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_mul_pd(a.raw, b.raw)}; +} + +HWY_API Vec512 operator/(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_div_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator/(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { + return Vec512{_mm512_rcp14_ps(v.raw)}; +} + +// Absolute value of difference. +HWY_API Vec512 AbsDiff(const Vec512 a, const Vec512 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns add - mul * x +HWY_API Vec512 NegMulAdd(const Vec512 mul, const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 NegMulAdd(const Vec512 mul, + const Vec512 x, + const Vec512 add) { + return Vec512{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns mul * x - sub +HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// Returns -mul * x - sub +HWY_API Vec512 NegMulSub(const Vec512 mul, const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 NegMulSub(const Vec512 mul, + const Vec512 x, + const Vec512 sub) { + return Vec512{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_ps(v.raw)}; +} +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec512 ApproximateReciprocalSqrt(const Vec512 v) { + return Vec512{_mm512_rsqrt14_ps(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Toward nearest integer, tie to even +HWY_API Vec512 Round(const Vec512 v) { + return Vec512{_mm512_roundscale_ps( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Round(const Vec512 v) { + return Vec512{_mm512_roundscale_pd( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +HWY_API Vec512 Trunc(const Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Trunc(const Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec512 Ceil(const Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Ceil(const Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +HWY_API Vec512 Floor(const Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Floor(const Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template +HWY_API Mask512 RebindMask(Full512 /*tag*/, Mask512 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask512{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<1> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<2> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<4> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<8> /*tag*/, const Vec512 v, + const Vec512 bit) { + return Mask512{_mm512_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask512 TestBit(const Vec512 v, const Vec512 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask512 operator<(Vec512 a, Vec512 b) { + return b > a; +} + +template +HWY_API Mask512 operator<=(Vec512 a, Vec512 b) { + return b >= a; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec512 v) { + return Mask512{_mm512_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask512 MaskFromVec(const Vec512 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +HWY_API Mask512 MaskFromVec(const Vec512 v) { + return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; +} +HWY_API Mask512 MaskFromVec(const Vec512 v) { + return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi8(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi8(v.raw)}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi16(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi16(v.raw)}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi32(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi32(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_castsi512_ps(_mm512_movm_epi32(v.raw))}; +} + +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi64(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_movm_epi64(v.raw)}; +} +HWY_API Vec512 VecFromMask(const Mask512 v) { + return Vec512{_mm512_castsi512_pd(_mm512_movm_epi64(v.raw))}; +} + +template +HWY_API Vec512 VecFromMask(Full512 /* tag */, const Mask512 v) { + return VecFromMask(v); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask512 Not(hwy::SizeTag<1> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask64(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<2> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask32(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<4> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask16(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<8> /*tag*/, const Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask8(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFF)}; +#endif +} + +template +HWY_INLINE Mask512 And(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask64(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Or(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<1> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<2> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<4> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<8> /*tag*/, const Mask512 a, + const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask64(a.raw, b.raw)}; +#else + return Mask512{~(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask512 a, const Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask512 Not(const Mask512 m) { + return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask512 And(const Mask512 a, Mask512 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 AndNot(const Mask512 a, Mask512 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Or(const Mask512 a, Mask512 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Xor(const Mask512 a, Mask512 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 ExclusiveNeither(const Mask512 a, Mask512 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return VecFromMask(v < Zero(Full512())); +} + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return ShiftRight<31>(v); +} + +HWY_API Vec512 BroadcastSignBit(const Vec512 v) { + return Vec512{_mm512_srai_epi64(v.raw, 63)}; +} + +// ------------------------------ Floating-point classification (Not) + +HWY_API Mask512 IsNaN(const Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask(v.raw, 0x81)}; +} +HWY_API Mask512 IsNaN(const Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask(v.raw, 0x81)}; +} + +HWY_API Mask512 IsInf(const Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask(v.raw, 0x18)}; +} +HWY_API Mask512 IsInf(const Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask(v.raw, 0x18)}; +} + +// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for +// positive, so we have to check for inf/NaN and negate. +HWY_API Mask512 IsFinite(const Vec512 v) { + return Not(Mask512{_mm512_fpclass_ps_mask(v.raw, 0x99)}); +} +HWY_API Mask512 IsFinite(const Vec512 v) { + return Not(Mask512{_mm512_fpclass_pd_mask(v.raw, 0x99)}); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec512 Load(Full512 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_si512(aligned)}; +} +HWY_API Vec512 Load(Full512 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_ps(aligned)}; +} +HWY_API Vec512 Load(Full512 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_pd(aligned)}; +} + +template +HWY_API Vec512 LoadU(Full512 /* tag */, const T* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_si512(p)}; +} +HWY_API Vec512 LoadU(Full512 /* tag */, + const float* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_ps(p)}; +} +HWY_API Vec512 LoadU(Full512 /* tag */, + const double* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi16(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const T* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_epi64(m.raw, p)}; +} + +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const float* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_ps(m.raw, p)}; +} + +HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, + const double* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_pd(m.raw, p)}; +} + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. +template +HWY_API Vec512 LoadDup128(Full512 /* tag */, + const T* const HWY_RESTRICT p) { + const auto x4 = LoadU(Full128(), p); + return Vec512{_mm512_broadcast_i32x4(x4.raw)}; +} +HWY_API Vec512 LoadDup128(Full512 /* tag */, + const float* const HWY_RESTRICT p) { + const __m128 x4 = _mm_loadu_ps(p); + return Vec512{_mm512_broadcast_f32x4(x4)}; +} + +HWY_API Vec512 LoadDup128(Full512 /* tag */, + const double* const HWY_RESTRICT p) { + const __m128d x2 = _mm_loadu_pd(p); + return Vec512{_mm512_broadcast_f64x2(x2)}; +} + +// ------------------------------ Store + +template +HWY_API void Store(const Vec512 v, Full512 /* tag */, + T* HWY_RESTRICT aligned) { + _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +HWY_API void Store(const Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT aligned) { + _mm512_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec512 v, Full512 /* tag */, + double* HWY_RESTRICT aligned) { + _mm512_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(const Vec512 v, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); +} +HWY_API void StoreU(const Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT p) { + _mm512_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec512 v, Full512, + double* HWY_RESTRICT p) { + _mm512_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi16(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, Full512 /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi64(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec512 v, Mask512 m, + Full512 /* tag */, float* HWY_RESTRICT p) { + _mm512_mask_storeu_ps(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec512 v, Mask512 m, + Full512 /* tag */, double* HWY_RESTRICT p) { + _mm512_mask_storeu_pd(p, m.raw, v.raw); +} + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(const Vec512 v, Full512 /* tag */, + T* HWY_RESTRICT aligned) { + _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +HWY_API void Stream(const Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT aligned) { + _mm512_stream_ps(aligned, v.raw); +} +HWY_API void Stream(const Vec512 v, Full512, + double* HWY_RESTRICT aligned) { + _mm512_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +namespace detail { + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); +} +template +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec512 v, + Full512 /* tag */, T* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec512 v, Full512 d, T* HWY_RESTRICT base, + const Vec512 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec512 v, Full512 d, T* HWY_RESTRICT base, + const Vec512 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, + float* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i32scatter_ps(base, index.raw, v.raw, 4); +} + +HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, + double* HWY_RESTRICT base, + const Vec512 offset) { + _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, + double* HWY_RESTRICT base, + const Vec512 index) { + _mm512_i64scatter_pd(base, index.raw, v.raw, 8); +} + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<4> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i32gather_epi32(offset.raw, base, 1)}; +} +template +HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<4> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i32gather_epi32(index.raw, base, 4)}; +} + +template +HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<8> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i64gather_epi64(offset.raw, base, 1)}; +} +template +HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<8> /* tag */, + Full512 /* tag */, + const T* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i64gather_epi64(index.raw, base, 8)}; +} + +} // namespace detail + +template +HWY_API Vec512 GatherOffset(Full512 d, const T* HWY_RESTRICT base, + const Vec512 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec512 GatherIndex(Full512 d, const T* HWY_RESTRICT base, + const Vec512 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +HWY_API Vec512 GatherOffset(Full512 /* tag */, + const float* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i32gather_ps(offset.raw, base, 1)}; +} +HWY_API Vec512 GatherIndex(Full512 /* tag */, + const float* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i32gather_ps(index.raw, base, 4)}; +} + +HWY_API Vec512 GatherOffset(Full512 /* tag */, + const double* HWY_RESTRICT base, + const Vec512 offset) { + return Vec512{_mm512_i64gather_pd(offset.raw, base, 1)}; +} +HWY_API Vec512 GatherIndex(Full512 /* tag */, + const double* HWY_RESTRICT base, + const Vec512 index) { + return Vec512{_mm512_i64gather_pd(index.raw, base, 8)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_castsi512_si256(v.raw)}; +} +HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_castps512_ps256(v.raw)}; +} +HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_castpd512_pd256(v.raw)}; +} + +template +HWY_API Vec256 LowerHalf(Vec512 v) { + return LowerHalf(Full256(), v); +} + +// ------------------------------ UpperHalf + +template +HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_extracti32x8_epi32(v.raw, 1)}; +} +HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_extractf32x8_ps(v.raw, 1)}; +} +HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { + return Vec256{_mm512_extractf64x4_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec512 v, size_t i) { + const Full512 d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec512 InsertLane(const Vec512 v, size_t i, T t) { + const Full512 d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec512 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +template +HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { +#if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. + return Vec512{_mm512_zextsi256_si512(lo.raw)}; +#else + return Vec512{_mm512_inserti32x8(_mm512_setzero_si512(), lo.raw, 0)}; +#endif +} +HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, + Vec256 lo) { +#if HWY_HAVE_ZEXT + return Vec512{_mm512_zextps256_ps512(lo.raw)}; +#else + return Vec512{_mm512_insertf32x8(_mm512_setzero_ps(), lo.raw, 0)}; +#endif +} +HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, + Vec256 lo) { +#if HWY_HAVE_ZEXT + return Vec512{_mm512_zextpd256_pd512(lo.raw)}; +#else + return Vec512{_mm512_insertf64x4(_mm512_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ Combine + +template +HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512{_mm512_inserti32x8(lo512.raw, hi.raw, 1)}; +} +HWY_API Vec512 Combine(Full512 d, Vec256 hi, + Vec256 lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512{_mm512_insertf32x8(lo512.raw, hi.raw, 1)}; +} +HWY_API Vec512 Combine(Full512 d, Vec256 hi, + Vec256 lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512{_mm512_insertf64x4(lo512.raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API Vec512 ShiftLeftBytes(Full512 /* tag */, const Vec512 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec512{_mm512_bslli_epi128(v.raw, kBytes)}; +} + +template +HWY_API Vec512 ShiftLeftBytes(const Vec512 v) { + return ShiftLeftBytes(Full512(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API Vec512 ShiftLeftLanes(Full512 d, const Vec512 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec512 ShiftLeftLanes(const Vec512 v) { + return ShiftLeftLanes(Full512(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API Vec512 ShiftRightBytes(Full512 /* tag */, const Vec512 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec512{_mm512_bsrli_epi128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API Vec512 ShiftRightLanes(Full512 d, const Vec512 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API V CombineShiftRightBytes(Full512 d, V hi, V lo) { + const Repartition d8; + return BitCast(d, Vec512{_mm512_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec512{_mm512_unpacklo_epi64(lo, lo)}; + } else { + const __m512i hi = + _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec512{_mm512_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +// Signed +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec512{_mm512_unpacklo_epi64(lo, lo)}; + } else { + const __m512i hi = + _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec512{_mm512_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +// Float +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, perm)}; +} +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); + return Vec512{_mm512_shuffle_pd(v.raw, v.raw, perm)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec512 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; +} +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +namespace detail { + +template +HWY_API Vec512 Shuffle2301(const Vec512 a, const Vec512 b) { + const Full512 d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_CDAB)}); +} +template +HWY_API Vec512 Shuffle1230(const Vec512 a, const Vec512 b) { + const Full512 d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_BCDA)}); +} +template +HWY_API Vec512 Shuffle3012(const Vec512 a, const Vec512 b) { + const Full512 d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_DABC)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + // Shorter encoding than _mm512_permute_ps. + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + // Shorter encoding than _mm512_permute_pd. + return Vec512{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; +} + +// Rotate right 32 bits +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; +} +// Rotate left 32 bits +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; +} + +// Reverse +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template +struct Indices512 { + __m512i raw; +}; + +template +HWY_API Indices512 IndicesFromVec(Full512 /* tag */, Vec512 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full512 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(64 / sizeof(T)))))); +#endif + return Indices512{vec.raw}; +} + +template +HWY_API Indices512 SetTableIndices(const Full512 d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi32(idx.raw, v.raw)}; +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi64(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_ps(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 v, + Indices512 idx) { + return Vec512{_mm512_permutexvar_pd(idx.raw, v.raw)}; +} + +// ------------------------------ Reverse + +template +HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec512 idx = Load(di, kReverse); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { + alignas(64) constexpr int32_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { + alignas(64) constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +// ------------------------------ Reverse2 + +template +HWY_API Vec512 Reverse2(Full512 d, const Vec512 v) { + const Full512 du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template +HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { + return Shuffle2301(v); +} + +template +HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API Vec512 Reverse4(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse4[32] = { + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, + 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; + const Vec512 idx = Load(di, kReverse4); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { + return Shuffle0123(v); +} + +template +HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { + return Vec512{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} +HWY_API Vec512 Reverse4(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} + +// ------------------------------ Reverse8 + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int16_t kReverse8[32] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + const RebindToSigned di; + alignas(64) constexpr int32_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { + return Reverse(d, v); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveUpper(const Vec512 a, + const Vec512 b) { + return Vec512{_mm512_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template > +HWY_API V InterleaveUpper(Full512 /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template > +HWY_API Vec512 ZipLower(Vec512 a, Vec512 b) { + return BitCast(Full512(), InterleaveLower(a, b)); +} +template > +HWY_API Vec512 ZipLower(Full512 /* d */, Vec512 a, Vec512 b) { + return BitCast(Full512(), InterleaveLower(a, b)); +} + +template > +HWY_API Vec512 ZipUpper(Full512 d, Vec512 a, Vec512 b) { + return BitCast(Full512(), InterleaveUpper(d, a, b)); +} + +// ------------------------------ Concat* halves + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, + const Vec512 lo) { + // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks + // are efficiently loaded from 32-bit regs. + const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); + return Vec512{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)}; +} +HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); + return Vec512{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; +} +HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, + const Vec512 hi, + const Vec512 lo) { + const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); + return Vec512{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; +} + +// ------------------------------ ConcatOdd + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[64] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, + 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, + 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, + 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, + 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; + return BitCast(d, + Vec512{_mm512_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec512 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec512 uL = ShiftRight<8>(BitCast(dw, lo)); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512 du64; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint16_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +} + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, hi.raw)}; +} + +template +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +} + +// ------------------------------ ConcatEven + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, + 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, + 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, + 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, + 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; + return BitCast(d, + Vec512{_mm512_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec512 mask = Set(dw, 0x00FF); + const Vec512 uH = And(BitCast(dw, hi), mask); + const Vec512 uL = And(BitCast(dw, lo), mask); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512 du64; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint16_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +} + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, hi.raw)}; +} + +template +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +} + +HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, + Vec512 lo) { + const RebindToUnsigned du; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; +} +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; +} + +template +HWY_API Vec512 DupEven(const Vec512 v) { + return InterleaveLower(Full512(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; +} +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; +} + +template +HWY_API Vec512 DupOdd(const Vec512 v) { + return InterleaveUpper(Full512(), v, v); +} + +// ------------------------------ OddEven + +template +HWY_API Vec512 OddEven(const Vec512 a, const Vec512 b) { + constexpr size_t s = sizeof(T); + constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; + return IfThenElse(Mask512{0x5555555555555555ull >> shift}, b, a); +} + +// ------------------------------ OddEvenBlocks + +template +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{_mm512_mask_blend_epi64(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +// ------------------------------ ReverseBlocks + +template +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 ReverseBlocks(Full512 /* tag */, + Vec512 v) { + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec512 TableLookupBytes(Vec512 bytes, Vec512 indices) { + return Vec512{_mm512_shuffle_epi8(bytes.raw, indices.raw)}; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec512 bytes, Vec128 from) { + const Full512 d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 from_full{from.raw}; + const auto from_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); + const auto tbl_full = TableLookupBytes(bytes, from_512); + // Shrink to 256, then 128, then partial. + return Vec128{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; +} +template +HWY_API Vec256 TableLookupBytes(Vec512 bytes, Vec256 from) { + const auto from_512 = ZeroExtendVector(Full512(), from); + return LowerHalf(Full256(), TableLookupBytes(bytes, from_512)); +} + +// Partial table vector +template +HWY_API Vec512 TableLookupBytes(Vec128 bytes, Vec512 from) { + const Full512 d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 bytes_full{bytes.raw}; + const auto bytes_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); + return TableLookupBytes(bytes_512, from); +} +template +HWY_API Vec512 TableLookupBytes(Vec256 bytes, Vec512 from) { + const auto bytes_512 = ZeroExtendVector(Full512(), bytes); + return TableLookupBytes(bytes_512, from); +} + +// Partial both are handled by x86_128/256. + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec128 v) { + return Vec512{_mm512_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec128 v) { + return Vec512{_mm512_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepi8_epi16(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec128 v) { + return Vec512{_mm512_cvtepi8_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepi16_epi32(v.raw)}; +} +HWY_API Vec512 PromoteTo(Full512 /* tag */, + Vec256 v) { + return Vec512{_mm512_cvtepi32_epi64(v.raw)}; +} + +// Float +HWY_API Vec512 PromoteTo(Full512 /* tag */, + const Vec256 v) { + return Vec512{_mm512_cvtph_ps(v.raw)}; +} + +HWY_API Vec512 PromoteTo(Full512 df32, + const Vec256 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { + return Vec512{_mm512_cvtps_pd(v.raw)}; +} + +HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { + return Vec512{_mm512_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; + return LowerHalf(even); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec512 v) { + const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const Vec512 i16{ + _mm512_and_si512(u16.raw, _mm512_set1_epi16(0x7FFF))}; + const Vec512 u8{_mm512_packus_epi16(i16.raw, i16.raw)}; + + alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12}; + const auto idx32 = LoadDup128(Full512(), kLanes); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 u8{_mm512_packus_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec512 v) { + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + const Vec512 i8{_mm512_packs_epi16(i16.raw, i16.raw)}; + + alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12, + 0, 4, 8, 12, 0, 4, 8, 12}; + const auto idx32 = LoadDup128(Full512(), kLanes); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const Vec512 u8{_mm512_packs_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512(), kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Vec256{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; + HWY_DIAGNOSTICS(pop) +} + +HWY_API Vec256 DemoteTo(Full256 dbf16, + const Vec512 v) { + // TODO(janwas): _mm512_cvtneps_pbh once we have avx512bf16. + const Rebind di32; + const Rebind du32; // for logical shift right + const Rebind du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +HWY_API Vec512 ReorderDemote2To(Full512 dbf16, + Vec512 a, Vec512 b) { + // TODO(janwas): _mm512_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned du16; + const Repartition du32; + const Vec512 b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec512 ReorderDemote2To(Full512 /*d16*/, + Vec512 a, Vec512 b) { + return Vec512{_mm512_packs_epi32(a.raw, b.raw)}; +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + return Vec256{_mm512_cvtpd_ps(v.raw)}; +} + +HWY_API Vec256 DemoteTo(Full256 /* tag */, + const Vec512 v) { + const auto clamped = detail::ClampF64ToI32Max(Full512(), v); + return Vec256{_mm512_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec512 v) { + const Full512 d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, + ~0u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + // Gather the lowest 4 bytes of 4 128-bit blocks. + alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +// ------------------------------ Truncations + +HWY_API Vec128 TruncateTo(Simd d, + const Vec512 v) { +#if HWY_TARGET == HWY_AVX3_DL + (void)d; + const Full512 d8; + alignas(16) static constexpr uint8_t k8From64[16] = { + 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56}; + const Vec512 bytes{ + _mm512_permutexvar_epi8(LoadDup128(d8, k8From64).raw, v.raw)}; + return LowerHalf(LowerHalf(LowerHalf(bytes))); +#else + const Full512 d32; + alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return TruncateTo(d, LowerHalf(even)); +#endif +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec512 v) { + const Full512 d16; + alignas(16) static constexpr uint16_t k16From64[8] = { + 0, 4, 8, 12, 16, 20, 24, 28}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec512 v) { + const Full512 d32; + alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128 TruncateTo(Simd /* tag */, + const Vec512 v) { +#if HWY_TARGET == HWY_AVX3_DL + const Full512 d8; + alignas(16) static constexpr uint8_t k8From32[16] = { + 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(LoadDup128(d8, k8From32).raw, v.raw)}; +#else + const Full512 d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, + ~0u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + // Gather the lowest 4 bytes of 4 128-bit blocks. + alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(LowerHalf(bytes)); +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec512 v) { + const Full512 d16; + alignas(64) static constexpr uint16_t k16From32[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; + return LowerHalf(bytes); +} + +HWY_API Vec256 TruncateTo(Simd /* tag */, + const Vec512 v) { +#if HWY_TARGET == HWY_AVX3_DL + const Full512 d8; + alignas(64) static constexpr uint8_t k8From16[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + const Vec512 bytes{ + _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; +#else + const Full512 d32; + alignas(16) static constexpr uint32_t k16From32[4] = { + 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k16From32)); + alignas(64) static constexpr uint32_t kIndex32[16] = { + 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(Load(d32, kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(bytes); +} + +// ------------------------------ Convert integer <=> floating point + +HWY_API Vec512 ConvertTo(Full512 /* tag */, + const Vec512 v) { + return Vec512{_mm512_cvtepi32_ps(v.raw)}; +} + +HWY_API Vec512 ConvertTo(Full512 /* tag */, + const Vec512 v) { + return Vec512{_mm512_cvtepi64_pd(v.raw)}; +} + +HWY_API Vec512 ConvertTo(Full512 /* tag*/, + const Vec512 v) { + return Vec512{_mm512_cvtepu32_ps(v.raw)}; +} + +HWY_API Vec512 ConvertTo(Full512 /* tag*/, + const Vec512 v) { + return Vec512{_mm512_cvtepu64_pd(v.raw)}; +} + +// Truncates (rounds toward zero). +HWY_API Vec512 ConvertTo(Full512 d, const Vec512 v) { + return detail::FixConversionOverflow(d, v, _mm512_cvttps_epi32(v.raw)); +} +HWY_API Vec512 ConvertTo(Full512 di, const Vec512 v) { + return detail::FixConversionOverflow(di, v, _mm512_cvttpd_epi64(v.raw)); +} + +HWY_API Vec512 NearestInt(const Vec512 v) { + const Full512 di; + return detail::FixConversionOverflow(di, v, _mm512_cvtps_epi32(v.raw)); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec512 AESRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESLastRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 CLMulLower(Vec512 va, Vec512 vb) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const Full512 d; + const Full128 d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +HWY_API Vec512 CLMulUpper(Vec512 va, Vec512 vb) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const Full512 d; + const Full128 d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec512 Iota(const Full512 d, const T2 first) { + HWY_ALIGN T lanes[64 / sizeof(T)]; + for (size_t i = 0; i < 64 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} + +} // namespace detail + +template +HWY_API bool AllFalse(const Full512 /* tag */, const Mask512 mask) { + return detail::AllFalse(hwy::SizeTag(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFull; +#endif +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Full512 /* tag */, const Mask512 mask) { + return detail::AllTrue(hwy::SizeTag(), mask); +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API Mask512 LoadMaskBits(const Full512 /* tag */, + const uint8_t* HWY_RESTRICT bits) { + Mask512 mask; + CopyBytes<8 / sizeof(T)>(bits, &mask.raw); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return mask; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(const Full512 /* tag */, const Mask512 mask, + uint8_t* bits) { + const size_t kNumBytes = 8 / sizeof(T); + CopyBytes(&mask.raw, bits); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return kNumBytes; +} + +template +HWY_API size_t CountTrue(const Full512 /* tag */, const Mask512 mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full512 /* tag */, + const Mask512 mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API size_t FindKnownFirstTrue(const Full512 /* tag */, + const Mask512 mask) { + return Num0BitsBelowLS1Bit_Nonzero64(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(const Full512 d, const Mask512 mask) { + return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +// ------------------------------ Compress + +template +HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; +} + +HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { + return Vec512{_mm512_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) constexpr uint64_t packed_array[256] = { + // From PrintCompress32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120, + 0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310, + 0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140, + 0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210, + 0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320, + 0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510, + 0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530, + 0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210, + 0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420, + 0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310, + 0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160, + 0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210, + 0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320, + 0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410, + 0x75310642, 0x75316420, 0x75306421, 0x75364210, 0x75210643, 0x75216430, + 0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210, + 0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520, + 0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310, + 0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540, + 0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210, + 0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320, + 0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710, + 0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730, + 0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210, + 0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420, + 0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310, + 0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750, + 0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210, + 0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320, + 0x64075321, 0x64753210, 0x63210754, 0x63217540, 0x63207541, 0x63275410, + 0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430, + 0x62075431, 0x62754310, 0x61075432, 0x61754320, 0x60754321, 0x67543210, + 0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620, + 0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310, + 0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640, + 0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210, + 0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320, + 0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510, + 0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530, + 0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210, + 0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420, + 0x30765421, 0x37654210, 0x21076543, 0x21765430, 0x20765431, 0x27654310, + 0x10765432, 0x17654320, 0x07654321, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. + const Full512 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// 16-bit may use the 32-bit Compress and must be defined after it. +// +// Ignore IDE redefinition error - this is not actually defined in x86_256 if +// we are including x86_512-inl.h. +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + const Full256 d; + const Rebind du; + const auto vu = BitCast(du, v); // (required for float16_t inputs) + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + const Vec256 cu{_mm256_maskz_compress_epi16(mask.raw, vu.raw)}; +#else + // Promote to i32 (512-bit vector!) so we can use the native Compress. + const auto vw = PromoteTo(Rebind(), vu); + const Mask512 mask32{static_cast<__mmask16>(mask.raw)}; + const auto cu = DemoteTo(du, Compress(vw, mask32)); +#endif // HWY_TARGET == HWY_AVX3_DL + + return BitCast(d, cu); +} + +// Expands to 32-bit, compresses, concatenate demoted halves. +template +HWY_API Vec512 Compress(Vec512 v, const Mask512 mask) { + const Full512 d; + const Rebind du; + const auto vu = BitCast(du, v); // (required for float16_t inputs) + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + const Vec512 cu{_mm512_maskz_compress_epi16(mask.raw, vu.raw)}; +#else + const Repartition dw; + const Half duh; + const auto promoted0 = PromoteTo(dw, LowerHalf(duh, vu)); + const auto promoted1 = PromoteTo(dw, UpperHalf(duh, vu)); + + const uint32_t mask_bits{mask.raw}; + const Mask512 mask0{static_cast<__mmask16>(mask_bits & 0xFFFF)}; + const Mask512 mask1{static_cast<__mmask16>(mask_bits >> 16)}; + const auto compressed0 = Compress(promoted0, mask0); + const auto compressed1 = Compress(promoted1, mask1); + + const auto demoted0 = ZeroExtendVector(du, DemoteTo(duh, compressed0)); + const auto demoted1 = ZeroExtendVector(du, DemoteTo(duh, compressed1)); + + // Concatenate into single vector by shifting upper with writemask. + const size_t num0 = CountTrue(dw, mask0); + const __mmask32 m_upper = ~((1u << num0) - 1); + alignas(64) uint16_t iota[64] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + const Vec512 idx = LoadU(du, iota + 32 - num0); + const Vec512 cu{_mm512_mask_permutexvar_epi16( + demoted0.raw, m_upper, idx.raw, demoted1.raw)}; +#endif // HWY_TARGET == HWY_AVX3_DL + + return BitCast(d, cu); +} + +// ------------------------------ CompressNot + +template +HWY_API Vec512 CompressNot(Vec512 v, const Mask512 mask) { + return Compress(v, Not(mask)); +} + +template +HWY_API Vec512 CompressNot(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) constexpr uint64_t packed_array[256] = { + // From PrintCompressNot32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431, + 0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542, + 0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321, + 0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653, + 0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651, + 0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432, + 0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421, + 0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764, + 0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631, + 0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762, + 0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321, + 0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543, + 0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541, + 0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532, + 0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521, + 0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075, + 0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431, + 0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742, + 0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321, + 0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073, + 0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071, + 0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432, + 0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421, + 0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654, + 0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531, + 0x74216530, 0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652, + 0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321, + 0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643, + 0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641, + 0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632, + 0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621, + 0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106, + 0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431, + 0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542, + 0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321, + 0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053, + 0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051, + 0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432, + 0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421, + 0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104, + 0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031, + 0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102, + 0x76543210, 0x76543201, 0x76543210, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. + const Full512 d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +HWY_API Vec512 CompressBlocksNot(Vec512 v, + Mask512 mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressBits +template +HWY_API Vec512 CompressBits(Vec512 v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Full512(), bits)); +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 d, + T* HWY_RESTRICT unaligned) { + const Rebind du; + const auto vu = BitCast(du, v); // (required for float16_t inputs) + + const uint64_t mask_bits{mask.raw}; + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, vu.raw); +#else + const Repartition dw; + const Half duh; + const auto promoted0 = PromoteTo(dw, LowerHalf(duh, vu)); + const auto promoted1 = PromoteTo(dw, UpperHalf(duh, vu)); + + const uint64_t maskL = mask_bits & 0xFFFF; + const uint64_t maskH = mask_bits >> 16; + const Mask512 mask0{static_cast<__mmask16>(maskL)}; + const Mask512 mask1{static_cast<__mmask16>(maskH)}; + const auto compressed0 = Compress(promoted0, mask0); + const auto compressed1 = Compress(promoted1, mask1); + + const Half dh; + const auto demoted0 = BitCast(dh, DemoteTo(duh, compressed0)); + const auto demoted1 = BitCast(dh, DemoteTo(duh, compressed1)); + + // Store 256-bit halves + StoreU(demoted0, dh, unaligned); + StoreU(demoted1, dh, unaligned + PopCount(maskL)); +#endif + + return PopCount(mask_bits); +} + +template +HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, + T* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); +// Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +template +HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, + T* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); +// Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; +} + +HWY_API size_t CompressStore(Vec512 v, Mask512 mask, + Full512 /* tag */, + float* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); +// Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(float)); +#endif + return count; +} + +HWY_API size_t CompressStore(Vec512 v, Mask512 mask, + Full512 /* tag */, + double* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); +// Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(double)); +#endif + return count; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(Vec512 v, Mask512 m, Full512 d, + T* HWY_RESTRICT unaligned) { + // AVX-512 already does the blending at no extra cost (latency 11, + // rthroughput 2 - same as compress plus store). + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); +// Workaround for MSAN not marking output as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#endif + return count; + } +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(Vec512 v, const uint8_t* HWY_RESTRICT bits, + Full512 d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ LoadInterleaved4 + +// Actually implemented in generic_ops, we just overload LoadTransposedBlocks4. +namespace detail { + +// Type-safe wrapper. +template <_MM_PERM_ENUM kPerm, typename T> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_i64x2(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)}; +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// Output: +// 9 6 3 0 (LSB of A) +// a 7 4 1 +// b 8 5 2 +template +HWY_API void LoadTransposedBlocks3(Full512 d, + const T* HWY_RESTRICT unaligned, + Vec512& A, Vec512& B, Vec512& C) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 v3210 = LoadU(d, unaligned + 0 * N); + const Vec512 v7654 = LoadU(d, unaligned + 1 * N); + const Vec512 vba98 = LoadU(d, unaligned + 2 * N); + + const Vec512 v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654); + const Vec512 va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98); + + A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976); + B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976); + C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98); +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// f e d c +// Output: +// c 8 4 0 (LSB of A) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +template +HWY_API void LoadTransposedBlocks4(Full512 d, + const T* HWY_RESTRICT unaligned, + Vec512& A, Vec512& B, Vec512& C, + Vec512& D) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 v3210 = LoadU(d, unaligned + 0 * N); + const Vec512 v7654 = LoadU(d, unaligned + 1 * N); + const Vec512 vba98 = LoadU(d, unaligned + 2 * N); + const Vec512 vfedc = LoadU(d, unaligned + 3 * N); + + const Vec512 v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654); + const Vec512 vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc); + const Vec512 v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654); + const Vec512 vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc); + A = detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98); + B = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98); + C = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba); + D = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2 + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 6 4 2 0 (LSB of i) +// 7 5 3 1 +// Output: +// 3 2 1 0 +// 7 6 5 4 +template +HWY_API void StoreTransposedBlocks2(const Vec512 i, const Vec512 j, + const Full512 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const auto j1_i1_j0_i0 = + detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0); + const auto j3_i3_j2_i2 = + detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2); + StoreU(j1_i1_j0_i0, d, unaligned + 0 * N); + StoreU(j3_i3_j2_i2, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 9 6 3 0 (LSB of i) +// a 7 4 1 +// b 8 5 2 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +template +HWY_API void StoreTransposedBlocks3(const Vec512 i, const Vec512 j, + const Vec512 k, Full512 d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j); + const Vec512 i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i); + const Vec512 j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j); + + const Vec512 out0 = // i1 k0 j0 i0 + detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0); + const Vec512 out1 = // j2 i2 k1 j1 + detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0); + const Vec512 out2 = // k3 j3 i3 k2 + detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1); + + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// c 8 4 0 (LSB of i) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +// f e d c +template +HWY_API void StoreTransposedBlocks4(const Vec512 i, const Vec512 j, + const Vec512 k, const Vec512 l, + Full512 d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const Vec512 j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const Vec512 l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l); + const Vec512 j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const Vec512 l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l); + const Vec512 out0 = + detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0); + const Vec512 out1 = + detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0); + const Vec512 out2 = + detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2); + const Vec512 out3 = + detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ MulEven/Odd (Shuffle2301, InterleaveLower) + +HWY_INLINE Vec512 MulEven(const Vec512 a, + const Vec512 b) { + const DFromV du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveLower(mulL, mulH); +} + +HWY_INLINE Vec512 MulOdd(const Vec512 a, + const Vec512 b) { + const DFromV du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need bits [95:64] (= upper half of input) + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Same as above, but we're using the odd results (upper 64 bits per block). + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveUpper(du64, mulL, mulH); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +HWY_API Vec512 ReorderWidenMulAccumulate(Full512 df32, + Vec512 a, + Vec512 b, + const Vec512 sum0, + Vec512& sum1) { + // TODO(janwas): _mm512_dpbf16_ps when available + const Repartition du16; + const RebindToUnsigned du32; + const Vec512 zero = Zero(du16); + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo. + const Vec512 a0 = ZipLower(du32, zero, BitCast(du16, a)); + const Vec512 a1 = ZipUpper(du32, zero, BitCast(du16, a)); + const Vec512 b0 = ZipLower(du32, zero, BitCast(du16, b)); + const Vec512 b1 = ZipUpper(du32, zero, BitCast(du16, b)); + sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); + return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); +} + +HWY_API Vec512 ReorderWidenMulAccumulate(Full512 /*d32*/, + Vec512 a, + Vec512 b, + const Vec512 sum0, + Vec512& /*sum1*/) { + return sum0 + Vec512{_mm512_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Reductions + +// Returns the sum in each lane. +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_epi32(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_epi64(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, static_cast(_mm512_reduce_add_epi32(v.raw))); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, static_cast(_mm512_reduce_add_epi64(v.raw))); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_ps(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_add_pd(v.raw)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(d32, even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +HWY_API Vec512 SumOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(d32, even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +// Returns the minimum in each lane. +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epi32(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epi64(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epu32(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_epu64(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_ps(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_min_pd(v.raw)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(d32, Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec512 MinOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(d32, Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// Returns the maximum in each lane. +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epi32(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epi64(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epu32(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_epu64(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_ps(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + return Set(d, _mm512_reduce_max_pd(v.raw)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(d32, Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec512 MaxOfLanes(Full512 d, Vec512 v) { + const RepartitionToWide d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(d32, Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/media/highway/src/hwy/per_target.cc b/media/highway/src/hwy/per_target.cc new file mode 100644 index 000000000..4cbf15232 --- /dev/null +++ b/media/highway/src/hwy/per_target.cc @@ -0,0 +1,50 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/per_target.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/per_target.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +// On SVE, Lanes rounds down to a power of two, but we want to know the actual +// size here. Otherwise, hypothetical SVE with 48 bytes would round down to 32 +// and we'd enable HWY_SVE_256, and then fail reverse_test because Reverse on +// HWY_SVE_256 requires the actual vector to be a power of two. +#if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE_256 +size_t GetVectorBytes() { return detail::AllHardwareLanes(hwy::SizeTag<1>()); } +#else +size_t GetVectorBytes() { return Lanes(ScalableTag()); } +#endif +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(GetVectorBytes); // Local function. +} // namespace + +size_t VectorBytes() { return HWY_DYNAMIC_DISPATCH(GetVectorBytes)(); } + +} // namespace hwy +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/per_target.h b/media/highway/src/hwy/per_target.h new file mode 100644 index 000000000..da85de322 --- /dev/null +++ b/media/highway/src/hwy/per_target.h @@ -0,0 +1,37 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PER_TARGET_H_ +#define HIGHWAY_HWY_PER_TARGET_H_ + +#include + +// Per-target functions. + +namespace hwy { + +// Returns size in bytes of a vector, i.e. `Lanes(ScalableTag())`. +// +// Do not cache the result, which may change after calling DisableTargets, or +// if software requests a different vector size (e.g. when entering/exiting SME +// streaming mode). Instead call this right before the code that depends on the +// result, without any DisableTargets or SME transition in-between. Note that +// this involves an indirect call, so prefer not to call this frequently nor +// unnecessarily. +size_t VectorBytes(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_PER_TARGET_H_ diff --git a/media/highway/src/hwy/print-inl.h b/media/highway/src/hwy/print-inl.h new file mode 100644 index 000000000..d256657eb --- /dev/null +++ b/media/highway/src/hwy/print-inl.h @@ -0,0 +1,55 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Print() function + +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/highway.h" +#include "hwy/print.h" + +// Per-target include guard +#if defined(HIGHWAY_HWY_PRINT_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_PRINT_INL_H_ +#undef HIGHWAY_HWY_PRINT_INL_H_ +#else +#define HIGHWAY_HWY_PRINT_INL_H_ +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Prints lanes around `lane`, in memory order. +template > +void Print(const D d, const char* caption, VecArg v, size_t lane_u = 0, + size_t max_lanes = 7) { + const size_t N = Lanes(d); + using T = TFromD; + auto lanes = AllocateAligned(N); + Store(v, d, lanes.get()); + + const auto info = hwy::detail::MakeTypeInfo(); + hwy::detail::PrintArray(info, caption, lanes.get(), N, lane_u, max_lanes); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/media/highway/src/hwy/print.cc b/media/highway/src/hwy/print.cc new file mode 100644 index 000000000..0b52cde1b --- /dev/null +++ b/media/highway/src/hwy/print.cc @@ -0,0 +1,110 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/print.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#include +#include + +#include "hwy/base.h" + +namespace hwy { +namespace detail { + +HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100) { + const char prefix = info.is_float ? 'f' : (info.is_signed ? 'i' : 'u'); + // Omit the xN suffix for scalars. + if (N == 1) { + // NOLINTNEXTLINE + snprintf(string100, 64, "%c%d", prefix, + static_cast(info.sizeof_t * 8)); + } else { + // NOLINTNEXTLINE + snprintf(string100, 64, "%c%dx%d", prefix, + static_cast(info.sizeof_t * 8), static_cast(N)); + } +} + +HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, + char* string100) { + if (info.sizeof_t == 1) { + uint8_t byte; + CopyBytes<1>(ptr, &byte); // endian-safe: we ensured sizeof(T)=1. + snprintf(string100, 100, "0x%02X", byte); // NOLINT + } else if (info.sizeof_t == 2) { + uint16_t bits; + CopyBytes<2>(ptr, &bits); + snprintf(string100, 100, "0x%04X", bits); // NOLINT + } else if (info.sizeof_t == 4) { + if (info.is_float) { + float value; + CopyBytes<4>(ptr, &value); + snprintf(string100, 100, "%g", static_cast(value)); // NOLINT + } else if (info.is_signed) { + int32_t value; + CopyBytes<4>(ptr, &value); + snprintf(string100, 100, "%d", value); // NOLINT + } else { + uint32_t value; + CopyBytes<4>(ptr, &value); + snprintf(string100, 100, "%u", value); // NOLINT + } + } else { + HWY_ASSERT(info.sizeof_t == 8); + if (info.is_float) { + double value; + CopyBytes<8>(ptr, &value); + snprintf(string100, 100, "%g", value); // NOLINT + } else if (info.is_signed) { + int64_t value; + CopyBytes<8>(ptr, &value); + snprintf(string100, 100, "%" PRIi64 "", value); // NOLINT + } else { + uint64_t value; + CopyBytes<8>(ptr, &value); + snprintf(string100, 100, "%" PRIu64 "", value); // NOLINT + } + } +} + +HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption, + const void* array_void, size_t N, size_t lane_u, + size_t max_lanes) { + const uint8_t* array_bytes = reinterpret_cast(array_void); + + char type_name[100]; + TypeName(info, N, type_name); + + const intptr_t lane = intptr_t(lane_u); + const size_t begin = static_cast(HWY_MAX(0, lane - 2)); + const size_t end = HWY_MIN(begin + max_lanes, N); + fprintf(stderr, "%s %s [%" PRIu64 "+ ->]:\n ", type_name, caption, + static_cast(begin)); + for (size_t i = begin; i < end; ++i) { + const void* ptr = array_bytes + i * info.sizeof_t; + char str[100]; + ToString(info, ptr, str); + fprintf(stderr, "%s,", str); + } + if (begin >= end) fprintf(stderr, "(out of bounds)"); + fprintf(stderr, "\n"); +} + +} // namespace detail +} // namespace hwy diff --git a/media/highway/src/hwy/print.h b/media/highway/src/hwy/print.h new file mode 100644 index 000000000..13792866a --- /dev/null +++ b/media/highway/src/hwy/print.h @@ -0,0 +1,73 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_PRINT_H_ +#define HWY_PRINT_H_ + +// Helpers for printing vector lanes. + +#include +#include + +#include "hwy/base.h" +#include "hwy/highway_export.h" + +namespace hwy { + +namespace detail { + +// For implementing value comparisons etc. as type-erased functions to reduce +// template bloat. +struct TypeInfo { + size_t sizeof_t; + bool is_float; + bool is_signed; +}; + +template +HWY_INLINE TypeInfo MakeTypeInfo() { + TypeInfo info; + info.sizeof_t = sizeof(T); + info.is_float = IsFloat(); + info.is_signed = IsSigned(); + return info; +} + +HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100); +HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, + char* string100); + +HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption, + const void* array_void, size_t N, + size_t lane_u = 0, size_t max_lanes = 7); + +} // namespace detail + +template +HWY_NOINLINE void PrintValue(T value) { + char str[100]; + detail::ToString(hwy::detail::MakeTypeInfo(), &value, str); + fprintf(stderr, "%s,", str); +} + +template +HWY_NOINLINE void PrintArray(const T* value, size_t count) { + detail::PrintArray(hwy::detail::MakeTypeInfo(), "", value, count, 0, + count); +} + +} // namespace hwy + +#endif // HWY_PRINT_H_ diff --git a/media/highway/src/hwy/targets.cc b/media/highway/src/hwy/targets.cc new file mode 100644 index 000000000..2fde4db9a --- /dev/null +++ b/media/highway/src/hwy/targets.cc @@ -0,0 +1,434 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/targets.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include // PRIx64 +#include +#include +#include +#include + +#include + +#include "hwy/per_target.h" // VectorBytes + +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN +#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace +#endif + +#include // abort / exit + +#if HWY_ARCH_X86 +#include +#if HWY_COMPILER_MSVC +#include +#else // !HWY_COMPILER_MSVC +#include +#endif // HWY_COMPILER_MSVC + +#elif HWY_ARCH_ARM && HWY_OS_LINUX +#include +#include +#endif // HWY_ARCH_* + +namespace hwy { +namespace { + +#if HWY_ARCH_X86 + +HWY_INLINE bool IsBitSet(const uint32_t reg, const int index) { + return (reg & (1U << index)) != 0; +} + +// Calls CPUID instruction with eax=level and ecx=count and returns the result +// in abcd array where abcd = {eax, ebx, ecx, edx} (hence the name abcd). +HWY_INLINE void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* HWY_RESTRICT abcd) { +#if HWY_COMPILER_MSVC + int regs[4]; + __cpuidex(regs, level, count); + for (int i = 0; i < 4; ++i) { + abcd[i] = regs[i]; + } +#else // HWY_COMPILER_MSVC + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif // HWY_COMPILER_MSVC +} + +// Returns the lower 32 bits of extended control register 0. +// Requires CPU support for "OSXSAVE" (see below). +uint32_t ReadXCR0() { +#if HWY_COMPILER_MSVC + return static_cast(_xgetbv(0)); +#else // HWY_COMPILER_MSVC + uint32_t xcr0, xcr0_high; + const uint32_t index = 0; + asm volatile(".byte 0x0F, 0x01, 0xD0" + : "=a"(xcr0), "=d"(xcr0_high) + : "c"(index)); + return xcr0; +#endif // HWY_COMPILER_MSVC +} + +#endif // HWY_ARCH_X86 + +// When running tests, this value can be set to the mocked supported targets +// mask. Only written to from a single thread before the test starts. +int64_t supported_targets_for_test_ = 0; + +// Mask of targets disabled at runtime with DisableTargets. +int64_t supported_mask_ = LimitsMax(); + +#if HWY_ARCH_X86 +// Arbritrary bit indices indicating which instruction set extensions are +// supported. Use enum to ensure values are distinct. +enum class FeatureIndex : uint32_t { + kSSE = 0, + kSSE2, + kSSE3, + kSSSE3, + + kSSE41, + kSSE42, + kCLMUL, + kAES, + + kAVX, + kAVX2, + kF16C, + kFMA, + kLZCNT, + kBMI, + kBMI2, + + kAVX512F, + kAVX512VL, + kAVX512DQ, + kAVX512BW, + + kVNNI, + kVPCLMULQDQ, + kVBMI, + kVBMI2, + kVAES, + kPOPCNTDQ, + kBITALG, + + kSentinel +}; +static_assert(static_cast(FeatureIndex::kSentinel) < 64, + "Too many bits for u64"); + +HWY_INLINE constexpr uint64_t Bit(FeatureIndex index) { + return 1ull << static_cast(index); +} + +constexpr uint64_t kGroupSSSE3 = + Bit(FeatureIndex::kSSE) | Bit(FeatureIndex::kSSE2) | + Bit(FeatureIndex::kSSE3) | Bit(FeatureIndex::kSSSE3); + +constexpr uint64_t kGroupSSE4 = + Bit(FeatureIndex::kSSE41) | Bit(FeatureIndex::kSSE42) | + Bit(FeatureIndex::kCLMUL) | Bit(FeatureIndex::kAES) | kGroupSSSE3; + +// We normally assume BMI/BMI2/FMA are available if AVX2 is. This allows us to +// use BZHI and (compiler-generated) MULX. However, VirtualBox lacks them +// [https://www.virtualbox.org/ticket/15471]. Thus we provide the option of +// avoiding using and requiring these so AVX2 can still be used. +#ifdef HWY_DISABLE_BMI2_FMA +constexpr uint64_t kGroupBMI2_FMA = 0; +#else +constexpr uint64_t kGroupBMI2_FMA = Bit(FeatureIndex::kBMI) | + Bit(FeatureIndex::kBMI2) | + Bit(FeatureIndex::kFMA); +#endif + +#ifdef HWY_DISABLE_F16C +constexpr uint64_t kGroupF16C = 0; +#else +constexpr uint64_t kGroupF16C = Bit(FeatureIndex::kF16C); +#endif + +constexpr uint64_t kGroupAVX2 = + Bit(FeatureIndex::kAVX) | Bit(FeatureIndex::kAVX2) | + Bit(FeatureIndex::kLZCNT) | kGroupBMI2_FMA | kGroupF16C | kGroupSSE4; + +constexpr uint64_t kGroupAVX3 = + Bit(FeatureIndex::kAVX512F) | Bit(FeatureIndex::kAVX512VL) | + Bit(FeatureIndex::kAVX512DQ) | Bit(FeatureIndex::kAVX512BW) | kGroupAVX2; + +constexpr uint64_t kGroupAVX3_DL = + Bit(FeatureIndex::kVNNI) | Bit(FeatureIndex::kVPCLMULQDQ) | + Bit(FeatureIndex::kVBMI) | Bit(FeatureIndex::kVBMI2) | + Bit(FeatureIndex::kVAES) | Bit(FeatureIndex::kPOPCNTDQ) | + Bit(FeatureIndex::kBITALG) | kGroupAVX3; + +#endif // HWY_ARCH_X86 + +// Returns targets supported by the CPU, independently of DisableTargets. +// Factored out of SupportedTargets to make its structure more obvious. Note +// that x86 CPUID may take several hundred cycles. +int64_t DetectTargets() { + // Apps will use only one of these (the default is EMU128), but compile flags + // for this TU may differ from that of the app, so allow both. + int64_t bits = HWY_SCALAR | HWY_EMU128; + +#if HWY_ARCH_X86 + bool has_osxsave = false; + { // ensures we do not accidentally use flags outside this block + uint64_t flags = 0; + uint32_t abcd[4]; + + Cpuid(0, 0, abcd); + const uint32_t max_level = abcd[0]; + + // Standard feature flags + Cpuid(1, 0, abcd); + flags |= IsBitSet(abcd[3], 25) ? Bit(FeatureIndex::kSSE) : 0; + flags |= IsBitSet(abcd[3], 26) ? Bit(FeatureIndex::kSSE2) : 0; + flags |= IsBitSet(abcd[2], 0) ? Bit(FeatureIndex::kSSE3) : 0; + flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kCLMUL) : 0; + flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kSSSE3) : 0; + flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kFMA) : 0; + flags |= IsBitSet(abcd[2], 19) ? Bit(FeatureIndex::kSSE41) : 0; + flags |= IsBitSet(abcd[2], 20) ? Bit(FeatureIndex::kSSE42) : 0; + flags |= IsBitSet(abcd[2], 25) ? Bit(FeatureIndex::kAES) : 0; + flags |= IsBitSet(abcd[2], 28) ? Bit(FeatureIndex::kAVX) : 0; + flags |= IsBitSet(abcd[2], 29) ? Bit(FeatureIndex::kF16C) : 0; + has_osxsave = IsBitSet(abcd[2], 27); + + // Extended feature flags + Cpuid(0x80000001U, 0, abcd); + flags |= IsBitSet(abcd[2], 5) ? Bit(FeatureIndex::kLZCNT) : 0; + + // Extended features + if (max_level >= 7) { + Cpuid(7, 0, abcd); + flags |= IsBitSet(abcd[1], 3) ? Bit(FeatureIndex::kBMI) : 0; + flags |= IsBitSet(abcd[1], 5) ? Bit(FeatureIndex::kAVX2) : 0; + flags |= IsBitSet(abcd[1], 8) ? Bit(FeatureIndex::kBMI2) : 0; + + flags |= IsBitSet(abcd[1], 16) ? Bit(FeatureIndex::kAVX512F) : 0; + flags |= IsBitSet(abcd[1], 17) ? Bit(FeatureIndex::kAVX512DQ) : 0; + flags |= IsBitSet(abcd[1], 30) ? Bit(FeatureIndex::kAVX512BW) : 0; + flags |= IsBitSet(abcd[1], 31) ? Bit(FeatureIndex::kAVX512VL) : 0; + + flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kVBMI) : 0; + flags |= IsBitSet(abcd[2], 6) ? Bit(FeatureIndex::kVBMI2) : 0; + flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kVAES) : 0; + flags |= IsBitSet(abcd[2], 10) ? Bit(FeatureIndex::kVPCLMULQDQ) : 0; + flags |= IsBitSet(abcd[2], 11) ? Bit(FeatureIndex::kVNNI) : 0; + flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kBITALG) : 0; + flags |= IsBitSet(abcd[2], 14) ? Bit(FeatureIndex::kPOPCNTDQ) : 0; + } + + // Set target bit(s) if all their group's flags are all set. + if ((flags & kGroupAVX3_DL) == kGroupAVX3_DL) { + bits |= HWY_AVX3_DL; + } + if ((flags & kGroupAVX3) == kGroupAVX3) { + bits |= HWY_AVX3; + } + if ((flags & kGroupAVX2) == kGroupAVX2) { + bits |= HWY_AVX2; + } + if ((flags & kGroupSSE4) == kGroupSSE4) { + bits |= HWY_SSE4; + } + if ((flags & kGroupSSSE3) == kGroupSSSE3) { + bits |= HWY_SSSE3; + } + } + + // Clear bits if the OS does not support XSAVE - otherwise, registers + // are not preserved across context switches. + if (has_osxsave) { + const uint32_t xcr0 = ReadXCR0(); + const int64_t min_avx3 = HWY_AVX3 | HWY_AVX3_DL; + const int64_t min_avx2 = HWY_AVX2 | min_avx3; + // XMM + if (!IsBitSet(xcr0, 1)) { + bits &= ~(HWY_SSSE3 | HWY_SSE4 | min_avx2); + } + // YMM + if (!IsBitSet(xcr0, 2)) { + bits &= ~min_avx2; + } + // opmask, ZMM lo/hi + if (!IsBitSet(xcr0, 5) || !IsBitSet(xcr0, 6) || !IsBitSet(xcr0, 7)) { + bits &= ~min_avx3; + } + } + + if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) { + fprintf(stderr, + "WARNING: CPU supports %" PRIx64 " but software requires %" PRIx64 + "\n", + bits, static_cast(HWY_ENABLED_BASELINE)); + } + +#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH + using CapBits = unsigned long; // NOLINT + const CapBits hw = getauxval(AT_HWCAP); + (void)hw; + +#if HWY_ARCH_ARM_A64 + +#if defined(HWCAP_AES) + // aarch64 always has NEON and VFPv4, but not necessarily AES, which we + // require and thus must still check for. + if (hw & HWCAP_AES) { + bits |= HWY_NEON; + } +#endif // HWCAP_AES + +#if defined(HWCAP_SVE) + if (hw & HWCAP_SVE) { + bits |= HWY_SVE; + } +#endif + +#if defined(HWCAP2_SVE2) && defined(HWCAP2_SVEAES) + const CapBits hw2 = getauxval(AT_HWCAP2); + if ((hw2 & HWCAP2_SVE2) && (hw2 & HWCAP2_SVEAES)) { + bits |= HWY_SVE2; + } +#endif + +#else // HWY_ARCH_ARM_A64 + +// Some old auxv.h / hwcap.h do not define these. If not, treat as unsupported. +// Note that AES has a different HWCAP bit compared to aarch64. +#if defined(HWCAP_NEON) && defined(HWCAP_VFPv4) + if ((hw & HWCAP_NEON) && (hw & HWCAP_VFPv4)) { + bits |= HWY_NEON; + } +#endif + +#endif // HWY_ARCH_ARM_A64 + if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) { + fprintf(stderr, + "WARNING: CPU supports %" PRIx64 " but software requires %" PRIx64 + "\n", + bits, static_cast(HWY_ENABLED_BASELINE)); + } +#else // HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH + // TODO(janwas): detect for other platforms and check for baseline + // This file is typically compiled without HWY_IS_TEST, but targets_test has + // it set, and will expect all of its HWY_TARGETS (= all attainable) to be + // supported. + bits |= HWY_ENABLED_BASELINE; +#endif // HWY_ARCH_X86 + + return bits; +} + +} // namespace + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...) { + char buf[2000]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + fprintf(stderr, "Abort at %s:%d: %s\n", file, line, buf); + +// If compiled with any sanitizer, they can also print a stack trace. +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + __sanitizer_print_stack_trace(); +#endif // HWY_IS_* + fflush(stderr); + +// Now terminate the program: +#if HWY_ARCH_RVV + exit(1); // trap/abort just freeze Spike. +#elif HWY_IS_DEBUG_BUILD && !HWY_COMPILER_MSVC + // Facilitates breaking into a debugger, but don't use this in non-debug + // builds because it looks like "illegal instruction", which is misleading. + __builtin_trap(); +#else + abort(); // Compile error without this due to HWY_NORETURN. +#endif +} + +HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets) { + supported_mask_ = static_cast(~disabled_targets); + // This will take effect on the next call to SupportedTargets, which is + // called right before GetChosenTarget::Update. However, calling Update here + // would make it appear that HWY_DYNAMIC_DISPATCH was called, which we want + // to check in tests. We instead de-initialize such that the next + // HWY_DYNAMIC_DISPATCH calls GetChosenTarget::Update via FunctionCache. + GetChosenTarget().DeInit(); +} + +HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets) { + supported_targets_for_test_ = targets; + GetChosenTarget().DeInit(); // see comment above +} + +HWY_DLLEXPORT int64_t SupportedTargets() { + int64_t targets = supported_targets_for_test_; + if (HWY_LIKELY(targets == 0)) { + // Mock not active. Re-detect instead of caching just in case we're on a + // heterogeneous ISA (also requires some app support to pin threads). This + // is only reached on the first HWY_DYNAMIC_DISPATCH or after each call to + // DisableTargets or SetSupportedTargetsForTest. + targets = DetectTargets(); + + // VectorBytes invokes HWY_DYNAMIC_DISPATCH. To prevent infinite recursion, + // first set up ChosenTarget. No need to Update() again afterwards with the + // final targets - that will be done by a caller of this function. + GetChosenTarget().Update(targets); + + // Now that we can call VectorBytes, check for targets with specific sizes. + if (HWY_ARCH_ARM_A64) { + const size_t vec_bytes = VectorBytes(); // uncached, see declaration + if ((targets & HWY_SVE) && vec_bytes == 32) { + targets = static_cast(targets | HWY_SVE_256); + } else { + targets = static_cast(targets & ~HWY_SVE_256); + } + if ((targets & HWY_SVE2) && vec_bytes == 16) { + targets = static_cast(targets | HWY_SVE2_128); + } else { + targets = static_cast(targets & ~HWY_SVE2_128); + } + } // HWY_ARCH_ARM_A64 + } + + targets &= supported_mask_; + return targets == 0 ? HWY_STATIC_TARGET : targets; +} + +HWY_DLLEXPORT ChosenTarget& GetChosenTarget() { + static ChosenTarget chosen_target; + return chosen_target; +} + +} // namespace hwy diff --git a/media/highway/src/hwy/targets.h b/media/highway/src/hwy/targets.h new file mode 100644 index 000000000..2d9afbff4 --- /dev/null +++ b/media/highway/src/hwy/targets.h @@ -0,0 +1,318 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_TARGETS_H_ +#define HIGHWAY_HWY_TARGETS_H_ + +#include + +// For SIMD module implementations and their callers. Defines which targets to +// generate and call. + +#include "hwy/base.h" +#include "hwy/detect_targets.h" +#include "hwy/highway_export.h" + +#if !HWY_ARCH_RVV +#include +#endif + +namespace hwy { + +// Returns bitfield of enabled targets that are supported on this CPU; there is +// always at least one such target, hence the return value is never 0. The +// targets returned may change after calling DisableTargets. This function is +// always defined, but the HWY_SUPPORTED_TARGETS wrapper may allow eliding +// calls to it if there is only a single target enabled. +HWY_DLLEXPORT int64_t SupportedTargets(); + +// Evaluates to a function call, or literal if there is a single target. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) == 0 +#define HWY_SUPPORTED_TARGETS HWY_TARGETS +#else +#define HWY_SUPPORTED_TARGETS hwy::SupportedTargets() +#endif + +// Subsequent SupportedTargets will not return targets whose bit(s) are set in +// `disabled_targets`. Exception: if SupportedTargets would return 0, it will +// instead return HWY_STATIC_TARGET (there must always be one target to call). +// +// This function is useful for disabling targets known to be buggy, or if the +// best available target is undesirable (perhaps due to throttling or memory +// bandwidth limitations). Use SetSupportedTargetsForTest instead of this +// function for iteratively enabling specific targets for testing. +HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets); + +// Subsequent SupportedTargets will return the given set of targets, except +// those disabled via DisableTargets. Call with a mask of 0 to disable the mock +// and return to the normal SupportedTargets behavior. Used to run tests for +// all targets. +HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets); + +// Return the list of targets in HWY_TARGETS supported by the CPU as a list of +// individual HWY_* target macros such as HWY_SCALAR or HWY_NEON. This list +// is affected by the current SetSupportedTargetsForTest() mock if any. +HWY_INLINE std::vector SupportedAndGeneratedTargets() { + std::vector ret; + for (int64_t targets = SupportedTargets() & HWY_TARGETS; targets != 0; + targets = targets & (targets - 1)) { + int64_t current_target = targets & ~(targets - 1); + ret.push_back(current_target); + } + return ret; +} + +static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { + switch (target) { +#if HWY_ARCH_X86 + case HWY_SSSE3: + return "SSSE3"; + case HWY_SSE4: + return "SSE4"; + case HWY_AVX2: + return "AVX2"; + case HWY_AVX3: + return "AVX3"; + case HWY_AVX3_DL: + return "AVX3_DL"; +#endif + +#if HWY_ARCH_ARM + case HWY_SVE2_128: + return "SVE2_128"; + case HWY_SVE_256: + return "SVE_256"; + case HWY_SVE2: + return "SVE2"; + case HWY_SVE: + return "SVE"; + case HWY_NEON: + return "NEON"; +#endif + +#if HWY_ARCH_PPC + case HWY_PPC8: + return "PPC8"; +#endif + +#if HWY_ARCH_WASM + case HWY_WASM: + return "WASM"; + case HWY_WASM_EMU256: + return "WASM_EMU256"; +#endif + +#if HWY_ARCH_RVV + case HWY_RVV: + return "RVV"; +#endif + + case HWY_EMU128: + return "EMU128"; + case HWY_SCALAR: + return "SCALAR"; + + default: + return "Unknown"; // must satisfy gtest IsValidParamName() + } +} + +// The maximum number of dynamic targets on any architecture is defined by +// HWY_MAX_DYNAMIC_TARGETS and depends on the arch. + +// For the ChosenTarget mask and index we use a different bit arrangement than +// in the HWY_TARGETS mask. Only the targets involved in the current +// architecture are used in this mask, and therefore only the least significant +// (HWY_MAX_DYNAMIC_TARGETS + 2) bits of the int64_t mask are used. The least +// significant bit is set when the mask is not initialized, the next +// HWY_MAX_DYNAMIC_TARGETS more significant bits are a range of bits from the +// HWY_TARGETS or SupportedTargets() mask for the given architecture shifted to +// that position and the next more significant bit is used for HWY_SCALAR (if +// HWY_COMPILE_ONLY_SCALAR is defined) or HWY_EMU128. Because of this we need to +// define equivalent values for HWY_TARGETS in this representation. +// This mask representation allows to use ctz() on this mask and obtain a small +// number that's used as an index of the table for dynamic dispatch. In this +// way the first entry is used when the mask is uninitialized, the following +// HWY_MAX_DYNAMIC_TARGETS are for dynamic dispatch and the last one is for +// scalar. + +// The HWY_SCALAR/HWY_EMU128 bit in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_SCALAR (1LL << (HWY_MAX_DYNAMIC_TARGETS + 1)) + +// Converts from a HWY_TARGETS mask to a ChosenTarget mask format for the +// current architecture. +#define HWY_CHOSEN_TARGET_SHIFT(X) \ + ((((X) >> (HWY_HIGHEST_TARGET_BIT + 1 - HWY_MAX_DYNAMIC_TARGETS)) & \ + ((1LL << HWY_MAX_DYNAMIC_TARGETS) - 1)) \ + << 1) + +// The HWY_TARGETS mask in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_TARGETS \ + (HWY_CHOSEN_TARGET_SHIFT(HWY_TARGETS) | HWY_CHOSEN_TARGET_MASK_SCALAR | 1LL) + +#if HWY_ARCH_X86 +// Maximum number of dynamic targets, changing this value is an ABI incompatible +// change +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_X86 +// These must match the order in which the HWY_TARGETS are defined +// starting by the least significant (HWY_HIGHEST_TARGET_BIT + 1 - +// HWY_MAX_DYNAMIC_TARGETS) bit. This list must contain exactly +// HWY_MAX_DYNAMIC_TARGETS elements and does not include SCALAR. The first entry +// corresponds to the best target. Don't include a "," at the end of the list. +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_AVX3_DL(func_name), /* AVX3_DL */ \ + HWY_CHOOSE_AVX3(func_name), /* AVX3 */ \ + HWY_CHOOSE_AVX2(func_name), /* AVX2 */ \ + nullptr, /* AVX */ \ + HWY_CHOOSE_SSE4(func_name), /* SSE4 */ \ + HWY_CHOOSE_SSSE3(func_name), /* SSSE3 */ \ + nullptr , /* reserved - SSE3? */ \ + nullptr /* reserved - SSE2? */ + +#elif HWY_ARCH_ARM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_ARM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2_128(func_name), /* SVE2 128-bit */ \ + HWY_CHOOSE_SVE_256(func_name), /* SVE 256-bit */ \ + HWY_CHOOSE_SVE2(func_name), /* SVE2 */ \ + HWY_CHOOSE_SVE(func_name), /* SVE */ \ + HWY_CHOOSE_NEON(func_name), /* NEON */ \ + nullptr /* reserved - Helium? */ + +#elif HWY_ARCH_RVV +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_RVV +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_RVV(func_name), /* RVV */ \ + nullptr /* reserved */ + +#elif HWY_ARCH_PPC +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_PPC +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_PPC8(func_name), /* PPC8 */ \ + nullptr, /* reserved (VSX or AltiVec) */ \ + nullptr /* reserved (VSX or AltiVec) */ + +#elif HWY_ARCH_WASM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_WASM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_WASM_EMU256(func_name), /* WASM_EMU256 */ \ + HWY_CHOOSE_WASM(func_name), /* WASM */ \ + nullptr /* reserved */ + +#else +// Unknown architecture, will use HWY_SCALAR without dynamic dispatch, though +// still creating single-entry tables in HWY_EXPORT to ensure portability. +#define HWY_MAX_DYNAMIC_TARGETS 1 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_SCALAR +#endif + +// Bitfield of supported and enabled targets. The format differs from that of +// HWY_TARGETS; the lowest bit governs the first function pointer (which is +// special in that it calls FunctionCache, then Update, then dispatches to the +// actual implementation) in the tables created by HWY_EXPORT. Monostate (see +// GetChosenTarget), thread-safe except on RVV. +struct ChosenTarget { + public: + // Reset bits according to `targets` (typically the return value of + // SupportedTargets()). Postcondition: IsInitialized() == true. + void Update(int64_t targets) { + // These are `targets` shifted downwards, see above. Also include SCALAR + // (corresponds to the last entry in the function table) as fallback. + StoreMask(HWY_CHOSEN_TARGET_SHIFT(targets) | HWY_CHOSEN_TARGET_MASK_SCALAR); + } + + // Reset to the uninitialized state, so that FunctionCache will call Update + // during the next HWY_DYNAMIC_DISPATCH, and IsInitialized returns false. + void DeInit() { StoreMask(1); } + + // Whether Update was called. This indicates whether any HWY_DYNAMIC_DISPATCH + // function was called, which we check in tests. + bool IsInitialized() const { return LoadMask() != 1; } + + // Return the index in the dynamic dispatch table to be used by the current + // CPU. Note that this method must be in the header file so it uses the value + // of HWY_CHOSEN_TARGET_MASK_TARGETS defined in the translation unit that + // calls it, which may be different from others. This means we only enable + // those targets that were actually compiled in this module. + size_t HWY_INLINE GetIndex() const { + return hwy::Num0BitsBelowLS1Bit_Nonzero64( + static_cast(LoadMask() & HWY_CHOSEN_TARGET_MASK_TARGETS)); + } + + private: + // TODO(janwas): remove #if once is available +#if HWY_ARCH_RVV + int64_t LoadMask() const { return mask_; } + void StoreMask(int64_t mask) { mask_ = mask; } + + int64_t mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#else + int64_t LoadMask() const { return mask_.load(); } + void StoreMask(int64_t mask) { mask_.store(mask); } + + std::atomic mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#endif // HWY_ARCH_RVV +}; + +// For internal use (e.g. by FunctionCache and DisableTargets). +HWY_DLLEXPORT ChosenTarget& GetChosenTarget(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_TARGETS_H_ diff --git a/media/highway/src/hwy/targets_test.cc b/media/highway/src/hwy/targets_test.cc new file mode 100644 index 000000000..e58a6fa46 --- /dev/null +++ b/media/highway/src/hwy/targets_test.cc @@ -0,0 +1,135 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/targets.h" + +#include "hwy/tests/test_util-inl.h" + +namespace fake { + +#define DECLARE_FUNCTION(TGT) \ + namespace N_##TGT { \ + /* Function argument is just to ensure/demonstrate they are possible. */ \ + int64_t FakeFunction(int) { return HWY_##TGT; } \ + } + +DECLARE_FUNCTION(AVX3_DL) +DECLARE_FUNCTION(AVX3) +DECLARE_FUNCTION(AVX2) +DECLARE_FUNCTION(SSE4) +DECLARE_FUNCTION(SSSE3) +DECLARE_FUNCTION(NEON) +DECLARE_FUNCTION(SVE) +DECLARE_FUNCTION(SVE2) +DECLARE_FUNCTION(SVE_256) +DECLARE_FUNCTION(SVE2_128) +DECLARE_FUNCTION(PPC8) +DECLARE_FUNCTION(WASM) +DECLARE_FUNCTION(RVV) +DECLARE_FUNCTION(SCALAR) +DECLARE_FUNCTION(EMU128) + +HWY_EXPORT(FakeFunction); + +void CallFunctionForTarget(int64_t target, int line) { + if ((HWY_TARGETS & target) == 0) return; + hwy::SetSupportedTargetsForTest(target); + + // Call Update() first to make &HWY_DYNAMIC_DISPATCH() return + // the pointer to the already cached function. + hwy::GetChosenTarget().Update(hwy::SupportedTargets()); + + EXPECT_EQ(target, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; + + // Calling DeInit() will test that the initializer function + // also calls the right function. + hwy::GetChosenTarget().DeInit(); + +#if HWY_DISPATCH_WORKAROUND + EXPECT_EQ(HWY_STATIC_TARGET, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; +#else + EXPECT_EQ(target, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; +#endif + + // Second call uses the cached value from the previous call. + EXPECT_EQ(target, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; +} + +void CheckFakeFunction() { + // When adding a target, also add to DECLARE_FUNCTION above. + CallFunctionForTarget(HWY_AVX3_DL, __LINE__); + CallFunctionForTarget(HWY_AVX3, __LINE__); + CallFunctionForTarget(HWY_AVX2, __LINE__); + CallFunctionForTarget(HWY_SSE4, __LINE__); + CallFunctionForTarget(HWY_SSSE3, __LINE__); + CallFunctionForTarget(HWY_NEON, __LINE__); + CallFunctionForTarget(HWY_SVE, __LINE__); + CallFunctionForTarget(HWY_SVE2, __LINE__); + CallFunctionForTarget(HWY_SVE_256, __LINE__); + CallFunctionForTarget(HWY_SVE2_128, __LINE__); + CallFunctionForTarget(HWY_PPC8, __LINE__); + CallFunctionForTarget(HWY_WASM, __LINE__); + CallFunctionForTarget(HWY_RVV, __LINE__); + // The tables only have space for either HWY_SCALAR or HWY_EMU128; the former + // is opt-in only. +#if defined(HWY_COMPILE_ONLY_SCALAR) || HWY_BROKEN_EMU128 + CallFunctionForTarget(HWY_SCALAR, __LINE__); +#else + CallFunctionForTarget(HWY_EMU128, __LINE__); +#endif +} + +} // namespace fake + +namespace hwy { + +class HwyTargetsTest : public testing::Test { + protected: + void TearDown() override { + SetSupportedTargetsForTest(0); + DisableTargets(0); // Reset the mask. + } +}; + +// Test that the order in the HWY_EXPORT static array matches the expected +// value of the target bits. This is only checked for the targets that are +// enabled in the current compilation. +TEST_F(HwyTargetsTest, ChosenTargetOrderTest) { fake::CheckFakeFunction(); } + +TEST_F(HwyTargetsTest, DisabledTargetsTest) { + DisableTargets(~0LL); + // Check that disabling everything at least leaves the static target. + HWY_ASSERT(HWY_STATIC_TARGET == SupportedTargets()); + + DisableTargets(0); // Reset the mask. + const int64_t current_targets = SupportedTargets(); + const int64_t enabled_baseline = static_cast(HWY_ENABLED_BASELINE); + // Exclude these two because they are always returned by SupportedTargets. + const int64_t fallback = HWY_SCALAR | HWY_EMU128; + if ((current_targets & ~enabled_baseline & ~fallback) == 0) { + // We can't test anything else if the only compiled target is the baseline. + return; + } + + // Get the lowest bit in the mask (the best target) and disable that one. + const int64_t best_target = current_targets & (~current_targets + 1); + DisableTargets(best_target); + + // Check that the other targets are still enabled. + HWY_ASSERT((best_target ^ current_targets) == SupportedTargets()); + DisableTargets(0); // Reset the mask. +} + +} // namespace hwy diff --git a/media/highway/src/hwy/tests/arithmetic_test.cc b/media/highway/src/hwy/tests/arithmetic_test.cc new file mode 100644 index 000000000..1fbbd29ad --- /dev/null +++ b/media/highway/src/hwy/tests/arithmetic_test.cc @@ -0,0 +1,445 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/arithmetic_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestPlusMinus { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, T(2)); + const auto v3 = Iota(d, T(3)); + const auto v4 = Iota(d, T(4)); + + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + lanes[i] = static_cast((2 + i) + (3 + i)); + } + HWY_ASSERT_VEC_EQ(d, lanes.get(), Add(v2, v3)); + HWY_ASSERT_VEC_EQ(d, Set(d, 2), Sub(v4, v2)); + + for (size_t i = 0; i < N; ++i) { + lanes[i] = static_cast((2 + i) + (4 + i)); + } + auto sum = v2; + sum = Add(sum, v4); // sum == 6,8.. + HWY_ASSERT_VEC_EQ(d, Load(d, lanes.get()), sum); + + sum = Sub(sum, v4); + HWY_ASSERT_VEC_EQ(d, v2, sum); + } +}; + +HWY_NOINLINE void TestAllPlusMinus() { + ForAllTypes(ForPartialVectors()); +} + +struct TestUnsignedSaturatingArithmetic { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vi = Iota(d, 1); + const auto vm = Set(d, LimitsMax()); + + HWY_ASSERT_VEC_EQ(d, Add(v0, v0), SaturatedAdd(v0, v0)); + HWY_ASSERT_VEC_EQ(d, Add(v0, vi), SaturatedAdd(v0, vi)); + HWY_ASSERT_VEC_EQ(d, Add(v0, vm), SaturatedAdd(v0, vm)); + HWY_ASSERT_VEC_EQ(d, vm, SaturatedAdd(vi, vm)); + HWY_ASSERT_VEC_EQ(d, vm, SaturatedAdd(vm, vm)); + + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(vi, vi)); + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(vi, vm)); + HWY_ASSERT_VEC_EQ(d, Sub(vm, vi), SaturatedSub(vm, vi)); + } +}; + +struct TestSignedSaturatingArithmetic { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vpm = Set(d, LimitsMax()); + // Ensure all lanes are positive, even if Iota wraps around + const auto vi = Or(And(Iota(d, 0), vpm), Set(d, 1)); + const auto vn = Sub(v0, vi); + const auto vnm = Set(d, LimitsMin()); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), Gt(vi, v0)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), Lt(vn, v0)); + + HWY_ASSERT_VEC_EQ(d, v0, SaturatedAdd(v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, SaturatedAdd(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(v0, vpm)); + HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(vi, vpm)); + HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(vpm, vpm)); + + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, v0)); + HWY_ASSERT_VEC_EQ(d, Sub(v0, vi), SaturatedSub(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vn, SaturatedSub(vn, v0)); + HWY_ASSERT_VEC_EQ(d, vnm, SaturatedSub(vnm, vi)); + HWY_ASSERT_VEC_EQ(d, vnm, SaturatedSub(vnm, vpm)); + } +}; + +HWY_NOINLINE void TestAllSaturatingArithmetic() { + const ForPartialVectors test_unsigned; + test_unsigned(uint8_t()); + test_unsigned(uint16_t()); + + const ForPartialVectors test_signed; + test_signed(int8_t()); + test_signed(int16_t()); +} + +struct TestAverage { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto v2 = Set(d, T(2)); + + HWY_ASSERT_VEC_EQ(d, v0, AverageRound(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v0, v1)); + HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v1, v1)); + HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v2, v2)); + } +}; + +HWY_NOINLINE void TestAllAverage() { + const ForPartialVectors test; + test(uint8_t()); + test(uint16_t()); +} + +struct TestAbs { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp1 = Set(d, T(1)); + const auto vn1 = Set(d, T(-1)); + const auto vpm = Set(d, LimitsMax()); + const auto vnm = Set(d, LimitsMin()); + + HWY_ASSERT_VEC_EQ(d, v0, Abs(v0)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vp1)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vn1)); + HWY_ASSERT_VEC_EQ(d, vpm, Abs(vpm)); + HWY_ASSERT_VEC_EQ(d, vnm, Abs(vnm)); + } +}; + +struct TestFloatAbs { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp1 = Set(d, T(1)); + const auto vn1 = Set(d, T(-1)); + const auto vp2 = Set(d, T(0.01)); + const auto vn2 = Set(d, T(-0.01)); + + HWY_ASSERT_VEC_EQ(d, v0, Abs(v0)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vp1)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vn1)); + HWY_ASSERT_VEC_EQ(d, vp2, Abs(vp2)); + HWY_ASSERT_VEC_EQ(d, vp2, Abs(vn2)); + } +}; + +HWY_NOINLINE void TestAllAbs() { + ForSignedTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +struct TestNeg { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vn = Set(d, T(-3)); + const auto vp = Set(d, T(3)); + HWY_ASSERT_VEC_EQ(d, v0, Neg(v0)); + HWY_ASSERT_VEC_EQ(d, vp, Neg(vn)); + HWY_ASSERT_VEC_EQ(d, vn, Neg(vp)); + } +}; + +HWY_NOINLINE void TestAllNeg() { + ForSignedTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +struct TestUnsignedMinMax { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + // Leave headroom such that v1 < v2 even after wraparound. + const auto mod = And(Iota(d, 0), Set(d, LimitsMax() >> 1)); + const auto v1 = Add(mod, Set(d, 1)); + const auto v2 = Add(mod, Set(d, 2)); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v0, Min(v1, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v0)); + + const auto vmin = Set(d, LimitsMin()); + const auto vmax = Set(d, LimitsMax()); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +struct TestSignedMinMax { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Leave headroom such that v1 < v2 even after wraparound. + const auto mod = And(Iota(d, 0), Set(d, LimitsMax() >> 1)); + const auto v1 = Add(mod, Set(d, 1)); + const auto v2 = Add(mod, Set(d, 2)); + const auto v_neg = Sub(Zero(d), v1); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v_neg, Min(v1, v_neg)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v_neg)); + + const auto v0 = Zero(d); + const auto vmin = Set(d, LimitsMin()); + const auto vmax = Set(d, LimitsMax()); + HWY_ASSERT_VEC_EQ(d, vmin, Min(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Max(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, v0, Max(vmin, v0)); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +struct TestFloatMinMax { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + const auto v_neg = Iota(d, -T(Lanes(d))); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v_neg, Min(v1, v_neg)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v_neg)); + + const auto v0 = Zero(d); + const auto vmin = Set(d, T(-1E30)); + const auto vmax = Set(d, T(1E30)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Max(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, v0, Max(vmin, v0)); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +HWY_NOINLINE void TestAllMinMax() { + ForUnsignedTypes(ForPartialVectors()); + ForSignedTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +template +static HWY_NOINLINE Vec Make128(D d, uint64_t hi, uint64_t lo) { + alignas(16) uint64_t in[2]; + in[0] = lo; + in[1] = hi; + return LoadDup128(d, in); +} + +struct TestMinMax128 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const size_t N = Lanes(d); + auto a_lanes = AllocateAligned(N); + auto b_lanes = AllocateAligned(N); + auto min_lanes = AllocateAligned(N); + auto max_lanes = AllocateAligned(N); + RandomState rng; + + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + // Same arg + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Min128(d, v11, v11)); + HWY_ASSERT_VEC_EQ(d, v00, Max128(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v11, v11)); + + // First arg less + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v10, v11)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v10, v11)); + + // Second arg less + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v11, v10)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v11, v10)); + + // Also check 128-bit blocks are independent + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + a_lanes[i] = Random64(&rng); + b_lanes[i] = Random64(&rng); + } + const V a = Load(d, a_lanes.get()); + const V b = Load(d, b_lanes.get()); + for (size_t i = 0; i < N; i += 2) { + const bool lt = a_lanes[i + 1] == b_lanes[i + 1] + ? (a_lanes[i] < b_lanes[i]) + : (a_lanes[i + 1] < b_lanes[i + 1]); + min_lanes[i + 0] = lt ? a_lanes[i + 0] : b_lanes[i + 0]; + min_lanes[i + 1] = lt ? a_lanes[i + 1] : b_lanes[i + 1]; + max_lanes[i + 0] = lt ? b_lanes[i + 0] : a_lanes[i + 0]; + max_lanes[i + 1] = lt ? b_lanes[i + 1] : a_lanes[i + 1]; + } + HWY_ASSERT_VEC_EQ(d, min_lanes.get(), Min128(d, a, b)); + HWY_ASSERT_VEC_EQ(d, max_lanes.get(), Max128(d, a, b)); + } + } +}; + +HWY_NOINLINE void TestAllMinMax128() { + ForGEVectors<128, TestMinMax128>()(uint64_t()); +} + +struct TestMinMax128Upper { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const size_t N = Lanes(d); + auto a_lanes = AllocateAligned(N); + auto b_lanes = AllocateAligned(N); + auto min_lanes = AllocateAligned(N); + auto max_lanes = AllocateAligned(N); + RandomState rng; + + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + // Same arg + HWY_ASSERT_VEC_EQ(d, v00, Min128Upper(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128Upper(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Min128Upper(d, v11, v11)); + HWY_ASSERT_VEC_EQ(d, v00, Max128Upper(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Max128Upper(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128Upper(d, v11, v11)); + + // Equivalent but not equal (chooses second arg) + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v11, Min128Upper(d, v10, v11)); + HWY_ASSERT_VEC_EQ(d, v00, Min128Upper(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v10, Min128Upper(d, v11, v10)); + HWY_ASSERT_VEC_EQ(d, v00, Max128Upper(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v11, v10)); + HWY_ASSERT_VEC_EQ(d, v01, Max128Upper(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v11, Max128Upper(d, v10, v11)); + + // First arg less + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v01, v10)); + + // Second arg less + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v10, v01)); + + // Also check 128-bit blocks are independent + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + a_lanes[i] = Random64(&rng); + b_lanes[i] = Random64(&rng); + } + const V a = Load(d, a_lanes.get()); + const V b = Load(d, b_lanes.get()); + for (size_t i = 0; i < N; i += 2) { + const bool lt = a_lanes[i + 1] < b_lanes[i + 1]; + min_lanes[i + 0] = lt ? a_lanes[i + 0] : b_lanes[i + 0]; + min_lanes[i + 1] = lt ? a_lanes[i + 1] : b_lanes[i + 1]; + max_lanes[i + 0] = lt ? b_lanes[i + 0] : a_lanes[i + 0]; + max_lanes[i + 1] = lt ? b_lanes[i + 1] : a_lanes[i + 1]; + } + HWY_ASSERT_VEC_EQ(d, min_lanes.get(), Min128Upper(d, a, b)); + HWY_ASSERT_VEC_EQ(d, max_lanes.get(), Max128Upper(d, a, b)); + } + } +}; + +HWY_NOINLINE void TestAllMinMax128Upper() { + ForGEVectors<128, TestMinMax128Upper>()(uint64_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyArithmeticTest); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllPlusMinus); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllSaturatingArithmetic); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAverage); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAbs); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllNeg); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax128); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax128Upper); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/blockwise_shift_test.cc b/media/highway/src/hwy/tests/blockwise_shift_test.cc new file mode 100644 index 000000000..d14fb86e3 --- /dev/null +++ b/media/highway/src/hwy/tests/blockwise_shift_test.cc @@ -0,0 +1,268 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memcpy + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/blockwise_shift_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestShiftBytes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Bytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const Repartition du8; + const size_t N8 = Lanes(du8); + + // Zero remains zero + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(v0)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(d, v0)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightBytes<1>(d, v0)); + + // Zero after shifting out the high/low byte + auto bytes = AllocateAligned(N8); + std::fill(bytes.get(), bytes.get() + N8, 0); + bytes[N8 - 1] = 0x7F; + const auto vhi = BitCast(d, Load(du8, bytes.get())); + bytes[N8 - 1] = 0; + bytes[0] = 0x7F; + const auto vlo = BitCast(d, Load(du8, bytes.get())); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(vhi)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(d, vhi)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightBytes<1>(d, vlo)); + + // Check expected result with Iota + const size_t N = Lanes(d); + auto in = AllocateAligned(N); + const uint8_t* in_bytes = reinterpret_cast(in.get()); + const auto v = BitCast(d, Iota(du8, 1)); + Store(v, d, in.get()); + + auto expected = AllocateAligned(N); + uint8_t* expected_bytes = reinterpret_cast(expected.get()); + + const size_t block_size = HWY_MIN(N8, 16); + for (size_t block = 0; block < N8; block += block_size) { + expected_bytes[block] = 0; + memcpy(expected_bytes + block + 1, in_bytes + block, block_size - 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftBytes<1>(v)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftBytes<1>(d, v)); + + for (size_t block = 0; block < N8; block += block_size) { + memcpy(expected_bytes + block, in_bytes + block + 1, block_size - 1); + expected_bytes[block + block_size - 1] = 0; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightBytes<1>(d, v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllShiftBytes() { + ForIntegerTypes(ForPartialVectors()); +} + +struct TestShiftLeftLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Lanes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const auto v = Iota(d, T(1)); + const size_t N = Lanes(d); + if (N == 1) return; + auto expected = AllocateAligned(N); + + HWY_ASSERT_VEC_EQ(d, v, ShiftLeftLanes<0>(v)); + HWY_ASSERT_VEC_EQ(d, v, ShiftLeftLanes<0>(d, v)); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + + for (size_t i = 0; i < N; ++i) { + expected[i] = (i % kLanesPerBlock) == 0 ? T(0) : T(i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftLanes<1>(v)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftLanes<1>(d, v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +struct TestShiftRightLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Lanes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const auto v = Iota(d, T(1)); + const size_t N = Lanes(d); + if (N == 1) return; + auto expected = AllocateAligned(N); + + HWY_ASSERT_VEC_EQ(d, v, ShiftRightLanes<0>(d, v)); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + + for (size_t i = 0; i < N; ++i) { + const size_t mod = i % kLanesPerBlock; + expected[i] = mod == (kLanesPerBlock - 1) || i >= N - 1 ? T(0) : T(2 + i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightLanes<1>(d, v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllShiftLeftLanes() { + ForAllTypes(ForPartialVectors()); +} + +HWY_NOINLINE void TestAllShiftRightLanes() { + ForAllTypes(ForPartialVectors()); +} + +// Scalar does not define CombineShiftRightBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template +struct TestCombineShiftRightBytes { + template + HWY_NOINLINE void operator()(T, D d) { + constexpr size_t kBlockSize = 16; + static_assert(kBytes < kBlockSize, "Shift count is per block"); + const Repartition d8; + const size_t N8 = Lanes(d8); + if (N8 < 16) return; + auto hi_bytes = AllocateAligned(N8); + auto lo_bytes = AllocateAligned(N8); + auto expected_bytes = AllocateAligned(N8); + uint8_t combined[2 * kBlockSize]; + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(100); ++rep) { + for (size_t i = 0; i < N8; ++i) { + hi_bytes[i] = static_cast(Random64(&rng) & 0xFF); + lo_bytes[i] = static_cast(Random64(&rng) & 0xFF); + } + for (size_t i = 0; i < N8; i += kBlockSize) { + // Arguments are not the same size. + CopyBytes(&lo_bytes[i], combined); + CopyBytes(&hi_bytes[i], combined + kBlockSize); + CopyBytes(combined + kBytes, &expected_bytes[i]); + } + + const auto hi = BitCast(d, Load(d8, hi_bytes.get())); + const auto lo = BitCast(d, Load(d8, lo_bytes.get())); + const auto expected = BitCast(d, Load(d8, expected_bytes.get())); + HWY_ASSERT_VEC_EQ(d, expected, CombineShiftRightBytes(d, hi, lo)); + } + } +}; + +template +struct TestCombineShiftRightLanes { + template + HWY_NOINLINE void operator()(T, D d) { + const Repartition d8; + const size_t N8 = Lanes(d8); + if (N8 < 16) return; + + auto hi_bytes = AllocateAligned(N8); + auto lo_bytes = AllocateAligned(N8); + auto expected_bytes = AllocateAligned(N8); + constexpr size_t kBlockSize = 16; + uint8_t combined[2 * kBlockSize]; + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(100); ++rep) { + for (size_t i = 0; i < N8; ++i) { + hi_bytes[i] = static_cast(Random64(&rng) & 0xFF); + lo_bytes[i] = static_cast(Random64(&rng) & 0xFF); + } + for (size_t i = 0; i < N8; i += kBlockSize) { + // Arguments are not the same size. + CopyBytes(&lo_bytes[i], combined); + CopyBytes(&hi_bytes[i], combined + kBlockSize); + CopyBytes(combined + kLanes * sizeof(T), + &expected_bytes[i]); + } + + const auto hi = BitCast(d, Load(d8, hi_bytes.get())); + const auto lo = BitCast(d, Load(d8, lo_bytes.get())); + const auto expected = BitCast(d, Load(d8, expected_bytes.get())); + HWY_ASSERT_VEC_EQ(d, expected, CombineShiftRightLanes(d, hi, lo)); + } + } +}; + +#endif // #if HWY_TARGET != HWY_SCALAR + +struct TestCombineShiftRight { + template + HWY_NOINLINE void operator()(T t, D d) { +// Scalar does not define CombineShiftRightBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + constexpr int kMaxBytes = + HWY_MIN(16, static_cast(MaxLanes(d) * sizeof(T))); + constexpr int kMaxLanes = kMaxBytes / static_cast(sizeof(T)); + TestCombineShiftRightBytes()(t, d); + TestCombineShiftRightBytes()(t, d); + TestCombineShiftRightBytes<1>()(t, d); + + TestCombineShiftRightLanes()(t, d); + TestCombineShiftRightLanes()(t, d); + TestCombineShiftRightLanes<1>()(t, d); +#else + (void)t; + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllCombineShiftRight() { + // Need at least 2 lanes. + ForAllTypes(ForShrinkableVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyBlockwiseShiftTest); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllShiftBytes); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllShiftLeftLanes); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllShiftRightLanes); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllCombineShiftRight); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/blockwise_test.cc b/media/highway/src/hwy/tests/blockwise_test.cc new file mode 100644 index 000000000..41097eeca --- /dev/null +++ b/media/highway/src/hwy/tests/blockwise_test.cc @@ -0,0 +1,452 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/blockwise_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct TestBroadcastR { + HWY_NOINLINE void operator()() const { + using T = typename D::T; + const D d; + const size_t N = Lanes(d); + if (kLane >= N) return; + auto in_lanes = AllocateAligned(N); + std::fill(in_lanes.get(), in_lanes.get() + N, T(0)); + const size_t blockN = HWY_MIN(N * sizeof(T), 16) / sizeof(T); + // Need to set within each 128-bit block + for (size_t block = 0; block < N; block += blockN) { + in_lanes[block + kLane] = static_cast(block + 1); + } + const auto in = Load(d, in_lanes.get()); + auto expected = AllocateAligned(N); + for (size_t block = 0; block < N; block += blockN) { + for (size_t i = 0; i < blockN; ++i) { + expected[block + i] = T(block + 1); + } + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Broadcast(in)); + + TestBroadcastR()(); + } +}; + +template +struct TestBroadcastR { + void operator()() const {} +}; + +struct TestBroadcast { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + TestBroadcastR()(); + } +}; + +HWY_NOINLINE void TestAllBroadcast() { + const ForPartialVectors test; + // No u/i8. + test(uint16_t()); + test(int16_t()); + ForUIF3264(test); +} + +template +struct ChooseTableSize { + template + using type = DIdx; +}; +template <> +struct ChooseTableSize { + template + using type = ScalableTag; +}; + +template +struct TestTableLookupBytes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + RandomState rng; + + const typename ChooseTableSize::template type d_tbl; + const Repartition d_tbl8; + const size_t NT8 = Lanes(d_tbl8); + + const Repartition d8; + const size_t N8 = Lanes(d8); + + // Random input bytes + auto in_bytes = AllocateAligned(NT8); + for (size_t i = 0; i < NT8; ++i) { + in_bytes[i] = Random32(&rng) & 0xFF; + } + const auto in = BitCast(d_tbl, Load(d_tbl8, in_bytes.get())); + + // Enough test data; for larger vectors, upper lanes will be zero. + const uint8_t index_bytes_source[64] = { + // Same index as source, multiple outputs from same input, + // unused input (9), ascending/descending and nonconsecutive neighbors. + 0, 2, 1, 2, 15, 12, 13, 14, 6, 7, 8, 5, 4, 3, 10, 11, + 11, 10, 3, 4, 5, 8, 7, 6, 14, 13, 12, 15, 2, 1, 2, 0, + 4, 3, 2, 2, 5, 6, 7, 7, 15, 15, 15, 15, 15, 15, 0, 1}; + auto index_bytes = AllocateAligned(N8); + const size_t max_index = HWY_MIN(NT8, 16) - 1; + for (size_t i = 0; i < N8; ++i) { + index_bytes[i] = (i < 64) ? index_bytes_source[i] : 0; + // Avoid asan error for partial vectors. + index_bytes[i] = static_cast(HWY_MIN(index_bytes[i], max_index)); + } + const auto indices = Load(d, reinterpret_cast(index_bytes.get())); + + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + uint8_t* expected_bytes = reinterpret_cast(expected.get()); + + for (size_t block = 0; block < N8; block += 16) { + for (size_t i = 0; i < 16 && (block + i) < N8; ++i) { + const uint8_t index = index_bytes[block + i]; + HWY_ASSERT(index <= max_index); + // Note that block + index may exceed NT8 on RVV, which is fine because + // the operation uses the larger of the table and index vector size. + HWY_ASSERT(block + index < HWY_MAX(N8, NT8)); + // For large vectors, the lane index may wrap around due to block, + // also wrap around after 8-bit overflow. + expected_bytes[block + i] = + in_bytes[(block + index) % HWY_MIN(NT8, 256)]; + } + } + HWY_ASSERT_VEC_EQ(d, expected.get(), TableLookupBytes(in, indices)); + + // Individually test zeroing each byte position. + for (size_t i = 0; i < N8; ++i) { + const uint8_t prev_expected = expected_bytes[i]; + const uint8_t prev_index = index_bytes[i]; + expected_bytes[i] = 0; + + const int idx = 0x80 + (static_cast(Random32(&rng) & 7) << 4); + HWY_ASSERT(0x80 <= idx && idx < 256); + index_bytes[i] = static_cast(idx); + + const auto indices = + Load(d, reinterpret_cast(index_bytes.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), TableLookupBytesOr0(in, indices)); + expected_bytes[i] = prev_expected; + index_bytes[i] = prev_index; + } +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllTableLookupBytesSame() { + // Partial index, same-sized table. + ForIntegerTypes(ForPartialVectors>()); +} + +HWY_NOINLINE void TestAllTableLookupBytesMixed() { + // Partial index, full-size table. + ForIntegerTypes(ForPartialVectors>()); +} + +struct TestInterleaveLower { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto even_lanes = AllocateAligned(N); + auto odd_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast(2 * i + 0); + odd_lanes[i] = static_cast(2 * i + 1); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const size_t blockN = HWY_MIN(16 / sizeof(T), N); + for (size_t i = 0; i < Lanes(d); ++i) { + const size_t block = i / blockN; + const size_t index = (i % blockN) + block * 2 * blockN; + expected[i] = static_cast(index & LimitsMax()); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveLower(even, odd)); + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveLower(d, even, odd)); + } +}; + +struct TestInterleaveUpper { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + if (N == 1) return; + auto even_lanes = AllocateAligned(N); + auto odd_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast(2 * i + 0); + odd_lanes[i] = static_cast(2 * i + 1); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const size_t blockN = HWY_MIN(16 / sizeof(T), N); + for (size_t i = 0; i < Lanes(d); ++i) { + const size_t block = i / blockN; + expected[i] = T((i % blockN) + block * 2 * blockN + blockN); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveUpper(d, even, odd)); + } +}; + +HWY_NOINLINE void TestAllInterleave() { + // Not DemoteVectors because this cannot be supported by HWY_SCALAR. + ForAllTypes(ForShrinkableVectors()); + ForAllTypes(ForShrinkableVectors()); +} + +struct TestZipLower { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using WideT = MakeWide; + static_assert(sizeof(T) * 2 == sizeof(WideT), "Must be double-width"); + static_assert(IsSigned() == IsSigned(), "Must have same sign"); + const size_t N = Lanes(d); + auto even_lanes = AllocateAligned(N); + auto odd_lanes = AllocateAligned(N); + // At least 2 lanes for HWY_SCALAR + auto zip_lanes = AllocateAligned(HWY_MAX(N, 2)); + const T kMaxT = LimitsMax(); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast((2 * i + 0) & kMaxT); + odd_lanes[i] = static_cast((2 * i + 1) & kMaxT); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const Repartition dw; +#if HWY_TARGET == HWY_SCALAR + // Safely handle big-endian + const auto expected = Set(dw, static_cast(1ULL << (sizeof(T) * 8))); +#else + const size_t blockN = HWY_MIN(size_t(16) / sizeof(T), N); + for (size_t i = 0; i < N; i += 2) { + const size_t base = (i / blockN) * blockN; + const size_t mod = i % blockN; + zip_lanes[i + 0] = even_lanes[mod / 2 + base]; + zip_lanes[i + 1] = odd_lanes[mod / 2 + base]; + } + const auto expected = + Load(dw, reinterpret_cast(zip_lanes.get())); +#endif // HWY_TARGET == HWY_SCALAR + HWY_ASSERT_VEC_EQ(dw, expected, ZipLower(even, odd)); + HWY_ASSERT_VEC_EQ(dw, expected, ZipLower(dw, even, odd)); + } +}; + +HWY_NOINLINE void TestAllZipLower() { + const ForDemoteVectors lower_unsigned; + lower_unsigned(uint8_t()); + lower_unsigned(uint16_t()); +#if HWY_HAVE_INTEGER64 + lower_unsigned(uint32_t()); // generates u64 +#endif + + const ForDemoteVectors lower_signed; + lower_signed(int8_t()); + lower_signed(int16_t()); +#if HWY_HAVE_INTEGER64 + lower_signed(int32_t()); // generates i64 +#endif + + // No float - concatenating f32 does not result in a f64 +} + +// Remove this test (so it does not show as having run) if the only target is +// HWY_SCALAR, which does not support this op. +#if HWY_TARGETS != HWY_SCALAR + +struct TestZipUpper { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET == HWY_SCALAR + (void)d; +#else + using WideT = MakeWide; + static_assert(sizeof(T) * 2 == sizeof(WideT), "Must be double-width"); + static_assert(IsSigned() == IsSigned(), "Must have same sign"); + const size_t N = Lanes(d); + if (N < 16 / sizeof(T)) return; + auto even_lanes = AllocateAligned(N); + auto odd_lanes = AllocateAligned(N); + auto zip_lanes = AllocateAligned(N); + const T kMaxT = LimitsMax(); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast((2 * i + 0) & kMaxT); + odd_lanes[i] = static_cast((2 * i + 1) & kMaxT); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const size_t blockN = HWY_MIN(size_t(16) / sizeof(T), N); + + for (size_t i = 0; i < N; i += 2) { + const size_t base = (i / blockN) * blockN + blockN / 2; + const size_t mod = i % blockN; + zip_lanes[i + 0] = even_lanes[mod / 2 + base]; + zip_lanes[i + 1] = odd_lanes[mod / 2 + base]; + } + const Repartition dw; + const auto expected = + Load(dw, reinterpret_cast(zip_lanes.get())); + HWY_ASSERT_VEC_EQ(dw, expected, ZipUpper(dw, even, odd)); +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllZipUpper() { + const ForShrinkableVectors upper_unsigned; + upper_unsigned(uint8_t()); + upper_unsigned(uint16_t()); +#if HWY_HAVE_INTEGER64 + upper_unsigned(uint32_t()); // generates u64 +#endif + + const ForShrinkableVectors upper_signed; + upper_signed(int8_t()); + upper_signed(int16_t()); +#if HWY_HAVE_INTEGER64 + upper_signed(int32_t()); // generates i64 +#endif + + // No float - concatenating f32 does not result in a f64 +} + +#endif // HWY_TARGETS != HWY_SCALAR + +class TestSpecialShuffle32 { + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, 0); + VerifyLanes32(d, Shuffle2301(v), 2, 3, 0, 1, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle1032(v), 1, 0, 3, 2, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle0321(v), 0, 3, 2, 1, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle2103(v), 2, 1, 0, 3, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle0123(v), 0, 1, 2, 3, __FILE__, __LINE__); + } + + private: + // HWY_INLINE works around a Clang SVE compiler bug where all but the first + // 128 bits (the NEON register) of actual are zero. + template + HWY_INLINE void VerifyLanes32(D d, VecArg actual, const size_t i3, + const size_t i2, const size_t i1, + const size_t i0, const char* filename, + const int line) { + using T = TFromD; + constexpr size_t kBlockN = 16 / sizeof(T); + const size_t N = Lanes(d); + if (N < 4) return; + auto expected = AllocateAligned(N); + for (size_t block = 0; block < N; block += kBlockN) { + expected[block + 3] = static_cast(block + i3); + expected[block + 2] = static_cast(block + i2); + expected[block + 1] = static_cast(block + i1); + expected[block + 0] = static_cast(block + i0); + } + AssertVecEqual(d, expected.get(), actual, filename, line); + } +}; + +class TestSpecialShuffle64 { + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, 0); + VerifyLanes64(d, Shuffle01(v), 0, 1, __FILE__, __LINE__); + } + + private: + // HWY_INLINE works around a Clang SVE compiler bug where all but the first + // 128 bits (the NEON register) of actual are zero. + template + HWY_INLINE void VerifyLanes64(D d, VecArg actual, const size_t i1, + const size_t i0, const char* filename, + const int line) { + using T = TFromD; + constexpr size_t kBlockN = 16 / sizeof(T); + const size_t N = Lanes(d); + if (N < 2) return; + auto expected = AllocateAligned(N); + for (size_t block = 0; block < N; block += kBlockN) { + expected[block + 1] = static_cast(block + i1); + expected[block + 0] = static_cast(block + i0); + } + AssertVecEqual(d, expected.get(), actual, filename, line); + } +}; + +HWY_NOINLINE void TestAllSpecialShuffles() { + const ForGEVectors<128, TestSpecialShuffle32> test32; + test32(uint32_t()); + test32(int32_t()); + test32(float()); + +#if HWY_HAVE_INTEGER64 + const ForGEVectors<128, TestSpecialShuffle64> test64; + test64(uint64_t()); + test64(int64_t()); +#endif + +#if HWY_HAVE_FLOAT64 + const ForGEVectors<128, TestSpecialShuffle64> test_d; + test_d(double()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyBlockwiseTest); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllBroadcast); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllTableLookupBytesSame); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllTableLookupBytesMixed); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllInterleave); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllZipLower); +#if HWY_TARGETS != HWY_SCALAR +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllZipUpper); +#endif +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllSpecialShuffles); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/combine_test.cc b/media/highway/src/hwy/tests/combine_test.cc new file mode 100644 index 000000000..b99f07a7d --- /dev/null +++ b/media/highway/src/hwy/tests/combine_test.cc @@ -0,0 +1,273 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memcpy + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/combine_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLowerHalf { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Half d2; + + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + auto lanes2 = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + std::fill(lanes2.get(), lanes2.get() + N, T(0)); + const auto v = Iota(d, 1); + Store(LowerHalf(d2, v), d2, lanes.get()); + Store(LowerHalf(v), d2, lanes2.get()); // optionally without D + size_t i = 0; + for (; i < Lanes(d2); ++i) { + HWY_ASSERT_EQ(T(1 + i), lanes[i]); + HWY_ASSERT_EQ(T(1 + i), lanes2[i]); + } + // Other half remains unchanged + for (; i < N; ++i) { + HWY_ASSERT_EQ(T(0), lanes[i]); + HWY_ASSERT_EQ(T(0), lanes2[i]); + } + } +}; + +struct TestLowerQuarter { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Half d2; + const Half d4; + + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + auto lanes2 = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + std::fill(lanes2.get(), lanes2.get() + N, T(0)); + const auto v = Iota(d, 1); + const auto lo = LowerHalf(d4, LowerHalf(d2, v)); + const auto lo2 = LowerHalf(LowerHalf(v)); // optionally without D + Store(lo, d4, lanes.get()); + Store(lo2, d4, lanes2.get()); + size_t i = 0; + for (; i < Lanes(d4); ++i) { + HWY_ASSERT_EQ(T(i + 1), lanes[i]); + HWY_ASSERT_EQ(T(i + 1), lanes2[i]); + } + // Upper 3/4 remain unchanged + for (; i < N; ++i) { + HWY_ASSERT_EQ(T(0), lanes[i]); + HWY_ASSERT_EQ(T(0), lanes2[i]); + } + } +}; + +HWY_NOINLINE void TestAllLowerHalf() { + ForAllTypes(ForHalfVectors()); + + // The minimum vector size is 128 bits, so there's no guarantee we can have + // quarters of 64-bit lanes, hence test 'all' other types. + ForHalfVectors test_quarter; + ForUI8(test_quarter); + ForUI16(test_quarter); // exclude float16_t - cannot compare + ForUIF32(test_quarter); +} + +struct TestUpperHalf { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define UpperHalf. +#if HWY_TARGET != HWY_SCALAR + const Half d2; + const size_t N2 = Lanes(d2); + HWY_ASSERT(N2 * 2 == Lanes(d)); + auto expected = AllocateAligned(N2); + size_t i = 0; + for (; i < N2; ++i) { + expected[i] = static_cast(N2 + 1 + i); + } + HWY_ASSERT_VEC_EQ(d2, expected.get(), UpperHalf(d2, Iota(d, 1))); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllUpperHalf() { + ForAllTypes(ForHalfVectors()); +} + +struct TestZeroExtendVector { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Twice d2; + + const auto v = Iota(d, 1); + const size_t N = Lanes(d); + const size_t N2 = Lanes(d2); + // If equal, then N was already MaxLanes(d) and it's not clear what + // Combine or ZeroExtendVector should return. + if (N2 == N) return; + HWY_ASSERT(N2 == 2 * N); + auto lanes = AllocateAligned(N2); + Store(v, d, &lanes[0]); + Store(v, d, &lanes[N]); + + const auto ext = ZeroExtendVector(d2, v); + Store(ext, d2, lanes.get()); + + // Lower half is unchanged + HWY_ASSERT_VEC_EQ(d, v, Load(d, &lanes[0])); + // Upper half is zero + HWY_ASSERT_VEC_EQ(d, Zero(d), Load(d, &lanes[N])); + } +}; + +HWY_NOINLINE void TestAllZeroExtendVector() { + ForAllTypes(ForExtendableVectors()); +} + +struct TestCombine { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Twice d2; + const size_t N2 = Lanes(d2); + auto lanes = AllocateAligned(N2); + + const auto lo = Iota(d, 1); + const auto hi = Iota(d, static_cast(N2 / 2 + 1)); + const auto combined = Combine(d2, hi, lo); + Store(combined, d2, lanes.get()); + + const auto expected = Iota(d2, 1); + HWY_ASSERT_VEC_EQ(d2, expected, combined); + } +}; + +HWY_NOINLINE void TestAllCombine() { + ForAllTypes(ForExtendableVectors()); +} + +struct TestConcat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + if (N == 1) return; + const size_t half_bytes = N * sizeof(T) / 2; + + auto hi = AllocateAligned(N); + auto lo = AllocateAligned(N); + auto expected = AllocateAligned(N); + RandomState rng; + for (size_t rep = 0; rep < 10; ++rep) { + for (size_t i = 0; i < N; ++i) { + hi[i] = static_cast(Random64(&rng) & 0xFF); + lo[i] = static_cast(Random64(&rng) & 0xFF); + } + + { + memcpy(&expected[N / 2], &hi[N / 2], half_bytes); + memcpy(&expected[0], &lo[0], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatUpperLower(d, vhi, vlo)); + } + + { + memcpy(&expected[N / 2], &hi[N / 2], half_bytes); + memcpy(&expected[0], &lo[N / 2], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatUpperUpper(d, vhi, vlo)); + } + + { + memcpy(&expected[N / 2], &hi[0], half_bytes); + memcpy(&expected[0], &lo[N / 2], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatLowerUpper(d, vhi, vlo)); + } + + { + memcpy(&expected[N / 2], &hi[0], half_bytes); + memcpy(&expected[0], &lo[0], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatLowerLower(d, vhi, vlo)); + } + } + } +}; + +HWY_NOINLINE void TestAllConcat() { + ForAllTypes(ForShrinkableVectors()); +} + +struct TestConcatOddEven { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + const auto hi = Iota(d, static_cast(N)); + const auto lo = Iota(d, 0); + const auto even = Add(Iota(d, 0), Iota(d, 0)); + const auto odd = Add(even, Set(d, 1)); + HWY_ASSERT_VEC_EQ(d, odd, ConcatOdd(d, hi, lo)); + HWY_ASSERT_VEC_EQ(d, even, ConcatEven(d, hi, lo)); + + // This test catches inadvertent saturation. + const auto min = Set(d, LowestValue()); + const auto max = Set(d, HighestValue()); + HWY_ASSERT_VEC_EQ(d, max, ConcatOdd(d, max, max)); + HWY_ASSERT_VEC_EQ(d, max, ConcatEven(d, max, max)); + HWY_ASSERT_VEC_EQ(d, min, ConcatOdd(d, min, min)); + HWY_ASSERT_VEC_EQ(d, min, ConcatEven(d, min, min)); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllConcatOddEven() { + ForAllTypes(ForShrinkableVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCombineTest); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllLowerHalf); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllUpperHalf); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllZeroExtendVector); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllCombine); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllConcat); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllConcatOddEven); +} // namespace hwy + +#endif // HWY_ONCE diff --git a/media/highway/src/hwy/tests/compare_test.cc b/media/highway/src/hwy/tests/compare_test.cc new file mode 100644 index 000000000..a96e29fc6 --- /dev/null +++ b/media/highway/src/hwy/tests/compare_test.cc @@ -0,0 +1,509 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memset + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/compare_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// All types. +struct TestEquality { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, 2); + const auto v2b = Iota(d, 2); + const auto v3 = Iota(d, 3); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq(v2, v3)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq(v3, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq(v2, v2b)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ne(v2, v3)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne(v3, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne(v2, v2b)); + } +}; + +HWY_NOINLINE void TestAllEquality() { + ForAllTypes(ForPartialVectors()); +} + +// a > b should be true, verify that for Gt/Lt and with swapped args. +template +void EnsureGreater(D d, TFromD a, TFromD b, const char* file, int line) { + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + const auto va = Set(d, a); + const auto vb = Set(d, b); + AssertMaskEqual(d, mask_true, Gt(va, vb), file, line); + AssertMaskEqual(d, mask_false, Lt(va, vb), file, line); + + // Swapped order + AssertMaskEqual(d, mask_false, Gt(vb, va), file, line); + AssertMaskEqual(d, mask_true, Lt(vb, va), file, line); + + // Also ensure irreflexive + AssertMaskEqual(d, mask_false, Gt(va, va), file, line); + AssertMaskEqual(d, mask_false, Gt(vb, vb), file, line); + AssertMaskEqual(d, mask_false, Lt(va, va), file, line); + AssertMaskEqual(d, mask_false, Lt(vb, vb), file, line); +} + +#define HWY_ENSURE_GREATER(d, a, b) EnsureGreater(d, a, b, __FILE__, __LINE__) + +struct TestStrictUnsigned { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T max = LimitsMax(); + const auto v0 = Zero(d); + const auto v2 = And(Iota(d, T(2)), Set(d, 255)); // 0..255 + + const auto mask_false = MaskFalse(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 128, 127); + HWY_ENSURE_GREATER(d, max, max / 2); + HWY_ENSURE_GREATER(d, max, 1); + HWY_ENSURE_GREATER(d, max, 0); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + } +}; + +HWY_NOINLINE void TestAllStrictUnsigned() { + ForUnsignedTypes(ForPartialVectors()); +} + +struct TestStrictInt { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T min = LimitsMin(); + const T max = LimitsMax(); + const auto v0 = Zero(d); + const auto v2 = And(Iota(d, T(2)), Set(d, 127)); // 0..127 + const auto vn = Sub(Neg(v2), Set(d, 1)); // -1..-128 + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 0, -1); + HWY_ENSURE_GREATER(d, -1, -2); + HWY_ENSURE_GREATER(d, max, max / 2); + HWY_ENSURE_GREATER(d, max, 1); + HWY_ENSURE_GREATER(d, max, 0); + HWY_ENSURE_GREATER(d, max, -1); + HWY_ENSURE_GREATER(d, max, min); + HWY_ENSURE_GREATER(d, 0, min); + HWY_ENSURE_GREATER(d, min / 2, min); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_true, Gt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt(vn, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(vn, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, vn)); + } +}; + +// S-SSE3 bug (#795): same upper, differing MSB in lower +struct TestStrictInt64 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto m0 = MaskFalse(d); + const auto m1 = MaskTrue(d); + HWY_ASSERT_MASK_EQ(d, m0, Lt(Set(d, 0x380000000LL), Set(d, 0x300000001LL))); + HWY_ASSERT_MASK_EQ(d, m1, Lt(Set(d, 0xF00000000LL), Set(d, 0xF80000000LL))); + HWY_ASSERT_MASK_EQ(d, m1, Lt(Set(d, 0xF00000000LL), Set(d, 0xF80000001LL))); + } +}; + +HWY_NOINLINE void TestAllStrictInt() { + ForSignedTypes(ForPartialVectors()); + ForPartialVectors()(int64_t()); +} + +struct TestStrictFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T huge_neg = T(-1E35); + const T huge_pos = T(1E36); + const auto v0 = Zero(d); + const auto v2 = Iota(d, T(2)); + const auto vn = Neg(v2); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 0, -1); + HWY_ENSURE_GREATER(d, -1, -2); + HWY_ENSURE_GREATER(d, huge_pos, 1); + HWY_ENSURE_GREATER(d, huge_pos, 0); + HWY_ENSURE_GREATER(d, huge_pos, -1); + HWY_ENSURE_GREATER(d, huge_pos, huge_neg); + HWY_ENSURE_GREATER(d, 0, huge_neg); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_true, Gt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt(vn, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(vn, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, vn)); + } +}; + +HWY_NOINLINE void TestAllStrictFloat() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestWeakFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, T(2)); + const auto vn = Iota(d, -T(Lanes(d))); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ge(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Le(vn, vn)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ge(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Le(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Le(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ge(vn, v2)); + } +}; + +HWY_NOINLINE void TestAllWeakFloat() { + ForFloatTypes(ForPartialVectors()); +} + +template +static HWY_NOINLINE Vec Make128(D d, uint64_t hi, uint64_t lo) { + alignas(16) uint64_t in[2]; + in[0] = lo; + in[1] = hi; + return LoadDup128(d, in); +} + +struct TestLt128 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax(), LimitsMax()); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllLt128() { ForGEVectors<128, TestLt128>()(uint64_t()); } + +struct TestLt128Upper { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax(), LimitsMax()); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllLt128Upper() { + ForGEVectors<128, TestLt128Upper>()(uint64_t()); +} + +struct TestEq128 { // Also Ne128 + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, v10, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v11, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, Add(iota, v10), iota)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax(), LimitsMax()); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, vm, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v11, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllEq128() { ForGEVectors<128, TestEq128>()(uint64_t()); } + +struct TestEq128Upper { // Also Ne128Upper + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v10, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v00, v01)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v01, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v01, v00)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v11, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, Add(iota, v10), iota)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax(), LimitsMax()); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, vm, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v11, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllEq128Upper() { + ForGEVectors<128, TestEq128Upper>()(uint64_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCompareTest); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEquality); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictUnsigned); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictInt); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictFloat); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllWeakFloat); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/compress_test.cc b/media/highway/src/hwy/tests/compress_test.cc new file mode 100644 index 000000000..e2d0ef0ba --- /dev/null +++ b/media/highway/src/hwy/tests/compress_test.cc @@ -0,0 +1,757 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memset + +#include // IWYU pragma: keep + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/compress_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Regenerate tables used in the implementation, instead of testing. +#define HWY_PRINT_TABLES 0 + +#if !HWY_PRINT_TABLES || HWY_IDE + +template , typename TI = TFromD> +void CheckStored(D d, DI di, size_t expected_pos, size_t actual_pos, + size_t num_to_check, const AlignedFreeUniquePtr& in, + const AlignedFreeUniquePtr& mask_lanes, + const AlignedFreeUniquePtr& expected, const T* actual_u, + int line) { + if (expected_pos != actual_pos) { + hwy::Abort(__FILE__, line, "Size mismatch for %s: expected %d, actual %d\n", + TypeName(T(), Lanes(d)).c_str(), static_cast(expected_pos), + static_cast(actual_pos)); + } + // Modified from AssertVecEqual - we may not be checking all lanes. + for (size_t i = 0; i < num_to_check; ++i) { + if (!IsEqual(expected[i], actual_u[i])) { + const size_t N = Lanes(d); + fprintf(stderr, "Mismatch at i=%d of %d, line %d:\n\n", + static_cast(i), static_cast(num_to_check), line); + Print(di, "mask", Load(di, mask_lanes.get()), 0, N); + Print(d, "in", Load(d, in.get()), 0, N); + Print(d, "expect", Load(d, expected.get()), 0, N); + Print(d, "actual", Load(d, actual_u), 0, N); + HWY_ASSERT(false); + } + } +} + +struct TestCompress { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(d); + + const T zero{0}; + + for (int frac : {0, 2, 3}) { + // For CompressStore + const size_t misalign = static_cast(frac) * N / 4; + + auto in_lanes = AllocateAligned(N); + auto mask_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + auto actual_a = AllocateAligned(misalign + N); + T* actual_u = actual_a.get() + misalign; + + const size_t bits_size = RoundUpTo((N + 7) / 8, 8); + auto bits = AllocateAligned(bits_size); + memset(bits.get(), 0, bits_size); // for MSAN + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + size_t expected_pos = 0; + for (size_t i = 0; i < N; ++i) { + const uint64_t bits = Random32(&rng); + in_lanes[i] = T(); // cannot initialize float16_t directly. + CopyBytes(&bits, &in_lanes[i]); // not same size + mask_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + if (mask_lanes[i] > 0) { + expected[expected_pos++] = in_lanes[i]; + } + } + size_t num_to_check; + if (CompressIsPartition::value) { + // For non-native Compress, also check that mask=false lanes were + // moved to the back of the vector (highest indices). + size_t extra = expected_pos; + for (size_t i = 0; i < N; ++i) { + if (mask_lanes[i] == 0) { + expected[extra++] = in_lanes[i]; + } + } + HWY_ASSERT(extra == N); + num_to_check = N; + } else { + // For native Compress, only the mask=true lanes are defined. + num_to_check = expected_pos; + } + + const auto in = Load(d, in_lanes.get()); + const auto mask = + RebindMask(d, Gt(Load(di, mask_lanes.get()), Zero(di))); + StoreMaskBits(d, mask, bits.get()); + + // Compress + memset(actual_u, 0, N * sizeof(T)); + StoreU(Compress(in, mask), d, actual_u); + CheckStored(d, di, expected_pos, expected_pos, num_to_check, in_lanes, + mask_lanes, expected, actual_u, __LINE__); + + // CompressNot + memset(actual_u, 0, N * sizeof(T)); + StoreU(CompressNot(in, Not(mask)), d, actual_u); + CheckStored(d, di, expected_pos, expected_pos, num_to_check, in_lanes, + mask_lanes, expected, actual_u, __LINE__); + + // CompressStore + memset(actual_u, 0, N * sizeof(T)); + const size_t size1 = CompressStore(in, mask, d, actual_u); + // expected_pos instead of num_to_check because this op is not + // affected by CompressIsPartition. + CheckStored(d, di, expected_pos, size1, expected_pos, in_lanes, + mask_lanes, expected, actual_u, __LINE__); + + // CompressBlendedStore + memset(actual_u, 0, N * sizeof(T)); + const size_t size2 = CompressBlendedStore(in, mask, d, actual_u); + // expected_pos instead of num_to_check because this op only writes + // the mask=true lanes. + CheckStored(d, di, expected_pos, size2, expected_pos, in_lanes, + mask_lanes, expected, actual_u, __LINE__); + // Subsequent lanes are untouched. + for (size_t i = size2; i < N; ++i) { + HWY_ASSERT_EQ(zero, actual_u[i]); + } + + // CompressBits + memset(actual_u, 0, N * sizeof(T)); + StoreU(CompressBits(in, bits.get()), d, actual_u); + CheckStored(d, di, expected_pos, expected_pos, num_to_check, in_lanes, + mask_lanes, expected, actual_u, __LINE__); + + // CompressBitsStore + memset(actual_u, 0, N * sizeof(T)); + const size_t size3 = CompressBitsStore(in, bits.get(), d, actual_u); + // expected_pos instead of num_to_check because this op is not + // affected by CompressIsPartition. + CheckStored(d, di, expected_pos, size3, expected_pos, in_lanes, + mask_lanes, expected, actual_u, __LINE__); + } // rep + } // frac + } // operator() +}; + +HWY_NOINLINE void TestAllCompress() { + ForUIF163264(ForPartialVectors()); +} + +struct TestCompressBlocks { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET == HWY_SCALAR + (void)d; +#else + static_assert(sizeof(T) == 8 && !IsSigned(), "Should be u64"); + RandomState rng; + + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(d); + + auto in_lanes = AllocateAligned(N); + auto mask_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + auto actual = AllocateAligned(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + size_t expected_pos = 0; + for (size_t i = 0; i < N; i += 2) { + const uint64_t bits = Random32(&rng); + in_lanes[i + 1] = in_lanes[i] = T(); // cannot set float16_t directly. + CopyBytes(&bits, &in_lanes[i]); // not same size + CopyBytes(&bits, &in_lanes[i + 1]); // not same size + mask_lanes[i + 1] = mask_lanes[i] = TI{(Random32(&rng) & 8) ? 1 : 0}; + if (mask_lanes[i] > 0) { + expected[expected_pos++] = in_lanes[i]; + expected[expected_pos++] = in_lanes[i + 1]; + } + } + size_t num_to_check; + if (CompressIsPartition::value) { + // For non-native Compress, also check that mask=false lanes were + // moved to the back of the vector (highest indices). + size_t extra = expected_pos; + for (size_t i = 0; i < N; ++i) { + if (mask_lanes[i] == 0) { + expected[extra++] = in_lanes[i]; + } + } + HWY_ASSERT(extra == N); + num_to_check = N; + } else { + // For native Compress, only the mask=true lanes are defined. + num_to_check = expected_pos; + } + + const auto in = Load(d, in_lanes.get()); + const auto mask = RebindMask(d, Gt(Load(di, mask_lanes.get()), Zero(di))); + + // CompressBlocksNot + memset(actual.get(), 0, N * sizeof(T)); + StoreU(CompressBlocksNot(in, Not(mask)), d, actual.get()); + CheckStored(d, di, expected_pos, expected_pos, num_to_check, in_lanes, + mask_lanes, expected, actual.get(), __LINE__); + } // rep +#endif // HWY_TARGET == HWY_SCALAR + } // operator() +}; + +HWY_NOINLINE void TestAllCompressBlocks() { + ForGE128Vectors()(uint64_t()); +} + +#endif // !HWY_PRINT_TABLES + +#if HWY_PRINT_TABLES || HWY_IDE +namespace detail { // for code folding + +void PrintCompress16x8Tables() { + printf("======================================= 16x8\n"); + constexpr size_t N = 8; // 128-bit SIMD + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Doubled (for converting lane to byte indices) + for (size_t i = 0; i < N; ++i) { + printf("%d,", 2 * indices[i]); + } + printf(code & 1 ? "//\n" : "/**/"); + } + printf("\n"); +} + +void PrintCompressNot16x8Tables() { + printf("======================================= Not 16x8\n"); + constexpr size_t N = 8; // 128-bit SIMD + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Doubled (for converting lane to byte indices) + for (size_t i = 0; i < N; ++i) { + printf("%d,", 2 * indices[i]); + } + printf(not_code & 1 ? "//\n" : "/**/"); + } + printf("\n"); +} + +// Compressed to nibbles, unpacked via variable right shift. Also includes +// FirstN bits in the nibble MSB. +void PrintCompress32x8Tables() { + printf("======================================= 32/64x8\n"); + constexpr size_t N = 8; // AVX2 or 64-bit AVX3 + for (uint64_t code = 0; code < (1ull << N); ++code) { + const size_t count = PopCount(code); + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + if (i < count) { + indices[i] |= N; + HWY_ASSERT(indices[i] < 0x10); + } + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast(packed)); + } + printf("\n"); +} + +void PrintCompressNot32x8Tables() { + printf("======================================= Not 32/64x8\n"); + constexpr size_t N = 8; // AVX2 or 64-bit AVX3 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + const size_t count = PopCount(code); + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + if (i < count) { + indices[i] |= N; + HWY_ASSERT(indices[i] < 0x10); + } + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast(packed)); + } + printf("\n"); +} + +// Compressed to nibbles (for AVX3 64x4) +void PrintCompress64x4NibbleTables() { + printf("======================================= 64x4Nibble\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast(packed)); + } + printf("\n"); +} + +void PrintCompressNot64x4NibbleTables() { + printf("======================================= Not 64x4Nibble\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast(packed)); + } + printf("\n"); +} + +void PrintCompress64x4Tables() { + printf("======================================= 64x4 uncompressed\n"); + constexpr size_t N = 4; // SVE_256 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + printf("%d,", static_cast(indices[i])); + } + } + printf("\n"); +} + +void PrintCompressNot64x4Tables() { + printf("======================================= Not 64x4 uncompressed\n"); + constexpr size_t N = 4; // SVE_256 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + printf("%d,", static_cast(indices[i])); + } + } + printf("\n"); +} + +// Same as above, but prints pairs of u32 indices (for AVX2). Also includes +// FirstN bits in the nibble MSB. +void PrintCompress64x4PairTables() { + printf("======================================= 64x4 u32 index\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t code = 0; code < (1ull << N); ++code) { + const size_t count = PopCount(code); + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + const int first_n_bit = i < count ? 8 : 0; + const int low = static_cast(2 * indices[i]) + first_n_bit; + HWY_ASSERT(low < 0x10); + printf("%d, %d, ", low, low + 1); + } + } + printf("\n"); +} + +void PrintCompressNot64x4PairTables() { + printf("======================================= Not 64x4 u32 index\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + const size_t count = PopCount(code); + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + const int first_n_bit = i < count ? 8 : 0; + const int low = static_cast(2 * indices[i]) + first_n_bit; + HWY_ASSERT(low < 0x10); + printf("%d, %d, ", low, low + 1); + } + } + printf("\n"); +} + +// 4-tuple of byte indices +void PrintCompress32x4Tables() { + printf("======================================= 32x4\n"); + using T = uint32_t; + constexpr size_t N = 4; // SSE4 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +void PrintCompressNot32x4Tables() { + printf("======================================= Not 32x4\n"); + using T = uint32_t; + constexpr size_t N = 4; // SSE4 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +// 8-tuple of byte indices +void PrintCompress64x2Tables() { + printf("======================================= 64x2\n"); + using T = uint64_t; + constexpr size_t N = 2; // SSE4 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +void PrintCompressNot64x2Tables() { + printf("======================================= Not 64x2\n"); + using T = uint64_t; + constexpr size_t N = 2; // SSE4 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +} // namespace detail + +HWY_NOINLINE void PrintTables() { + // Only print once. +#if HWY_TARGET == HWY_STATIC_TARGET + detail::PrintCompress32x8Tables(); + detail::PrintCompressNot32x8Tables(); + detail::PrintCompress64x4NibbleTables(); + detail::PrintCompressNot64x4NibbleTables(); + detail::PrintCompress64x4Tables(); + detail::PrintCompressNot64x4Tables(); + detail::PrintCompress32x4Tables(); + detail::PrintCompressNot32x4Tables(); + detail::PrintCompress64x2Tables(); + detail::PrintCompressNot64x2Tables(); + detail::PrintCompress64x4PairTables(); + detail::PrintCompressNot64x4PairTables(); + detail::PrintCompress16x8Tables(); + detail::PrintCompressNot16x8Tables(); +#endif +} + +#endif // HWY_PRINT_TABLES + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCompressTest); +#if HWY_PRINT_TABLES +// Only print instead of running tests; this will be visible in the log. +HWY_EXPORT_AND_TEST_P(HwyCompressTest, PrintTables); +#else +HWY_EXPORT_AND_TEST_P(HwyCompressTest, TestAllCompress); +HWY_EXPORT_AND_TEST_P(HwyCompressTest, TestAllCompressBlocks); +#endif +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/convert_test.cc b/media/highway/src/hwy/tests/convert_test.cc new file mode 100644 index 000000000..a7aea5fe9 --- /dev/null +++ b/media/highway/src/hwy/tests/convert_test.cc @@ -0,0 +1,643 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include // std::isfinite + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/convert_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Cast and ensure bytes are the same. Called directly from TestAllBitCast or +// via TestBitCastFrom. +template +struct TestBitCast { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Repartition dto; + const size_t N = Lanes(d); + const size_t Nto = Lanes(dto); + if (N == 0 || Nto == 0) return; + HWY_ASSERT_EQ(N * sizeof(T), Nto * sizeof(ToT)); + const auto vf = Iota(d, 1); + const auto vt = BitCast(dto, vf); + // Must return the same bits + auto from_lanes = AllocateAligned(Lanes(d)); + auto to_lanes = AllocateAligned(Lanes(dto)); + Store(vf, d, from_lanes.get()); + Store(vt, dto, to_lanes.get()); + HWY_ASSERT( + BytesEqual(from_lanes.get(), to_lanes.get(), Lanes(d) * sizeof(T))); + } +}; + +// From D to all types. +struct TestBitCastFrom { + template + HWY_NOINLINE void operator()(T t, D d) { + TestBitCast()(t, d); + TestBitCast()(t, d); + TestBitCast()(t, d); +#if HWY_HAVE_INTEGER64 + TestBitCast()(t, d); +#endif + TestBitCast()(t, d); + TestBitCast()(t, d); + TestBitCast()(t, d); +#if HWY_HAVE_INTEGER64 + TestBitCast()(t, d); +#endif + TestBitCast()(t, d); +#if HWY_HAVE_FLOAT64 + TestBitCast()(t, d); +#endif + } +}; + +HWY_NOINLINE void TestAllBitCast() { + // For HWY_SCALAR and partial vectors, we can only cast to same-sized types: + // the former can't partition its single lane, and the latter can be smaller + // than a destination type. + const ForPartialVectors> to_u8; + to_u8(uint8_t()); + to_u8(int8_t()); + + const ForPartialVectors> to_i8; + to_i8(uint8_t()); + to_i8(int8_t()); + + const ForPartialVectors> to_u16; + to_u16(uint16_t()); + to_u16(int16_t()); + + const ForPartialVectors> to_i16; + to_i16(uint16_t()); + to_i16(int16_t()); + + const ForPartialVectors> to_u32; + to_u32(uint32_t()); + to_u32(int32_t()); + to_u32(float()); + + const ForPartialVectors> to_i32; + to_i32(uint32_t()); + to_i32(int32_t()); + to_i32(float()); + +#if HWY_HAVE_INTEGER64 + const ForPartialVectors> to_u64; + to_u64(uint64_t()); + to_u64(int64_t()); +#if HWY_HAVE_FLOAT64 + to_u64(double()); +#endif + + const ForPartialVectors> to_i64; + to_i64(uint64_t()); + to_i64(int64_t()); +#if HWY_HAVE_FLOAT64 + to_i64(double()); +#endif +#endif // HWY_HAVE_INTEGER64 + + const ForPartialVectors> to_float; + to_float(uint32_t()); + to_float(int32_t()); + to_float(float()); + +#if HWY_HAVE_FLOAT64 + const ForPartialVectors> to_double; + to_double(double()); +#if HWY_HAVE_INTEGER64 + to_double(uint64_t()); + to_double(int64_t()); +#endif // HWY_HAVE_INTEGER64 +#endif // HWY_HAVE_FLOAT64 + +#if HWY_TARGET != HWY_SCALAR + // For non-scalar vectors, we can cast all types to all. + ForAllTypes(ForGEVectors<64, TestBitCastFrom>()); +#endif +} + +template +struct TestPromoteTo { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + static_assert(sizeof(T) < sizeof(ToT), "Input type must be narrower"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + const uint64_t bits = rng(); + CopyBytes(&bits, &from[i]); // not same size + expected[i] = from[i]; + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + PromoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllPromoteTo() { + const ForPromoteVectors, 1> to_u16div2; + to_u16div2(uint8_t()); + + const ForPromoteVectors, 2> to_u32div4; + to_u32div4(uint8_t()); + + const ForPromoteVectors, 1> to_u32div2; + to_u32div2(uint16_t()); + + const ForPromoteVectors, 1> to_i16div2; + to_i16div2(uint8_t()); + to_i16div2(int8_t()); + + const ForPromoteVectors, 1> to_i32div2; + to_i32div2(uint16_t()); + to_i32div2(int16_t()); + + const ForPromoteVectors, 2> to_i32div4; + to_i32div4(uint8_t()); + to_i32div4(int8_t()); + + // Must test f16/bf16 separately because we can only load/store/convert them. + +#if HWY_HAVE_INTEGER64 + const ForPromoteVectors, 1> to_u64div2; + to_u64div2(uint32_t()); + + const ForPromoteVectors, 1> to_i64div2; + to_i64div2(int32_t()); +#endif + +#if HWY_HAVE_FLOAT64 + const ForPromoteVectors, 1> to_f64div2; + to_f64div2(int32_t()); + to_f64div2(float()); +#endif +} + +template +bool IsFinite(T t) { + return std::isfinite(t); +} +// Wrapper avoids calling std::isfinite for integer types (ambiguous). +template +bool IsFinite(T /*unused*/) { + return true; +} + +template +AlignedFreeUniquePtr F16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // +/- 1 + 1.0f, -1.0f, + // +/- 0 + 0.0f, -0.0f, + // near 0 + 0.25f, -0.25f, + // +/- integer + 4.0f, -32.0f, + // positive near limit + 65472.0f, 65504.0f, + // negative near limit + -65472.0f, -65504.0f, + // positive +/- delta + 2.00390625f, 3.99609375f, + // negative +/- delta + -2.00390625f, -3.99609375f, + // No infinity/NaN - implementation-defined due to ARM. + }; + constexpr size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + HWY_ASSERT(N != 0); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned(padded); + auto expected = AllocateAligned(padded); + size_t i = 0; + for (; i < kNumTestCases; ++i) { + in[i] = test_cases[i]; + } + for (; i < padded; ++i) { + in[i] = 0.0f; + } + return in; +} + +struct TestF16 { + template + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { +#if HWY_HAVE_FLOAT16 + size_t padded; + const size_t N = Lanes(d32); // same count for f16 + HWY_ASSERT(N != 0); + auto in = F16TestCases(d32, padded); + using TF16 = float16_t; + const Rebind d16; + auto temp16 = AllocateAligned(N); + + for (size_t i = 0; i < padded; i += N) { + const auto loaded = Load(d32, &in[i]); + Store(DemoteTo(d16, loaded), d16, temp16.get()); + HWY_ASSERT_VEC_EQ(d32, loaded, PromoteTo(d32, Load(d16, temp16.get()))); + } +#else + (void)d32; +#endif + } +}; + +HWY_NOINLINE void TestAllF16() { ForDemoteVectors()(float()); } + +template +AlignedFreeUniquePtr BF16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // +/- 1 + 1.0f, -1.0f, + // +/- 0 + 0.0f, -0.0f, + // near 0 + 0.25f, -0.25f, + // +/- integer + 4.0f, -32.0f, + // positive near limit + 3.389531389251535E38f, 1.99384199368e+38f, + // negative near limit + -3.389531389251535E38f, -1.99384199368e+38f, + // positive +/- delta + 2.015625f, 3.984375f, + // negative +/- delta + -2.015625f, -3.984375f, + }; + constexpr size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + HWY_ASSERT(N != 0); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned(padded); + auto expected = AllocateAligned(padded); + size_t i = 0; + for (; i < kNumTestCases; ++i) { + in[i] = test_cases[i]; + } + for (; i < padded; ++i) { + in[i] = 0.0f; + } + return in; +} + +struct TestBF16 { + template + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { +#if !defined(HWY_EMULATE_SVE) + size_t padded; + auto in = BF16TestCases(d32, padded); + using TBF16 = bfloat16_t; +#if HWY_TARGET == HWY_SCALAR + const Rebind dbf16; // avoid 4/2 = 2 lanes +#else + const Repartition dbf16; +#endif + const Half dbf16_half; + const size_t N = Lanes(d32); + HWY_ASSERT(Lanes(dbf16_half) <= N); + auto temp16 = AllocateAligned(N); + + for (size_t i = 0; i < padded; i += N) { + const auto loaded = Load(d32, &in[i]); + const auto v16 = DemoteTo(dbf16_half, loaded); + Store(v16, dbf16_half, temp16.get()); + const auto v16_loaded = Load(dbf16_half, temp16.get()); + HWY_ASSERT_VEC_EQ(d32, loaded, PromoteTo(d32, v16_loaded)); + } +#else + (void)d32; +#endif + } +}; + +HWY_NOINLINE void TestAllBF16() { ForShrinkableVectors()(float()); } + +struct TestConvertU8 { + template + HWY_NOINLINE void operator()(T /*unused*/, const D du32) { + const Rebind du8; + const auto wrap = Set(du32, 0xFF); + HWY_ASSERT_VEC_EQ(du8, Iota(du8, 0), U8FromU32(And(Iota(du32, 0), wrap))); + HWY_ASSERT_VEC_EQ(du8, Iota(du8, 0x7F), + U8FromU32(And(Iota(du32, 0x7F), wrap))); + } +}; + +HWY_NOINLINE void TestAllConvertU8() { + ForDemoteVectors()(uint32_t()); +} + +template +constexpr bool IsSupportedTruncation() { + return (sizeof(To) < sizeof(From)) && + (Pow2(Rebind()) + 3 >= static_cast(CeilLog2(sizeof(To)))); +} + +struct TestTruncateTo { + template ()>* = nullptr> + HWY_NOINLINE void testTo(From, To, const D) { + // do nothing + } + + template ()>* = nullptr> + HWY_NOINLINE void testTo(From, To, const D d) { + constexpr uint32_t base = 0xFA578D00; + const Rebind dTo; + const auto src = Iota(d, static_cast(base)); + const auto expected = Iota(dTo, static_cast(base)); + const VFromD actual = TruncateTo(dTo, src); + HWY_ASSERT_VEC_EQ(dTo, expected, actual); + } + + template + HWY_NOINLINE void operator()(T from, const D d) { + testTo(from, uint8_t(), d); + testTo(from, uint16_t(), d); + testTo(from, uint32_t(), d); + } +}; + +HWY_NOINLINE void TestAllTruncate() { + ForUnsignedTypes(ForPartialVectors()); +} + +// Separate function to attempt to work around a compiler bug on ARM: when this +// is merged with TestIntFromFloat, outputs match a previous Iota(-(N+1)) input. +struct TestIntFromFloatHuge { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + // The ARMv7 manual says that float->int saturates, i.e. chooses the + // nearest representable value. This works correctly on armhf with GCC, but + // not with clang. For reasons unknown, MSVC also runs into an out-of-memory + // error here. +#if HWY_COMPILER_CLANG || HWY_COMPILER_MSVC + (void)df; +#else + using TI = MakeSigned; + const Rebind di; + + // Workaround for incorrect 32-bit GCC codegen for SSSE3 - Print-ing + // the expected lvalue also seems to prevent the issue. + const size_t N = Lanes(df); + auto expected = AllocateAligned(N); + + // Huge positive + Store(Set(di, LimitsMax()), di, expected.get()); + HWY_ASSERT_VEC_EQ(di, expected.get(), ConvertTo(di, Set(df, TF(1E20)))); + + // Huge negative + Store(Set(di, LimitsMin()), di, expected.get()); + HWY_ASSERT_VEC_EQ(di, expected.get(), ConvertTo(di, Set(df, TF(-1E20)))); +#endif + } +}; + +class TestIntFromFloat { + template + static HWY_NOINLINE void TestPowers(TF /*unused*/, const DF df) { + using TI = MakeSigned; + const Rebind di; + constexpr size_t kBits = sizeof(TF) * 8; + + // Powers of two, plus offsets to set some mantissa bits. + const int64_t ofs_table[3] = {0LL, 3LL << (kBits / 2), 1LL << (kBits - 15)}; + for (int sign = 0; sign < 2; ++sign) { + for (size_t shift = 0; shift < kBits - 1; ++shift) { + for (int64_t ofs : ofs_table) { + const int64_t mag = (int64_t{1} << shift) + ofs; + const int64_t val = sign ? mag : -mag; + HWY_ASSERT_VEC_EQ(di, Set(di, static_cast(val)), + ConvertTo(di, Set(df, static_cast(val)))); + } + } + } + } + + template + static HWY_NOINLINE void TestRandom(TF /*unused*/, const DF df) { + using TI = MakeSigned; + const Rebind di; + const size_t N = Lanes(df); + + // TF does not have enough precision to represent TI. + const double min = static_cast(LimitsMin()); + const double max = static_cast(LimitsMax()); + + // Also check random values. + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + CopyBytes(&bits, &from[i]); // not same size + } while (!std::isfinite(from[i])); + if (from[i] >= max) { + expected[i] = LimitsMax(); + } else if (from[i] <= min) { + expected[i] = LimitsMin(); + } else { + expected[i] = static_cast(from[i]); + } + } + + HWY_ASSERT_VEC_EQ(di, expected.get(), + ConvertTo(di, Load(df, from.get()))); + } + } + + public: + template + HWY_NOINLINE void operator()(TF tf, const DF df) { + using TI = MakeSigned; + const Rebind di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), ConvertTo(di, Iota(df, TF(4.0)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), ConvertTo(di, Iota(df, -TF(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), ConvertTo(di, Iota(df, TF(2.001)))); + + // Below positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), ConvertTo(di, Iota(df, TF(3.9999)))); + + const TF eps = static_cast(0.0001); + // Above negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), + ConvertTo(di, Iota(df, -TF(N + 1) + eps))); + + // Below negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), + ConvertTo(di, Iota(df, -TF(N + 1) - eps))); + + TestPowers(tf, df); + TestRandom(tf, df); + } +}; + +HWY_NOINLINE void TestAllIntFromFloat() { + ForFloatTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +struct TestFloatFromInt { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = MakeSigned; + const RebindToSigned di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), ConvertTo(df, Iota(di, TI(4)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(df, Iota(df, -TF(N)), ConvertTo(df, Iota(di, -TI(N)))); + + // Max positive + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax())), + ConvertTo(df, Set(di, LimitsMax()))); + + // Min negative + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMin())), + ConvertTo(df, Set(di, LimitsMin()))); + } +}; + +HWY_NOINLINE void TestAllFloatFromInt() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestFloatFromUint { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TU = MakeUnsigned; + const RebindToUnsigned du; + + // Integer positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), ConvertTo(df, Iota(du, TU(4)))); + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(65535.0)), + ConvertTo(df, Iota(du, 65535))); // 2^16-1 + if (sizeof(TF) > 4) { + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4294967295.0)), + ConvertTo(df, Iota(du, 4294967295ULL))); // 2^32-1 + } + + // Max positive + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax())), + ConvertTo(df, Set(du, LimitsMax()))); + + // Zero + HWY_ASSERT_VEC_EQ(df, Zero(df), ConvertTo(df, Zero(du))); + } +}; + +HWY_NOINLINE void TestAllFloatFromUint() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestI32F64 { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = int32_t; + const Rebind di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(df, Iota(df, -TF(N)), PromoteTo(df, Iota(di, -TI(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(2.0)), PromoteTo(df, Iota(di, TI(2)))); + + // Below positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); + + // Above negative + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-4.0)), PromoteTo(df, Iota(di, TI(-4)))); + + // Below negative + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-2.0)), PromoteTo(df, Iota(di, TI(-2)))); + + // Max positive int + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax())), + PromoteTo(df, Set(di, LimitsMax()))); + + // Min negative int + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMin())), + PromoteTo(df, Set(di, LimitsMin()))); + } +}; + +HWY_NOINLINE void TestAllI32F64() { +#if HWY_HAVE_FLOAT64 + ForDemoteVectors()(double()); +#endif +} + + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyConvertTest); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllBitCast); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllPromoteTo); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllF16); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllBF16); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllConvertU8); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllTruncate); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllIntFromFloat); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllFloatFromInt); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllFloatFromUint); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllI32F64); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/crypto_test.cc b/media/highway/src/hwy/tests/crypto_test.cc new file mode 100644 index 000000000..b7dfb198a --- /dev/null +++ b/media/highway/src/hwy/tests/crypto_test.cc @@ -0,0 +1,553 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memcpy + +#include "hwy/aligned_allocator.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/crypto_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#define HWY_PRINT_CLMUL_GOLDEN 0 + +#if HWY_TARGET != HWY_SCALAR + +class TestAES { + template + HWY_NOINLINE void TestSBox(T /*unused*/, D d) { + // The generic implementation of the S-box is difficult to verify by + // inspection, so we add a white-box test that verifies it using enumeration + // (outputs for 0..255 vs. https://en.wikipedia.org/wiki/Rijndael_S-box). + const uint8_t sbox[256] = { + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, + 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, + 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, + 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, + 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, + 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, + 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, + 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, + 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, + 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, + 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, + 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, + 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, + 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, + 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, + 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, + 0xb0, 0x54, 0xbb, 0x16}; + + // Ensure it's safe to load an entire vector by padding. + const size_t N = Lanes(d); + const size_t padded = RoundUpTo(256, N); + auto expected = AllocateAligned(padded); + // Must wrap around to match the input (Iota). + for (size_t pos = 0; pos < padded;) { + const size_t remaining = HWY_MIN(padded - pos, size_t(256)); + memcpy(expected.get() + pos, sbox, remaining); + pos += remaining; + } + + for (size_t i = 0; i < 256; i += N) { + const auto in = Iota(d, static_cast(i)); + HWY_ASSERT_VEC_EQ(d, expected.get() + i, detail::SubBytes(in)); + } + } + + public: + template + HWY_NOINLINE void operator()(T t, D d) { + // Test vector (after first KeyAddition) from + // https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/AES_Core128.pdf + alignas(16) constexpr uint8_t test_lanes[16] = { + 0x40, 0xBF, 0xAB, 0xF4, 0x06, 0xEE, 0x4D, 0x30, + 0x42, 0xCA, 0x6B, 0x99, 0x7A, 0x5C, 0x58, 0x16}; + const auto test = LoadDup128(d, test_lanes); + + // = ShiftRow result + alignas(16) constexpr uint8_t expected_sr_lanes[16] = { + 0x09, 0x28, 0x7F, 0x47, 0x6F, 0x74, 0x6A, 0xBF, + 0x2C, 0x4A, 0x62, 0x04, 0xDA, 0x08, 0xE3, 0xEE}; + const auto expected_sr = LoadDup128(d, expected_sr_lanes); + + // = MixColumn result + alignas(16) constexpr uint8_t expected_mc_lanes[16] = { + 0x52, 0x9F, 0x16, 0xC2, 0x97, 0x86, 0x15, 0xCA, + 0xE0, 0x1A, 0xAE, 0x54, 0xBA, 0x1A, 0x26, 0x59}; + const auto expected_mc = LoadDup128(d, expected_mc_lanes); + + // = KeyAddition result + alignas(16) constexpr uint8_t expected_lanes[16] = { + 0xF2, 0x65, 0xE8, 0xD5, 0x1F, 0xD2, 0x39, 0x7B, + 0xC3, 0xB9, 0x97, 0x6D, 0x90, 0x76, 0x50, 0x5C}; + const auto expected = LoadDup128(d, expected_lanes); + + alignas(16) uint8_t key_lanes[16]; + for (size_t i = 0; i < 16; ++i) { + key_lanes[i] = expected_mc_lanes[i] ^ expected_lanes[i]; + } + const auto round_key = LoadDup128(d, key_lanes); + + HWY_ASSERT_VEC_EQ(d, expected_mc, AESRound(test, Zero(d))); + HWY_ASSERT_VEC_EQ(d, expected, AESRound(test, round_key)); + HWY_ASSERT_VEC_EQ(d, expected_sr, AESLastRound(test, Zero(d))); + HWY_ASSERT_VEC_EQ(d, Xor(expected_sr, round_key), + AESLastRound(test, round_key)); + + TestSBox(t, d); + } +}; +HWY_NOINLINE void TestAllAES() { ForGEVectors<128, TestAES>()(uint8_t()); } + +#else +HWY_NOINLINE void TestAllAES() {} +#endif // HWY_TARGET != HWY_SCALAR + +struct TestCLMul { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // needs 64 bit lanes and 128-bit result +#if HWY_TARGET != HWY_SCALAR && HWY_HAVE_INTEGER64 + const size_t N = Lanes(d); + if (N == 1) return; + + auto in1 = AllocateAligned(N); + auto in2 = AllocateAligned(N); + + constexpr size_t kCLMulNum = 512; + // Depends on rng! + static constexpr uint64_t kCLMulLower[kCLMulNum] = { + 0x24511d4ce34d6350ULL, 0x4ca582edde1236bbULL, 0x537e58f72dac25a8ULL, + 0x4e942d5e130b9225ULL, 0x75a906c519257a68ULL, 0x1df9f85126d96c5eULL, + 0x464e7c13f4ad286aULL, 0x138535ee35dabc40ULL, 0xb2f7477b892664ecULL, + 0x01557b077167c25dULL, 0xf32682490ee49624ULL, 0x0025bac603b9e140ULL, + 0xcaa86aca3e3daf40ULL, 0x1fbcfe4af73eb6c4ULL, 0x8ee8064dd0aae5dcULL, + 0x1248cb547858c213ULL, 0x37a55ee5b10fb34cULL, 0x6eb5c97b958f86e2ULL, + 0x4b1ab3eb655ea7cdULL, 0x1d66645a85627520ULL, 0xf8728e96daa36748ULL, + 0x38621043e6ff5e3bULL, 0xd1d28b5da5ffefb4ULL, 0x0a5cd65931546df7ULL, + 0x2a0639be3d844150ULL, 0x0e2d0f18c8d6f045ULL, 0xfacc770b963326c1ULL, + 0x19611b31ca2ef141ULL, 0xabea29510dd87518ULL, 0x18a7dc4b205f2768ULL, + 0x9d3975ea5612dc86ULL, 0x06319c139e374773ULL, 0x6641710400b4c390ULL, + 0x356c29b6001c3670ULL, 0xe9e04d851e040a00ULL, 0x21febe561222d79aULL, + 0xc071eaae6e148090ULL, 0x0eed351a0af94f5bULL, 0x04324eedb3c03688ULL, + 0x39e89b136e0d6ccdULL, 0x07d0fd2777a31600ULL, 0x44b8573827209822ULL, + 0x6d690229ea177d78ULL, 0x1b9749d960ba9f18ULL, 0x190945271c0fbb94ULL, + 0x189aea0e07d2c88eULL, 0xf18eab6b65a6beb2ULL, 0x57744b21c13d0d84ULL, + 0xf63050a613e95c2eULL, 0x12cd20d25f97102fULL, 0x5a5df0678dbcba60ULL, + 0x0b08fb80948bfafcULL, 0x44cf1cbe7c6fc3c8ULL, 0x166a470ef25da288ULL, + 0x2c498a609204e48cULL, 0x261b0a22585697ecULL, 0x737750574af7dde4ULL, + 0x4079959c60b01e0cULL, 0x06ed8aac13f782d6ULL, 0x019d454ba9b5ef20ULL, + 0xea1edbf96d49e858ULL, 0x17c2f3ebde9ac469ULL, 0x5cf72706e3d6f5e4ULL, + 0x16e856aa3c841516ULL, 0x256f7e3cef83368eULL, 0x47e17c8eb2774e77ULL, + 0x9b48ac150a804821ULL, 0x584523f61ccfdf22ULL, 0xedcb6a2a75d9e7f2ULL, + 0x1fe3d1838e537aa7ULL, 0x778872e9f64549caULL, 0x2f1cea6f0d3faf92ULL, + 0x0e8c4b6a9343f326ULL, 0x01902d1ba3048954ULL, 0xc5c1fd5269e91dc0ULL, + 0x0ef8a4707817eb9cULL, 0x1f696f09a5354ca4ULL, 0x369cd9de808b818cULL, + 0xf6917d1dd43fd784ULL, 0x7f4b76bf40dc166fULL, 0x4ce67698724ace12ULL, + 0x02c3bf60e6e9cd92ULL, 0xb8229e45b21458e8ULL, 0x415efd41e91adf49ULL, + 0x5edfcd516bb921cdULL, 0x5ff2c29429fd187eULL, 0x0af666b17103b3e0ULL, + 0x1f5e4ff8f54c9a5bULL, 0x429253d8a5544ba6ULL, 0x19de2fdf9f4d9dcaULL, + 0x29bf3d37ddc19a40ULL, 0x04d4513a879552baULL, 0x5cc7476cf71ee155ULL, + 0x40011f8c238784a5ULL, 0x1a3ae50b0fd2ee2bULL, 0x7db22f432ba462baULL, + 0x417290b0bee2284aULL, 0x055a6bd5bb853db2ULL, 0xaa667daeed8c2a34ULL, + 0x0d6b316bda7f3577ULL, 0x72d35598468e3d5dULL, 0x375b594804bfd33aULL, + 0x16ed3a319b540ae8ULL, 0x093bace4b4695afdULL, 0xc7118754ec2737ceULL, + 0x0fff361f0505c81aULL, 0x996e9e7291321af0ULL, 0x496b1d9b0b89ba8cULL, + 0x65a98b2e9181da9cULL, 0x70759c8dd45575dfULL, 0x3446fe727f5e2cbbULL, + 0x1121ae609d195e74ULL, 0x5ff5d68ce8a21018ULL, 0x0e27eca3825b60d6ULL, + 0x82f628bceca3d1daULL, 0x2756a0914e344047ULL, 0xa460406c1c708d50ULL, + 0x63ce32a0c083e491ULL, 0xc883e5a685c480e0ULL, 0x602c951891e600f9ULL, + 0x02ecb2e3911ca5f8ULL, 0x0d8675f4bb70781aULL, 0x43545cc3c78ea496ULL, + 0x04164b01d6b011c2ULL, 0x3acbb323dcab2c9bULL, 0x31c5ba4e22793082ULL, + 0x5a6484af5f7c2d10ULL, 0x1a929b16194e8078ULL, 0x7a6a75d03b313924ULL, + 0x0553c73a35b1d525ULL, 0xf18628c51142be34ULL, 0x1b51cf80d7efd8f5ULL, + 0x52e0ca4df63ee258ULL, 0x0e977099160650c9ULL, 0x6be1524e92024f70ULL, + 0x0ee2152625438b9dULL, 0xfa32af436f6d8eb4ULL, 0x5ecf49c2154287e5ULL, + 0x6b72f4ae3590569dULL, 0x086c5ee6e87bfb68ULL, 0x737a4f0dc04b6187ULL, + 0x08c3439280edea41ULL, 0x9547944f01636c5cULL, 0x6acfbfc2571cd71fULL, + 0x85d7842972449637ULL, 0x252ea5e5a7fad86aULL, 0x4e41468f99ba1632ULL, + 0x095e0c3ae63b25a2ULL, 0xb005ce88fd1c9425ULL, 0x748e668abbe09f03ULL, + 0xb2cfdf466b187d18ULL, 0x60b11e633d8fe845ULL, 0x07144c4d246db604ULL, + 0x139bcaac55e96125ULL, 0x118679b5a6176327ULL, 0x1cebe90fa4d9f83fULL, + 0x22244f52f0d312acULL, 0x669d4e17c9bfb713ULL, 0x96390e0b834bb0d0ULL, + 0x01f7f0e82ba08071ULL, 0x2dffeee31ca6d284ULL, 0x1f4738745ef039feULL, + 0x4ce0dd2b603b6420ULL, 0x0035fc905910a4d5ULL, 0x07df2b533df6fb04ULL, + 0x1cee2735c9b910ddULL, 0x2bc4af565f7809eaULL, 0x2f876c1f5cb1076cULL, + 0x33e079524099d056ULL, 0x169e0405d2f9efbaULL, 0x018643ab548a358cULL, + 0x1bb6fc4331cffe92ULL, 0x05111d3a04e92faaULL, 0x23c27ecf0d638b73ULL, + 0x1b79071dc1685d68ULL, 0x0662d20aba8e1e0cULL, 0xe7f6440277144c6fULL, + 0x4ca38b64c22196c0ULL, 0x43c05f6d1936fbeeULL, 0x0654199d4d1faf0fULL, + 0xf2014054e71c2d04ULL, 0x0a103e47e96b4c84ULL, 0x7986e691dd35b040ULL, + 0x4e1ebb53c306a341ULL, 0x2775bb3d75d65ba6ULL, 0x0562ab0adeff0f15ULL, + 0x3c2746ad5eba3eacULL, 0x1facdb5765680c60ULL, 0xb802a60027d81d00ULL, + 0x1191d0f6366ae3a9ULL, 0x81a97b5ae0ea5d14ULL, 0x06bee05b6178a770ULL, + 0xc7baeb2fe1d6aeb3ULL, 0x594cb5b867d04fdfULL, 0xf515a80138a4e350ULL, + 0x646417ad8073cf38ULL, 0x4a229a43373fb8d4ULL, 0x10fa6eafff1ca453ULL, + 0x9f060700895cc731ULL, 0x00521133d11d11f4ULL, 0xb940a2bb912a7a5cULL, + 0x3fab180670ad2a3cULL, 0x45a5f0e5b6fdb95dULL, 0x27c1baad6f946b15ULL, + 0x336c6bdbe527cf58ULL, 0x3b83aa602a5baea3ULL, 0xdf749153f9bcc376ULL, + 0x1a05513a6c0b4a90ULL, 0xb81e0b570a075c47ULL, 0x471fabb40bdc27ceULL, + 0x9dec9472f6853f60ULL, 0x361f71b88114193bULL, 0x3b550a8c4feeff00ULL, + 0x0f6cde5a68bc9bc0ULL, 0x3f50121a925703e0ULL, 0x6967ff66d6d343a9ULL, + 0xff6b5bd2ce7bc3ccULL, 0x05474cea08bf6cd8ULL, 0xf76eabbfaf108eb0ULL, + 0x067529be4fc6d981ULL, 0x4d766b137cf8a988ULL, 0x2f09c7395c5cfbbdULL, + 0x388793712da06228ULL, 0x02c9ff342c8f339aULL, 0x152c734139a860a3ULL, + 0x35776eb2b270c04dULL, 0x0f8d8b41f11c4608ULL, 0x0c2071665be6b288ULL, + 0xc034e212b3f71d88ULL, 0x071d961ef3276f99ULL, 0xf98598ee75b60773ULL, + 0x062062c58c6724e4ULL, 0xd156438e2125572cULL, 0x38552d59a7f0f7c8ULL, + 0x1a402178206e413cULL, 0x1f1f996c68293b26ULL, 0x8bce3cafe1730f7eULL, + 0x2d0480a0828f6bf5ULL, 0x6c99cffa171f92f6ULL, 0x0087f842bb0ac681ULL, + 0x11d7ed06e1e7fd3eULL, 0x07cb1186f2385dc6ULL, 0x5d7763ebff1e170fULL, + 0x2dacc870231ac292ULL, 0x8486317a9ffb390cULL, 0x1c3a6dd20c959ac6ULL, + 0x90dc96e3992e06b8ULL, 0x70d60bfa33e72b67ULL, 0x70c9bddd0985ee63ULL, + 0x012c9767b3673093ULL, 0xfcd3bc5580f6a88aULL, 0x0ac80017ef6308c3ULL, + 0xdb67d709ef4bba09ULL, 0x4c63e324f0e247ccULL, 0xa15481d3fe219d60ULL, + 0x094c4279cdccb501ULL, 0x965a28c72575cb82ULL, 0x022869db25e391ebULL, + 0x37f528c146023910ULL, 0x0c1290636917deceULL, 0x9aee25e96251ca9cULL, + 0x728ac5ba853b69c2ULL, 0x9f272c93c4be20c8ULL, 0x06c1aa6319d28124ULL, + 0x4324496b1ca8a4f7ULL, 0x0096ecfe7dfc0189ULL, 0x9e06131b19ae0020ULL, + 0x15278b15902f4597ULL, 0x2a9fece8c13842d8ULL, 0x1d4e6781f0e1355eULL, + 0x6855b712d3dbf7c0ULL, 0x06a07fad99be6f46ULL, 0x3ed9d7957e4d1d7cULL, + 0x0c326f7cbc248bb2ULL, 0xe6363ad2c537cf51ULL, 0x0e12eb1c40723f13ULL, + 0xf5c6ac850afba803ULL, 0x0322a79d615fa9f0ULL, 0x6116696ed97bd5f8ULL, + 0x0d438080fbbdc9f1ULL, 0x2e4dc42c38f1e243ULL, 0x64948e9104f3a5bfULL, + 0x9fd622371bdb5f00ULL, 0x0f12bf082b2a1b6eULL, 0x4b1f8d867d78031cULL, + 0x134392ea9f5ef832ULL, 0xf3d70472321bc23eULL, 0x05fcbe5e9eea268eULL, + 0x136dede7175a22cfULL, 0x1308f8baac2cbcccULL, 0xd691026f0915eb64ULL, + 0x0e49a668345c3a38ULL, 0x24ddbbe8bc96f331ULL, 0x4d2ec9479b640578ULL, + 0x450f0697327b359cULL, 0x32b45360f4488ee0ULL, 0x4f6d9ecec46a105aULL, + 0x5500c63401ae8e80ULL, 0x47dea495cf6f98baULL, 0x13dc9a2dfca80babULL, + 0xe6f8a93f7b24ca92ULL, 0x073f57a6d900a87fULL, 0x9ddb935fd3aa695aULL, + 0x101e98d24b39e8aaULL, 0x6b8d0eb95a507ddcULL, 0x45a908b3903d209bULL, + 0x6c96a3e119e617d4ULL, 0x2442787543d3be48ULL, 0xd3bc055c7544b364ULL, + 0x7693bb042ca8653eULL, 0xb95e3a4ea5d0101eULL, 0x116f0d459bb94a73ULL, + 0x841244b72cdc5e90ULL, 0x1271acced6cb34d3ULL, 0x07d289106524d638ULL, + 0x537c9cf49c01b5bbULL, 0x8a8e16706bb7a5daULL, 0x12e50a9c499dc3a9ULL, + 0x1cade520db2ba830ULL, 0x1add52f000d7db70ULL, 0x12cf15db2ce78e30ULL, + 0x0657eaf606bfc866ULL, 0x4026816d3b05b1d0ULL, 0x1ba0ebdf90128e4aULL, + 0xdfd649375996dd6eULL, 0x0f416e906c23d9aeULL, 0x384273cad0582a24ULL, + 0x2ff27b0378a46189ULL, 0xc4ecd18a2d7a7616ULL, 0x35cef0b5cd51d640ULL, + 0x7d582363643f48b7ULL, 0x0984ad746ad0ab7cULL, 0x2990a999835f9688ULL, + 0x2d4df66a97b19e05ULL, 0x592c79720af99aa2ULL, 0x052863c230602cd3ULL, + 0x5f5e2b15edcf2840ULL, 0x01dff1b694b978b0ULL, 0x14345a48b622025eULL, + 0x028fab3b6407f715ULL, 0x3455d188e6feca50ULL, 0x1d0d40288fb1b5fdULL, + 0x4685c5c2b6a1e5aeULL, 0x3a2077b1e5fe5adeULL, 0x1bc55d611445a0d8ULL, + 0x05480ae95f3f83feULL, 0xbbb59cfcf7e17fb6ULL, 0x13f7f10970bbb990ULL, + 0x6d00ac169425a352ULL, 0x7da0db397ef2d5d3ULL, 0x5b512a247f8d2479ULL, + 0x637eaa6a977c3c32ULL, 0x3720f0ae37cba89cULL, 0x443df6e6aa7f525bULL, + 0x28664c287dcef321ULL, 0x03c267c00cf35e49ULL, 0x690185572d4021deULL, + 0x2707ff2596e321c2ULL, 0xd865f5af7722c380ULL, 0x1ea285658e33aafbULL, + 0xc257c5e88755bef4ULL, 0x066f67275cfcc31eULL, 0xb09931945cc0fed0ULL, + 0x58c1dc38d6e3a03fULL, 0xf99489678fc94ee8ULL, 0x75045bb99be5758aULL, + 0x6c163bc34b40feefULL, 0x0420063ce7bdd3b4ULL, 0xf86ef10582bf2e28ULL, + 0x162c3449ca14858cULL, 0x94106aa61dfe3280ULL, 0x4073ae7a4e7e4941ULL, + 0x32b13fd179c250b4ULL, 0x0178fbb216a7e744ULL, 0xf840ae2f1cf92669ULL, + 0x18fc709acc80243dULL, 0x20ac2ebd69f4d558ULL, 0x6e580ad9c73ad46aULL, + 0x76d2b535b541c19dULL, 0x6c7a3fb9dd0ce0afULL, 0xc3481689b9754f28ULL, + 0x156e813b6557abdbULL, 0x6ee372e31276eb10ULL, 0x19cf37c038c8d381ULL, + 0x00d4d906c9ae3072ULL, 0x09f03cbb6dfbfd40ULL, 0x461ba31c4125f3cfULL, + 0x25b29fc63ad9f05bULL, 0x6808c95c2dddede9ULL, 0x0564224337066d9bULL, + 0xc87eb5f4a4d966f2ULL, 0x66fc66e1701f5847ULL, 0xc553a3559f74da28ULL, + 0x1dfd841be574df43ULL, 0x3ee2f100c3ebc082ULL, 0x1a2c4f9517b56e89ULL, + 0x502f65c4b535c8ffULL, 0x1da5663ab6f96ec0ULL, 0xba1f80b73988152cULL, + 0x364ff12182ac8dc1ULL, 0xe3457a3c4871db31ULL, 0x6ae9cadf92fd7e84ULL, + 0x9621ba3d6ca15186ULL, 0x00ff5af878c144ceULL, 0x918464dc130101a4ULL, + 0x036511e6b187efa6ULL, 0x06667d66550ff260ULL, 0x7fd18913f9b51bc1ULL, + 0x3740e6b27af77aa8ULL, 0x1f546c2fd358ff8aULL, 0x42f1424e3115c891ULL, + 0x03767db4e3a1bb33ULL, 0xa171a1c564345060ULL, 0x0afcf632fd7b1324ULL, + 0xb59508d933ffb7d0ULL, 0x57d766c42071be83ULL, 0x659f0447546114a2ULL, + 0x4070364481c460aeULL, 0xa2b9752280644d52ULL, 0x04ab884bea5771bdULL, + 0x87cd135602a232b4ULL, 0x15e54cd9a8155313ULL, 0x1e8005efaa3e1047ULL, + 0x696b93f4ab15d39fULL, 0x0855a8e540de863aULL, 0x0bb11799e79f9426ULL, + 0xeffa61e5c1b579baULL, 0x1e060a1d11808219ULL, 0x10e219205667c599ULL, + 0x2f7b206091c49498ULL, 0xb48854c820064860ULL, 0x21c4aaa3bfbe4a38ULL, + 0x8f4a032a3fa67e9cULL, 0x3146b3823401e2acULL, 0x3afee26f19d88400ULL, + 0x167087c485791d38ULL, 0xb67a1ed945b0fb4bULL, 0x02436eb17e27f1c0ULL, + 0xe05afce2ce2d2790ULL, 0x49c536fc6224cfebULL, 0x178865b3b862b856ULL, + 0x1ce530de26acde5bULL, 0x87312c0b30a06f38ULL, 0x03e653b578558d76ULL, + 0x4d3663c21d8b3accULL, 0x038003c23626914aULL, 0xd9d5a2c052a09451ULL, + 0x39b5acfe08a49384ULL, 0x40f349956d5800e4ULL, 0x0968b6950b1bd8feULL, + 0xd60b2ca030f3779cULL, 0x7c8bc11a23ce18edULL, 0xcc23374e27630bc2ULL, + 0x2e38fc2a8bb33210ULL, 0xe421357814ee5c44ULL, 0x315fb65ea71ec671ULL, + 0xfb1b0223f70ed290ULL, 0x30556c9f983eaf07ULL, 0x8dd438c3d0cd625aULL, + 0x05a8fd0c7ffde71bULL, 0x764d1313b5aeec7aULL, 0x2036af5de9622f47ULL, + 0x508a5bfadda292feULL, 0x3f77f04ba2830e90ULL, 0x9047cd9c66ca66d2ULL, + 0x1168b5318a54eb21ULL, 0xc93462d221da2e15ULL, 0x4c2c7cc54abc066eULL, + 0x767a56fec478240eULL, 0x095de72546595bd3ULL, 0xc9da535865158558ULL, + 0x1baccf36f33e73fbULL, 0xf3d7dbe64df77f18ULL, 0x1f8ebbb7be4850b8ULL, + 0x043c5ed77bce25a1ULL, 0x07d401041b2a178aULL, 0x9181ebb8bd8d5618ULL, + 0x078b935dc3e4034aULL, 0x7b59c08954214300ULL, 0x03570dc2a4f84421ULL, + 0xdd8715b82f6b4078ULL, 0x2bb49c8bb544163bULL, 0xc9eb125564d59686ULL, + 0x5fdc7a38f80b810aULL, 0x3a4a6d8fff686544ULL, 0x28360e2418627d3aULL, + 0x60874244c95ed992ULL, 0x2115cc1dd9c34ed3ULL, 0xfaa3ef61f55e9efcULL, + 0x27ac9b1ef1adc7e6ULL, 0x95ea00478fec3f54ULL, 0x5aea808b2d99ab43ULL, + 0xc8f79e51fe43a580ULL, 0x5dbccd714236ce25ULL, 0x783fa76ed0753458ULL, + 0x48cb290f19d84655ULL, 0xc86a832f7696099aULL, 0x52f30c6fec0e71d3ULL, + 0x77d4e91e8cdeb886ULL, 0x7169a703c6a79ccdULL, 0x98208145b9596f74ULL, + 0x0945695c761c0796ULL, 0x0be897830d17bae0ULL, 0x033ad3924caeeeb4ULL, + 0xedecb6cfa2d303a8ULL, 0x3f86b074818642e7ULL, 0xeefa7c878a8b03f4ULL, + 0x093c101b80922551ULL, 0xfb3b4e6c26ac0034ULL, 0x162bf87999b94f5eULL, + 0xeaedae76e975b17cULL, 0x1852aa090effe18eULL}; + + static constexpr uint64_t kCLMulUpper[kCLMulNum] = { + 0xbb41199b1d587c69ULL, 0x514d94d55894ee29ULL, 0xebc6cd4d2efd5d16ULL, + 0x042044ad2de477fdULL, 0xb865c8b0fcdf4b15ULL, 0x0724d7e551cc40f3ULL, + 0xb15a16f39edb0bccULL, 0x37d64419ede7a171ULL, 0x2aa01bb80c753401ULL, + 0x06ff3f8a95fdaf4dULL, 0x79898cc0838546deULL, 0x776acbd1b237c60aULL, + 0x4c1753be4f4e0064ULL, 0x0ba9243601206ed3ULL, 0xd567c3b1bf3ec557ULL, + 0x043fac7bcff61fb3ULL, 0x49356232b159fb2fULL, 0x3910c82038102d4dULL, + 0x30592fef753eb300ULL, 0x7b2660e0c92a9e9aULL, 0x8246c9248d671ef0ULL, + 0x5a0dcd95147af5faULL, 0x43fde953909cc0eaULL, 0x06147b972cb96e1bULL, + 0xd84193a6b2411d80ULL, 0x00cd7711b950196fULL, 0x1088f9f4ade7fa64ULL, + 0x05a13096ec113cfbULL, 0x958d816d53b00edcULL, 0x3846154a7cdba9cbULL, + 0x8af516db6b27d1e6ULL, 0x1a1d462ab8a33b13ULL, 0x4040b0ac1b2c754cULL, + 0x05127fe9af2fe1d6ULL, 0x9f96e79374321fa6ULL, 0x06ff64a4d9c326f3ULL, + 0x28709566e158ac15ULL, 0x301701d7111ca51cULL, 0x31e0445d1b9d9544ULL, + 0x0a95aff69bf1d03eULL, 0x7c298c8414ecb879ULL, 0x00801499b4143195ULL, + 0x91521a00dd676a5cULL, 0x2777526a14c2f723ULL, 0xfa26aac6a6357dddULL, + 0x1d265889b0187a4bULL, 0xcd6e70fa8ed283e4ULL, 0x18a815aa50ea92caULL, + 0xc01e082694a263c6ULL, 0x4b40163ba53daf25ULL, 0xbc658caff6501673ULL, + 0x3ba35359586b9652ULL, 0x74f96acc97a4936cULL, 0x3989dfdb0cf1d2cfULL, + 0x358a01eaa50dda32ULL, 0x01109a5ed8f0802bULL, 0x55b84922e63c2958ULL, + 0x55b14843d87551d5ULL, 0x1db8ec61b1b578d8ULL, 0x79a2d49ef8c3658fULL, + 0xa304516816b3fbe0ULL, 0x163ecc09cc7b82f9ULL, 0xab91e8d22aabef00ULL, + 0x0ed6b09262de8354ULL, 0xcfd47d34cf73f6f2ULL, 0x7dbd1db2390bc6c3ULL, + 0x5ae789d3875e7b00ULL, 0x1d60fd0e70fe8fa4ULL, 0x690bc15d5ae4f6f5ULL, + 0x121ef5565104fb44ULL, 0x6e98e89297353b54ULL, 0x42554949249d62edULL, + 0xd6d6d16b12df78d2ULL, 0x320b33549b74975dULL, 0xd2a0618763d22e00ULL, + 0x0808deb93cba2017ULL, 0x01bd3b2302a2cc70ULL, 0x0b7b8dd4d71c8dd6ULL, + 0x34d60a3382a0756cULL, 0x40984584c8219629ULL, 0xf1152cba10093a66ULL, + 0x068001c6b2159ccbULL, 0x3d70f13c6cda0800ULL, 0x0e6b6746a322b956ULL, + 0x83a494319d8c770bULL, 0x0faecf64a8553e9aULL, 0xa34919222c39b1bcULL, + 0x0c63850d89e71c6fULL, 0x585f0bee92e53dc8ULL, 0x10f222b13b4fa5deULL, + 0x61573114f94252f2ULL, 0x09d59c311fba6c27ULL, 0x014effa7da49ed4eULL, + 0x4a400a1bc1c31d26ULL, 0xc9091c047b484972ULL, 0x3989f341ec2230ccULL, + 0xdcb03a98b3aee41eULL, 0x4a54a676a33a95e1ULL, 0xe499b7753951ef7cULL, + 0x2f43b1d1061d8b48ULL, 0xc3313bdc68ceb146ULL, 0x5159f6bc0e99227fULL, + 0x98128e6d9c05efcaULL, 0x15ea32b27f77815bULL, 0xe882c054e2654eecULL, + 0x003d2cdb8faee8c6ULL, 0xb416dd333a9fe1dfULL, 0x73f6746aefcfc98bULL, + 0x93dc114c10a38d70ULL, 0x05055941657845eaULL, 0x2ed7351347349334ULL, + 0x26fb1ee2c69ae690ULL, 0xa4575d10dc5b28e0ULL, 0x3395b11295e485ebULL, + 0xe840f198a224551cULL, 0x78e6e5a431d941d4ULL, 0xa1fee3ceab27f391ULL, + 0x07d35b3c5698d0dcULL, 0x983c67fca9174a29ULL, 0x2bb6bbae72b5144aULL, + 0xa7730b8d13ce58efULL, 0x51b5272883de1998ULL, 0xb334e128bb55e260ULL, + 0x1cacf5fbbe1b9974ULL, 0x71a9df4bb743de60ULL, 0x5176fe545c2d0d7aULL, + 0xbe592ecf1a16d672ULL, 0x27aa8a30c3efe460ULL, 0x4c78a32f47991e06ULL, + 0x383459294312f26aULL, 0x97ba789127f1490cULL, 0x51c9aa8a3abd1ef1ULL, + 0xcc7355188121e50fULL, 0x0ecb3a178ae334c1ULL, 0x84879a5e574b7160ULL, + 0x0765298f6389e8f3ULL, 0x5c6750435539bb22ULL, 0x11a05cf056c937b5ULL, + 0xb5dc2172dbfb7662ULL, 0x3ffc17915d9f40e8ULL, 0xbc7904daf3b431b0ULL, + 0x71f2088490930a7cULL, 0xa89505fd9efb53c4ULL, 0x02e194afd61c5671ULL, + 0x99a97f4abf35fcecULL, 0x26830aad30fae96fULL, 0x4b2abc16b25cf0b0ULL, + 0x07ec6fffa1cafbdbULL, 0xf38188fde97a280cULL, 0x121335701afff64dULL, + 0xea5ef38b4e672a64ULL, 0x477edbcae3eabf03ULL, 0xa32813cc0e0d244dULL, + 0x13346d2af4972eefULL, 0xcbc18357af1cfa9aULL, 0x561b630316e73fa6ULL, + 0xe9dfb53249249305ULL, 0x5d2b9dd1479312eeULL, 0x3458008119b56d04ULL, + 0x50e6790b49801385ULL, 0x5bb9febe2349492bULL, 0x0c2813954299098fULL, + 0xf747b0c890a071d5ULL, 0x417e8f82cc028d77ULL, 0xa134fee611d804f8ULL, + 0x24c99ee9a0408761ULL, 0x3ebb224e727137f3ULL, 0x0686022073ceb846ULL, + 0xa05e901fb82ad7daULL, 0x0ece7dc43ab470fcULL, 0x2d334ecc58f7d6a3ULL, + 0x23166fadacc54e40ULL, 0x9c3a4472f839556eULL, 0x071717ab5267a4adULL, + 0xb6600ac351ba3ea0ULL, 0x30ec748313bb63d4ULL, 0xb5374e39287b23ccULL, + 0x074d75e784238aebULL, 0x77315879243914a4ULL, 0x3bbb1971490865f1ULL, + 0xa355c21f4fbe02d3ULL, 0x0027f4bb38c8f402ULL, 0xeef8708e652bc5f0ULL, + 0x7b9aa56cf9440050ULL, 0x113ac03c16cfc924ULL, 0x395db36d3e4bef9fULL, + 0x5d826fabcaa597aeULL, 0x2a77d3c58786d7e0ULL, 0x85996859a3ba19d4ULL, + 0x01e7e3c904c2d97fULL, 0x34f90b9b98d51fd0ULL, 0x243aa97fd2e99bb7ULL, + 0x40a0cebc4f65c1e8ULL, 0x46d3922ed4a5503eULL, 0x446e7ecaf1f9c0a4ULL, + 0x49dc11558bc2e6aeULL, 0xe7a9f20881793af8ULL, 0x5771cc4bc98103f1ULL, + 0x2446ea6e718fce90ULL, 0x25d14aca7f7da198ULL, 0x4347af186f9af964ULL, + 0x10cb44fc9146363aULL, 0x8a35587afce476b4ULL, 0x575144662fee3d3aULL, + 0x69f41177a6bc7a05ULL, 0x02ff8c38d6b3c898ULL, 0x57c73589a226ca40ULL, + 0x732f6b5baae66683ULL, 0x00c008bbedd4bb34ULL, 0x7412ff09524d6cadULL, + 0xb8fd0b5ad8c145a8ULL, 0x74bd9f94b6cdc7dfULL, 0x68233b317ca6c19cULL, + 0x314b9c2c08b15c54ULL, 0x5bd1ad72072ebd08ULL, 0x6610e6a6c07030e4ULL, + 0xa4fc38e885ead7ceULL, 0x36975d1ca439e034ULL, 0xa358f0fe358ffb1aULL, + 0x38e247ad663acf7dULL, 0x77daed3643b5deb8ULL, 0x5507c2aeae1ec3d0ULL, + 0xfdec226c73acf775ULL, 0x1b87ff5f5033492dULL, 0xa832dee545d9033fULL, + 0x1cee43a61e41783bULL, 0xdff82b2e2d822f69ULL, 0x2bbc9a376cb38cf2ULL, + 0x117b1cdaf765dc02ULL, 0x26a407f5682be270ULL, 0x8eb664cf5634af28ULL, + 0x17cb4513bec68551ULL, 0xb0df6527900cbfd0ULL, 0x335a2dc79c5afdfcULL, + 0xa2f0ca4cd38dca88ULL, 0x1c370713b81a2de1ULL, 0x849d5df654d1adfcULL, + 0x2fd1f7675ae14e44ULL, 0x4ff64dfc02247f7bULL, 0x3a2bcf40e395a48dULL, + 0x436248c821b187c1ULL, 0x29f4337b1c7104c0ULL, 0xfc317c46e6630ec4ULL, + 0x2774bccc4e3264c7ULL, 0x2d03218d9d5bee23ULL, 0x36a0ed04d659058aULL, + 0x452484461573cab6ULL, 0x0708edf87ed6272bULL, 0xf07960a1587446cbULL, + 0x3660167b067d84e0ULL, 0x65990a6993ddf8c4ULL, 0x0b197cd3d0b40b3fULL, + 0x1dcec4ab619f3a05ULL, 0x722ab223a84f9182ULL, 0x0822d61a81e7c38fULL, + 0x3d22ad75da563201ULL, 0x93cef6979fd35e0fULL, 0x05c3c25ae598b14cULL, + 0x1338df97dd496377ULL, 0x15bc324dc9c20acfULL, 0x96397c6127e6e8cfULL, + 0x004d01069ef2050fULL, 0x2fcf2e27893fdcbcULL, 0x072f77c3e44f4a5cULL, + 0x5eb1d80b3fe44918ULL, 0x1f59e7c28cc21f22ULL, 0x3390ce5df055c1f8ULL, + 0x4c0ef11df92cb6bfULL, 0x50f82f9e0848c900ULL, 0x08d0fde3ffc0ae38ULL, + 0xbd8d0089a3fbfb73ULL, 0x118ba5b0f311ef59ULL, 0x9be9a8407b926a61ULL, + 0x4ea04fbb21318f63ULL, 0xa1c8e7bb07b871ffULL, 0x1253a7262d5d3b02ULL, + 0x13e997a0512e5b29ULL, 0x54318460ce9055baULL, 0x4e1d8a4db0054798ULL, + 0x0b235226e2cade32ULL, 0x2588732c1476b315ULL, 0x16a378750ba8ac68ULL, + 0xba0b116c04448731ULL, 0x4dd02bd47694c2f1ULL, 0x16d6797b218b6b25ULL, + 0x769eb3709cfbf936ULL, 0x197746a0ce396f38ULL, 0x7d17ad8465961d6eULL, + 0xfe58f4998ae19bb4ULL, 0x36df24305233ce69ULL, 0xb88a4eb008f4ee72ULL, + 0x302b2eb923334787ULL, 0x15a4e3edbe13d448ULL, 0x39a4bf64dd7730ceULL, + 0xedf25421b31090c4ULL, 0x4d547fc131be3b69ULL, 0x2b316e120ca3b90eULL, + 0x0faf2357bf18a169ULL, 0x71f34b54ee2c1d62ULL, 0x18eaf6e5c93a3824ULL, + 0x7e168ba03c1b4c18ULL, 0x1a534dd586d9e871ULL, 0xa2cccd307f5f8c38ULL, + 0x2999a6fb4dce30f6ULL, 0x8f6d3b02c1d549a6ULL, 0x5cf7f90d817aac5aULL, + 0xd2a4ceefe66c8170ULL, 0x11560edc4ca959feULL, 0x89e517e6f0dc464dULL, + 0x75bb8972dddd2085ULL, 0x13859ed1e459d65aULL, 0x057114653326fa84ULL, + 0xe2e6f465173cc86cULL, 0x0ada4076497d7de4ULL, 0xa856fa10ec6dbf8aULL, + 0x41505d9a7c25d875ULL, 0x3091b6278382eccdULL, 0x055737185b2c3f13ULL, + 0x2f4df8ecd6f9c632ULL, 0x0633e89c33552d98ULL, 0xf7673724d16db440ULL, + 0x7331bd08e636c391ULL, 0x0252f29672fee426ULL, 0x1fc384946b6b9ddeULL, + 0x03460c12c901443aULL, 0x003a0792e10abcdaULL, 0x8dbec31f624e37d0ULL, + 0x667420d5bfe4dcbeULL, 0xfbfa30e874ed7641ULL, 0x46d1ae14db7ecef6ULL, + 0x216bd7e8f5448768ULL, 0x32bcd40d3d69cc88ULL, 0x2e991dbc39b65abeULL, + 0x0e8fb123a502f553ULL, 0x3d2d486b2c7560c0ULL, 0x09aba1db3079fe03ULL, + 0xcb540c59398c9bceULL, 0x363970e5339ed600ULL, 0x2caee457c28af00eULL, + 0x005e7d7ee47f41a0ULL, 0x69fad3eb10f44100ULL, 0x048109388c75beb3ULL, + 0x253dddf96c7a6fb8ULL, 0x4c47f705b9d47d09ULL, 0x6cec894228b5e978ULL, + 0x04044bb9f8ff45c2ULL, 0x079e75704d775caeULL, 0x073bd54d2a9e2c33ULL, + 0xcec7289270a364fbULL, 0x19e7486f19cd9e4eULL, 0xb50ac15b86b76608ULL, + 0x0620cf81f165c812ULL, 0x63eaaf13be7b11d4ULL, 0x0e0cf831948248c2ULL, + 0xf0412df8f46e7957ULL, 0x671c1fe752517e3fULL, 0x8841bfb04dd3f540ULL, + 0x122de4142249f353ULL, 0x40a4959fb0e76870ULL, 0x25cfd3d4b4bbc459ULL, + 0x78a07c82930c60d0ULL, 0x12c2de24d4cbc969ULL, 0x85d44866096ad7f4ULL, + 0x1fd917ca66b2007bULL, 0x01fbbb0751764764ULL, 0x3d2a4953c6fe0fdcULL, + 0xcc1489c5737afd94ULL, 0x1817c5b6a5346f41ULL, 0xe605a6a7e9985644ULL, + 0x3c50412328ff1946ULL, 0xd8c7fd65817f1291ULL, 0x0bd66975ab66339bULL, + 0x2baf8fa1c7d10fa9ULL, 0x24abdf06ddef848dULL, 0x14df0c9b2ea4f6c2ULL, + 0x2be950edfd2cb1f7ULL, 0x21911e21094178b6ULL, 0x0fa54d518a93b379ULL, + 0xb52508e0ac01ab42ULL, 0x0e035b5fd8cb79beULL, 0x1c1c6d1a3b3c8648ULL, + 0x286037b42ea9871cULL, 0xfe67bf311e48a340ULL, 0x02324131e932a472ULL, + 0x2486dc2dd919e2deULL, 0x008aec7f1da1d2ebULL, 0x63269ba0e8d3eb3aULL, + 0x23c0f11154adb62fULL, 0xc6052393ecd4c018ULL, 0x523585b7d2f5b9fcULL, + 0xf7e6f8c1e87564c9ULL, 0x09eb9fe5dd32c1a3ULL, 0x4d4f86886e055472ULL, + 0x67ea17b58a37966bULL, 0x3d3ce8c23b1ed1a8ULL, 0x0df97c5ac48857ceULL, + 0x9b6992623759eb12ULL, 0x275aa9551ae091f2ULL, 0x08855e19ac5e62e5ULL, + 0x1155fffe0ae083ccULL, 0xbc9c78db7c570240ULL, 0x074560c447dd2418ULL, + 0x3bf78d330bcf1e70ULL, 0x49867cd4b7ed134bULL, 0x8e6eee0cb4470accULL, + 0x1dabafdf59233dd6ULL, 0xea3a50d844fc3fb8ULL, 0x4f03f4454764cb87ULL, + 0x1f2f41cc36c9e6ecULL, 0x53cba4df42963441ULL, 0x10883b70a88d91fbULL, + 0x62b1fc77d4eb9481ULL, 0x893d8f2604b362e1ULL, 0x0933b7855368b440ULL, + 0x9351b545703b2fceULL, 0x59c1d489b9bdd3b4ULL, 0xe72a9c4311417b18ULL, + 0x5355df77e88eb226ULL, 0xe802c37aa963d7e1ULL, 0x381c3747bd6c3bc3ULL, + 0x378565573444258cULL, 0x37848b1e52b43c18ULL, 0x5da2cd32bdce12b6ULL, + 0x13166c5da615f6fdULL, 0xa51ef95efcc66ac8ULL, 0x640c95e473f1e541ULL, + 0x6ec68def1f217500ULL, 0x49ce3543c76a4079ULL, 0x5fc6fd3cddc706b5ULL, + 0x05c3c0f0f6a1fb0dULL, 0xe7820c0996ad1bddULL, 0x21f0d752a088f35cULL, + 0x755405b51d6fc4a0ULL, 0x7ec7649ca4b0e351ULL, 0x3d2b6a46a251f790ULL, + 0x23e1176b19f418adULL, 0x06056575efe8ac05ULL, 0x0f75981b6966e477ULL, + 0x06e87ec41ad437e4ULL, 0x43f6c255d5e1cb84ULL, 0xe4e67d1120ceb580ULL, + 0x2cd67b9e12c26d7bULL, 0xcd00b5ff7fd187f1ULL, 0x3f6cd40accdc4106ULL, + 0x3e895c835459b330ULL, 0x0814d53a217c0850ULL, 0xc9111fe78bc3a62dULL, + 0x719967e351473204ULL, 0xe757707d24282aa4ULL, 0x7226b7f5607f98e6ULL, + 0x7b268ffae3c08d96ULL, 0x16d3917c8b86020eULL, 0x5128bca51c49ea64ULL, + 0x345ffea02bb1698dULL, 0x9460f5111fe4fbc8ULL, 0x60dd1aa5762852cbULL, + 0xbb7440ed3c81667cULL, 0x0a4b12affa7f6f5cULL, 0x95cbcb0ae03861b6ULL, + 0x07ab3b0591db6070ULL, 0xc6476a4c3de78982ULL, 0x204e82e8623ad725ULL, + 0x569a5b4e8ac2a5ccULL, 0x425a1d77d72ebae2ULL, 0xcdaad5551ab33830ULL, + 0x0b7c68fd8422939eULL, 0x46d9a01f53ec3020ULL, 0x102871edbb29e852ULL, + 0x7a8e8084039075a5ULL, 0x40eaede8615e376aULL, 0x4dc67d757a1c751fULL, + 0x1176ef33063f9145ULL, 0x4ea230285b1c8156ULL, 0x6b2aa46ce0027392ULL, + 0x32b13230fba1b068ULL, 0x0e69796851bb984fULL, 0xb749f4542db698c0ULL, + 0x19ad0241ffffd49cULL, 0x2f41e92ef6caff52ULL, 0x4d0b068576747439ULL, + 0x14d607aef7463e00ULL, 0x1443d00d85fb440eULL, 0x529b43bf68688780ULL, + 0x21133a6bc3a3e378ULL, 0x865b6436dae0e7e5ULL, 0x6b4fe83dc1d6defcULL, + 0x03a5858a0ca0be46ULL, 0x1e841b187e67f312ULL, 0x61ee22ef40a66940ULL, + 0x0494bd2e9e741ef8ULL, 0x4eb59e323010e72cULL, 0x19f2abcfb749810eULL, + 0xb30f1e4f994ef9bcULL, 0x53cf6cdd51bd2d96ULL, 0x263943036497a514ULL, + 0x0d4b52170aa2edbaULL, 0x0c4758a1c7b4f758ULL, 0x178dadb1b502b51aULL, + 0x1ddbb20a602eb57aULL, 0x1fc2e2564a9f27fdULL, 0xd5f8c50a0e3d6f90ULL, + 0x0081da3bbe72ac09ULL, 0xcf140d002ccdb200ULL, 0x0ae8389f09b017feULL, + 0x17cc9ffdc03f4440ULL, 0x04eb921d704bcdddULL, 0x139a0ce4cdc521abULL, + 0x0bfce00c145cb0f0ULL, 0x99925ff132eff707ULL, 0x063f6e5da50c3d35ULL, + 0xa0c25dea3f0e6e29ULL, 0x0c7a9048cc8e040fULL, + }; + + const size_t padded = RoundUpTo(kCLMulNum, N); + auto expected_lower = AllocateAligned(padded); + auto expected_upper = AllocateAligned(padded); + CopyBytes(kCLMulLower, expected_lower.get()); + CopyBytes(kCLMulUpper, expected_upper.get()); + const size_t padding_size = (padded - kCLMulNum) * sizeof(T); + memset(expected_lower.get() + kCLMulNum, 0, padding_size); + memset(expected_upper.get() + kCLMulNum, 0, padding_size); + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < kCLMulNum / N; ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = Random64(&rng); + in2[i] = Random64(&rng); + } + + const auto a = Load(d, in1.get()); + const auto b = Load(d, in2.get()); +#if HWY_PRINT_CLMUL_GOLDEN + Store(CLMulLower(a, b), d, expected_lower.get() + rep * N); + Store(CLMulUpper(a, b), d, expected_upper.get() + rep * N); +#else + HWY_ASSERT_VEC_EQ(d, expected_lower.get() + rep * N, CLMulLower(a, b)); + HWY_ASSERT_VEC_EQ(d, expected_upper.get() + rep * N, CLMulUpper(a, b)); +#endif + } + +#if HWY_PRINT_CLMUL_GOLDEN + // RVV lacks PRIu64, so print 32-bit halves. + for (size_t i = 0; i < kCLMulNum; ++i) { + printf("0x%08x%08xULL,", static_cast(expected_lower[i] >> 32), + static_cast(expected_lower[i] & 0xFFFFFFFFU)); + } + printf("\n"); + for (size_t i = 0; i < kCLMulNum; ++i) { + printf("0x%08x%08xULL,", static_cast(expected_upper[i] >> 32), + static_cast(expected_upper[i] & 0xFFFFFFFFU)); + } +#endif // HWY_PRINT_CLMUL_GOLDEN +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllCLMul() { ForGEVectors<128, TestCLMul>()(uint64_t()); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCryptoTest); +HWY_EXPORT_AND_TEST_P(HwyCryptoTest, TestAllAES); +HWY_EXPORT_AND_TEST_P(HwyCryptoTest, TestAllCLMul); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/demote_test.cc b/media/highway/src/hwy/tests/demote_test.cc new file mode 100644 index 000000000..4339a5437 --- /dev/null +++ b/media/highway/src/hwy/tests/demote_test.cc @@ -0,0 +1,326 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/demote_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +// Causes build timeout. +#if !HWY_IS_MSAN + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +bool IsFiniteT(T t) { + return std::isfinite(t); +} +// Wrapper avoids calling std::isfinite for integer types (ambiguous). +template +bool IsFiniteT(T /*unused*/) { + return true; +} + +template +struct TestDemoteTo { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + static_assert(!IsFloat(), "Use TestDemoteToFloat for float output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Narrower range in the wider type, for clamping before we cast + const T min = LimitsMin(); + const T max = LimitsMax(); + + const auto value_ok = [&](T& value) { + if (!IsFiniteT(value)) return false; + return true; + }; + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + CopyBytes(&bits, &from[i]); // not same size + } while (!value_ok(from[i])); + expected[i] = static_cast(HWY_MIN(HWY_MAX(min, from[i]), max)); + } + + const auto in = Load(from_d, from.get()); + HWY_ASSERT_VEC_EQ(to_d, expected.get(), DemoteTo(to_d, in)); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToInt() { + ForDemoteVectors>()(int16_t()); + ForDemoteVectors, 2>()(int32_t()); + + ForDemoteVectors>()(int16_t()); + ForDemoteVectors, 2>()(int32_t()); + + const ForDemoteVectors> to_u16; + to_u16(int32_t()); + + const ForDemoteVectors> to_i16; + to_i16(int32_t()); +} + +HWY_NOINLINE void TestAllDemoteToMixed() { +#if HWY_HAVE_FLOAT64 + const ForDemoteVectors> to_i32; + to_i32(double()); +#endif +} + +template +struct TestDemoteToFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + // For floats, we clamp differently and cannot call LimitsMin. + static_assert(IsFloat(), "Use TestDemoteTo for integer output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + CopyBytes(&bits, &from[i]); // not same size + } while (!IsFiniteT(from[i])); + const T magn = std::abs(from[i]); + const T max_abs = HighestValue(); + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + const T clipped = copysign(HWY_MIN(magn, max_abs), from[i]); + expected[i] = static_cast(clipped); + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + DemoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToFloat() { + // Must test f16 separately because we can only load/store/convert them. + +#if HWY_HAVE_FLOAT64 + const ForDemoteVectors, 1> to_float; + to_float(double()); +#endif +} + +template +AlignedFreeUniquePtr ReorderBF16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // Same as BF16TestCases: + // +/- 1 + 1.0f, + -1.0f, + // +/- 0 + 0.0f, + -0.0f, + // near 0 + 0.25f, + -0.25f, + // +/- integer + 4.0f, + -32.0f, + // positive +/- delta + 2.015625f, + 3.984375f, + // negative +/- delta + -2.015625f, + -3.984375f, + + // No huge values - would interfere with sum. But add more to fill 2 * N: + -2.0f, + -10.0f, + 0.03125f, + 1.03125f, + 1.5f, + 2.0f, + 4.0f, + 5.0f, + 6.0f, + 8.0f, + 10.0f, + 256.0f, + 448.0f, + 2080.0f, + }; + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, 2 * N); // allow loading pairs of vectors + auto in = AllocateAligned(padded); + auto expected = AllocateAligned(padded); + std::copy(test_cases, test_cases + kNumTestCases, in.get()); + std::fill(in.get() + kNumTestCases, in.get() + padded, 0.0f); + return in; +} + +class TestReorderDemote2To { + // In-place N^2 selection sort to avoid dependencies + void Sort(float* p, size_t count) { + for (size_t i = 0; i < count - 1; ++i) { + // Find min_element + size_t idx_min = i; + for (size_t j = i + 1; j < count; j++) { + if (p[j] < p[idx_min]) { + idx_min = j; + } + } + + // Swap with current + const float tmp = p[i]; + p[i] = p[idx_min]; + p[idx_min] = tmp; + } + } + + public: + template + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { +#if HWY_TARGET != HWY_SCALAR + size_t padded; + auto in = ReorderBF16TestCases(d32, padded); + + using TBF16 = bfloat16_t; + const Repartition dbf16; + const Half dbf16_half; + const size_t N = Lanes(d32); + auto temp16 = AllocateAligned(2 * N); + auto expected = AllocateAligned(2 * N); + auto actual = AllocateAligned(2 * N); + + for (size_t i = 0; i < padded; i += 2 * N) { + const auto f0 = Load(d32, &in[i + 0]); + const auto f1 = Load(d32, &in[i + N]); + const auto v16 = ReorderDemote2To(dbf16, f0, f1); + Store(v16, dbf16, temp16.get()); + const auto promoted0 = PromoteTo(d32, Load(dbf16_half, temp16.get() + 0)); + const auto promoted1 = PromoteTo(d32, Load(dbf16_half, temp16.get() + N)); + + // Smoke test: sum should be same (with tolerance for non-associativity) + const auto sum_expected = GetLane(SumOfLanes(d32, Add(f0, f1))); + const auto sum_actual = + GetLane(SumOfLanes(d32, Add(promoted0, promoted1))); + + HWY_ASSERT(sum_expected - 1E-4 <= sum_actual && + sum_actual <= sum_expected + 1E-4); + + // Ensure values are the same after sorting to undo the Reorder + Store(f0, d32, expected.get() + 0); + Store(f1, d32, expected.get() + N); + Store(promoted0, d32, actual.get() + 0); + Store(promoted1, d32, actual.get() + N); + Sort(expected.get(), 2 * N); + Sort(actual.get(), 2 * N); + HWY_ASSERT_VEC_EQ(d32, expected.get() + 0, Load(d32, actual.get() + 0)); + HWY_ASSERT_VEC_EQ(d32, expected.get() + N, Load(d32, actual.get() + N)); + } +#else // HWY_SCALAR + (void)d32; +#endif + } +}; + +HWY_NOINLINE void TestAllReorderDemote2To() { + ForShrinkableVectors()(float()); +} + +struct TestI32F64 { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = int32_t; + const Rebind di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), DemoteTo(di, Iota(df, TF(4.0)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), DemoteTo(di, Iota(df, -TF(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), DemoteTo(di, Iota(df, TF(2.001)))); + + // Below positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), DemoteTo(di, Iota(df, TF(3.9999)))); + + const TF eps = static_cast(0.0001); + // Above negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), + DemoteTo(di, Iota(df, -TF(N + 1) + eps))); + + // Below negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), + DemoteTo(di, Iota(df, -TF(N + 1) - eps))); + + // Huge positive float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMax()), + DemoteTo(di, Set(df, TF(1E12)))); + + // Huge negative float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMin()), + DemoteTo(di, Set(df, TF(-1E12)))); + } +}; + +HWY_NOINLINE void TestAllI32F64() { +#if HWY_HAVE_FLOAT64 + ForDemoteVectors()(double()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // !HWY_IS_MSAN + +#if HWY_ONCE + +namespace hwy { +#if !HWY_IS_MSAN +HWY_BEFORE_TEST(HwyDemoteTest); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToInt); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToMixed); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToFloat); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllReorderDemote2To); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllI32F64); +#endif // !HWY_IS_MSAN +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/float_test.cc b/media/highway/src/hwy/tests/float_test.cc new file mode 100644 index 000000000..05d7b7605 --- /dev/null +++ b/media/highway/src/hwy/tests/float_test.cc @@ -0,0 +1,349 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests some ops specific to floating-point types (Div, Round etc.) + +#include +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/float_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestDiv { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(-2)); + const auto v1 = Set(d, T(1)); + + // Unchanged after division by 1. + HWY_ASSERT_VEC_EQ(d, v, Div(v, v1)); + + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = (T(i) - 2) / T(2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Div(v, Set(d, T(2)))); + } +}; + +HWY_NOINLINE void TestAllDiv() { ForFloatTypes(ForPartialVectors()); } + +struct TestApproximateReciprocal { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(-2)); + const auto nonzero = IfThenElse(Eq(v, Zero(d)), Set(d, T(1)), v); + const size_t N = Lanes(d); + auto input = AllocateAligned(N); + Store(nonzero, d, input.get()); + + auto actual = AllocateAligned(N); + Store(ApproximateReciprocal(nonzero), d, actual.get()); + + double max_l1 = 0.0; + double worst_expected = 0.0; + double worst_actual = 0.0; + for (size_t i = 0; i < N; ++i) { + const double expected = 1.0 / input[i]; + const double l1 = std::abs(expected - actual[i]); + if (l1 > max_l1) { + max_l1 = l1; + worst_expected = expected; + worst_actual = actual[i]; + } + } + const double abs_worst_expected = std::abs(worst_expected); + if (abs_worst_expected > 1E-5) { + const double max_rel = max_l1 / abs_worst_expected; + fprintf(stderr, "max l1 %f rel %f (%f vs %f)\n", max_l1, max_rel, + worst_expected, worst_actual); + HWY_ASSERT(max_rel < 0.004); + } + } +}; + +HWY_NOINLINE void TestAllApproximateReciprocal() { + ForPartialVectors()(float()); +} + +struct TestSquareRoot { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto vi = Iota(d, 0); + HWY_ASSERT_VEC_EQ(d, vi, Sqrt(Mul(vi, vi))); + } +}; + +HWY_NOINLINE void TestAllSquareRoot() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestReciprocalSquareRoot { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Set(d, 123.0f); + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + Store(ApproximateReciprocalSqrt(v), d, lanes.get()); + for (size_t i = 0; i < N; ++i) { + float err = lanes[i] - 0.090166f; + if (err < 0.0f) err = -err; + if (err >= 4E-4f) { + HWY_ABORT("Lane %d (%d): actual %f err %f\n", static_cast(i), + static_cast(N), lanes[i], err); + } + } + } +}; + +HWY_NOINLINE void TestAllReciprocalSquareRoot() { + ForPartialVectors()(float()); +} + +template +AlignedFreeUniquePtr RoundTestCases(T /*unused*/, D d, size_t& padded) { + const T eps = std::numeric_limits::epsilon(); + const T test_cases[] = { + // +/- 1 + T(1), + T(-1), + // +/- 0 + T(0), + T(-0), + // near 0 + T(0.4), + T(-0.4), + // +/- integer + T(4), + T(-32), + // positive near limit + MantissaEnd() - T(1.5), + MantissaEnd() + T(1.5), + // negative near limit + -MantissaEnd() - T(1.5), + -MantissaEnd() + T(1.5), + // positive tiebreak + T(1.5), + T(2.5), + // negative tiebreak + T(-1.5), + T(-2.5), + // positive +/- delta + T(2.0001), + T(3.9999), + // negative +/- delta + T(-999.9999), + T(-998.0001), + // positive +/- epsilon + T(1) + eps, + T(1) - eps, + // negative +/- epsilon + T(-1) + eps, + T(-1) - eps, + // +/- huge (but still fits in float) + T(1E34), + T(-1E35), + // +/- infinity + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + // qNaN + GetLane(NaN(d)) + }; + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned(padded); + auto expected = AllocateAligned(padded); + std::copy(test_cases, test_cases + kNumTestCases, in.get()); + std::fill(in.get() + kNumTestCases, in.get() + padded, T(0)); + return in; +} + +struct TestRound { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + // Avoid [std::]round, which does not round to nearest *even*. + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + expected[i] = static_cast(nearbyint(in[i])); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Round(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllRound() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestNearestInt { + template + HWY_NOINLINE void operator()(TF tf, const DF df) { + using TI = MakeSigned; + const RebindToSigned di; + + size_t padded; + auto in = RoundTestCases(tf, df, padded); + auto expected = AllocateAligned(padded); + + constexpr double max = static_cast(LimitsMax()); + for (size_t i = 0; i < padded; ++i) { + if (std::isnan(in[i])) { + // We replace NaN with 0 below (no_nan) + expected[i] = 0; + } else if (std::isinf(in[i]) || double{std::abs(in[i])} >= max) { + // Avoid undefined result for lrintf + expected[i] = std::signbit(in[i]) ? LimitsMin() : LimitsMax(); + } else { + expected[i] = static_cast(lrintf(in[i])); + } + } + for (size_t i = 0; i < padded; i += Lanes(df)) { + const auto v = Load(df, &in[i]); + const auto no_nan = IfThenElse(Eq(v, v), v, Zero(df)); + HWY_ASSERT_VEC_EQ(di, &expected[i], NearestInt(no_nan)); + } + } +}; + +HWY_NOINLINE void TestAllNearestInt() { + ForPartialVectors()(float()); +} + +struct TestTrunc { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + expected[i] = static_cast(trunc(in[i])); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Trunc(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllTrunc() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestCeil { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + expected[i] = std::ceil(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Ceil(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllCeil() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestFloor { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + expected[i] = std::floor(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Floor(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllFloor() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestAbsDiff { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes_a = AllocateAligned(N); + auto in_lanes_b = AllocateAligned(N); + auto out_lanes = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + in_lanes_a[i] = static_cast((i ^ 1u) << i); + in_lanes_b[i] = static_cast(i << i); + out_lanes[i] = std::abs(in_lanes_a[i] - in_lanes_b[i]); + } + const auto a = Load(d, in_lanes_a.get()); + const auto b = Load(d, in_lanes_b.get()); + const auto expected = Load(d, out_lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected, AbsDiff(a, b)); + HWY_ASSERT_VEC_EQ(d, expected, AbsDiff(b, a)); + } +}; + +HWY_NOINLINE void TestAllAbsDiff() { + ForPartialVectors()(float()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyFloatTest); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllDiv); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllApproximateReciprocal); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllSquareRoot); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllReciprocalSquareRoot); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllRound); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllNearestInt); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllTrunc); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllCeil); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllFloor); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllAbsDiff); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/hwy_gtest.h b/media/highway/src/hwy/tests/hwy_gtest.h new file mode 100644 index 000000000..acecee8e3 --- /dev/null +++ b/media/highway/src/hwy/tests/hwy_gtest.h @@ -0,0 +1,157 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_TESTS_HWY_GTEST_H_ +#define HWY_TESTS_HWY_GTEST_H_ + +// Adapters for GUnit to run tests for all targets. + +#include +#include + +#include +#include // std::tuple + +#include "gtest/gtest.h" +#include "hwy/highway.h" + +namespace hwy { + +// googletest before 1.10 didn't define INSTANTIATE_TEST_SUITE_P() but instead +// used INSTANTIATE_TEST_CASE_P which is now deprecated. +#ifdef INSTANTIATE_TEST_SUITE_P +#define HWY_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_SUITE_P +#else +#define HWY_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_CASE_P +#endif + +// Helper class to run parametric tests using the hwy target as parameter. To +// use this define the following in your test: +// class MyTestSuite : public TestWithParamTarget { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P(MyTestSuite); +// TEST_P(MyTestSuite, MyTest) { ... } +class TestWithParamTarget : public testing::TestWithParam { + protected: + void SetUp() override { SetSupportedTargetsForTest(GetParam()); } + + void TearDown() override { + // Check that the parametric test calls SupportedTargets() when the source + // was compiled with more than one target. In the single-target case only + // static dispatch will be used anyway. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) != 0 + EXPECT_TRUE(GetChosenTarget().IsInitialized()) + << "This hwy target parametric test doesn't use dynamic-dispatch and " + "doesn't need to be parametric."; +#endif + SetSupportedTargetsForTest(0); + } +}; + +// Function to convert the test parameter of a TestWithParamTarget for +// displaying it in the gtest test name. +static inline std::string TestParamTargetName( + const testing::TestParamInfo& info) { + return TargetName(info.param); +} + +#define HWY_TARGET_INSTANTIATE_TEST_SUITE_P(suite) \ + HWY_GTEST_INSTANTIATE_TEST_SUITE_P( \ + suite##Group, suite, \ + testing::ValuesIn(::hwy::SupportedAndGeneratedTargets()), \ + ::hwy::TestParamTargetName) + +// Helper class similar to TestWithParamTarget to run parametric tests that +// depend on the target and another parametric test. If you need to use multiple +// extra parameters use a std::tuple<> of them and ::testing::Generate(...) as +// the generator. To use this class define the following in your test: +// class MyTestSuite : public TestWithParamTargetT { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(MyTestSuite, ::testing::Range(0, 9)); +// TEST_P(MyTestSuite, MyTest) { ... GetParam() .... } +template +class TestWithParamTargetAndT + : public ::testing::TestWithParam> { + public: + // Expose the parametric type here so it can be used by the + // HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T macro. + using HwyParamType = T; + + protected: + void SetUp() override { + SetSupportedTargetsForTest(std::get<0>( + ::testing::TestWithParam>::GetParam())); + } + + void TearDown() override { + // Check that the parametric test calls SupportedTargets() when the source + // was compiled with more than one target. In the single-target case only + // static dispatch will be used anyway. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) != 0 + EXPECT_TRUE(GetChosenTarget().IsInitialized()) + << "This hwy target parametric test doesn't use dynamic-dispatch and " + "doesn't need to be parametric."; +#endif + SetSupportedTargetsForTest(0); + } + + T GetParam() { + return std::get<1>( + ::testing::TestWithParam>::GetParam()); + } +}; + +template +std::string TestParamTargetNameAndT( + const testing::TestParamInfo>& info) { + return std::string(TargetName(std::get<0>(info.param))) + "_" + + ::testing::PrintToString(std::get<1>(info.param)); +} + +#define HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(suite, generator) \ + HWY_GTEST_INSTANTIATE_TEST_SUITE_P( \ + suite##Group, suite, \ + ::testing::Combine( \ + testing::ValuesIn(::hwy::SupportedAndGeneratedTargets()), \ + generator), \ + ::hwy::TestParamTargetNameAndT) + +// Helper macro to export a function and define a test that tests it. This is +// equivalent to do a HWY_EXPORT of a void(void) function and run it in a test: +// class MyTestSuite : public TestWithParamTarget { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P(MyTestSuite); +// HWY_EXPORT_AND_TEST_P(MyTestSuite, MyTest); +#define HWY_EXPORT_AND_TEST_P(suite, func_name) \ + HWY_EXPORT(func_name); \ + TEST_P(suite, func_name) { HWY_DYNAMIC_DISPATCH(func_name)(); } \ + static_assert(true, "For requiring trailing semicolon") + +#define HWY_EXPORT_AND_TEST_P_T(suite, func_name) \ + HWY_EXPORT(func_name); \ + TEST_P(suite, func_name) { HWY_DYNAMIC_DISPATCH(func_name)(GetParam()); } \ + static_assert(true, "For requiring trailing semicolon") + +#define HWY_BEFORE_TEST(suite) \ + class suite : public hwy::TestWithParamTarget {}; \ + HWY_TARGET_INSTANTIATE_TEST_SUITE_P(suite); \ + static_assert(true, "For requiring trailing semicolon") + +} // namespace hwy + +#endif // HWY_TESTS_HWY_GTEST_H_ diff --git a/media/highway/src/hwy/tests/if_test.cc b/media/highway/src/hwy/tests/if_test.cc new file mode 100644 index 000000000..e44a878a0 --- /dev/null +++ b/media/highway/src/hwy/tests/if_test.cc @@ -0,0 +1,175 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "hwy/aligned_allocator.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/if_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestIfThenElse { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(d); + auto in1 = AllocateAligned(N); + auto in2 = AllocateAligned(N); + auto bool_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast(Random32(&rng)); + in2[i] = static_cast(Random32(&rng)); + bool_lanes[i] = (Random32(&rng) & 16) ? TI(1) : TI(0); + } + + const auto v1 = Load(d, in1.get()); + const auto v2 = Load(d, in2.get()); + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + + for (size_t i = 0; i < N; ++i) { + expected[i] = bool_lanes[i] ? in1[i] : in2[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenElse(mask, v1, v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = bool_lanes[i] ? in1[i] : T(0); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenElseZero(mask, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = bool_lanes[i] ? T(0) : in2[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenZeroElse(mask, v2)); + } + } +}; + +HWY_NOINLINE void TestAllIfThenElse() { + ForAllTypes(ForPartialVectors()); +} + +struct TestIfVecThenElse { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TU = MakeUnsigned; // For all-one mask + const Rebind du; + const size_t N = Lanes(d); + auto in1 = AllocateAligned(N); + auto in2 = AllocateAligned(N); + auto vec_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast(Random32(&rng)); + in2[i] = static_cast(Random32(&rng)); + vec_lanes[i] = (Random32(&rng) & 16) ? static_cast(~TU(0)) : TU(0); + } + + const auto v1 = Load(d, in1.get()); + const auto v2 = Load(d, in2.get()); + const auto vec = BitCast(d, Load(du, vec_lanes.get())); + + for (size_t i = 0; i < N; ++i) { + expected[i] = vec_lanes[i] ? in1[i] : in2[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfVecThenElse(vec, v1, v2)); + } + } +}; + +HWY_NOINLINE void TestAllIfVecThenElse() { + ForAllTypes(ForPartialVectors()); +} + +struct TestZeroIfNegative { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Iota(d, T(-1E5)); // assumes N < 10^5 + + // Zero and positive remain unchanged + HWY_ASSERT_VEC_EQ(d, v0, ZeroIfNegative(v0)); + HWY_ASSERT_VEC_EQ(d, vp, ZeroIfNegative(vp)); + + // Negative are all replaced with zero + HWY_ASSERT_VEC_EQ(d, v0, ZeroIfNegative(vn)); + } +}; + +HWY_NOINLINE void TestAllZeroIfNegative() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestIfNegative { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Or(vp, SignBit(d)); + + // Zero and positive remain unchanged + HWY_ASSERT_VEC_EQ(d, v0, IfNegativeThenElse(v0, vn, v0)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(v0, v0, vn)); + HWY_ASSERT_VEC_EQ(d, vp, IfNegativeThenElse(vp, vn, vp)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(vp, vp, vn)); + + // Negative are replaced with 2nd arg + HWY_ASSERT_VEC_EQ(d, v0, IfNegativeThenElse(vn, v0, vp)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(vn, vn, v0)); + HWY_ASSERT_VEC_EQ(d, vp, IfNegativeThenElse(vn, vp, vn)); + } +}; + +HWY_NOINLINE void TestAllIfNegative() { + ForFloatTypes(ForPartialVectors()); + ForSignedTypes(ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyIfTest); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllIfThenElse); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllIfVecThenElse); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllZeroIfNegative); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllIfNegative); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/interleaved_test.cc b/media/highway/src/hwy/tests/interleaved_test.cc new file mode 100644 index 000000000..4d1fbd5ac --- /dev/null +++ b/media/highway/src/hwy/tests/interleaved_test.cc @@ -0,0 +1,256 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/interleaved_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLoadStoreInterleaved2 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned(2 * N); + for (size_t i = 0; i < 2 * N; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned(3 * N); + auto actual_aligned = AllocateAligned(3 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[2 * i + 0] = bytes[0 * N + i]; + expected[2 * i + 1] = bytes[1 * N + i]; + // Ensure we do not write more than 2*N bytes + expected[2 * N + i] = actual[2 * N + i] = 0; + } + StoreInterleaved2(in0, in1, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 3 * N * sizeof(T), &pos)) { + Print(d, "in0", in0, pos / 4); + Print(d, "in1", in1, pos / 4); + const size_t i = pos; + fprintf(stderr, "interleaved i=%d %f %f %f %f %f %f %f %f\n", + static_cast(i), static_cast(actual[i]), + static_cast(actual[i + 1]), + static_cast(actual[i + 2]), + static_cast(actual[i + 3]), + static_cast(actual[i + 4]), + static_cast(actual[i + 5]), + static_cast(actual[i + 6]), + static_cast(actual[i + 7])); + HWY_ASSERT(false); + } + + Vec out0, out1; + LoadInterleaved2(d, actual, out0, out1); + HWY_ASSERT_VEC_EQ(d, in0, out0); + HWY_ASSERT_VEC_EQ(d, in1, out1); + } + } +}; + +HWY_NOINLINE void TestAllLoadStoreInterleaved2() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. + const ForExtendableVectors test; +#else + const ForPartialVectors test; +#endif + ForAllTypes(test); +} + +// Workaround for build timeout on GCC 12 aarch64, see #776 +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && HWY_ARCH_ARM_A64 +#define HWY_BROKEN_LOAD34 1 +#else +#define HWY_BROKEN_LOAD34 0 +#endif + +#if !HWY_BROKEN_LOAD34 + +struct TestLoadStoreInterleaved3 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned(3 * N); + for (size_t i = 0; i < 3 * N; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + const auto in2 = Load(d, &bytes[2 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned(4 * N); + auto actual_aligned = AllocateAligned(4 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[3 * i + 0] = bytes[0 * N + i]; + expected[3 * i + 1] = bytes[1 * N + i]; + expected[3 * i + 2] = bytes[2 * N + i]; + // Ensure we do not write more than 3*N bytes + expected[3 * N + i] = actual[3 * N + i] = 0; + } + StoreInterleaved3(in0, in1, in2, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 4 * N * sizeof(T), &pos)) { + Print(d, "in0", in0, pos / 3, N); + Print(d, "in1", in1, pos / 3, N); + Print(d, "in2", in2, pos / 3, N); + const size_t i = pos; + fprintf(stderr, "interleaved i=%d %f %f %f %f %f %f\n", + static_cast(i), static_cast(actual[i]), + static_cast(actual[i + 1]), + static_cast(actual[i + 2]), + static_cast(actual[i + 3]), + static_cast(actual[i + 4]), + static_cast(actual[i + 5])); + HWY_ASSERT(false); + } + + Vec out0, out1, out2; + LoadInterleaved3(d, actual, out0, out1, out2); + HWY_ASSERT_VEC_EQ(d, in0, out0); + HWY_ASSERT_VEC_EQ(d, in1, out1); + HWY_ASSERT_VEC_EQ(d, in2, out2); + } + } +}; + +HWY_NOINLINE void TestAllLoadStoreInterleaved3() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. + const ForExtendableVectors test; +#else + const ForPartialVectors test; +#endif + ForAllTypes(test); +} + +struct TestLoadStoreInterleaved4 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned(4 * N); + + for (size_t i = 0; i < 4 * N; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + const auto in2 = Load(d, &bytes[2 * N]); + const auto in3 = Load(d, &bytes[3 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned(5 * N); + auto actual_aligned = AllocateAligned(5 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[4 * i + 0] = bytes[0 * N + i]; + expected[4 * i + 1] = bytes[1 * N + i]; + expected[4 * i + 2] = bytes[2 * N + i]; + expected[4 * i + 3] = bytes[3 * N + i]; + // Ensure we do not write more than 4*N bytes + expected[4 * N + i] = actual[4 * N + i] = 0; + } + StoreInterleaved4(in0, in1, in2, in3, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 5 * N * sizeof(T), &pos)) { + Print(d, "in0", in0, pos / 4); + Print(d, "in1", in1, pos / 4); + Print(d, "in2", in2, pos / 4); + Print(d, "in3", in3, pos / 4); + const size_t i = pos; + fprintf(stderr, "interleaved i=%d %f %f %f %f %f %f %f %f\n", + static_cast(i), static_cast(actual[i]), + static_cast(actual[i + 1]), + static_cast(actual[i + 2]), + static_cast(actual[i + 3]), + static_cast(actual[i + 4]), + static_cast(actual[i + 5]), + static_cast(actual[i + 6]), + static_cast(actual[i + 7])); + HWY_ASSERT(false); + } + + Vec out0, out1, out2, out3; + LoadInterleaved4(d, actual, out0, out1, out2, out3); + HWY_ASSERT_VEC_EQ(d, in0, out0); + HWY_ASSERT_VEC_EQ(d, in1, out1); + HWY_ASSERT_VEC_EQ(d, in2, out2); + HWY_ASSERT_VEC_EQ(d, in3, out3); + } + } +}; + +HWY_NOINLINE void TestAllLoadStoreInterleaved4() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. + const ForExtendableVectors test; +#else + const ForPartialVectors test; +#endif + ForAllTypes(test); +} + +#endif // !HWY_BROKEN_LOAD34 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyInterleavedTest); +HWY_EXPORT_AND_TEST_P(HwyInterleavedTest, TestAllLoadStoreInterleaved2); +#if !HWY_BROKEN_LOAD34 +HWY_EXPORT_AND_TEST_P(HwyInterleavedTest, TestAllLoadStoreInterleaved3); +HWY_EXPORT_AND_TEST_P(HwyInterleavedTest, TestAllLoadStoreInterleaved4); +#endif +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/list_targets.cc b/media/highway/src/hwy/tests/list_targets.cc new file mode 100644 index 000000000..d09ee4fe8 --- /dev/null +++ b/media/highway/src/hwy/tests/list_targets.cc @@ -0,0 +1,71 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Simple tool to print the list of targets that were compiled in when building +// this tool. + +#include + +#include "hwy/highway.h" + +void PrintTargets(const char* msg, int64_t targets) { + fprintf(stderr, "%s", msg); + // For each bit: + for (int64_t x = targets; x != 0; x = x & (x - 1)) { + // Extract value of least-significant bit. + fprintf(stderr, " %s", hwy::TargetName(x & (~x + 1))); + } + fprintf(stderr, "\n"); +} + +int main() { +#ifdef HWY_COMPILE_ONLY_EMU128 + const int only_emu128 = 1; +#else + const int only_emu128 = 0; +#endif +#ifdef HWY_COMPILE_ONLY_SCALAR + const int only_scalar = 1; +#else + const int only_scalar = 0; +#endif +#ifdef HWY_COMPILE_ONLY_STATIC + const int only_static = 1; +#else + const int only_static = 0; +#endif +#ifdef HWY_COMPILE_ALL_ATTAINABLE + const int all_attain = 1; +#else + const int all_attain = 0; +#endif +#ifdef HWY_IS_TEST + const int is_test = 1; +#else + const int is_test = 0; +#endif + + fprintf(stderr, + "Config: emu128:%d scalar:%d static:%d all_attain:%d is_test:%d\n", + only_emu128, only_scalar, only_static, all_attain, is_test); + PrintTargets("Compiled HWY_TARGETS: ", HWY_TARGETS); + PrintTargets("HWY_ATTAINABLE_TARGETS:", HWY_ATTAINABLE_TARGETS); + PrintTargets("HWY_BASELINE_TARGETS: ", HWY_BASELINE_TARGETS); + PrintTargets("HWY_STATIC_TARGET: ", HWY_STATIC_TARGET); + PrintTargets("HWY_BROKEN_TARGETS: ", HWY_BROKEN_TARGETS); + PrintTargets("HWY_DISABLED_TARGETS: ", HWY_DISABLED_TARGETS); + PrintTargets("Current CPU supports: ", hwy::SupportedTargets()); + return 0; +} diff --git a/media/highway/src/hwy/tests/logical_test.cc b/media/highway/src/hwy/tests/logical_test.cc new file mode 100644 index 000000000..fa2b9b9ad --- /dev/null +++ b/media/highway/src/hwy/tests/logical_test.cc @@ -0,0 +1,270 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memcmp + +#include "hwy/aligned_allocator.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/logical_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLogicalInteger { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vi = Iota(d, 0); + const auto ones = VecFromMask(d, Eq(v0, v0)); + const auto v1 = Set(d, 1); + const auto vnot1 = Set(d, T(~T(1))); + + HWY_ASSERT_VEC_EQ(d, v0, Not(ones)); + HWY_ASSERT_VEC_EQ(d, ones, Not(v0)); + HWY_ASSERT_VEC_EQ(d, v1, Not(vnot1)); + HWY_ASSERT_VEC_EQ(d, vnot1, Not(v1)); + + HWY_ASSERT_VEC_EQ(d, v0, And(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, And(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, And(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Or(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Xor(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Xor(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Xor(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, AndNot(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, v0, Or3(v0, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(v0, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(v0, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(v0, vi, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, vi, vi)); + + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(v0, vi, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, vi, vi)); + + auto v = vi; + v = And(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = And(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + + v = Or(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = Or(v, v0); + HWY_ASSERT_VEC_EQ(d, vi, v); + + v = Xor(v, vi); + HWY_ASSERT_VEC_EQ(d, v0, v); + v = Xor(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + } +}; + +HWY_NOINLINE void TestAllLogicalInteger() { + ForIntegerTypes(ForPartialVectors()); +} + +struct TestLogicalFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vi = Iota(d, 0); + + HWY_ASSERT_VEC_EQ(d, v0, And(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, And(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, And(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Or(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Xor(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Xor(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Xor(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, AndNot(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, vi)); + + auto v = vi; + v = And(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = And(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + + v = Or(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = Or(v, v0); + HWY_ASSERT_VEC_EQ(d, vi, v); + + v = Xor(v, vi); + HWY_ASSERT_VEC_EQ(d, v0, v); + v = Xor(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + } +}; + +HWY_NOINLINE void TestAllLogicalFloat() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestCopySign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Iota(d, T(-1E5)); // assumes N < 10^5 + + // Zero remains zero regardless of sign + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, vp)); + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, vn)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, vp)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, vn)); + + // Positive input, positive sign => unchanged + HWY_ASSERT_VEC_EQ(d, vp, CopySign(vp, vp)); + HWY_ASSERT_VEC_EQ(d, vp, CopySignToAbs(vp, vp)); + + // Positive input, negative sign => negated + HWY_ASSERT_VEC_EQ(d, Neg(vp), CopySign(vp, vn)); + HWY_ASSERT_VEC_EQ(d, Neg(vp), CopySignToAbs(vp, vn)); + + // Negative input, negative sign => unchanged + HWY_ASSERT_VEC_EQ(d, vn, CopySign(vn, vn)); + + // Negative input, positive sign => negated + HWY_ASSERT_VEC_EQ(d, Neg(vn), CopySign(vn, vp)); + } +}; + +HWY_NOINLINE void TestAllCopySign() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestBroadcastSignBit { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto s0 = Zero(d); + const auto s1 = Set(d, -1); // all bit set + const auto vpos = And(Iota(d, 0), Set(d, LimitsMax())); + const auto vneg = Sub(s1, vpos); + + HWY_ASSERT_VEC_EQ(d, s0, BroadcastSignBit(vpos)); + HWY_ASSERT_VEC_EQ(d, s0, BroadcastSignBit(Set(d, LimitsMax()))); + + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(vneg)); + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(Set(d, LimitsMin()))); + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(Set(d, LimitsMin() / 2))); + } +}; + +HWY_NOINLINE void TestAllBroadcastSignBit() { + ForSignedTypes(ForPartialVectors()); +} + +struct TestTestBit { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t kNumBits = sizeof(T) * 8; + for (size_t i = 0; i < kNumBits; ++i) { + const auto bit1 = Set(d, T(1ull << i)); + const auto bit2 = Set(d, T(1ull << ((i + 1) % kNumBits))); + const auto bit3 = Set(d, T(1ull << ((i + 2) % kNumBits))); + const auto bits12 = Or(bit1, bit2); + const auto bits23 = Or(bit2, bit3); + HWY_ASSERT(AllTrue(d, TestBit(bit1, bit1))); + HWY_ASSERT(AllTrue(d, TestBit(bits12, bit1))); + HWY_ASSERT(AllTrue(d, TestBit(bits12, bit2))); + + HWY_ASSERT(AllFalse(d, TestBit(bits12, bit3))); + HWY_ASSERT(AllFalse(d, TestBit(bits23, bit1))); + HWY_ASSERT(AllFalse(d, TestBit(bit1, bit2))); + HWY_ASSERT(AllFalse(d, TestBit(bit2, bit1))); + HWY_ASSERT(AllFalse(d, TestBit(bit1, bit3))); + HWY_ASSERT(AllFalse(d, TestBit(bit3, bit1))); + HWY_ASSERT(AllFalse(d, TestBit(bit2, bit3))); + HWY_ASSERT(AllFalse(d, TestBit(bit3, bit2))); + } + } +}; + +HWY_NOINLINE void TestAllTestBit() { + ForIntegerTypes(ForPartialVectors()); +} + +struct TestPopulationCount { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + size_t N = Lanes(d); + auto data = AllocateAligned(N); + auto popcnt = AllocateAligned(N); + for (size_t i = 0; i < AdjustedReps(1 << 18) / N; i++) { + for (size_t i = 0; i < N; i++) { + data[i] = static_cast(rng()); + popcnt[i] = static_cast(PopCount(data[i])); + } + HWY_ASSERT_VEC_EQ(d, popcnt.get(), PopulationCount(Load(d, data.get()))); + } + } +}; + +HWY_NOINLINE void TestAllPopulationCount() { + ForUnsignedTypes(ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyLogicalTest); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogicalInteger); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogicalFloat); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllCopySign); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllBroadcastSignBit); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllPopulationCount); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/mask_mem_test.cc b/media/highway/src/hwy/tests/mask_mem_test.cc new file mode 100644 index 000000000..c44119dcd --- /dev/null +++ b/media/highway/src/hwy/tests/mask_mem_test.cc @@ -0,0 +1,197 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#include +#include +#include // memcmp + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/mask_mem_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestMaskedLoad { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned(N); + + auto lanes = AllocateAligned(N); + Store(Iota(d, T{1}), d, lanes.get()); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + const auto expected = IfThenElseZero(mask, Load(d, lanes.get())); + const auto actual = MaskedLoad(mask, d, lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected, actual); + } + } +}; + +HWY_NOINLINE void TestAllMaskedLoad() { + ForAllTypes(ForPartialVectors()); +} + +struct TestBlendedStore { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned(N); + + const Vec v = Iota(d, T{1}); + auto actual = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + // Re-initialize to something distinct from v[i]. + actual[i] = static_cast(127 - (i & 127)); + expected[i] = bool_lanes[i] ? static_cast(i + 1) : actual[i]; + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + BlendedStore(v, mask, d, actual.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), Load(d, actual.get())); + } + } +}; + +HWY_NOINLINE void TestAllBlendedStore() { + ForAllTypes(ForPartialVectors()); +} + +class TestStoreMaskBits { + public: + template + HWY_NOINLINE void operator()(T /*t*/, D /*d*/) { + RandomState rng; + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned(N); + + const ScalableTag d_bits; + const size_t expected_num_bytes = (N + 7) / 8; + auto expected = AllocateAligned(expected_num_bytes); + auto actual = AllocateAligned(HWY_MAX(8, expected_num_bytes)); + + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + // Generate random mask pattern. + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = static_cast((rng() & 1024) ? 1 : 0); + } + const auto bools = Load(di, bool_lanes.get()); + const auto mask = Gt(bools, Zero(di)); + + // Requires at least 8 bytes, ensured above. + const size_t bytes_written = StoreMaskBits(di, mask, actual.get()); + if (bytes_written != expected_num_bytes) { + fprintf(stderr, "%s expected %" PRIu64 " bytes, actual %" PRIu64 "\n", + TypeName(T(), N).c_str(), + static_cast(expected_num_bytes), + static_cast(bytes_written)); + + HWY_ASSERT(false); + } + + // Requires at least 8 bytes, ensured above. + const auto mask2 = LoadMaskBits(di, actual.get()); + HWY_ASSERT_MASK_EQ(di, mask, mask2); + + memset(expected.get(), 0, expected_num_bytes); + for (size_t i = 0; i < N; ++i) { + expected[i / 8] = + static_cast(expected[i / 8] | (bool_lanes[i] << (i % 8))); + } + + size_t i = 0; + // Stored bits must match original mask + for (; i < N; ++i) { + const TI is_set = (actual[i / 8] & (1 << (i % 8))) ? 1 : 0; + if (is_set != bool_lanes[i]) { + fprintf(stderr, "%s lane %" PRIu64 ": expected %d, actual %d\n", + TypeName(T(), N).c_str(), static_cast(i), + static_cast(bool_lanes[i]), static_cast(is_set)); + Print(di, "bools", bools, 0, N); + Print(d_bits, "expected bytes", Load(d_bits, expected.get()), 0, + expected_num_bytes); + Print(d_bits, "actual bytes", Load(d_bits, actual.get()), 0, + expected_num_bytes); + + HWY_ASSERT(false); + } + } + // Any partial bits in the last byte must be zero + for (; i < 8 * bytes_written; ++i) { + const int bit = (actual[i / 8] & (1 << (i % 8))); + if (bit != 0) { + fprintf(stderr, "%s: bit #%" PRIu64 " should be zero\n", + TypeName(T(), N).c_str(), static_cast(i)); + Print(di, "bools", bools, 0, N); + Print(d_bits, "expected bytes", Load(d_bits, expected.get()), 0, + expected_num_bytes); + Print(d_bits, "actual bytes", Load(d_bits, actual.get()), 0, + expected_num_bytes); + + HWY_ASSERT(false); + } + } + } + } +}; + +HWY_NOINLINE void TestAllStoreMaskBits() { + ForAllTypes(ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMaskTest); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllMaskedLoad); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllBlendedStore); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllStoreMaskBits); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/mask_test.cc b/media/highway/src/hwy/tests/mask_test.cc new file mode 100644 index 000000000..f48b476be --- /dev/null +++ b/media/highway/src/hwy/tests/mask_test.cc @@ -0,0 +1,293 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memcmp + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/mask_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// All types. +struct TestFromVec { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + + memset(lanes.get(), 0, N * sizeof(T)); + const auto actual_false = MaskFromVec(Load(d, lanes.get())); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), actual_false); + + memset(lanes.get(), 0xFF, N * sizeof(T)); + const auto actual_true = MaskFromVec(Load(d, lanes.get())); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), actual_true); + } +}; + +HWY_NOINLINE void TestAllFromVec() { + ForAllTypes(ForPartialVectors()); +} + +struct TestFirstN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned(N); + + using TN = SignedFromSize; + const size_t max_len = static_cast(LimitsMax()); + + const size_t max_lanes = HWY_MIN(2 * N, AdjustedReps(512)); + for (size_t len = 0; len <= HWY_MIN(max_lanes, max_len); ++len) { + // Loop instead of Iota+Lt to avoid wraparound for 8-bit T. + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (i < len) ? T{1} : 0; + } + const auto expected = Eq(Load(d, bool_lanes.get()), Set(d, T{1})); + HWY_ASSERT_MASK_EQ(d, expected, FirstN(d, len)); + } + + // Also ensure huge values yield all-true (unless the vector is actually + // larger than max_len). + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (i < max_len) ? T{1} : 0; + } + const auto expected = Eq(Load(d, bool_lanes.get()), Set(d, T{1})); + HWY_ASSERT_MASK_EQ(d, expected, FirstN(d, max_len)); + } +}; + +HWY_NOINLINE void TestAllFirstN() { + ForAllTypes(ForPartialVectors()); +} + +struct TestMaskVec { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + HWY_ASSERT_MASK_EQ(d, mask, MaskFromVec(VecFromMask(d, mask))); + } + } +}; + +HWY_NOINLINE void TestAllMaskVec() { + const ForPartialVectors test; + + test(uint16_t()); + test(int16_t()); + // TODO(janwas): float16_t - cannot compare yet + + ForUIF3264(test); +} + +struct TestAllTrueFalse { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto zero = Zero(d); + auto v = zero; + + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + + HWY_ASSERT(AllTrue(d, Eq(v, zero))); + HWY_ASSERT(!AllFalse(d, Eq(v, zero))); + + // Single lane implies AllFalse = !AllTrue. Otherwise, there are multiple + // lanes and one is nonzero. + const bool expected_all_false = (N != 1); + + // Set each lane to nonzero and back to zero + for (size_t i = 0; i < N; ++i) { + lanes[i] = T(1); + v = Load(d, lanes.get()); + + HWY_ASSERT(!AllTrue(d, Eq(v, zero))); + + HWY_ASSERT(expected_all_false ^ AllFalse(d, Eq(v, zero))); + + lanes[i] = T(-1); + v = Load(d, lanes.get()); + HWY_ASSERT(!AllTrue(d, Eq(v, zero))); + HWY_ASSERT(expected_all_false ^ AllFalse(d, Eq(v, zero))); + + // Reset to all zero + lanes[i] = T(0); + v = Load(d, lanes.get()); + HWY_ASSERT(AllTrue(d, Eq(v, zero))); + HWY_ASSERT(!AllFalse(d, Eq(v, zero))); + } + } +}; + +HWY_NOINLINE void TestAllAllTrueFalse() { + ForAllTypes(ForPartialVectors()); +} + +struct TestCountTrue { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned(N); + memset(bool_lanes.get(), 0, N * sizeof(TI)); + + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = HWY_MIN(N, size_t(10)); + + for (size_t code = 0; code < (1ull << max_lanes); ++code) { + // Number of zeros written = number of mask lanes that are true. + size_t expected = 0; + for (size_t i = 0; i < max_lanes; ++i) { + const bool is_true = (code & (1ull << i)) != 0; + bool_lanes[i] = is_true ? TI(1) : TI(0); + expected += is_true; + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + const size_t actual = CountTrue(d, mask); + HWY_ASSERT_EQ(expected, actual); + } + } +}; + +HWY_NOINLINE void TestAllCountTrue() { + ForAllTypes(ForPartialVectors()); +} + +struct TestFindFirstTrue { // Also FindKnownFirstTrue + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned(N); + memset(bool_lanes.get(), 0, N * sizeof(TI)); + + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = AdjustedLog2Reps(HWY_MIN(N, size_t(9))); + + HWY_ASSERT_EQ(intptr_t(-1), FindFirstTrue(d, MaskFalse(d))); + HWY_ASSERT_EQ(intptr_t(0), FindFirstTrue(d, MaskTrue(d))); + HWY_ASSERT_EQ(size_t(0), FindKnownFirstTrue(d, MaskTrue(d))); + + for (size_t code = 1; code < (1ull << max_lanes); ++code) { + for (size_t i = 0; i < max_lanes; ++i) { + bool_lanes[i] = (code & (1ull << i)) ? TI(1) : TI(0); + } + + const size_t expected = + Num0BitsBelowLS1Bit_Nonzero32(static_cast(code)); + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + HWY_ASSERT_EQ(static_cast(expected), FindFirstTrue(d, mask)); + HWY_ASSERT_EQ(expected, FindKnownFirstTrue(d, mask)); + } + } +}; + +HWY_NOINLINE void TestAllFindFirstTrue() { + ForAllTypes(ForPartialVectors()); +} + +struct TestLogicalMask { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto m0 = MaskFalse(d); + const auto m_all = MaskTrue(d); + + using TI = MakeSigned; // For mask > 0 comparison + const Rebind di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned(N); + memset(bool_lanes.get(), 0, N * sizeof(TI)); + + HWY_ASSERT_MASK_EQ(d, m0, Not(m_all)); + HWY_ASSERT_MASK_EQ(d, m_all, Not(m0)); + + Print(d, ".", VecFromMask(d, ExclusiveNeither(m0, m0))); + HWY_ASSERT_MASK_EQ(d, m_all, ExclusiveNeither(m0, m0)); + HWY_ASSERT_MASK_EQ(d, m0, ExclusiveNeither(m_all, m0)); + HWY_ASSERT_MASK_EQ(d, m0, ExclusiveNeither(m0, m_all)); + + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = AdjustedLog2Reps(HWY_MIN(N, size_t(6))); + for (size_t code = 0; code < (1ull << max_lanes); ++code) { + for (size_t i = 0; i < max_lanes; ++i) { + bool_lanes[i] = (code & (1ull << i)) ? TI(1) : TI(0); + } + + const auto m = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + + HWY_ASSERT_MASK_EQ(d, m0, Xor(m, m)); + HWY_ASSERT_MASK_EQ(d, m0, AndNot(m, m)); + HWY_ASSERT_MASK_EQ(d, m0, AndNot(m_all, m)); + + HWY_ASSERT_MASK_EQ(d, m, Or(m, m)); + HWY_ASSERT_MASK_EQ(d, m, Or(m0, m)); + HWY_ASSERT_MASK_EQ(d, m, Or(m, m0)); + HWY_ASSERT_MASK_EQ(d, m, Xor(m0, m)); + HWY_ASSERT_MASK_EQ(d, m, Xor(m, m0)); + HWY_ASSERT_MASK_EQ(d, m, And(m, m)); + HWY_ASSERT_MASK_EQ(d, m, And(m_all, m)); + HWY_ASSERT_MASK_EQ(d, m, And(m, m_all)); + HWY_ASSERT_MASK_EQ(d, m, AndNot(m0, m)); + } + } +}; + +HWY_NOINLINE void TestAllLogicalMask() { + ForAllTypes(ForPartialVectors()); +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMaskTest); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFromVec); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFirstN); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllMaskVec); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllAllTrueFalse); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllCountTrue); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFindFirstTrue); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllLogicalMask); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/memory_test.cc b/media/highway/src/hwy/tests/memory_test.cc new file mode 100644 index 000000000..b78be2bce --- /dev/null +++ b/media/highway/src/hwy/tests/memory_test.cc @@ -0,0 +1,341 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Ensure incompabilities with Windows macros (e.g. #define StoreFence) are +// detected. Must come before Highway headers. +#include "hwy/base.h" +#if defined(_WIN32) || defined(_WIN64) +#include +#endif + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/memory_test.cc" +#include "hwy/cache_control.h" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLoadStore { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto hi = Iota(d, static_cast(1 + N)); + const auto lo = Iota(d, 1); + auto lanes = AllocateAligned(2 * N); + Store(hi, d, &lanes[N]); + Store(lo, d, &lanes[0]); + + // Aligned load + const auto lo2 = Load(d, &lanes[0]); + HWY_ASSERT_VEC_EQ(d, lo2, lo); + + // Aligned store + auto lanes2 = AllocateAligned(2 * N); + Store(lo2, d, &lanes2[0]); + Store(hi, d, &lanes2[N]); + for (size_t i = 0; i < 2 * N; ++i) { + HWY_ASSERT_EQ(lanes[i], lanes2[i]); + } + + // Unaligned load + const auto vu = LoadU(d, &lanes[1]); + auto lanes3 = AllocateAligned(N); + Store(vu, d, lanes3.get()); + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT_EQ(T(i + 2), lanes3[i]); + } + + // Unaligned store + StoreU(lo2, d, &lanes2[N / 2]); + size_t i = 0; + for (; i < N / 2; ++i) { + HWY_ASSERT_EQ(lanes[i], lanes2[i]); + } + for (; i < 3 * N / 2; ++i) { + HWY_ASSERT_EQ(T(i - N / 2 + 1), lanes2[i]); + } + // Subsequent values remain unchanged. + for (; i < 2 * N; ++i) { + HWY_ASSERT_EQ(T(i + 1), lanes2[i]); + } + } +}; + +HWY_NOINLINE void TestAllLoadStore() { + ForAllTypes(ForPartialVectors()); +} + +struct TestSafeCopyN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto v = Iota(d, 1); + auto from = AllocateAligned(N + 2); + auto to = AllocateAligned(N + 2); + Store(v, d, from.get()); + + // 0: nothing changes + to[0] = T(); + SafeCopyN(0, d, from.get(), to.get()); + HWY_ASSERT_EQ(T(), to[0]); + + // 1: only first changes + to[1] = T(); + SafeCopyN(1, d, from.get(), to.get()); + HWY_ASSERT_EQ(static_cast(1), to[0]); + HWY_ASSERT_EQ(T(), to[1]); + + // N-1: last does not change + to[N - 1] = T(); + SafeCopyN(N - 1, d, from.get(), to.get()); + HWY_ASSERT_EQ(T(), to[N - 1]); + // Also check preceding lanes + to[N - 1] = static_cast(N); + HWY_ASSERT_VEC_EQ(d, to.get(), v); + + // N: all change + to[N] = T(); + SafeCopyN(N, d, from.get(), to.get()); + HWY_ASSERT_VEC_EQ(d, to.get(), v); + HWY_ASSERT_EQ(T(), to[N]); + + // N+1: subsequent lane does not change if using masked store + to[N + 1] = T(); + SafeCopyN(N + 1, d, from.get(), to.get()); + HWY_ASSERT_VEC_EQ(d, to.get(), v); +#if !HWY_MEM_OPS_MIGHT_FAULT + HWY_ASSERT_EQ(T(), to[N + 1]); +#endif + } +}; + +HWY_NOINLINE void TestAllSafeCopyN() { + ForAllTypes(ForPartialVectors()); +} + +struct TestLoadDup128 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define LoadDup128. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + constexpr size_t N128 = 16 / sizeof(T); + alignas(16) T lanes[N128]; + for (size_t i = 0; i < N128; ++i) { + lanes[i] = static_cast(1 + i); + } + + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast(i % N128 + 1); + } + + HWY_ASSERT_VEC_EQ(d, expected.get(), LoadDup128(d, lanes)); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllLoadDup128() { + ForAllTypes(ForGEVectors<128, TestLoadDup128>()); +} + +struct TestStream { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(1)); + const size_t affected_bytes = + (Lanes(d) * sizeof(T) + HWY_STREAM_MULTIPLE - 1) & + ~size_t(HWY_STREAM_MULTIPLE - 1); + const size_t affected_lanes = affected_bytes / sizeof(T); + auto out = AllocateAligned(2 * affected_lanes); + std::fill(out.get(), out.get() + 2 * affected_lanes, T(0)); + + Stream(v, d, out.get()); + FlushStream(); + const auto actual = Load(d, out.get()); + HWY_ASSERT_VEC_EQ(d, v, actual); + // Ensure Stream didn't modify more memory than expected + for (size_t i = affected_lanes; i < 2 * affected_lanes; ++i) { + HWY_ASSERT_EQ(T(0), out[i]); + } + } +}; + +HWY_NOINLINE void TestAllStream() { + const ForPartialVectors test; + // No u8,u16. + test(uint32_t()); + test(uint64_t()); + // No i8,i16. + test(int32_t()); + test(int64_t()); + ForFloatTypes(test); +} + +// Assumes little-endian byte order! +struct TestScatter { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Offset = MakeSigned; + + const size_t N = Lanes(d); + const size_t range = 4 * N; // number of items to scatter + const size_t max_bytes = range * sizeof(T); // upper bound on offset + + RandomState rng; + + // Data to be scattered + auto bytes = AllocateAligned(max_bytes); + for (size_t i = 0; i < max_bytes; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + const auto data = Load(d, reinterpret_cast(bytes.get())); + + // Scatter into these regions, ensure vector results match scalar + auto expected = AllocateAligned(range); + auto actual = AllocateAligned(range); + + const Rebind d_offsets; + auto offsets = AllocateAligned(N); // or indices + + for (size_t rep = 0; rep < 100; ++rep) { + // Byte offsets + std::fill(expected.get(), expected.get() + range, T(0)); + std::fill(actual.get(), actual.get() + range, T(0)); + for (size_t i = 0; i < N; ++i) { + // Must be aligned + offsets[i] = static_cast((Random32(&rng) % range) * sizeof(T)); + CopyBytes( + bytes.get() + i * sizeof(T), + reinterpret_cast(expected.get()) + offsets[i]); + } + const auto voffsets = Load(d_offsets, offsets.get()); + ScatterOffset(data, d, actual.get(), voffsets); + if (!BytesEqual(expected.get(), actual.get(), max_bytes)) { + Print(d, "Data", data); + Print(d_offsets, "Offsets", voffsets); + HWY_ASSERT(false); + } + + // Indices + std::fill(expected.get(), expected.get() + range, T(0)); + std::fill(actual.get(), actual.get() + range, T(0)); + for (size_t i = 0; i < N; ++i) { + offsets[i] = static_cast(Random32(&rng) % range); + CopyBytes(bytes.get() + i * sizeof(T), + &expected[size_t(offsets[i])]); + } + const auto vindices = Load(d_offsets, offsets.get()); + ScatterIndex(data, d, actual.get(), vindices); + if (!BytesEqual(expected.get(), actual.get(), max_bytes)) { + Print(d, "Data", data); + Print(d_offsets, "Indices", vindices); + HWY_ASSERT(false); + } + } + } +}; + +HWY_NOINLINE void TestAllScatter() { + ForUIF3264(ForPartialVectors()); +} + +struct TestGather { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Offset = MakeSigned; + + const size_t N = Lanes(d); + const size_t range = 4 * N; // number of items to gather + const size_t max_bytes = range * sizeof(T); // upper bound on offset + + RandomState rng; + + // Data to be gathered from + auto bytes = AllocateAligned(max_bytes); + for (size_t i = 0; i < max_bytes; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + + auto expected = AllocateAligned(N); + auto offsets = AllocateAligned(N); + auto indices = AllocateAligned(N); + + for (size_t rep = 0; rep < 100; ++rep) { + // Offsets + for (size_t i = 0; i < N; ++i) { + // Must be aligned + offsets[i] = static_cast((Random32(&rng) % range) * sizeof(T)); + CopyBytes(bytes.get() + offsets[i], &expected[i]); + } + + const Rebind d_offset; + const T* base = reinterpret_cast(bytes.get()); + auto actual = GatherOffset(d, base, Load(d_offset, offsets.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + + // Indices + for (size_t i = 0; i < N; ++i) { + indices[i] = + static_cast(Random32(&rng) % (max_bytes / sizeof(T))); + CopyBytes(base + indices[i], &expected[i]); + } + actual = GatherIndex(d, base, Load(d_offset, indices.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + } + } +}; + +HWY_NOINLINE void TestAllGather() { + ForUIF3264(ForPartialVectors()); +} + +HWY_NOINLINE void TestAllCache() { + LoadFence(); + FlushStream(); + int test = 0; + Prefetch(&test); + FlushCacheline(&test); + Pause(); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMemoryTest); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadStore); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllSafeCopyN); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadDup128); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStream); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllScatter); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllGather); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllCache); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/mul_test.cc b/media/highway/src/hwy/tests/mul_test.cc new file mode 100644 index 000000000..fab4292d4 --- /dev/null +++ b/media/highway/src/hwy/tests/mul_test.cc @@ -0,0 +1,446 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/mul_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +constexpr uint64_t FirstBits() { + return (1ull << kBits) - 1; +} +template <> +constexpr uint64_t FirstBits<64>() { + return ~uint64_t{0}; +} + +struct TestUnsignedMul { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto vi = Iota(d, 1); + const auto vj = Iota(d, 3); + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + HWY_ASSERT_VEC_EQ(d, v0, Mul(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Mul(v1, v1)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(v1, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(vi, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((1 + i) * (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vi)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((1 + i) * (3 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vj)); + + const T max = LimitsMax(); + const auto vmax = Set(d, max); + HWY_ASSERT_VEC_EQ(d, vmax, Mul(vmax, v1)); + HWY_ASSERT_VEC_EQ(d, vmax, Mul(v1, vmax)); + + constexpr uint64_t kMask = FirstBits(); + const T max2 = (static_cast(max) * max) & kMask; + HWY_ASSERT_VEC_EQ(d, Set(d, max2), Mul(vmax, vmax)); + } +}; + +struct TestSignedMul { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto vi = Iota(d, 1); + const auto vn = Iota(d, -T(N)); // no i8 supported, so no wraparound + HWY_ASSERT_VEC_EQ(d, v0, Mul(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Mul(v1, v1)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(v1, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(vi, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((1 + i) * (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vi)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((-T(N) + T(i)) * T(1u + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vn, vi)); + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vn)); + } +}; + +HWY_NOINLINE void TestAllMul() { + const ForPartialVectors test_unsigned; + // No u8. + test_unsigned(uint16_t()); + test_unsigned(uint32_t()); + test_unsigned(uint64_t()); + + const ForPartialVectors test_signed; + // No i8. + test_signed(int16_t()); + test_signed(int32_t()); + test_signed(int64_t()); +} + +struct TestMulHigh { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Wide = MakeWide; + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + auto expected_lanes = AllocateAligned(N); + + const auto vi = Iota(d, 1); + // no i8 supported, so no wraparound + const auto vni = Iota(d, T(static_cast(~N + 1))); + + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(vi, v0)); + + // Large positive squared + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = T(LimitsMax() >> i); + expected_lanes[i] = T((Wide(in_lanes[i]) * in_lanes[i]) >> 16); + } + auto v = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, v)); + + // Large positive * small positive + for (size_t i = 0; i < N; ++i) { + expected_lanes[i] = T((Wide(in_lanes[i]) * T(1u + i)) >> 16); + } + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, vi)); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(vi, v)); + + // Large positive * small negative + for (size_t i = 0; i < N; ++i) { + expected_lanes[i] = T((Wide(in_lanes[i]) * T(i - N)) >> 16); + } + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, vni)); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(vni, v)); + } +}; + +HWY_NOINLINE void TestAllMulHigh() { + ForPartialVectors test; + test(int16_t()); + test(uint16_t()); +} + +struct TestMulFixedPoint15 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, MulFixedPoint15(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, MulFixedPoint15(v0, v0)); + + const size_t N = Lanes(d); + auto in1 = AllocateAligned(N); + auto in2 = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(10000); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast(Random64(&rng) & 0xFFFF); + in2[i] = static_cast(Random64(&rng) & 0xFFFF); + } + + for (size_t i = 0; i < N; ++i) { + // There are three ways to compute the results. x86 and ARM are defined + // using 32-bit multiplication results: + const int arm = (2 * in1[i] * in2[i] + 0x8000) >> 16; + const int x86 = (((in1[i] * in2[i]) >> 14) + 1) >> 1; + // On other platforms, split the result into upper and lower 16 bits. + const auto v1 = Set(d, in1[i]); + const auto v2 = Set(d, in2[i]); + const int hi = GetLane(MulHigh(v1, v2)); + const int lo = GetLane(Mul(v1, v2)) & 0xFFFF; + const int split = 2 * hi + ((lo + 0x4000) >> 15); + expected[i] = static_cast(arm); + if (in1[i] != -32768 || in2[i] != -32768) { + HWY_ASSERT_EQ(arm, x86); + HWY_ASSERT_EQ(arm, split); + } + } + + const auto a = Load(d, in1.get()); + const auto b = Load(d, in2.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulFixedPoint15(a, b)); + } + } +}; + +HWY_NOINLINE void TestAllMulFixedPoint15() { + ForPartialVectors()(int16_t()); +} + +struct TestMulEven { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Wide = MakeWide; + const Repartition d2; + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d2, Zero(d2), MulEven(v0, v0)); + + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + auto expected = AllocateAligned(Lanes(d2)); + for (size_t i = 0; i < N; i += 2) { + in_lanes[i + 0] = LimitsMax() >> i; + if (N != 1) { + in_lanes[i + 1] = 1; // unused + } + expected[i / 2] = Wide(in_lanes[i + 0]) * in_lanes[i + 0]; + } + + const auto v = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(d2, expected.get(), MulEven(v, v)); + } +}; + +struct TestMulEvenOdd64 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, Zero(d), MulEven(v0, v0)); + HWY_ASSERT_VEC_EQ(d, Zero(d), MulOdd(v0, v0)); + + const size_t N = Lanes(d); + if (N == 1) return; + + auto in1 = AllocateAligned(N); + auto in2 = AllocateAligned(N); + auto expected_even = AllocateAligned(N); + auto expected_odd = AllocateAligned(N); + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = Random64(&rng); + in2[i] = Random64(&rng); + } + + for (size_t i = 0; i < N; i += 2) { + expected_even[i] = Mul128(in1[i], in2[i], &expected_even[i + 1]); + expected_odd[i] = Mul128(in1[i + 1], in2[i + 1], &expected_odd[i + 1]); + } + + const auto a = Load(d, in1.get()); + const auto b = Load(d, in2.get()); + HWY_ASSERT_VEC_EQ(d, expected_even.get(), MulEven(a, b)); + HWY_ASSERT_VEC_EQ(d, expected_odd.get(), MulOdd(a, b)); + } +#else + (void)d; +#endif // HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllMulEven() { + ForGEVectors<64, TestMulEven> test; + test(int32_t()); + test(uint32_t()); + + ForGEVectors<128, TestMulEvenOdd64>()(uint64_t()); +} + +#ifndef HWY_NATIVE_FMA +#error "Bug in set_macros-inl.h, did not set HWY_NATIVE_FMA" +#endif + +struct TestMulAdd { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto k0 = Zero(d); + const auto kNeg0 = Set(d, T(-0.0)); + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + HWY_ASSERT_VEC_EQ(d, k0, MulAdd(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, v2, MulAdd(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, MulAdd(v1, k0, v2)); + HWY_ASSERT_VEC_EQ(d, k0, NegMulAdd(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(v1, k0, v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 1) * (i + 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v1, v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(Neg(v2), v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v1, Neg(v2), k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 2) * (i + 2) + (i + 1)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v2, v1)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(Neg(v2), v2, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = + T(-T(i + 2u) * static_cast(i + 2) + static_cast(1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v2, v2, v1)); + + HWY_ASSERT_VEC_EQ(d, k0, MulSub(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, kNeg0, NegMulSub(k0, k0, k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = -T(i + 2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v1, k0, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(k0), v1, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(v1, Neg(k0), v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 1) * (i + 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v1, v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v2, v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(v1), v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(v2, Neg(v1), k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 2) * (i + 2) - (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v2, v2, v1)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(v2), v2, v1)); + } +}; + +HWY_NOINLINE void TestAllMulAdd() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestReorderWidenMulAccumulate { + template + HWY_NOINLINE void operator()(TN /*unused*/, DN dn) { + using TW = MakeWide; + const RepartitionToWide dw; + const Half dnh; + using VW = Vec; + using VN = Vec; + const size_t NN = Lanes(dn); + + const VW f0 = Zero(dw); + const VW f1 = Set(dw, TW{1}); + const VN bf0 = Zero(dn); + // Cannot Set() bfloat16_t directly. + const VN bf1 = ReorderDemote2To(dn, f1, f1); + + // Any input zero => both outputs zero + VW sum1 = f0; + HWY_ASSERT_VEC_EQ(dw, f0, + ReorderWidenMulAccumulate(dw, bf0, bf0, f0, sum1)); + HWY_ASSERT_VEC_EQ(dw, f0, sum1); + HWY_ASSERT_VEC_EQ(dw, f0, + ReorderWidenMulAccumulate(dw, bf0, bf1, f0, sum1)); + HWY_ASSERT_VEC_EQ(dw, f0, sum1); + HWY_ASSERT_VEC_EQ(dw, f0, + ReorderWidenMulAccumulate(dw, bf1, bf0, f0, sum1)); + HWY_ASSERT_VEC_EQ(dw, f0, sum1); + + // delta[p] := 1, all others zero. For each p: Dot(delta, all-ones) == 1. + auto delta_w = AllocateAligned(NN); + for (size_t i = 0; i < NN; ++i) { + delta_w[i] = TW{0}; + } + for (size_t p = 0; p < NN; ++p) { + delta_w[p] = TW{1}; + const VW delta0 = Load(dw, delta_w.get()); + const VW delta1 = Load(dw, delta_w.get() + NN / 2); + delta_w[p] = TW{0}; + const VN delta = ReorderDemote2To(dn, delta0, delta1); + + { + sum1 = f0; + const VW sum0 = ReorderWidenMulAccumulate(dw, delta, bf1, f0, sum1); + HWY_ASSERT_EQ(TW{1}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + // Swapped arg order + { + sum1 = f0; + const VW sum0 = ReorderWidenMulAccumulate(dw, bf1, delta, f0, sum1); + HWY_ASSERT_EQ(TW{1}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + // Start with nonzero sum0 or sum1 + { + VW sum0 = PromoteTo(dw, LowerHalf(dnh, delta)); + sum1 = PromoteTo(dw, UpperHalf(dnh, delta)); + sum0 = ReorderWidenMulAccumulate(dw, delta, bf1, sum0, sum1); + HWY_ASSERT_EQ(TW{2}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + // Start with nonzero sum0 or sum1, and swap arg order + { + VW sum0 = PromoteTo(dw, LowerHalf(dnh, delta)); + sum1 = PromoteTo(dw, UpperHalf(dnh, delta)); + sum0 = ReorderWidenMulAccumulate(dw, bf1, delta, sum0, sum1); + HWY_ASSERT_EQ(TW{2}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + } + } +}; + +HWY_NOINLINE void TestAllReorderWidenMulAccumulate() { + ForShrinkableVectors()(bfloat16_t()); + ForShrinkableVectors()(int16_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMulTest); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMul); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulHigh); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulFixedPoint15); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulEven); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulAdd); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllReorderWidenMulAccumulate); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/reduction_test.cc b/media/highway/src/hwy/tests/reduction_test.cc new file mode 100644 index 000000000..5e39abc55 --- /dev/null +++ b/media/highway/src/hwy/tests/reduction_test.cc @@ -0,0 +1,227 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/reduction_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestSumOfLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + + // Lane i = bit i, higher lanes 0 + double sum = 0.0; + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast(1ull << i) : 0; + sum += static_cast(in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, T(sum)), + SumOfLanes(d, Load(d, in_lanes.get()))); + + // Lane i = i (iota) to include upper lanes + sum = 0.0; + for (size_t i = 0; i < N; ++i) { + sum += static_cast(i); + } + HWY_ASSERT_VEC_EQ(d, Set(d, T(sum)), SumOfLanes(d, Iota(d, 0))); + } +}; + +HWY_NOINLINE void TestAllSumOfLanes() { + ForUIF3264(ForPartialVectors()); + ForUI16(ForPartialVectors()); +} + +struct TestMinOfLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + + // Lane i = bit i, higher lanes = 2 (not the minimum) + T min = HighestValue(); + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast(1ull << i) : 2; + min = HWY_MIN(min, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(d, Load(d, in_lanes.get()))); + + // Lane i = N - i to include upper lanes + min = HighestValue(); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = static_cast(N - i); // no 8-bit T so no wraparound + min = HWY_MIN(min, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(d, Load(d, in_lanes.get()))); + + // Bug #910: also check negative values + min = HighestValue(); + const T input_copy[] = {static_cast(-1), + static_cast(-2), + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14}; + size_t i = 0; + for (; i < HWY_MIN(N, sizeof(input_copy) / sizeof(T)); ++i) { + in_lanes[i] = input_copy[i]; + min = HWY_MIN(min, input_copy[i]); + } + // Pad with neutral element to full vector (so we can load) + for (; i < N; ++i) { + in_lanes[i] = min; + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(d, Load(d, in_lanes.get()))); + } +}; + +struct TestMaxOfLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + + T max = LowestValue(); + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast(1ull << i) : 0; + max = HWY_MAX(max, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(d, Load(d, in_lanes.get()))); + + // Lane i = i to include upper lanes + max = LowestValue(); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = static_cast(i); // no 8-bit T so no wraparound + max = HWY_MAX(max, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(d, Load(d, in_lanes.get()))); + + // Bug #910: also check negative values + max = LowestValue(); + const T input_copy[] = {static_cast(-1), + static_cast(-2), + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14}; + size_t i = 0; + for (; i < HWY_MIN(N, sizeof(input_copy) / sizeof(T)); ++i) { + in_lanes[i] = input_copy[i]; + max = HWY_MAX(max, in_lanes[i]); + } + // Pad with neutral element to full vector (so we can load) + for (; i < N; ++i) { + in_lanes[i] = max; + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(d, Load(d, in_lanes.get()))); + } +}; + +HWY_NOINLINE void TestAllMinMaxOfLanes() { + const ForPartialVectors test_min; + const ForPartialVectors test_max; + ForUIF3264(test_min); + ForUIF3264(test_max); + ForUI16(test_min); + ForUI16(test_max); +} + +struct TestSumsOf8 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const size_t N = Lanes(d); + if (N < 8) return; + const Repartition du64; + + auto in_lanes = AllocateAligned(N); + auto sum_lanes = AllocateAligned(N / 8); + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = Random64(&rng) & 0xFF; + } + + for (size_t idx_sum = 0; idx_sum < N / 8; ++idx_sum) { + uint64_t sum = 0; + for (size_t i = 0; i < 8; ++i) { + sum += in_lanes[idx_sum * 8 + i]; + } + sum_lanes[idx_sum] = sum; + } + + const Vec in = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(du64, sum_lanes.get(), SumsOf8(in)); + } + } +}; + +HWY_NOINLINE void TestAllSumsOf8() { + ForGEVectors<64, TestSumsOf8>()(uint8_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyReductionTest); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumOfLanes); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMinMaxOfLanes); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf8); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/reverse_test.cc b/media/highway/src/hwy/tests/reverse_test.cc new file mode 100644 index 000000000..fcbcb7fa1 --- /dev/null +++ b/media/highway/src/hwy/tests/reverse_test.cc @@ -0,0 +1,176 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/reverse_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestReverse { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[N - 1 - i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse(d, v)); + } +}; + +struct TestReverse2 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 1]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse2(d, v)); + } +}; + +struct TestReverse4 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 3]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse4(d, v)); + } +}; + +struct TestReverse8 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 7]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse8(d, v)); + } +}; + +HWY_NOINLINE void TestAllReverse() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF163264(ForPartialVectors()); +} + +HWY_NOINLINE void TestAllReverse2() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<128, TestReverse2>()); + ForUIF32(ForGEVectors<64, TestReverse2>()); + ForUIF16(ForGEVectors<32, TestReverse2>()); +} + +HWY_NOINLINE void TestAllReverse4() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<256, TestReverse4>()); + ForUIF32(ForGEVectors<128, TestReverse4>()); + ForUIF16(ForGEVectors<64, TestReverse4>()); +} + +HWY_NOINLINE void TestAllReverse8() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<512, TestReverse8>()); + ForUIF32(ForGEVectors<256, TestReverse8>()); + ForUIF16(ForGEVectors<128, TestReverse8>()); +} + +struct TestReverseBlocks { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned(N); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + const size_t num_blocks = N / kLanesPerBlock; + HWY_ASSERT(num_blocks != 0); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + const size_t idx_block = i / kLanesPerBlock; + const size_t base = (num_blocks - 1 - idx_block) * kLanesPerBlock; + expected[i] = copy[base + (i % kLanesPerBlock)]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ReverseBlocks(d, v)); + } +}; + +HWY_NOINLINE void TestAllReverseBlocks() { + ForAllTypes(ForGEVectors<128, TestReverseBlocks>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyReverseTest); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse2); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse4); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse8); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverseBlocks); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/shift_test.cc b/media/highway/src/hwy/tests/shift_test.cc new file mode 100644 index 000000000..585eba761 --- /dev/null +++ b/media/highway/src/hwy/tests/shift_test.cc @@ -0,0 +1,428 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/shift_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct TestLeftShifts { + template + HWY_NOINLINE void operator()(T t, D d) { + if (kSigned) { + // Also test positive values + TestLeftShifts()(t, d); + } + + using TI = MakeSigned; + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + // Values to shift + const auto values = Iota(d, static_cast(kSigned ? -TI(N) : TI(0))); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftLeft<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftLeftSame(values, 0)); + + // 1 + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(T(i) - T(N)) : T(i); + expected[i] = T(TU(value) << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, 1)); + + // max + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(T(i) - T(N)) : T(i); + expected[i] = T(TU(value) << kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, kMaxShift)); + } +}; + +template +struct TestVariableLeftShifts { + template + HWY_NOINLINE void operator()(T t, D d) { + if (kSigned) { + // Also test positive values + TestVariableLeftShifts()(t, d); + } + + using TI = MakeSigned; + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto values = Iota(d, kSigned ? -TI(N) : TI(0)); // value to shift + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + // Same: 0 + HWY_ASSERT_VEC_EQ(d, values, Shl(values, v0)); + + // Same: 1 + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, v1)); + + // Same: max + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, max_shift)); + + // Variable: small + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, small_shifts)); + + // Variable: large + for (size_t i = 0; i < N; ++i) { + expected[i] = T(TU(1) << (kMaxShift - (i & kMaxShift))); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(v1, large_shifts)); + } +}; + +struct TestUnsignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto values = Iota(d, 0); + + const T kMax = LimitsMax(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, kMaxShift)); + } +}; + +struct TestRotateRight { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + constexpr size_t kBits = sizeof(T) * 8; + const auto mask_shift = Set(d, T{kBits}); + // Cover as many bit positions as possible to test shifting out + const auto values = Shl(Set(d, T{1}), And(Iota(d, 0), mask_shift)); + + // Rotate by 0 + HWY_ASSERT_VEC_EQ(d, values, RotateRight<0>(values)); + + // Rotate by 1 + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> 1) | (expected[i] << (kBits - 1)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight<1>(values)); + + // Rotate by half + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> (kBits / 2)) | (expected[i] << (kBits / 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight(values)); + + // Rotate by max + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> (kBits - 1)) | (expected[i] << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight(values)); + } +}; + +struct TestVariableUnsignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto values = Iota(d, 0); + + const T kMax = LimitsMax(); + const auto max = Set(d, kMax); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + // Same: 0 + HWY_ASSERT_VEC_EQ(d, values, Shr(values, v0)); + + // Same: 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, v1)); + + // Same: max + HWY_ASSERT_VEC_EQ(d, v0, Shr(values, max_shift)); + + // Variable: small + for (size_t i = 0; i < N; ++i) { + expected[i] = T(i) >> (i & kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, small_shifts)); + + // Variable: Large + for (size_t i = 0; i < N; ++i) { + expected[i] = kMax >> (kMaxShift - (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(max, large_shifts)); + } +}; + +template +T RightShiftNegative(T val) { + // C++ shifts are implementation-defined for negative numbers, and we have + // seen divisions replaced with shifts, so resort to bit operations. + using TU = hwy::MakeUnsigned; + TU bits; + CopySameSize(&val, &bits); + + const TU shifted = TU(bits >> kAmount); + + const TU all = TU(~TU(0)); + const size_t num_zero = sizeof(TU) * 8 - 1 - kAmount; + const TU sign_extended = static_cast((all << num_zero) & LimitsMax()); + + bits = shifted | sign_extended; + CopySameSize(&bits, &val); + return val; +} + +class TestSignedRightShifts { + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + constexpr T kMin = LimitsMin(); + constexpr T kMax = LimitsMax(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. + const auto v0 = Zero(d); + const auto values = And(Iota(d, 0), Set(d, kMax)); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(values)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(values, kMaxShift)); + + // Even negative value + Test<0>(kMin, d, __LINE__); + Test<1>(kMin, d, __LINE__); + Test<2>(kMin, d, __LINE__); + Test(kMin, d, __LINE__); + + const T odd = static_cast(kMin + 1); + Test<0>(odd, d, __LINE__); + Test<1>(odd, d, __LINE__); + Test<2>(odd, d, __LINE__); + Test(odd, d, __LINE__); + } + + private: + template + void Test(T val, D d, int line) { + const auto expected = Set(d, RightShiftNegative(val)); + const auto in = Set(d, val); + const char* file = __FILE__; + AssertVecEqual(d, expected, ShiftRight(in), file, line); + AssertVecEqual(d, expected, ShiftRightSame(in, kAmount), file, line); + } +}; + +struct TestVariableSignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + constexpr T kMin = LimitsMin(); + constexpr T kMax = LimitsMax(); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. + const auto v0 = Zero(d); + const auto positive = Iota(d, 0) & Set(d, kMax); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, positive, ShiftRight<0>(positive)); + HWY_ASSERT_VEC_EQ(d, positive, ShiftRightSame(positive, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(positive)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(positive, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(positive)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(positive, kMaxShift)); + + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + const auto negative = Iota(d, kMin); + + // Test varying negative to shift + for (size_t i = 0; i < N; ++i) { + expected[i] = RightShiftNegative<1>(static_cast(kMin + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(negative, Set(d, 1))); + + // Shift MSB right by small amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = i & kMaxShift; + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopySameSize(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), small_shifts)); + + // Shift MSB right by large amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = kMaxShift - (i & kMaxShift); + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopySameSize(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), large_shifts)); + } +}; + +HWY_NOINLINE void TestAllShifts() { + ForUnsignedTypes(ForPartialVectors>()); + ForSignedTypes(ForPartialVectors>()); + ForUnsignedTypes(ForPartialVectors()); + ForSignedTypes(ForPartialVectors()); +} + +HWY_NOINLINE void TestAllVariableShifts() { + const ForPartialVectors> shl_u; + const ForPartialVectors> shl_s; + const ForPartialVectors shr_u; + const ForPartialVectors shr_s; + + shl_u(uint16_t()); + shr_u(uint16_t()); + + shl_u(uint32_t()); + shr_u(uint32_t()); + + shl_s(int16_t()); + shr_s(int16_t()); + + shl_s(int32_t()); + shr_s(int32_t()); + +#if HWY_HAVE_INTEGER64 + shl_u(uint64_t()); + shr_u(uint64_t()); + + shl_s(int64_t()); + shr_s(int64_t()); +#endif +} + +HWY_NOINLINE void TestAllRotateRight() { + const ForPartialVectors test; + test(uint32_t()); +#if HWY_HAVE_INTEGER64 + test(uint64_t()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyShiftTest); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllShifts); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllVariableShifts); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllRotateRight); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/swizzle_test.cc b/media/highway/src/hwy/tests/swizzle_test.cc new file mode 100644 index 000000000..f447f7a80 --- /dev/null +++ b/media/highway/src/hwy/tests/swizzle_test.cc @@ -0,0 +1,272 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include // memset + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/swizzle_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestGetLane { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(1)); + HWY_ASSERT_EQ(T(1), GetLane(v)); + } +}; + +HWY_NOINLINE void TestAllGetLane() { + ForAllTypes(ForPartialVectors()); +} + +struct TestExtractLane { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(1)); + for (size_t i = 0; i < Lanes(d); ++i) { + const T actual = ExtractLane(v, i); + HWY_ASSERT_EQ(static_cast(i + 1), actual); + } + } +}; + +HWY_NOINLINE void TestAllExtractLane() { + ForAllTypes(ForPartialVectors()); +} + +struct TestInsertLane { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec; + const V v = Iota(d, T(1)); + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + Store(v, d, lanes.get()); + + for (size_t i = 0; i < Lanes(d); ++i) { + lanes[i] = T{0}; + const V actual = InsertLane(v, i, static_cast(i + 1)); + HWY_ASSERT_VEC_EQ(d, v, actual); + Store(v, d, lanes.get()); // restore lane i + } + } +}; + +HWY_NOINLINE void TestAllInsertLane() { + ForAllTypes(ForPartialVectors()); +} + +struct TestDupEven { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((static_cast(i) & ~1) + 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), DupEven(Iota(d, 1))); + } +}; + +HWY_NOINLINE void TestAllDupEven() { + ForUIF3264(ForShrinkableVectors()); +} + +struct TestDupOdd { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((static_cast(i) & ~1) + 2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), DupOdd(Iota(d, 1))); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllDupOdd() { + ForUIF3264(ForShrinkableVectors()); +} + +struct TestOddEven { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto even = Iota(d, 1); + const auto odd = Iota(d, static_cast(1 + N)); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast(1 + i + ((i & 1) ? N : 0)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), OddEven(odd, even)); + } +}; + +HWY_NOINLINE void TestAllOddEven() { + ForAllTypes(ForShrinkableVectors()); +} + +struct TestOddEvenBlocks { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto even = Iota(d, 1); + const auto odd = Iota(d, static_cast(1 + N)); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + const size_t idx_block = i / (16 / sizeof(T)); + expected[i] = static_cast(1 + i + ((idx_block & 1) ? N : 0)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), OddEvenBlocks(odd, even)); + } +}; + +HWY_NOINLINE void TestAllOddEvenBlocks() { + ForAllTypes(ForGEVectors<128, TestOddEvenBlocks>()); +} + +struct TestSwapAdjacentBlocks { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + if (N < 2 * kLanesPerBlock) return; + const auto vi = Iota(d, 1); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + const size_t idx_block = i / kLanesPerBlock; + const size_t base = (idx_block ^ 1) * kLanesPerBlock; + const size_t mod = i % kLanesPerBlock; + expected[i] = static_cast(1 + base + mod); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), SwapAdjacentBlocks(vi)); + } +}; + +HWY_NOINLINE void TestAllSwapAdjacentBlocks() { + ForAllTypes(ForGEVectors<128, TestSwapAdjacentBlocks>()); +} + +struct TestTableLookupLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const RebindToSigned di; + using TI = TFromD; +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + auto idx = AllocateAligned(N); + memset(idx.get(), 0, N * sizeof(TI)); + auto expected = AllocateAligned(N); + const auto v = Iota(d, 1); + + if (N <= 8) { // Test all permutations + for (size_t i0 = 0; i0 < N; ++i0) { + idx[0] = static_cast(i0); + + for (size_t i1 = 0; i1 < N; ++i1) { + if (N >= 2) idx[1] = static_cast(i1); + for (size_t i2 = 0; i2 < N; ++i2) { + if (N >= 4) idx[2] = static_cast(i2); + for (size_t i3 = 0; i3 < N; ++i3) { + if (N >= 4) idx[3] = static_cast(i3); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast(idx[i] + 1); // == v[idx[i]] + } + + const auto opaque1 = IndicesFromVec(d, Load(di, idx.get())); + const auto actual1 = TableLookupLanes(v, opaque1); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual1); + + const auto opaque2 = SetTableIndices(d, idx.get()); + const auto actual2 = TableLookupLanes(v, opaque2); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual2); + } + } + } + } + } else { + // Too many permutations to test exhaustively; choose one with repeated + // and cross-block indices and ensure indices do not exceed #lanes. + // For larger vectors, upper lanes will be zero. + HWY_ALIGN TI idx_source[16] = {1, 3, 2, 2, 8, 1, 7, 6, + 15, 14, 14, 15, 4, 9, 8, 5}; + for (size_t i = 0; i < N; ++i) { + idx[i] = (i < 16) ? idx_source[i] : 0; + // Avoid undefined results / asan error for scalar by capping indices. + if (idx[i] >= static_cast(N)) { + idx[i] = static_cast(N - 1); + } + expected[i] = static_cast(idx[i] + 1); // == v[idx[i]] + } + + const auto opaque1 = IndicesFromVec(d, Load(di, idx.get())); + const auto actual1 = TableLookupLanes(v, opaque1); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual1); + + const auto opaque2 = SetTableIndices(d, idx.get()); + const auto actual2 = TableLookupLanes(v, opaque2); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual2); + } +#else + const TI index = 0; + const auto v = Set(d, 1); + const auto opaque1 = SetTableIndices(d, &index); + HWY_ASSERT_VEC_EQ(d, v, TableLookupLanes(v, opaque1)); + const auto opaque2 = IndicesFromVec(d, Zero(di)); + HWY_ASSERT_VEC_EQ(d, v, TableLookupLanes(v, opaque2)); +#endif + } +}; + +HWY_NOINLINE void TestAllTableLookupLanes() { + ForUIF3264(ForPartialVectors()); +} + + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwySwizzleTest); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllGetLane); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllExtractLane); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllInsertLane); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllDupEven); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllDupOdd); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllOddEven); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllOddEvenBlocks); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllSwapAdjacentBlocks); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllTableLookupLanes); +} // namespace hwy + +#endif diff --git a/media/highway/src/hwy/tests/test_util-inl.h b/media/highway/src/hwy/tests/test_util-inl.h new file mode 100644 index 000000000..d9c1aebc3 --- /dev/null +++ b/media/highway/src/hwy/tests/test_util-inl.h @@ -0,0 +1,665 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-specific helper functions for use by *_test.cc. + +#include + +#include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" +#include "hwy/tests/test_util.h" + +// After test_util (also includes highway.h) +#include "hwy/print-inl.h" + +// Per-target include guard +#if defined(HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_ +#undef HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_ +#else +#define HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_ +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Compare expected vector to vector. +// HWY_INLINE works around a Clang SVE compiler bug where all but the first +// 128 bits (the NEON register) of actual are zero. +template , class V = Vec> +HWY_INLINE void AssertVecEqual(D d, const T* expected, VecArg actual, + const char* filename, const int line) { + const size_t N = Lanes(d); + auto actual_lanes = AllocateAligned(N); + Store(actual, d, actual_lanes.get()); + + const auto info = hwy::detail::MakeTypeInfo(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected, actual_lanes.get(), N, + target_name, filename, line); +} + +// Compare expected lanes to vector. +// HWY_INLINE works around a Clang SVE compiler bug where all but the first +// 128 bits (the NEON register) of actual are zero. +template , class V = Vec> +HWY_INLINE void AssertVecEqual(D d, VecArg expected, VecArg actual, + const char* filename, int line) { + auto expected_lanes = AllocateAligned(Lanes(d)); + Store(expected, d, expected_lanes.get()); + AssertVecEqual(d, expected_lanes.get(), actual, filename, line); +} + +// Only checks the valid mask elements (those whose index < Lanes(d)). +template +HWY_NOINLINE void AssertMaskEqual(D d, VecArg> a, VecArg> b, + const char* filename, int line) { + // lvalues prevented MSAN failure in farm_sve. + const Vec va = VecFromMask(d, a); + const Vec vb = VecFromMask(d, b); + AssertVecEqual(d, va, vb, filename, line); + + const char* target_name = hwy::TargetName(HWY_TARGET); + AssertEqual(CountTrue(d, a), CountTrue(d, b), target_name, filename, line); + AssertEqual(AllTrue(d, a), AllTrue(d, b), target_name, filename, line); + AssertEqual(AllFalse(d, a), AllFalse(d, b), target_name, filename, line); + + const size_t N = Lanes(d); +#if HWY_TARGET == HWY_SCALAR + const Rebind d8; +#else + const Repartition d8; +#endif + const size_t N8 = Lanes(d8); + auto bits_a = AllocateAligned(HWY_MAX(8, N8)); + auto bits_b = AllocateAligned(HWY_MAX(8, N8)); + memset(bits_a.get(), 0, N8); + memset(bits_b.get(), 0, N8); + const size_t num_bytes_a = StoreMaskBits(d, a, bits_a.get()); + const size_t num_bytes_b = StoreMaskBits(d, b, bits_b.get()); + AssertEqual(num_bytes_a, num_bytes_b, target_name, filename, line); + size_t i = 0; + // First check whole bytes (if that many elements are still valid) + for (; i < N / 8; ++i) { + if (bits_a[i] != bits_b[i]) { + fprintf(stderr, "Mismatch in byte %d: %d != %d\n", static_cast(i), + bits_a[i], bits_b[i]); + Print(d8, "expect", Load(d8, bits_a.get()), 0, N8); + Print(d8, "actual", Load(d8, bits_b.get()), 0, N8); + hwy::Abort(filename, line, "Masks not equal"); + } + } + // Then the valid bit(s) in the last byte. + const size_t remainder = N % 8; + if (remainder != 0) { + const int mask = (1 << remainder) - 1; + const int valid_a = bits_a[i] & mask; + const int valid_b = bits_b[i] & mask; + if (valid_a != valid_b) { + fprintf(stderr, "Mismatch in last byte %d: %d != %d\n", + static_cast(i), valid_a, valid_b); + Print(d8, "expect", Load(d8, bits_a.get()), 0, N8); + Print(d8, "actual", Load(d8, bits_b.get()), 0, N8); + hwy::Abort(filename, line, "Masks not equal"); + } + } +} + +// Only sets valid elements (those whose index < Lanes(d)). This helps catch +// tests that are not masking off the (undefined) upper mask elements. +// +// TODO(janwas): with HWY_NOINLINE GCC zeros the upper half of AVX2 masks. +template +HWY_INLINE Mask MaskTrue(const D d) { + return FirstN(d, Lanes(d)); +} + +template +HWY_INLINE Mask MaskFalse(const D d) { + const auto zero = Zero(RebindToSigned()); + return RebindMask(d, Lt(zero, zero)); +} + +#ifndef HWY_ASSERT_EQ + +#define HWY_ASSERT_EQ(expected, actual) \ + hwy::AssertEqual(expected, actual, hwy::TargetName(HWY_TARGET), __FILE__, \ + __LINE__) + +#define HWY_ASSERT_ARRAY_EQ(expected, actual, count) \ + hwy::AssertArrayEqual(expected, actual, count, hwy::TargetName(HWY_TARGET), \ + __FILE__, __LINE__) + +#define HWY_ASSERT_STRING_EQ(expected, actual) \ + hwy::AssertStringEqual(expected, actual, hwy::TargetName(HWY_TARGET), \ + __FILE__, __LINE__) + +#define HWY_ASSERT_VEC_EQ(d, expected, actual) \ + AssertVecEqual(d, expected, actual, __FILE__, __LINE__) + +#define HWY_ASSERT_MASK_EQ(d, expected, actual) \ + AssertMaskEqual(d, expected, actual, __FILE__, __LINE__) + +#endif // HWY_ASSERT_EQ + +namespace detail { + +// Helpers for instantiating tests with combinations of lane types / counts. + +// Calls Test for each CappedTag where N is in [kMinLanes, kMul * kMinArg] +// and the resulting Lanes() is in [min_lanes, max_lanes]. The upper bound +// is required to ensure capped vectors remain extendable. Implemented by +// recursively halving kMul until it is zero. +template +struct ForeachCappedR { + static void Do(size_t min_lanes, size_t max_lanes) { + const CappedTag d; + + // If we already don't have enough lanes, stop. + const size_t lanes = Lanes(d); + if (lanes < min_lanes) return; + + if (lanes <= max_lanes) { + Test()(T(), d); + } + ForeachCappedR::Do(min_lanes, max_lanes); + } +}; + +// Base case to stop the recursion. +template +struct ForeachCappedR { + static void Do(size_t, size_t) {} +}; + +#if HWY_HAVE_SCALABLE + +template +constexpr int MinPow2() { + // Highway follows RVV LMUL in that the smallest fraction is 1/8th (encoded + // as kPow2 == -3). The fraction also must not result in zero lanes for the + // smallest possible vector size, which is 128 bits even on RISC-V (with the + // application processor profile). + return HWY_MAX(-3, -static_cast(CeilLog2(16 / sizeof(T)))); +} + +// Iterates kPow2 upward through +3. +template +struct ForeachShiftR { + static void Do(size_t min_lanes) { + const ScalableTag d; + + // Precondition: [kPow2, 3] + kAddPow2 is a valid fraction of the minimum + // vector size, so we always have enough lanes, except ForGEVectors. + if (Lanes(d) >= min_lanes) { + Test()(T(), d); + } else { + fprintf(stderr, "%d lanes < %d: T=%d pow=%d\n", + static_cast(Lanes(d)), static_cast(min_lanes), + static_cast(sizeof(T)), kPow2 + kAddPow2); + HWY_ASSERT(min_lanes != 1); + } + + ForeachShiftR::Do(min_lanes); + } +}; + +// Base case to stop the recursion. +template +struct ForeachShiftR { + static void Do(size_t) {} +}; +#else +// ForeachCappedR already handled all possible sizes. +#endif // HWY_HAVE_SCALABLE + +} // namespace detail + +// These 'adapters' call a test for all possible N or kPow2 subject to +// constraints such as "vectors must be extendable" or "vectors >= 128 bits". +// They may be called directly, or via For*Types. Note that for an adapter C, +// `C(T())` does not call the test - the correct invocation is +// `C()(T())`, or preferably `ForAllTypes(C())`. We check at runtime +// that operator() is called to prevent such bugs. Note that this is not +// thread-safe, but that is fine because C are typically local variables. + +// Calls Test for all power of two N in [1, Lanes(d) >> kPow2]. This is for +// ops that widen their input, e.g. Combine (not supported by HWY_SCALAR). +template +class ForExtendableVectors { + mutable bool called_ = false; + + public: + ~ForExtendableVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMaxCapped = HWY_LANES(T); + // Skip CappedTag that are already full vectors. + const size_t max_lanes = Lanes(ScalableTag()) >> kPow2; + (void)kMaxCapped; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + // not supported +#else + detail::ForeachCappedR> kPow2), 1, Test>::Do(1, max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2, 3 - kPow2]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR() + kPow2, -kPow2, Test>::Do(1); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2, 0 - kPow2]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR() + kPow2 + 3, -kPow2 - 3, + Test>::Do(1); +#endif +#endif // HWY_SCALAR + } +}; + +// Calls Test for all power of two N in [1 << kPow2, Lanes(d)]. This is for ops +// that narrow their input, e.g. UpperHalf. +template +class ForShrinkableVectors { + mutable bool called_ = false; + + public: + ~ForShrinkableVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMinLanes = size_t{1} << kPow2; + constexpr size_t kMaxCapped = HWY_LANES(T); + // For shrinking, an upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + + (void)kMinLanes; + (void)max_lanes; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + // not supported +#else + detail::ForeachCappedR> kPow2), kMinLanes, Test>::Do( + kMinLanes, max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR() + kPow2, 0, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR() + kPow2 + 3, -3, Test>::Do( + kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +// Calls Test for all supported power of two vectors of at least kMinBits. +// Examples: AES or 64x64 require 128 bits, casts may require 64 bits. +template +class ForGEVectors { + mutable bool called_ = false; + + public: + ~ForGEVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMaxCapped = HWY_LANES(T); + constexpr size_t kMinLanes = kMinBits / 8 / sizeof(T); + // An upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + (void)kMinLanes; // not supported +#else + detail::ForeachCappedR::Do( + kMinLanes, max_lanes); +#if HWY_TARGET == HWY_RVV + // Can be 0 (handled below) if kMinBits > 64. + constexpr size_t kRatio = 128 / kMinBits; + constexpr int kMinPow2 = + kRatio == 0 ? 0 : -static_cast(CeilLog2(kRatio)); + // For each [kMinPow2, 3]; counter is [kMinPow2, 3]. + detail::ForeachShiftR::Do(kMinLanes); +#elif HWY_HAVE_SCALABLE + // Can be 0 (handled below) if kMinBits > 128. + constexpr size_t kRatio = 128 / kMinBits; + constexpr int kMinPow2 = + kRatio == 0 ? 0 : -static_cast(CeilLog2(kRatio)); + // For each [kMinPow2, 0]; counter is [kMinPow2 + 3, 3]. + detail::ForeachShiftR::Do(kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +template +using ForGE128Vectors = ForGEVectors<128, Test>; + +// Calls Test for all N that can be promoted (not the same as Extendable because +// HWY_SCALAR has one lane). Also used for ZipLower, but not ZipUpper. +template +class ForPromoteVectors { + mutable bool called_ = false; + + public: + ~ForPromoteVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kFactor = size_t{1} << kPow2; + static_assert(kFactor >= 2 && kFactor * sizeof(T) <= sizeof(uint64_t), ""); + constexpr size_t kMaxCapped = HWY_LANES(T); + constexpr size_t kMinLanes = kFactor; + // Skip CappedTag that are already full vectors. + const size_t max_lanes = Lanes(ScalableTag()) >> kPow2; + (void)kMaxCapped; + (void)kMinLanes; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + detail::ForeachCappedR::Do(1, 1); +#else + // TODO(janwas): call Extendable if kMinLanes check not required? + detail::ForeachCappedR> kPow2), 1, Test>::Do(kMinLanes, + max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2, 3 - kPow2]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR() + kPow2, -kPow2, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2, 0 - kPow2]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR() + kPow2 + 3, -kPow2 - 3, + Test>::Do(kMinLanes); +#endif +#endif // HWY_SCALAR + } +}; + +// Calls Test for all N than can be demoted (not the same as Shrinkable because +// HWY_SCALAR has one lane). +template +class ForDemoteVectors { + mutable bool called_ = false; + + public: + ~ForDemoteVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMinLanes = size_t{1} << kPow2; + constexpr size_t kMaxCapped = HWY_LANES(T); + // For shrinking, an upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + + (void)kMinLanes; + (void)max_lanes; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + detail::ForeachCappedR::Do(1, 1); +#else + detail::ForeachCappedR> kPow2), kMinLanes, Test>::Do( + kMinLanes, max_lanes); + +// TODO(janwas): call Extendable if kMinLanes check not required? +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR() + kPow2, 0, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR() + kPow2 + 3, -3, Test>::Do( + kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +// For LowerHalf/Quarter. +template +class ForHalfVectors { + mutable bool called_ = false; + + public: + ~ForHalfVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template + void operator()(T /*unused*/) const { + called_ = true; +#if HWY_TARGET == HWY_SCALAR + detail::ForeachCappedR::Do(1, 1); +#else + constexpr size_t kMinLanes = size_t{1} << kPow2; + // For shrinking, an upper limit is unnecessary. + constexpr size_t kMaxCapped = HWY_LANES(T); + detail::ForeachCappedR> kPow2), kMinLanes, Test>::Do( + kMinLanes, kMaxCapped); + +// TODO(janwas): call Extendable if kMinLanes check not required? +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR() + kPow2, 0, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR() + kPow2 + 3, -3, Test>::Do( + kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +// Calls Test for all power of two N in [1, Lanes(d)]. This is the default +// for ops that do not narrow nor widen their input, nor require 128 bits. +template +class ForPartialVectors { + mutable bool called_ = false; + + public: + ~ForPartialVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template + void operator()(T t) const { + called_ = true; +#if HWY_TARGET == HWY_SCALAR + (void)t; + detail::ForeachCappedR::Do(1, 1); +#else + ForExtendableVectors()(t); +#endif + } +}; + +// Type lists to shorten call sites: + +template +void ForSignedTypes(const Func& func) { + func(int8_t()); + func(int16_t()); + func(int32_t()); +#if HWY_HAVE_INTEGER64 + func(int64_t()); +#endif +} + +template +void ForUnsignedTypes(const Func& func) { + func(uint8_t()); + func(uint16_t()); + func(uint32_t()); +#if HWY_HAVE_INTEGER64 + func(uint64_t()); +#endif +} + +template +void ForIntegerTypes(const Func& func) { + ForSignedTypes(func); + ForUnsignedTypes(func); +} + +template +void ForFloatTypes(const Func& func) { + func(float()); +#if HWY_HAVE_FLOAT64 + func(double()); +#endif +} + +template +void ForAllTypes(const Func& func) { + ForIntegerTypes(func); + ForFloatTypes(func); +} + +template +void ForUI8(const Func& func) { + func(uint8_t()); + func(int8_t()); +} + +template +void ForUI16(const Func& func) { + func(uint16_t()); + func(int16_t()); +} + +template +void ForUIF16(const Func& func) { + ForUI16(func); +#if HWY_HAVE_FLOAT16 + func(float16_t()); +#endif +} + +template +void ForUI32(const Func& func) { + func(uint32_t()); + func(int32_t()); +} + +template +void ForUIF32(const Func& func) { + ForUI32(func); + func(float()); +} + +template +void ForUI64(const Func& func) { +#if HWY_HAVE_INTEGER64 + func(uint64_t()); + func(int64_t()); +#endif +} + +template +void ForUIF64(const Func& func) { + ForUI64(func); +#if HWY_HAVE_FLOAT64 + func(double()); +#endif +} + +template +void ForUI3264(const Func& func) { + ForUI32(func); + ForUI64(func); +} + +template +void ForUIF3264(const Func& func) { + ForUIF32(func); + ForUIF64(func); +} + +template +void ForUI163264(const Func& func) { + ForUI16(func); + ForUI3264(func); +} + +template +void ForUIF163264(const Func& func) { + ForUIF16(func); + ForUIF3264(func); +} + +// For tests that involve loops, adjust the trip count so that emulated tests +// finish quickly (but always at least 2 iterations to ensure some diversity). +constexpr size_t AdjustedReps(size_t max_reps) { +#if HWY_ARCH_RVV + return HWY_MAX(max_reps / 32, 2); +#elif HWY_IS_DEBUG_BUILD + return HWY_MAX(max_reps / 8, 2); +#elif HWY_ARCH_ARM + return HWY_MAX(max_reps / 4, 2); +#else + return HWY_MAX(max_reps, 2); +#endif +} + +// Same as above, but the loop trip count will be 1 << max_pow2. +constexpr size_t AdjustedLog2Reps(size_t max_pow2) { + // If "negative" (unsigned wraparound), use original. +#if HWY_ARCH_RVV + return HWY_MIN(max_pow2 - 4, max_pow2); +#elif HWY_IS_DEBUG_BUILD + return HWY_MIN(max_pow2 - 1, max_pow2); +#elif HWY_ARCH_ARM + return HWY_MIN(max_pow2 - 1, max_pow2); +#else + return max_pow2; +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/media/highway/src/hwy/tests/test_util.cc b/media/highway/src/hwy/tests/test_util.cc new file mode 100644 index 000000000..a0796b15f --- /dev/null +++ b/media/highway/src/hwy/tests/test_util.cc @@ -0,0 +1,117 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/tests/test_util.h" + +#include +#include + +#include + +#include "hwy/base.h" +#include "hwy/print.h" + +namespace hwy { + +HWY_TEST_DLLEXPORT bool BytesEqual(const void* p1, const void* p2, + const size_t size, size_t* pos) { + const uint8_t* bytes1 = reinterpret_cast(p1); + const uint8_t* bytes2 = reinterpret_cast(p2); + for (size_t i = 0; i < size; ++i) { + if (bytes1[i] != bytes2[i]) { + if (pos != nullptr) { + *pos = i; + } + return false; + } + } + return true; +} + +void AssertStringEqual(const char* expected, const char* actual, + const char* target_name, const char* filename, + int line) { + while (*expected == *actual++) { + if (*expected++ == '\0') return; + } + + Abort(filename, line, "%s string mismatch: expected '%s', got '%s'.\n", + target_name, expected, actual); +} + +namespace detail { + +HWY_TEST_DLLEXPORT bool IsEqual(const TypeInfo& info, const void* expected_ptr, + const void* actual_ptr) { + if (!info.is_float) { + return BytesEqual(expected_ptr, actual_ptr, info.sizeof_t); + } + + if (info.sizeof_t == 4) { + float expected, actual; + CopyBytes<4>(expected_ptr, &expected); + CopyBytes<4>(actual_ptr, &actual); + return ComputeUlpDelta(expected, actual) <= 1; + } else if (info.sizeof_t == 8) { + double expected, actual; + CopyBytes<8>(expected_ptr, &expected); + CopyBytes<8>(actual_ptr, &actual); + return ComputeUlpDelta(expected, actual) <= 1; + } else { + HWY_ABORT("Unexpected float size %d\n", static_cast(info.sizeof_t)); + return false; + } +} + +HWY_TEST_DLLEXPORT HWY_NORETURN void PrintMismatchAndAbort( + const TypeInfo& info, const void* expected_ptr, const void* actual_ptr, + const char* target_name, const char* filename, int line, size_t lane, + size_t num_lanes) { + char type_name[100]; + TypeName(info, 1, type_name); + char expected_str[100]; + ToString(info, expected_ptr, expected_str); + char actual_str[100]; + ToString(info, actual_ptr, actual_str); + Abort(filename, line, + "%s, %sx%d lane %d mismatch: expected '%s', got '%s'.\n", target_name, + type_name, static_cast(num_lanes), static_cast(lane), + expected_str, actual_str); +} + +HWY_TEST_DLLEXPORT void AssertArrayEqual(const TypeInfo& info, + const void* expected_void, + const void* actual_void, size_t N, + const char* target_name, + const char* filename, int line) { + const uint8_t* expected_array = + reinterpret_cast(expected_void); + const uint8_t* actual_array = reinterpret_cast(actual_void); + for (size_t i = 0; i < N; ++i) { + const void* expected_ptr = expected_array + i * info.sizeof_t; + const void* actual_ptr = actual_array + i * info.sizeof_t; + if (!IsEqual(info, expected_ptr, actual_ptr)) { + fprintf(stderr, "\n\n"); + PrintArray(info, "expect", expected_array, N, i); + PrintArray(info, "actual", actual_array, N, i); + + PrintMismatchAndAbort(info, expected_ptr, actual_ptr, target_name, + filename, line, i, N); + } + } +} + +} // namespace detail +} // namespace hwy diff --git a/media/highway/src/hwy/tests/test_util.h b/media/highway/src/hwy/tests/test_util.h new file mode 100644 index 000000000..459de961c --- /dev/null +++ b/media/highway/src/hwy/tests/test_util.h @@ -0,0 +1,172 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_TESTS_TEST_UTIL_H_ +#define HWY_TESTS_TEST_UTIL_H_ + +// Target-independent helper functions for use by *_test.cc. + +#include +#include +#include + +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/highway.h" +#include "hwy/highway_export.h" +#include "hwy/print.h" + +namespace hwy { + +// The maximum vector size used in tests when defining test data. DEPRECATED. +constexpr size_t kTestMaxVectorSize = 64; + +// 64-bit random generator (Xorshift128+). Much smaller state than std::mt19937, +// which triggers a compiler bug. +class RandomState { + public: + explicit RandomState(const uint64_t seed = 0x123456789ull) { + s0_ = SplitMix64(seed + 0x9E3779B97F4A7C15ull); + s1_ = SplitMix64(s0_); + } + + HWY_INLINE uint64_t operator()() { + uint64_t s1 = s0_; + const uint64_t s0 = s1_; + const uint64_t bits = s1 + s0; + s0_ = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s1_ = s1; + return bits; + } + + private: + static uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + uint64_t s0_; + uint64_t s1_; +}; + +static HWY_INLINE uint32_t Random32(RandomState* rng) { + return static_cast((*rng)()); +} + +static HWY_INLINE uint64_t Random64(RandomState* rng) { return (*rng)(); } + +// Prevents the compiler from eliding the computations that led to "output". +// Works by indicating to the compiler that "output" is being read and modified. +// The +r constraint avoids unnecessary writes to memory, but only works for +// built-in types. +template +inline void PreventElision(T&& output) { +#if HWY_COMPILER_MSVC + (void)output; +#else // HWY_COMPILER_MSVC + asm volatile("" : "+r"(output) : : "memory"); +#endif // HWY_COMPILER_MSVC +} + +HWY_TEST_DLLEXPORT bool BytesEqual(const void* p1, const void* p2, + const size_t size, size_t* pos = nullptr); + +void AssertStringEqual(const char* expected, const char* actual, + const char* target_name, const char* filename, int line); + +namespace detail { + +template > +TU ComputeUlpDelta(const T expected, const T actual) { + // Handle -0 == 0 and infinities. + if (expected == actual) return 0; + + // Consider "equal" if both are NaN, so we can verify an expected NaN. + // Needs a special case because there are many possible NaN representations. + if (std::isnan(expected) && std::isnan(actual)) return 0; + + // Compute the difference in units of last place. We do not need to check for + // differing signs; they will result in large differences, which is fine. + TU ux, uy; + CopySameSize(&expected, &ux); + CopySameSize(&actual, &uy); + + // Avoid unsigned->signed cast: 2's complement is only guaranteed by C++20. + const TU ulp = HWY_MAX(ux, uy) - HWY_MIN(ux, uy); + return ulp; +} + +HWY_TEST_DLLEXPORT bool IsEqual(const TypeInfo& info, const void* expected_ptr, + const void* actual_ptr); + +HWY_TEST_DLLEXPORT HWY_NORETURN void PrintMismatchAndAbort( + const TypeInfo& info, const void* expected_ptr, const void* actual_ptr, + const char* target_name, const char* filename, int line, size_t lane = 0, + size_t num_lanes = 1); + +HWY_TEST_DLLEXPORT void AssertArrayEqual(const TypeInfo& info, + const void* expected_void, + const void* actual_void, size_t N, + const char* target_name, + const char* filename, int line); + +} // namespace detail + +// Returns a name for the vector/part/scalar. The type prefix is u/i/f for +// unsigned/signed/floating point, followed by the number of bits per lane; +// then 'x' followed by the number of lanes. Example: u8x16. This is useful for +// understanding which instantiation of a generic test failed. +template +std::string TypeName(T /*unused*/, size_t N) { + char string100[100]; + detail::TypeName(detail::MakeTypeInfo(), N, string100); + return string100; +} + +// Compare non-vector, non-string T. +template +HWY_INLINE bool IsEqual(const T expected, const T actual) { + const auto info = detail::MakeTypeInfo(); + return detail::IsEqual(info, &expected, &actual); +} + +template +HWY_INLINE void AssertEqual(const T expected, const T actual, + const char* target_name, const char* filename, + int line, size_t lane = 0) { + const auto info = detail::MakeTypeInfo(); + if (!detail::IsEqual(info, &expected, &actual)) { + detail::PrintMismatchAndAbort(info, &expected, &actual, target_name, + filename, line, lane); + } +} + +template +HWY_INLINE void AssertArrayEqual(const T* expected, const T* actual, + size_t count, const char* target_name, + const char* filename, int line) { + const auto info = hwy::detail::MakeTypeInfo(); + detail::AssertArrayEqual(info, expected, actual, count, target_name, filename, + line); +} + +} // namespace hwy + +#endif // HWY_TESTS_TEST_UTIL_H_ diff --git a/media/highway/src/hwy/tests/test_util_test.cc b/media/highway/src/hwy/tests/test_util_test.cc new file mode 100644 index 000000000..d55e2e8cb --- /dev/null +++ b/media/highway/src/hwy/tests/test_util_test.cc @@ -0,0 +1,105 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/test_util_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestName { + template + HWY_NOINLINE void operator()(T t, D d) { + char num[10]; + std::string expected = IsFloat() ? "f" : (IsSigned() ? "i" : "u"); + snprintf(num, sizeof(num), "%u" , static_cast(sizeof(T) * 8)); + expected += num; + + const size_t N = Lanes(d); + if (N != 1) { + expected += 'x'; + snprintf(num, sizeof(num), "%u", static_cast(N)); + expected += num; + } + const std::string actual = TypeName(t, N); + if (expected != actual) { + HWY_ABORT("%s mismatch: expected '%s', got '%s'.\n", + hwy::TargetName(HWY_TARGET), expected.c_str(), actual.c_str()); + } + } +}; + +HWY_NOINLINE void TestAllName() { ForAllTypes(ForPartialVectors()); } + +struct TestEqualInteger { + template + HWY_NOINLINE void operator()(T /*t*/) const { + HWY_ASSERT_EQ(T(0), T(0)); + HWY_ASSERT_EQ(T(1), T(1)); + HWY_ASSERT_EQ(T(-1), T(-1)); + HWY_ASSERT_EQ(LimitsMin(), LimitsMin()); + + HWY_ASSERT(!IsEqual(T(0), T(1))); + HWY_ASSERT(!IsEqual(T(1), T(0))); + HWY_ASSERT(!IsEqual(T(1), T(-1))); + HWY_ASSERT(!IsEqual(T(-1), T(1))); + HWY_ASSERT(!IsEqual(LimitsMin(), LimitsMax())); + HWY_ASSERT(!IsEqual(LimitsMax(), LimitsMin())); + } +}; + +struct TestEqualFloat { + template + HWY_NOINLINE void operator()(T /*t*/) const { + HWY_ASSERT(IsEqual(T(0), T(0))); + HWY_ASSERT(IsEqual(T(1), T(1))); + HWY_ASSERT(IsEqual(T(-1), T(-1))); + HWY_ASSERT(IsEqual(MantissaEnd(), MantissaEnd())); + + HWY_ASSERT(!IsEqual(T(0), T(1))); + HWY_ASSERT(!IsEqual(T(1), T(0))); + HWY_ASSERT(!IsEqual(T(1), T(-1))); + HWY_ASSERT(!IsEqual(T(-1), T(1))); + HWY_ASSERT(!IsEqual(LowestValue(), HighestValue())); + HWY_ASSERT(!IsEqual(HighestValue(), LowestValue())); + } +}; + +HWY_NOINLINE void TestAllEqual() { + ForIntegerTypes(TestEqualInteger()); + ForFloatTypes(TestEqualFloat()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(TestUtilTest); +HWY_EXPORT_AND_TEST_P(TestUtilTest, TestAllName); +HWY_EXPORT_AND_TEST_P(TestUtilTest, TestAllEqual); +} // namespace hwy + +#endif diff --git a/media/highway/src/libhwy-contrib.pc.in b/media/highway/src/libhwy-contrib.pc.in new file mode 100644 index 000000000..89c45f5e4 --- /dev/null +++ b/media/highway/src/libhwy-contrib.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ + +Name: libhwy-contrib +Description: Additions to Highway: dot product, image, math, sort +Version: @HWY_LIBRARY_VERSION@ +Libs: -L${libdir} -lhwy_contrib +Cflags: -I${includedir} diff --git a/media/highway/src/libhwy-test.pc.in b/media/highway/src/libhwy-test.pc.in new file mode 100644 index 000000000..0416b10df --- /dev/null +++ b/media/highway/src/libhwy-test.pc.in @@ -0,0 +1,11 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ + +Name: libhwy-test +Description: Efficient and performance-portable SIMD wrapper, test helpers. +Requires: gtest +Version: @HWY_LIBRARY_VERSION@ +Libs: -L${libdir} -lhwy_test +Cflags: -I${includedir} diff --git a/media/highway/src/libhwy.pc.in b/media/highway/src/libhwy.pc.in new file mode 100644 index 000000000..643989275 --- /dev/null +++ b/media/highway/src/libhwy.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ + +Name: libhwy +Description: Efficient and performance-portable SIMD wrapper +Version: @HWY_LIBRARY_VERSION@ +Libs: -L${libdir} -lhwy +Cflags: -I${includedir} -D@DLLEXPORT_TO_DEFINE@ diff --git a/media/highway/src/preamble.js.lds b/media/highway/src/preamble.js.lds new file mode 100644 index 000000000..f484a19d2 --- /dev/null +++ b/media/highway/src/preamble.js.lds @@ -0,0 +1,9 @@ +/* + * Copyright 2019 Google LLC + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* mock crypto module for benchmarks and unit tests or std::random_device fails at runtime */ +var crypto = { getRandomValues: function(array) { for (var i = 0; i < array.length; i++) array[i] = (Math.random()*256)|0 } }; \ No newline at end of file diff --git a/media/highway/src/run_tests.bat b/media/highway/src/run_tests.bat new file mode 100644 index 000000000..26600a2c4 --- /dev/null +++ b/media/highway/src/run_tests.bat @@ -0,0 +1,20 @@ +@echo off +REM Switch directory of this batch file +cd %~dp0 + +if not exist build_win mkdir build_win + +cd build_win +cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON -G Ninja || goto error +ninja || goto error +ctest -j || goto error + +cd .. +echo Success +goto end + +:error +echo Failure +exit /b 1 + +:end diff --git a/media/highway/src/run_tests.sh b/media/highway/src/run_tests.sh new file mode 100644 index 000000000..7f7d3447c --- /dev/null +++ b/media/highway/src/run_tests.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +# Switch to directory of this script +MYDIR=$(dirname $(realpath "$0")) +cd "${MYDIR}" + +# Exit if anything fails +set -e + +####################################### +echo RELEASE +rm -rf build +mkdir build +cd build +cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON +make -j +ctest -j +cd .. +rm -rf build + +####################################### +echo DEBUG Clang 9 +rm -rf build_dbg +mkdir build_dbg +cd build_dbg +CXX=clang++-9 CC=clang-9 cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON -DCMAKE_BUILD_TYPE=Debug +make -j +ctest -j +cd .. +rm -rf build_dbg + +####################################### +echo 32-bit GCC +rm -rf build_32 +mkdir build_32 +cd build_32 +CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 CXX=g++ CC=gcc cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON +make -j +ctest -j +cd .. +rm -rf build_32 + +####################################### +for VER in 10 11 12; do + echo GCC $VER + rm -rf build_g$VER + mkdir build_g$VER + cd build_g$VER + CC=gcc-$VER CXX=g++-$VER cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON + make -j + make test + cd .. + rm -rf build_g$VER +done + +####################################### +echo ARMv7 GCC +export QEMU_LD_PREFIX=/usr/arm-linux-gnueabihf +rm -rf build_arm7 +mkdir build_arm7 +cd build_arm7 +CC=arm-linux-gnueabihf-gcc-11 CXX=arm-linux-gnueabihf-g++-11 cmake .. -DHWY_CMAKE_ARM7:BOOL=ON -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON +make -j8 +ctest +cd .. +rm -rf build_arm7 + +####################################### +echo ARMv8 GCC +export QEMU_LD_PREFIX=/usr/aarch64-linux-gnu +rm -rf build_arm8 +mkdir build_arm8 +cd build_arm8 +CC=aarch64-linux-gnu-gcc-11 CXX=aarch64-linux-gnu-g++-11 cmake .. -DHWY_WARNINGS_ARE_ERRORS:BOOL=ON +make -j8 +ctest +cd .. +rm -rf build_arm8 + +echo Success diff --git a/media/libjxl/README_MCP b/media/libjxl/README_MCP new file mode 100644 index 000000000..aafe24dab --- /dev/null +++ b/media/libjxl/README_MCP @@ -0,0 +1,12 @@ +This directory contains build files for the JPEG-XL image +format reference implementation. + +Any patches or additional configuration to be applied to the +upstream source should be kept here in the media/libjxl +directory. + +The upstream jxl git repository is: + + https://github.com/libjxl/libjxl + +The version used was tagged 0.7.0. diff --git a/media/libjxl/include/jxl/jxl_export.h b/media/libjxl/include/jxl/jxl_export.h new file mode 100644 index 000000000..31834ec9e --- /dev/null +++ b/media/libjxl/include/jxl/jxl_export.h @@ -0,0 +1,15 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef JXL_EXPORT_H +#define JXL_EXPORT_H + +#define JXL_EXPORT + +// MSVC requires [[deprecated]] +#define JXL_DEPRECATED [[deprecated]] + +#endif /* JXL_EXPORT_H */ diff --git a/media/libjxl/include/jxl/jxl_threads_export.h b/media/libjxl/include/jxl/jxl_threads_export.h new file mode 100644 index 000000000..b08aabe76 --- /dev/null +++ b/media/libjxl/include/jxl/jxl_threads_export.h @@ -0,0 +1,12 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef JXL_THREADS_EXPORT_H +#define JXL_THREADS_EXPORT_H + +#define JXL_THREADS_EXPORT + +#endif /* JXL_THREADS_EXPORT_H */ diff --git a/media/libjxl/include/jxl/version.h b/media/libjxl/include/jxl/version.h new file mode 100644 index 000000000..9be7a2f5a --- /dev/null +++ b/media/libjxl/include/jxl/version.h @@ -0,0 +1,39 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file version.h + * @brief libjxl version information + */ + +#ifndef JXL_VERSION_H_ +#define JXL_VERSION_H_ + +#define JPEGXL_MAJOR_VERSION 0 ///< JPEG XL Major version +#define JPEGXL_MINOR_VERSION 7 ///< JPEG XL Minor version +#define JPEGXL_PATCH_VERSION 0 ///< JPEG XL Patch version + +/** Can be used to conditionally compile code for a specific JXL version + * @param[maj] major version + * @param[min] minor version + * + * @code + * #if JPEGXL_NUMERIC_VERSION < JPEGXL_COMPUTE_NUMERIC_VERSION(0,8,0) + * // use old/deprecated api + * #else + * // use current api + * #endif + * @endcode + */ +#define JPEGXL_COMPUTE_NUMERIC_VERSION(major,minor,patch) ((major<<24) | (minor<<16) | (patch<<8) | 0) + +/* Numeric representation of the version */ +#define JPEGXL_NUMERIC_VERSION JPEGXL_COMPUTE_NUMERIC_VERSION(JPEGXL_MAJOR_VERSION,JPEGXL_MINOR_VERSION,JPEGXL_PATCH_VERSION) + +#endif /* JXL_VERSION_H_ */ + +/** @}*/ diff --git a/media/libjxl/moz.build b/media/libjxl/moz.build new file mode 100644 index 000000000..3c55db02c --- /dev/null +++ b/media/libjxl/moz.build @@ -0,0 +1,153 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +LOCAL_INCLUDES += [ + "./include/", + "/media/libjxl/src/", + "/media/libjxl/src/lib/include/", +] + +SOURCES += [ + "/media/libjxl/src/lib/jxl/ac_strategy.cc", + "/media/libjxl/src/lib/jxl/alpha.cc", + "/media/libjxl/src/lib/jxl/ans_common.cc", + "/media/libjxl/src/lib/jxl/aux_out.cc", + "/media/libjxl/src/lib/jxl/base/cache_aligned.cc", + "/media/libjxl/src/lib/jxl/base/data_parallel.cc", + "/media/libjxl/src/lib/jxl/base/padded_bytes.cc", + "/media/libjxl/src/lib/jxl/base/random.cc", + "/media/libjxl/src/lib/jxl/blending.cc", + "/media/libjxl/src/lib/jxl/box_content_decoder.cc", + "/media/libjxl/src/lib/jxl/chroma_from_luma.cc", + "/media/libjxl/src/lib/jxl/coeff_order.cc", + "/media/libjxl/src/lib/jxl/color_encoding_internal.cc", + "/media/libjxl/src/lib/jxl/color_management.cc", + "/media/libjxl/src/lib/jxl/compressed_dc.cc", + "/media/libjxl/src/lib/jxl/convolve_separable5.cc", + "/media/libjxl/src/lib/jxl/convolve_separable7.cc", + "/media/libjxl/src/lib/jxl/convolve_slow.cc", + "/media/libjxl/src/lib/jxl/convolve_symmetric3.cc", + "/media/libjxl/src/lib/jxl/convolve_symmetric5.cc", + "/media/libjxl/src/lib/jxl/dct_scales.cc", + "/media/libjxl/src/lib/jxl/dec_ans.cc", + "/media/libjxl/src/lib/jxl/dec_cache.cc", + "/media/libjxl/src/lib/jxl/dec_context_map.cc", + "/media/libjxl/src/lib/jxl/dec_external_image.cc", + "/media/libjxl/src/lib/jxl/dec_frame.cc", + "/media/libjxl/src/lib/jxl/dec_group.cc", + "/media/libjxl/src/lib/jxl/dec_group_border.cc", + "/media/libjxl/src/lib/jxl/dec_huffman.cc", + "/media/libjxl/src/lib/jxl/dec_modular.cc", + "/media/libjxl/src/lib/jxl/dec_noise.cc", + "/media/libjxl/src/lib/jxl/dec_patch_dictionary.cc", + "/media/libjxl/src/lib/jxl/dec_xyb.cc", + "/media/libjxl/src/lib/jxl/decode.cc", + "/media/libjxl/src/lib/jxl/enc_bit_writer.cc", + "/media/libjxl/src/lib/jxl/entropy_coder.cc", + "/media/libjxl/src/lib/jxl/epf.cc", + "/media/libjxl/src/lib/jxl/fast_dct.cc", + "/media/libjxl/src/lib/jxl/fields.cc", + "/media/libjxl/src/lib/jxl/frame_header.cc", + "/media/libjxl/src/lib/jxl/gauss_blur.cc", + "/media/libjxl/src/lib/jxl/headers.cc", + "/media/libjxl/src/lib/jxl/huffman_table.cc", + "/media/libjxl/src/lib/jxl/icc_codec.cc", + "/media/libjxl/src/lib/jxl/icc_codec_common.cc", + "/media/libjxl/src/lib/jxl/image.cc", + "/media/libjxl/src/lib/jxl/image_bundle.cc", + "/media/libjxl/src/lib/jxl/image_metadata.cc", + "/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data.cc", + "/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data_writer.cc", + "/media/libjxl/src/lib/jxl/jpeg/jpeg_data.cc", + "/media/libjxl/src/lib/jxl/loop_filter.cc", + "/media/libjxl/src/lib/jxl/luminance.cc", + "/media/libjxl/src/lib/jxl/memory_manager_internal.cc", + "/media/libjxl/src/lib/jxl/modular/encoding/dec_ma.cc", + "/media/libjxl/src/lib/jxl/modular/encoding/encoding.cc", + "/media/libjxl/src/lib/jxl/modular/modular_image.cc", + "/media/libjxl/src/lib/jxl/modular/transform/rct.cc", + "/media/libjxl/src/lib/jxl/modular/transform/squeeze.cc", + "/media/libjxl/src/lib/jxl/modular/transform/transform.cc", + "/media/libjxl/src/lib/jxl/opsin_params.cc", + "/media/libjxl/src/lib/jxl/passes_state.cc", + "/media/libjxl/src/lib/jxl/quant_weights.cc", + "/media/libjxl/src/lib/jxl/quantizer.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/low_memory_render_pipeline.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/simple_render_pipeline.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_blending.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_chroma_upsampling.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_epf.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_from_linear.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_gaborish.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_noise.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_patches.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_splines.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_spot.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_to_linear.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_tone_mapping.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_upsampling.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_write.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_xyb.cc", + "/media/libjxl/src/lib/jxl/render_pipeline/stage_ycbcr.cc", + "/media/libjxl/src/lib/jxl/splines.cc", + "/media/libjxl/src/lib/jxl/toc.cc", +] + +SOURCES += [ + "/media/libjxl/src/lib/threads/thread_parallel_runner.cc", + "/media/libjxl/src/lib/threads/thread_parallel_runner_internal.cc", +] + +EXPORTS.jxl += [ + "./include/jxl/jxl_export.h", + "./include/jxl/jxl_threads_export.h", + "./include/jxl/version.h", + "/media/libjxl/src/lib/include/jxl/butteraugli.h", + "/media/libjxl/src/lib/include/jxl/butteraugli_cxx.h", + "/media/libjxl/src/lib/include/jxl/cms_interface.h", + "/media/libjxl/src/lib/include/jxl/codestream_header.h", + "/media/libjxl/src/lib/include/jxl/color_encoding.h", + "/media/libjxl/src/lib/include/jxl/decode.h", + "/media/libjxl/src/lib/include/jxl/decode_cxx.h", + "/media/libjxl/src/lib/include/jxl/encode.h", + "/media/libjxl/src/lib/include/jxl/encode_cxx.h", + "/media/libjxl/src/lib/include/jxl/memory_manager.h", + "/media/libjxl/src/lib/include/jxl/parallel_runner.h", + "/media/libjxl/src/lib/include/jxl/thread_parallel_runner.h", + "/media/libjxl/src/lib/include/jxl/thread_parallel_runner_cxx.h", + "/media/libjxl/src/lib/include/jxl/types.h", +] + +# DEFINES["JPEGXL_ENABLE_BOXES"] = "0" +DEFINES["JPEGXL_ENABLE_TRANSCODE_JPEG"] = "0" + +FINAL_LIBRARY = "gkmedias" + +# We allow warnings for third-party code that can be updated from upstream. +# XXX: libjxl produces way too many compiler warnings. +# Silence them in the meantime. +ALLOW_COMPILER_WARNINGS = False + +if CONFIG['_MSC_VER']: + CFLAGS += [ + '-wd4646', # function declared with 'noreturn' has non-void return type + '-wd4334', # result of 32-bit shift implicitly converted to 64 bits + '-wd4305', # truncation from 'type' to 'type' + '-wd4146', # unary minus operator applied to unsigned type, result still unsigned + ] + CXXFLAGS += [ + '-wd4646', # function declared with 'noreturn' has non-void return type + '-wd4334', # result of 32-bit shift implicitly converted to 64 bits + '-wd4305', # truncation from 'type' to 'type' + '-wd4146', # unary minus operator applied to unsigned type, result still unsigned + ] + +# Clang 5.0 has a compiler bug that prevents build in c++17 +# See https://gitlab.com/wg1/jpeg-xl/-/issues/227 +# This should be okay since we are using the C API. +if CONFIG["CC_TYPE"] == "clang": + CXXFLAGS += ["-std=c++11"] diff --git a/media/libjxl/src/.clang-format b/media/libjxl/src/.clang-format new file mode 100644 index 000000000..a61b61c56 --- /dev/null +++ b/media/libjxl/src/.clang-format @@ -0,0 +1,4 @@ +BasedOnStyle: Google +IncludeCategories: + - Regex: '^ +# instead . +# - modernize-return-braced-init-list: this often doesn't improve readability. +# - modernize-use-auto: is too aggressive towards using auto. +# - modernize-use-default-member-init: with a mix of constructors and default +# member initialization this can be confusing if enforced. +# - modernize-use-trailing-return-type: does not improve readability when used +# systematically. +# - modernize-use-using: typedefs are ok. +# +# - readability-else-after-return: It doesn't always improve readability. +# - readability-static-accessed-through-instance +# It is often more useful and readable to access a constant of a passed +# variable (like d.N) instead of using the type of the variable that could be +# long and complex. +# - readability-uppercase-literal-suffix: we write 1.0f, not 1.0F. + +Checks: >- + bugprone-*, + clang-*, + -clang-diagnostic-unused-command-line-argument, + google-*, + modernize-*, + performance-*, + readability-*, + -google-readability-todo, + -modernize-deprecated-headers, + -modernize-return-braced-init-list, + -modernize-use-auto, + -modernize-use-default-member-init, + -modernize-use-trailing-return-type, + -modernize-use-using, + -readability-else-after-return, + -readability-function-cognitive-complexity, + -readability-static-accessed-through-instance, + -readability-uppercase-literal-suffix, + + +WarningsAsErrors: >- + bugprone-argument-comment, + bugprone-macro-parentheses, + bugprone-suspicious-string-compare, + bugprone-use-after-move, + clang-*, + clang-analyzer-*, + -clang-diagnostic-unused-command-line-argument, + google-build-using-namespace, + google-explicit-constructor, + google-readability-braces-around-statements, + google-readability-namespace-comments, + modernize-use-override, + readability-inconsistent-declaration-parameter-name + +# We are only interested in the headers from this projects, excluding +# third_party/ and build/. +HeaderFilterRegex: '^.*/(lib|tools)/.*\.h$' + +CheckOptions: + - key: readability-braces-around-statements.ShortStatementLines + value: '2' + - key: google-readability-braces-around-statements.ShortStatementLines + value: '2' + - key: readability-implicit-bool-conversion.AllowPointerConditions + value: '1' + - key: readability-implicit-bool-conversion.AllowIntegerConditions + value: '1' diff --git a/media/libjxl/src/.gitignore b/media/libjxl/src/.gitignore new file mode 100644 index 000000000..58fea2d95 --- /dev/null +++ b/media/libjxl/src/.gitignore @@ -0,0 +1,17 @@ +# Build output directories +/build +/build* +/docker/*.log + +# The downloaded corpora files for benchmark. +/third_party/corpora + +# hdrvdp source code +third_party/hdrvdp-2.2.2 +third_party/hdrvdp-2.2.2.zip +third_party/hdrvdp-2.2.2.zip.tmp + +# Output plots +tools/benchmark/metrics/plots +tools/benchmark/metrics/results.csv +tools/conformance/__pycache__ diff --git a/media/libjxl/src/.gitmodules b/media/libjxl/src/.gitmodules new file mode 100644 index 000000000..bd008a612 --- /dev/null +++ b/media/libjxl/src/.gitmodules @@ -0,0 +1,27 @@ +[submodule "third_party/brotli"] + path = third_party/brotli + url = https://github.com/google/brotli +[submodule "third_party/lcms"] + path = third_party/lcms + url = https://github.com/mm2/Little-CMS +[submodule "third_party/googletest"] + path = third_party/googletest + url = https://github.com/google/googletest +[submodule "third_party/sjpeg"] + path = third_party/sjpeg + url = https://github.com/webmproject/sjpeg.git +[submodule "third_party/skcms"] + path = third_party/skcms + url = https://skia.googlesource.com/skcms +[submodule "third_party/highway"] + path = third_party/highway + url = https://github.com/google/highway +[submodule "third_party/libpng"] + path = third_party/libpng + url = https://github.com/glennrp/libpng.git +[submodule "third_party/zlib"] + path = third_party/zlib + url = https://github.com/madler/zlib.git +[submodule "third_party/testdata"] + path = testdata + url = https://github.com/libjxl/testdata diff --git a/media/libjxl/src/.readthedocs.yaml b/media/libjxl/src/.readthedocs.yaml new file mode 100644 index 000000000..6d714ba1a --- /dev/null +++ b/media/libjxl/src/.readthedocs.yaml @@ -0,0 +1,17 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +# +# readthedocs.io configuration file. See: +# https://docs.readthedocs.io/en/stable/config-file/v2.html + +version: 2 + +sphinx: + configuration: doc/sphinx/conf.py + +python: + version: "3.7" + install: + - requirements: doc/sphinx/requirements.txt diff --git a/media/libjxl/src/AUTHORS b/media/libjxl/src/AUTHORS new file mode 100644 index 000000000..c8522b8fc --- /dev/null +++ b/media/libjxl/src/AUTHORS @@ -0,0 +1,54 @@ +# List of the project authors for copyright purposes. When contributing to the +# project add your name or your organization's name to this list. See +# CONTRIBUTING.md for details. +# +# For organizations: +# Organization +# +# For individuals: +# Name +# +# Please keep each list sorted. If you wish to change your email address please +# send a pull request. + +# Organizations: +Cloudinary Ltd. <*@cloudinary.com> +Google LLC <*@google.com> + +# Individuals: +Alex Xu (Hello71) +Alexander Sago +Andrius Lukas Narbutas +Aous Naman +Artem Selishchev +Biswapriyo Nath +CanadianBaconBoi +Daniel Novomeský +David Burnett +Dirk Lemstra +Don Olmstead +Even Rouault +Heiko Becker +Jon Sneyers +Kai Hollberg +Kleis Auke Wolthuizen +L. E. Segovia +Leo Izen +Lovell Fuller +Maarten DB +Marcin Konicki +Martin Strunz +Mathieu Malaterre +Mikk Leini +Misaki Kasumi +Petr Diblík +Pieter Wuille +roland-rollo +Samuel Leong +Sandro +Stephan T. Lavavej +Thomas Bonfort +Vincent Torri +xiota +Yonatan Nebenzhal +Ziemowit Zabawa diff --git a/media/libjxl/src/CHANGELOG.md b/media/libjxl/src/CHANGELOG.md new file mode 100644 index 000000000..cf6840080 --- /dev/null +++ b/media/libjxl/src/CHANGELOG.md @@ -0,0 +1,261 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## Unreleased + +## [0.7] - 2022-07-21 + +### Added + - Export version information in headers. + - decoder API: Ability to decode the content of metadata boxes: + `JXL_DEC_BOX`, `JXL_DEC_BOX_NEED_MORE_OUTPUT`, `JxlDecoderSetBoxBuffer`, + `JxlDecoderGetBoxType`, `JxlDecoderGetBoxSizeRaw` and + `JxlDecoderSetDecompressBoxes`. + - decoder API: ability to mark the input is finished: `JxlDecoderCloseInput`. + - decoder API: ability to request updates on different progressive events using + `JxlDecoderSetProgressiveDetail`; currently supported events are + `kDC`, `kLastPasses` and `kPasses`. + - decoder API: ability to specify desired intensity target using + `JxlDecoderSetDesiredIntensityTarget` + - decoder API: new function `JxlDecoderSetCoalesced` to allow decoding + non-coalesced (unblended) frames, e.g. layers of a composite still image + or the cropped frames of a recompressed GIF/APNG. + - decoder API: new function `JxlDecoderSetUnpremultiplyAlpha` to set + preference for getting an associated alpha channel with premultiplied or + unpremultiplied colors. + - decoder API: field added to `JxlFrameHeader`: a `JxlLayerInfo` struct + that contains crop dimensions and offsets and blending information for + the non-coalesced case. + - decoder API: new function `JxlDecoderGetExtraChannelBlendInfo` to get + the blending information for extra channels in the non-coalesced case. + - decoder API: new function `JxlDecoderSetMultithreadedImageOutCallback`, + allowing output callbacks to receive more information about the number of + threads on which they are running. + - decoder API: new function `JxlDecoderSkipCurrentFrame` to skip processing + the current frame after a progressive detail is reached. + - decoder API: new function `JxlDecoderGetIntendedDownsamplingRatio` to get + the intended downsampling ratio of progressive steps, based on the + information in the frame header. + - decoder API: new function `JxlDecoderSetRenderSpotcolors` to allow disabling + rendering of spot colors. + - decoder/encoder API: add two fields to `JXLBasicInfo`: `intrinsic_xsize` + and `intrinsic_ysize` to signal the intrinsic size. + - encoder API: ability to add metadata boxes, added new functions + `JxlEncoderAddBox`, `JxlEncoderUseBoxes`, `JxlEncoderCloseBoxes` and + `JxlEncoderCloseFrames`. + - encoder API: added ability to set several encoder options / extra fields to + frames using `JxlEncoderSetFrameName`, `JxlEncoderFrameSettingsSetOption`, + `JxlEncoderFrameSettingsSetFloatOption`. + - encoder API: added ability to check required codestream compatibility level + and force specified using `JxlEncoderGetRequiredCodestreamLevel` and + `JxlEncoderSetCodestreamLevel`. + - encoder API: added ability to force emitting box-based container format + using `JxlEncoderUseContainer`. + - encoder API: added ability to store JPEG metadata for lossless reconstruction + using `JxlEncoderStoreJPEGMetadata` + - encoder API: new functions `JxlEncoderSetFrameHeader` and + `JxlEncoderSetExtraChannelBlendInfo` to set animation + and blending parameters of the frame, and `JxlEncoderInitFrameHeader` and + `JxlEncoderInitBlendInfo` to initialize the structs to set. + - encoder API: ability to encode arbitrary extra channels: + `JxlEncoderInitExtraChannelInfo`, `JxlEncoderSetExtraChannelInfo`, + `JxlEncoderSetExtraChannelName` and `JxlEncoderSetExtraChannelBuffer`. + - encoder API: ability to plug custom CMS implementation using + `JxlEncoderSetCms(JxlEncoder* enc, JxlCmsInterface cms)` + - encoder API: added `JxlEncoderGetError` to retrieve last encoder error. + +### Changed +- decoder API: using `JxlDecoderCloseInput` at the end of all input is required + when using JXL_DEC_BOX, and is now also encouraged in other cases, but not + required in those other cases for backwards compatibility. +- encoder API: `JxlEncoderCloseInput` now closes both frames and boxes input. +- CLI: `cjxl` and `djxl` have been reimplemented on the base of public decoder + and encoder API; dropped dependency on `gflags` for argument parsing. + +### Deprecated +- decoder API: `JXL_DEC_EXTENSIONS` event: use `JXL_DEC_BASIC_INFO` +- decoder / encoder API: pixel types `JXL_TYPE_BOOLEAN` and `JXL_TYPE_UINT32`: + consider using `JXL_TYPE_UINT8` and `JXL_TYPE_FLOAT` correspondingly. +- decoder API: pixel format parameter for `JxlDecoderGetColorAsEncodedProfile` + and `JxlDecoderGetICCProfileSize`: pass `NULL`. +- decoder API: `JxlDecoderDefaultPixelFormat` +- encoder API: `JxlEncoderOptions`: use `JxlEncoderFrameSettings` instead. +- encoder API: `JxlEncoderOptionsCreate`: use `JxlEncoderFrameSettingsCreate` + instead. +- encoder API: `JxlEncoderOptionsSetDistance`: use `JxlEncoderSetFrameDistance` + instead. +- encoder API: `JxlEncoderOptionsSetLossless`: use `JxlEncoderSetFrameLossless` + instead. +- encoder API: `JxlEncoderOptionsSetEffort`: use + `JxlEncoderFrameSettingsSetOption(frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, effort)` + instead. +- encoder API: `JxlEncoderOptionsSetDecodingSpeed`: use + `JxlEncoderFrameSettingsSetOption(frame_settings, JXL_ENC_FRAME_SETTING_DECODING_SPEED, tier)` + instead. +- encoder API: deprecated `JXL_ENC_NOT_SUPPORTED`, the encoder returns + `JXL_ENC_ERROR` instead and there is no need to handle + `JXL_ENC_NOT_SUPPORTED`. + +## [0.6.1] - 2021-10-29 +### Changed + - Security: Fix OOB read in splines rendering (#735 - + [CVE-2021-22563](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-22563)) + - Security: Fix OOB copy (read/write) in out-of-order/multi-threaded decoding + (#708 - [CVE-2021-22564](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-22564)) + - Fix segfault in `djxl` tool with `--allow_partial_files` flag (#781). + - Fix border in extra channels when using upsampling (#796) + +## [0.6] - 2021-10-04 +### Added + - API: New functions to decode extra channels: + `JxlDecoderExtraChannelBufferSize` and `JxlDecoderSetExtraChannelBuffer`. + - API: New function `JxlEncoderInitBasicInfo` to initialize `JxlBasicInfo` + (only needed when encoding). NOTE: it is now required to call this function + when using the encoder. Padding was added to the struct for forward + compatibility. + - API: Support for encoding oriented images. + - API: FLOAT16 support in the encoder API. + - Rewrite of the GDK pixbuf loader plugin. Added proper color management and + animation support. + - Rewrite of GIMP plugin. Added compression parameters dialog and switched to + using the public C API. + - Debian packages for GDK pixbuf loader (`libjxl-gdk-pixbuf`) and GIMP + (`libjxl-gimp-plugin`) plugins. + - `cjxl`/`djxl` support for `stdin` and `stdout`. + +### Changed + - API: Renamed the field `alpha_associated` in `JxlExtraChannelInfo` to + `alpha_premultiplied`, to match the corresponding name in `JxlBasicInfo`. + - Improved the 2x2 downscaling method in the encoder for the optional color + channel resampling for low bit rates. + - Fixed: the combination of floating point original data, XYB color encoding, + and Modular mode was broken (in both encoder and decoder). It now works. + NOTE: this can cause the current encoder to write jxl bitstreams that do + not decode with the old decoder. In particular this will happen when using + cjxl with PFM, EXR, or floating point PSD input, and a combination of XYB + and modular mode is used (which caused an encoder error before), e.g. + using options like `-m -q 80` (lossy modular), `-d 4.5` or `--progressive_dc=1` + (modular DC frame), or default lossy encoding on an image where patches + end up being used. There is no problem when using cjxl with PNG, JPEG, GIF, + APNG, PPM, PGM, PGX, or integer (8-bit or 16-bit) PSD input. + - `libjxl` static library now bundles skcms, fixing static linking in + downstream projects when skcms is used. + - Spline rendering performance improvements. + - Butteraugli changes for less visual masking. + +## [0.5] - 2021-08-02 +### Added + - API: New function to decode the image using a callback outputting a part of a + row per call. + - API: 16-bit float output support. + - API: `JxlDecoderRewind` and `JxlDecoderSkipFrames` functions to skip more + efficiently to earlier animation frames. + - API: `JxlDecoderSetPreferredColorProfile` function to choose color profile in + certain circumstances. + - encoder: Adding `center_x` and `center_y` flags for more control of the tile + order. + - New encoder speeds `lightning` (1) and `thunder` (2). + +### Changed + - Re-licensed the project under a BSD 3-Clause license. See the + [LICENSE](LICENSE) and [PATENTS](PATENTS) files for details. + - Full JPEG XL part 1 specification support: Implemented all the spec required + to decode files to pixels, including cases that are not used by the encoder + yet. Part 2 of the spec (container format) is final but not fully implemented + here. + - Butteraugli metric improvements. Exact numbers are different from previous + versions. + - Memory reductions during decoding. + - Reduce the size of the jxl_dec library by removing dependencies. + - A few encoding speedups. + - Clarify the security policy. + - Significant encoding improvements (~5 %) and less ringing. + - Butteraugli metric to have some less masking. + - `cjxl` flag `--speed` is deprecated and replaced by the `--effort` synonym. + +### Removed +- API for returning a downsampled DC was deprecated + (`JxlDecoderDCOutBufferSize` and `JxlDecoderSetDCOutBuffer`) and will be + removed in the next release. + +## [0.3.7] - 2021-03-29 +### Changed + - Fix a rounding issue in 8-bit decoding. + +## [0.3.6] - 2021-03-25 +### Changed + - Fix a bug that could result in the generation of invalid codestreams as + well as failure to decode valid streams. + +## [0.3.5] - 2021-03-23 +### Added + - New encode-time options for faster decoding at the cost of quality. + - Man pages for cjxl and djxl. + +### Changed + - Memory usage improvements. + - Faster decoding to 8-bit output with the C API. + - GIMP plugin: avoid the sRGB conversion dialog for sRGB images, do not show + a console window on Windows. + - Various bug fixes. + +## [0.3.4] - 2021-03-16 +### Changed + - Improved box parsing. + - Improved metadata handling. + - Performance and memory usage improvements. + +## [0.3.3] - 2021-03-05 +### Changed + - Performance improvements for small images. + - Add a (flag-protected) non-high-precision mode with better speed. + - Significantly speed up the PQ EOTF. + - Allow optional HDR tone mapping in djxl (--tone_map, --display_nits). + - Change the behavior of djxl -j to make it consistent with cjxl (#153). + - Improve image quality. + - Improve EXIF handling. + +## [0.3.2] - 2021-02-12 +### Changed + - Fix embedded ICC encoding regression + [#149](https://gitlab.com/wg1/jpeg-xl/-/issues/149). + +## [0.3.1] - 2021-02-10 +### Changed + - New experimental Butteraugli API (`jxl/butteraugli.h`). + - Encoder improvements to low quality settings. + - Bug fixes, including fuzzer-found potential security bug fixes. + - Fixed `-q 100` and `-d 0` not triggering lossless modes. + +## [0.3] - 2021-01-29 +### Changed + - Minor change to the Decoder C API to accommodate future work for other ways + to provide input. + - Future decoder C API changes will be backwards compatible. + - Lots of bug fixes since the previous version. + +## [0.2] - 2020-12-24 +### Added + - JPEG XL bitstream format is frozen. Files encoded with 0.2 will be supported + by future versions. + +### Changed + - Files encoded with previous versions are not supported. + +## [0.1.1] - 2020-12-01 + +## [0.1] - 2020-11-14 +### Added + - Initial release of an encoder (`cjxl`) and decoder (`djxl`) that work + together as well as a benchmark tool for comparison with other codecs + (`benchmark_xl`). + - Note: JPEG XL format is in the final stages of standardization, minor changes + to the codestream format are still possible but we are not expecting any + changes beyond what is required by bug fixing. + - API: new decoder API in C, check the `examples/` directory for its example + usage. The C API is a work in progress and likely to change both in API and + ABI in future releases. diff --git a/media/libjxl/src/CMakeLists.txt b/media/libjxl/src/CMakeLists.txt new file mode 100644 index 000000000..533815d23 --- /dev/null +++ b/media/libjxl/src/CMakeLists.txt @@ -0,0 +1,472 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Ubuntu bionic ships with cmake 3.10. +cmake_minimum_required(VERSION 3.10) +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +# Honor VISIBILITY_INLINES_HIDDEN on all types of targets. +if(POLICY CMP0063) + cmake_policy(SET CMP0063 NEW) +endif() +# Pass CMAKE_EXE_LINKER_FLAGS to CC and CXX compilers when testing if they work. +if(POLICY CMP0065) + cmake_policy(SET CMP0065 NEW) +endif() + +# Set PIE flags for POSITION_INDEPENDENT_CODE targets, added in 3.14. +if(POLICY CMP0083) + cmake_policy(SET CMP0083 NEW) +endif() + +project(LIBJXL LANGUAGES C CXX) + +include(CheckCXXSourceCompiles) +check_cxx_source_compiles( + "int main() { + #if !defined(__EMSCRIPTEN__) + static_assert(false, \"__EMSCRIPTEN__ is not defined\"); + #endif + return 0; + }" + JPEGXL_EMSCRIPTEN +) + +message(STATUS "CMAKE_SYSTEM_PROCESSOR is ${CMAKE_SYSTEM_PROCESSOR}") +include(CheckCXXCompilerFlag) +check_cxx_compiler_flag("-fsanitize=fuzzer-no-link" CXX_FUZZERS_SUPPORTED) +check_cxx_compiler_flag("-Xclang -mconstructor-aliases" CXX_CONSTRUCTOR_ALIASES_SUPPORTED) +check_cxx_compiler_flag("-fmacro-prefix-map=OLD=NEW" CXX_MACRO_PREFIX_MAP) +check_cxx_compiler_flag("-fno-rtti" CXX_NO_RTTI_SUPPORTED) + +# Enabled PIE binaries by default if supported. +include(CheckPIESupported OPTIONAL RESULT_VARIABLE CHECK_PIE_SUPPORTED) +if(CHECK_PIE_SUPPORTED) + check_pie_supported(LANGUAGES CXX) + if(CMAKE_CXX_LINK_PIE_SUPPORTED) + set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) + endif() +endif() + +### Project build options: +if(CXX_FUZZERS_SUPPORTED) + # Enabled by default except on arm64, Windows and Apple builds. + set(ENABLE_FUZZERS_DEFAULT true) +endif() +find_package(PkgConfig) +if(NOT APPLE AND NOT WIN32 AND NOT HAIKU AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + pkg_check_modules(TCMallocMinimalVersionCheck QUIET IMPORTED_TARGET + libtcmalloc_minimal) + if(TCMallocMinimalVersionCheck_FOUND AND + NOT TCMallocMinimalVersionCheck_VERSION VERSION_EQUAL 2.8.0) + # Enabled by default except on Windows and Apple builds for + # tcmalloc != 2.8.0. tcmalloc 2.8.1 already has a fix for this issue. + set(ENABLE_TCMALLOC_DEFAULT true) + else() + message(STATUS + "tcmalloc version ${TCMallocMinimalVersionCheck_VERSION} -- " + "tcmalloc 2.8.0 disabled due to " + "https://github.com/gperftools/gperftools/issues/1204") + endif() +endif() + +check_cxx_source_compiles( + "int main() { + #if !defined(HWY_DISABLED_TARGETS) + static_assert(false, \"HWY_DISABLED_TARGETS is not defined\"); + #endif + return 0; + }" + JXL_HWY_DISABLED_TARGETS_FORCED +) + +set(WARNINGS_AS_ERRORS_DEFAULT false) + +if((SANITIZER STREQUAL "msan") OR JPEGXL_EMSCRIPTEN) + set(BUNDLE_LIBPNG_DEFAULT YES) +else() + set(BUNDLE_LIBPNG_DEFAULT NO) +endif() + +# Standard cmake naming for building shared libraries. +get_property(SHARED_LIBS_SUPPORTED GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS) +option(BUILD_SHARED_LIBS "Build shared libraries instead of static ones" ${SHARED_LIBS_SUPPORTED}) + +set(JPEGXL_ENABLE_FUZZERS ${ENABLE_FUZZERS_DEFAULT} CACHE BOOL + "Build JPEGXL fuzzer targets.") +set(JPEGXL_ENABLE_DEVTOOLS false CACHE BOOL + "Build JPEGXL developer tools.") +set(JPEGXL_ENABLE_TOOLS true CACHE BOOL + "Build JPEGXL user tools: cjxl and djxl.") +set(JPEGXL_ENABLE_DOXYGEN true CACHE BOOL + "Generate C API documentation using Doxygen.") +set(JPEGXL_ENABLE_MANPAGES true CACHE BOOL + "Build and install man pages for the command-line tools.") +set(JPEGXL_ENABLE_BENCHMARK true CACHE BOOL + "Build JPEGXL benchmark tools.") +set(JPEGXL_ENABLE_EXAMPLES true CACHE BOOL + "Build JPEGXL library usage examples.") +set(JPEGXL_BUNDLE_LIBPNG ${BUNDLE_LIBPNG_DEFAULT} CACHE BOOL + "Build libpng from source and link it statically.") +set(JPEGXL_ENABLE_JNI true CACHE BOOL + "Build JPEGXL JNI Java wrapper, if Java dependencies are installed.") +set(JPEGXL_ENABLE_SJPEG true CACHE BOOL + "Build JPEGXL with support for encoding with sjpeg.") +set(JPEGXL_ENABLE_OPENEXR true CACHE BOOL + "Build JPEGXL with support for OpenEXR if available.") +set(JPEGXL_ENABLE_SKCMS true CACHE BOOL + "Build with skcms instead of lcms2.") +set(JPEGXL_BUNDLE_SKCMS true CACHE BOOL + "When building with skcms, bundle it into libjxl.a.") +set(JPEGXL_ENABLE_VIEWERS false CACHE BOOL + "Build JPEGXL viewer tools for evaluation.") +set(JPEGXL_ENABLE_TCMALLOC ${ENABLE_TCMALLOC_DEFAULT} CACHE BOOL + "Build JPEGXL using gperftools (tcmalloc) allocator.") +set(JPEGXL_ENABLE_PLUGINS false CACHE BOOL + "Build third-party plugins to support JPEG XL in other applications.") +set(JPEGXL_ENABLE_COVERAGE false CACHE BOOL + "Enable code coverage tracking for libjxl. This also enables debug and disables optimizations.") +set(JPEGXL_ENABLE_PROFILER false CACHE BOOL + "Builds in support for profiling (printed by tools if extra flags given)") +set(JPEGXL_ENABLE_SIZELESS_VECTORS false CACHE BOOL + "Builds in support for SVE/RVV vectorization") +set(JPEGXL_ENABLE_TRANSCODE_JPEG true CACHE BOOL + "Builds in support for decoding transcoded JXL files back to JPEG,\ + disabling it makes the decoder reject JXL_DEC_JPEG_RECONSTRUCTION events,\ + (default enabled)") +set(JPEGXL_STATIC false CACHE BOOL + "Build tools as static binaries.") +set(JPEGXL_WARNINGS_AS_ERRORS ${WARNINGS_AS_ERRORS_DEFAULT} CACHE BOOL + "Treat warnings as errors during compilation.") +set(JPEGXL_DEP_LICENSE_DIR "" CACHE STRING + "Directory where to search for system dependencies \"copyright\" files.") +set(JPEGXL_FORCE_NEON false CACHE BOOL + "Set flags to enable NEON in arm if not enabled by your toolchain.") + + +# Force system dependencies. +set(JPEGXL_FORCE_SYSTEM_BROTLI false CACHE BOOL + "Force using system installed brotli instead of third_party/brotli source.") +set(JPEGXL_FORCE_SYSTEM_GTEST false CACHE BOOL + "Force using system installed googletest (gtest/gmock) instead of third_party/googletest source.") +set(JPEGXL_FORCE_SYSTEM_LCMS2 false CACHE BOOL + "Force using system installed lcms2 instead of third_party/lcms source.") +set(JPEGXL_FORCE_SYSTEM_HWY false CACHE BOOL + "Force using system installed highway (libhwy-dev) instead of third_party/highway source.") + +# Check minimum compiler versions. Older compilers are not supported and fail +# with hard to understand errors. +if (NOT CMAKE_C_COMPILER_ID STREQUAL CMAKE_CXX_COMPILER_ID) + message(FATAL_ERROR "Different C/C++ compilers set: " + "${CMAKE_C_COMPILER_ID} vs ${CMAKE_CXX_COMPILER_ID}") +endif() +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + # Android NDK's toolchain.cmake fakes the clang version in + # CMAKE_CXX_COMPILER_VERSION with an incorrect number, so ignore this. + if (NOT CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION MATCHES "clang" + AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 5) + message(FATAL_ERROR + "Minimum Clang version required is Clang 5, please update.") + endif() +elseif (CMAKE_CXX_COMPILER_ID MATCHES "GNU") + if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7) + message(FATAL_ERROR + "Minimum GCC version required is 7, please update.") + endif() +endif() + +message(STATUS + "Compiled IDs C:${CMAKE_C_COMPILER_ID}, C++:${CMAKE_CXX_COMPILER_ID}") + +# CMAKE_EXPORT_COMPILE_COMMANDS is used to generate the compilation database +# used by clang-tidy. +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(JPEGXL_STATIC) + set(BUILD_SHARED_LIBS 0) + # Clang developers say that in case to use "static" we have to build stdlib + # ourselves; for real use case we don't care about stdlib, as it is "granted", + # so just linking all other libraries is fine. + if (NOT MSVC AND NOT APPLE) + set(CMAKE_FIND_LIBRARY_SUFFIXES .a) + set(CMAKE_EXE_LINKER_FLAGS + "${CMAKE_EXE_LINKER_FLAGS} -static -static-libgcc -static-libstdc++") + endif() +endif() # JPEGXL_STATIC + +# Threads +set(THREADS_PREFER_PTHREAD_FLAG YES) +find_package(Threads REQUIRED) + +# These settings are important to drive check_cxx_source_compiles +# See CMP0067 (min cmake version is 3.10 anyway) +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_STANDARD_REQUIRED YES) + +# Atomics +find_package(Atomics REQUIRED) + +if(JPEGXL_STATIC) + if (MINGW) + # In MINGW libstdc++ uses pthreads directly. When building statically a + # program (regardless of whether the source code uses pthread or not) the + # toolchain will add stdc++ and pthread to the linking step but stdc++ will + # be linked statically while pthread will be linked dynamically. + # To avoid this and have pthread statically linked with need to pass it in + # the command line with "-Wl,-Bstatic -lpthread -Wl,-Bdynamic" but the + # linker will discard it if not used by anything else up to that point in + # the linker command line. If the program or any dependency don't use + # pthread directly -lpthread is discarded and libstdc++ (added by the + # toolchain later) will then use the dynamic version. For this we also need + # to pass -lstdc++ explicitly before -lpthread. For pure C programs -lstdc++ + # will be discarded anyway. + # This adds these flags as dependencies for *all* targets. Adding this to + # CMAKE_EXE_LINKER_FLAGS instead would cause them to be included before any + # object files and therefore discarded. This should be set in the + # INTERFACE_LINK_LIBRARIES of Threads::Threads but some third_part targets + # don't depend on it. + link_libraries(-Wl,-Bstatic -lstdc++ -lpthread -Wl,-Bdynamic) + elseif(CMAKE_USE_PTHREADS_INIT) + # "whole-archive" is not supported on OSX. + if (NOT APPLE) + # Set pthreads as a whole-archive, otherwise weak symbols in the static + # libraries will discard pthreads symbols leading to segmentation fault at + # runtime. + message(STATUS "Using -lpthread as --whole-archive") + set_target_properties(Threads::Threads PROPERTIES + INTERFACE_LINK_LIBRARIES + "-Wl,--whole-archive;-lpthread;-Wl,--no-whole-archive") + endif() + endif() +endif() # JPEGXL_STATIC + +if (JPEGXL_EMSCRIPTEN) + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -pthread") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread") +endif() + +if (CXX_MACRO_PREFIX_MAP) + add_compile_options(-fmacro-prefix-map=${CMAKE_CURRENT_SOURCE_DIR}=.) +endif() + +if (CXX_NO_RTTI_SUPPORTED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti") +endif() + +if (MSVC) +# TODO(janwas): add flags +else () + +# Global compiler flags for all targets here and in subdirectories. +add_definitions( + # Avoid changing the binary based on the current time and date. + -D__DATE__="redacted" + -D__TIMESTAMP__="redacted" + -D__TIME__="redacted" +) + +# Avoid log spam from fopen etc. +if(MSVC) + add_definitions(-D_CRT_SECURE_NO_WARNINGS) +endif() + +# TODO(eustas): JXL currently compiles, but does not pass tests... +if (NOT JXL_HWY_DISABLED_TARGETS_FORCED AND NOT JPEGXL_ENABLE_SIZELESS_VECTORS) + add_definitions(-DHWY_DISABLED_TARGETS=\(HWY_SVE|HWY_SVE2|HWY_SVE_256|HWY_SVE2_128|HWY_RVV\)) + message("Warning: HWY_SVE, HWY_SVE2, HWY_SVE_256, HWY_SVE2_128 and HWY_RVV CPU targets are disabled") +endif() + +# In CMake before 3.12 it is problematic to pass repeated flags like -Xclang. +# For this reason we place them in CMAKE_CXX_FLAGS instead. +# See https://gitlab.kitware.com/cmake/cmake/issues/15826 + +# Machine flags. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funwind-tables") +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xclang -mrelax-all") +endif() +if (CXX_CONSTRUCTOR_ALIASES_SUPPORTED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xclang -mconstructor-aliases") +endif() + +if(WIN32) +# Not supported by clang-cl, but frame pointers are default on Windows +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") +endif() + +# CPU flags - remove once we have NEON dynamic dispatch + +# TODO(janwas): this also matches M1, but only ARMv7 is intended/needed. +if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm") +if(JPEGXL_FORCE_NEON) +# GCC requires these flags, otherwise __ARM_NEON is undefined. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \ + -mfpu=neon-vfpv4 -mfloat-abi=hard") +endif() +endif() + +# Force build with optimizations in release mode. +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2") + +add_compile_options( + # Ignore this to allow redefining __DATE__ and others. + -Wno-builtin-macro-redefined + + # Global warning settings. + -Wall +) + +if (JPEGXL_WARNINGS_AS_ERRORS) +add_compile_options(-Werror) +endif () +endif () # !MSVC + +include(GNUInstallDirs) + +# Separately build/configure testing frameworks and other third_party libraries +# to allow disabling tests in those libraries. +include(third_party/testing.cmake) +add_subdirectory(third_party) +# Copy the JXL license file to the output build directory. +configure_file("${CMAKE_CURRENT_SOURCE_DIR}/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.jpeg-xl COPYONLY) + +# Enable tests regardless of where they are defined. +enable_testing() +include(CTest) +# Specify default location of `testdata`: +if(NOT DEFINED JPEGXL_TEST_DATA_PATH) + set(JPEGXL_TEST_DATA_PATH "${PROJECT_SOURCE_DIR}/testdata") +endif() + +# Libraries. +add_subdirectory(lib) + +if(BUILD_TESTING) +# Script to run tests over the source code in bash. +find_program (BASH_PROGRAM bash) +if(BASH_PROGRAM) + add_test( + NAME bash_test + COMMAND ${BASH_PROGRAM} ${CMAKE_CURRENT_SOURCE_DIR}/bash_test.sh) +endif() +endif() # BUILD_TESTING + +# Documentation generated by Doxygen +if(JPEGXL_ENABLE_DOXYGEN) +find_package(Doxygen) +if(DOXYGEN_FOUND) +set(DOXYGEN_GENERATE_HTML "YES") +set(DOXYGEN_GENERATE_XML "YES") +set(DOXYGEN_STRIP_FROM_PATH "${CMAKE_CURRENT_SOURCE_DIR}/lib/include") +set(DOXYGEN_USE_MDFILE_AS_MAINPAGE "README.md") +if(JPEGXL_WARNINGS_AS_ERRORS) +set(DOXYGEN_WARN_AS_ERROR "YES") +endif() +set(DOXYGEN_QUIET "YES") +doxygen_add_docs(doc + "${CMAKE_CURRENT_SOURCE_DIR}/lib/include" + "${CMAKE_CURRENT_SOURCE_DIR}/doc/api.txt" + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" + COMMENT "Generating C API documentation") + +# Add sphinx doc build step for readthedocs.io (requires doxygen too). +find_program(SPHINX_BUILD_PROGRAM sphinx-build) +if(SPHINX_BUILD_PROGRAM) + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/rtd/nonexistent" + COMMENT "Generating readthedocs.io output on ${CMAKE_CURRENT_BINARY_DIR}/rtd" + COMMAND ${SPHINX_BUILD_PROGRAM} -q -W -b html -j auto + ${CMAKE_SOURCE_DIR}/doc/sphinx + ${CMAKE_CURRENT_BINARY_DIR}/rtd + DEPENDS doc + ) + # This command runs the documentation generation every time since the output + # target file doesn't exist. + add_custom_target(rtd-html + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/rtd/nonexistent + ) +else() # SPHINX_BUILD_PROGRAM\ + message(WARNING "sphinx-build not found, skipping rtd documentation") +endif() # SPHINX_BUILD_PROGRAM + +else() +# Create a "doc" target for compatibility since "doc" is not otherwise added to +# the build when doxygen is not installed. +add_custom_target(doc false + COMMENT "Error: Can't generate doc since Doxygen not installed.") +endif() # DOXYGEN_FOUND +endif() # JPEGXL_ENABLE_DOXYGEN + +if(JPEGXL_ENABLE_MANPAGES) +find_program(ASCIIDOC a2x) +if(ASCIIDOC) +file(STRINGS "${ASCIIDOC}" ASCIIDOC_SHEBANG LIMIT_COUNT 1) +if(ASCIIDOC_SHEBANG MATCHES "/sh|/bash") + set(ASCIIDOC_PY_FOUND ON) + # Run the program directly and set ASCIIDOC as empty. + set(ASCIIDOC_PY "${ASCIIDOC}") + set(ASCIIDOC "") +elseif(ASCIIDOC_SHEBANG MATCHES "python2") + find_package(Python2 COMPONENTS Interpreter) + set(ASCIIDOC_PY_FOUND "${Python2_Interpreter_FOUND}") + set(ASCIIDOC_PY Python2::Interpreter) +elseif(ASCIIDOC_SHEBANG MATCHES "python3") + find_package(Python3 COMPONENTS Interpreter) + set(ASCIIDOC_PY_FOUND "${Python3_Interpreter_FOUND}") + set(ASCIIDOC_PY Python3::Interpreter) +else() + find_package(Python COMPONENTS Interpreter QUIET) + if(NOT Python_Interpreter_FOUND) + find_program(ASCIIDOC_PY python) + if(ASCIIDOC_PY) + set(ASCIIDOC_PY_FOUND ON) + endif() + else() + set(ASCIIDOC_PY_FOUND "${Python_Interpreter_FOUND}") + set(ASCIIDOC_PY Python::Interpreter) + endif() +endif() + +if (ASCIIDOC_PY_FOUND) + set(MANPAGE_FILES "") + set(MANPAGES "") + foreach(PAGE IN ITEMS cjxl djxl) + # Invoking the Python interpreter ourselves instead of running the a2x binary + # directly is necessary on MSYS2, otherwise it is run through cmd.exe which + # does not recognize it. + add_custom_command( + OUTPUT "${PAGE}.1" + COMMAND "${ASCIIDOC_PY}" + ARGS ${ASCIIDOC} + --format manpage --destination-dir="${CMAKE_CURRENT_BINARY_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/doc/man/${PAGE}.txt" + MAIN_DEPENDENCY "${CMAKE_CURRENT_SOURCE_DIR}/doc/man/${PAGE}.txt") + list(APPEND MANPAGE_FILES "${CMAKE_CURRENT_BINARY_DIR}/${PAGE}.1") + list(APPEND MANPAGES "${PAGE}.1") + endforeach() + add_custom_target(manpages ALL DEPENDS ${MANPAGES}) + install(FILES ${MANPAGE_FILES} DESTINATION ${CMAKE_INSTALL_MANDIR}/man1) +endif() # ASCIIDOC_PY_FOUND +else() + message(WARNING "asciidoc was not found, the man pages will not be installed.") +endif() # ASCIIDOC +endif() # JPEGXL_ENABLE_MANPAGES + +# Example usage code. +if (JPEGXL_ENABLE_EXAMPLES) +include(examples/examples.cmake) +endif () + +# Plugins for third-party software +if (JPEGXL_ENABLE_PLUGINS) +add_subdirectory(plugins) +endif () + +# Binary tools +add_subdirectory(tools) diff --git a/media/libjxl/src/CODE_OF_CONDUCT.md b/media/libjxl/src/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..b2d81a321 --- /dev/null +++ b/media/libjxl/src/CODE_OF_CONDUCT.md @@ -0,0 +1,93 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of +experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, or to ban temporarily or permanently any +contributor for other behaviors that they deem inappropriate, threatening, +offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when the Project +Steward has a reasonable belief that an individual's behavior may have a +negative impact on the project or its community. + +## Conflict Resolution + +We do not believe that all conflict is bad; healthy debate and disagreement +often yield positive results. However, it is never okay to be disrespectful or +to engage in behavior that violates the project’s code of conduct. + +If you see someone violating the code of conduct, you are encouraged to address +the behavior directly with those involved. Many issues can be resolved quickly +and easily, and this gives people more control over the outcome of their +dispute. If you are unable to resolve the matter for any reason, or if the +behavior is threatening or harassing, report it. We are dedicated to providing +an environment where participants feel welcome and safe. + +Reports should be directed to Jyrki Alakuijala , the +Project Steward(s) for JPEG XL. It is the Project Steward’s duty to +receive and address reported violations of the code of conduct. They will then +work with a committee consisting of representatives from the Open Source +Programs Office and the Google Open Source Strategy team. If for any reason you +are uncomfortable reaching out to the Project Steward, please email +opensource@google.com. + +We will investigate every complaint, but you may not receive a direct response. +We will use our discretion in determining when and how to follow up on reported +incidents, which may range from not taking action to permanent expulsion from +the project and project-sponsored spaces. We will notify the accused of the +report and provide them an opportunity to discuss it before any action is taken. +The identity of the reporter will be omitted from the details of the report +supplied to the accused. In potentially harmful situations, such as ongoing +harassment or threats to anyone's safety, we may take action without notice. + +## Attribution + +This Code of Conduct is adapted from the Contributor Covenant, version 1.4, +available at +https://www.contributor-covenant.org/version/1/4/code-of-conduct.html diff --git a/media/libjxl/src/CONTRIBUTING.md b/media/libjxl/src/CONTRIBUTING.md new file mode 100644 index 000000000..cb6459797 --- /dev/null +++ b/media/libjxl/src/CONTRIBUTING.md @@ -0,0 +1,132 @@ +# Contributing to libjxl + +## Contributing with bug reports + +For security-related issues please see [SECURITY.md](SECURITY.md). + +We welcome suggestions, feature requests and bug reports. Before opening a new +issue please take a look if there is already an existing one in the following +link: + + * https://github.com/libjxl/libjxl/issues + +## Contributing with patches and Pull Requests + +We'd love to accept your contributions to the JPEG XL Project. Please read +through this section before sending a Pull Request. + +### Contributor License Agreements + +Our project is open source under the terms outlined in the [LICENSE](LICENSE) +and [PATENTS](PATENTS) files. Before we can accept your contributions, even for +small changes, there are just a few small guidelines you need to follow: + +Please fill out either the individual or corporate Contributor License Agreement +(CLA) with Google. JPEG XL Project is an an effort by multiple individuals and +companies, including the initial contributors Cloudinary and Google, but Google +is the legal entity in charge of receiving these CLA and relicensing this +software: + + * If you are an individual writing original source code and you're sure you + own the intellectual property, then you'll need to sign an [individual + CLA](https://code.google.com/legal/individual-cla-v1.0.html). + + * If you work for a company that wants to allow you to contribute your work, + then you'll need to sign a [corporate + CLA](https://code.google.com/legal/corporate-cla-v1.0.html). + +Follow either of the two links above to access the appropriate CLA and +instructions for how to sign and return it. Once we receive it, we'll be able +to accept your pull requests. + +***NOTE***: Only original source code from you and other people that have signed +the CLA can be accepted into the main repository. + +### License + +Contributions are licensed under the project's [LICENSE](LICENSE). Each new +file must include the following header when possible, with comment style adapted +to the language as needed: + +``` +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +``` + +### Code Reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +### Contribution philosophy + + * Prefer small changes, even if they don't implement a complete feature. Small + changes are easier to review and can be submitted faster. Think about what's + the smallest unit you can send that makes sense to review and submit in + isolation. For example, new modules that are not yet used by the tools but + have their own unittests are ok. If you have unrelated changes that + you discovered while working on something else, please send them in a + different Pull Request. If your are refactoring code and changing + functionality try to send the refactor first without any change in + functionality. Reviewers may ask you to split a Pull Request and it is + easier to create a smaller change from the beginning. + + * Describe your commits. Add a meaningful description to your commit message, explain what you are changing if it is not trivially obvious, but more importantly explain *why* you are making those changes. For example "Fix + build" is not a good commit message, describe what build and if it makes sense + why is this fixing it or why was it failing without this. It is very likely + that people far in the future without any context you have right now will be + looking at your commit trying to figure out why was the change introduced. If + related to an issue in this or another repository include a link to it. + + * Code Style: We follow the [Google C++ Coding + Style](https://google.github.io/styleguide/cppguide.html). A + [clang-format](https://clang.llvm.org/docs/ClangFormat.html) configuration + file is available to automatically format your code, you can invoke it with + the `./ci.sh lint` helper tool. + + * Testing: Test your change and explain in the commit message *how* your + commit was tested. For example adding unittests or in some cases just testing + with the existing ones is enough. In any case, mention what testing was + performed so reviewers can evaluate whether that's enough testing. In many + cases, testing that the Continuous Integration workflow passes is enough. + + * Make one commit per Pull Request / review, unless there's a good reason not + to. If you have multiple changes send multiple Pull Requests and each one can + have its own review. + + * When addressing comments from reviewers prefer to squash or fixup your + edits and force-push your commit. When merging changes into the repository we + don't want to include the history of code review back and forth changes or + typos. Reviewers can click on the "force-pushed" automatic comment on a Pull + Request to see the changes between versions. We use "Rebase and merge" policy + to keep a linear git history which is easier to reason about. + + * Your change must pass the build and test workflows. There's a `ci.sh` script + to help building and testing these configurations. See [building and + testing](doc/building_and_testing.md) for more details. + +### Contributing checklist. + + * Sign the CLA (only needed once per user, see above). + + * AUTHORS: If this is your first contribution, add your name or your + company name to the [AUTHORS](AUTHORS) file for copyright tracking purposes. + + * Style guide. Check `./ci.sh lint`. + + * Meaningful commit description: What and *why*, links to issues, testing + procedure. + + * Squashed multiple edits into a single commit. + + * Upload your changes to your fork and [create a Pull + Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request). + +# Community Guidelines + +This project follows [Google's Open Source Community +Guidelines](https://opensource.google.com/conduct/). diff --git a/media/libjxl/src/CONTRIBUTORS b/media/libjxl/src/CONTRIBUTORS new file mode 100644 index 000000000..848096f92 --- /dev/null +++ b/media/libjxl/src/CONTRIBUTORS @@ -0,0 +1,23 @@ +# This files lists individuals who made significant contributions to the JPEG XL +# code base, such as design, adding features, performing experiments, ... +# Small changes such as a small bugfix or fixing spelling errors are not +# included. If you'd like to be included in this file thanks to a significant +# contribution, feel free to send a pull request changing this file. +Alex Deymo +Alexander Rhatushnyak +Evgenii Kliuchnikov +Iulia-Maria Comșa +Jan Wassenberg +Jon Sneyers +Jyrki Alakuijala +Krzysztof Potempa +Lode Vandevenne +Luca Versari +Martin Bruse +Moritz Firsching +Renata Khasanova +Robert Obryk +Sami Boukortt +Sebastian Gomez-Gonzalez +Thomas Fischbacher +Zoltan Szabadka diff --git a/media/libjxl/src/LICENSE b/media/libjxl/src/LICENSE new file mode 100644 index 000000000..c66034b10 --- /dev/null +++ b/media/libjxl/src/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) the JPEG XL Project Authors. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/media/libjxl/src/PATENTS b/media/libjxl/src/PATENTS new file mode 100644 index 000000000..c95b8f410 --- /dev/null +++ b/media/libjxl/src/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the JPEG XL project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of JPEG XL, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of JPEG XL. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of JPEG XL or any code incorporated within this +implementation of JPEG XL constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of JPEG XL +shall terminate as of the date such litigation is filed. diff --git a/media/libjxl/src/README.Haiku.md b/media/libjxl/src/README.Haiku.md new file mode 100644 index 000000000..20111c570 --- /dev/null +++ b/media/libjxl/src/README.Haiku.md @@ -0,0 +1,20 @@ +## Disclaimer + +Haiku builds are not officially supported, i.e. the build might not work at all, +some tests may fail and some sub-projects are excluded from build. + +This manual outlines Haiku-specific setup. For general building and testing +instructions see "[README](README.md)" and +"[Building and Testing changes](doc/building_and_testing.md)". + +## Dependencies + +```shell +pkgman install llvm9_clang ninja cmake doxygen libjpeg_turbo_devel giflib_devel +``` + +## Building + +```shell +TEST_STACK_LIMIT=none CMAKE_FLAGS="-I/boot/system/develop/tools/lib/gcc/x86_64-unknown-haiku/8.3.0/include/c++ -I/boot/system/develop/tools/lib/gcc/x86_64-unknown-haiku/8.3.0/include/c++/x86_64-unknown-haiku" CMAKE_SHARED_LINKER_FLAGS="-shared -Xlinker -soname=libjpegxl.so -lpthread" ./ci.sh opt +``` diff --git a/media/libjxl/src/README.OSX.md b/media/libjxl/src/README.OSX.md new file mode 100644 index 000000000..8c6dc5a39 --- /dev/null +++ b/media/libjxl/src/README.OSX.md @@ -0,0 +1,41 @@ +## Disclaimer + +OSX builds have "best effort" support, i.e. build might not work at all, some +tests may fail and some sub-projects are excluded from build. + +This manual outlines OSX specific setup. For general building and testing +instructions see "[README](README.md)" and +"[Building and Testing changes](doc/building_and_testing.md)". + +[Homebrew](https://brew.sh/) is a popular package manager. JPEG XL library and +binaries could be installed using it: + +```bash +brew install jpeg-xl +``` + +## Dependencies + +Make sure that `brew doctor` does not report serious problems and up-to-date +version of XCode is installed. + +Installing (actually, building) `clang` might take a couple hours. + +```bash +brew install llvm +``` + +```bash +brew install coreutils cmake giflib jpeg-turbo libpng ninja zlib +``` + +Before building the project check that `which clang` is +`/usr/local/opt/llvm/bin/clang`, not the one provided by XCode. If not, update +`PATH` environment variable. + +Also, setting `CMAKE_PREFIX_PATH` might be necessary for correct include paths +resolving, e.g.: + +```bash +export CMAKE_PREFIX_PATH=`brew --prefix giflib`:`brew --prefix jpeg-turbo`:`brew --prefix libpng`:`brew --prefix zlib` +``` \ No newline at end of file diff --git a/media/libjxl/src/README.md b/media/libjxl/src/README.md new file mode 100644 index 000000000..b0f2e3b50 --- /dev/null +++ b/media/libjxl/src/README.md @@ -0,0 +1,197 @@ +# JPEG XL reference implementation + +[![Build/Test](https://github.com/libjxl/libjxl/actions/workflows/build_test.yml/badge.svg)]( +https://github.com/libjxl/libjxl/actions/workflows/build_test.yml) +[![Build/Test Cross](https://github.com/libjxl/libjxl/actions/workflows/build_test_cross.yml/badge.svg)]( +https://github.com/libjxl/libjxl/actions/workflows/build_test_cross.yml) +[![Conformance](https://github.com/libjxl/libjxl/actions/workflows/conformance.yml/badge.svg)]( +https://github.com/libjxl/libjxl/actions/workflows/conformance.yml) +[![CIFuzz](https://github.com/libjxl/libjxl/actions/workflows/fuzz.yml/badge.svg)]( +https://github.com/libjxl/libjxl/actions/workflows/fuzz.yml) +[![Releases](https://github.com/libjxl/libjxl/actions/workflows/release.yaml/badge.svg)]( +https://github.com/libjxl/libjxl/actions/workflows/release.yaml) +[![Doc](https://readthedocs.org/projects/libjxl/badge/?version=latest)]( +https://libjxl.readthedocs.io/en/latest/?badge=latest) +[![codecov](https://codecov.io/gh/libjxl/libjxl/branch/main/graph/badge.svg)]( +https://codecov.io/gh/libjxl/libjxl) + +JXL logo + +This repository contains a reference implementation of JPEG XL (encoder and +decoder), called `libjxl`. This software library is +[used by many applications that support JPEG XL](doc/software_support.md). + +JPEG XL is in the final stages of standardization and its codestream and file format +are frozen. + +The library API, command line options, and tools in this repository are subject +to change, however files encoded with `cjxl` conform to the JPEG XL format +specification and can be decoded with current and future `djxl` decoders or +`libjxl` decoding library. + +## Quick start guide + +For more details and other workflows see the "Advanced guide" below. + +### Checking out the code + +```bash +git clone https://github.com/libjxl/libjxl.git --recursive --shallow-submodules +``` + +This repository uses git submodules to handle some third party dependencies +under `third_party`, that's why is important to pass `--recursive`. If you +didn't check out with `--recursive`, or any submodule has changed, run: + +```bash +git submodule update --init --recursive --depth 1 --recommend-shallow +``` + +The `--shallow-submodules` and `--depth 1 --recommend-shallow` options create +shallow clones which only downloads the commits requested, and is all that is +needed to build `libjxl`. Should full clones be necessary, you could always run: + +```bash +git submodule foreach git fetch --unshallow +git submodule update --init --recursive +``` + +which pulls the rest of the commits in the submodules. + +Important: If you downloaded a zip file or tarball from the web interface you +won't get the needed submodules and the code will not compile. You can download +these external dependencies from source running `./deps.sh`. The git workflow +described above is recommended instead. + +### Installing dependencies + +Required dependencies for compiling the code, in a Debian/Ubuntu based +distribution run: + +```bash +sudo apt install cmake pkg-config libbrotli-dev +``` + +Optional dependencies for supporting other formats in the `cjxl`/`djxl` tools, +in a Debian/Ubuntu based distribution run: + +```bash +sudo apt install libgif-dev libjpeg-dev libopenexr-dev libpng-dev libwebp-dev +``` + +We recommend using a recent Clang compiler (version 7 or newer), for that +install clang and set `CC` and `CXX` variables. + +```bash +sudo apt install clang +export CC=clang CXX=clang++ +``` + +### Building + +```bash +cd libjxl +mkdir build +cd build +cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF .. +cmake --build . -- -j$(nproc) +``` + +The encoder/decoder tools will be available in the `build/tools` directory. + +### Installing + +```bash +sudo cmake --install . +``` + +### Basic encoder/decoder + +To encode a source image to JPEG XL with default settings: + +```bash +build/tools/cjxl input.png output.jxl +``` + +For more settings run `build/tools/cjxl --help` or for a full list of options +run `build/tools/cjxl -v -v --help`. + +To decode a JPEG XL file run: + +```bash +build/tools/djxl input.jxl output.png +``` + +When possible `cjxl`/`djxl` are able to read/write the following +image formats: .exr, .gif, .jpeg/.jpg, .pfm, .pgm/.ppm, .pgx, .png. + +### Benchmarking + +For speed benchmarks on single images in single or multi-threaded decoding +`djxl` can print decoding speed information. See `djxl --help` for details +on the decoding options and note that the output image is optional for +benchmarking purposes. + +For more comprehensive benchmarking options, see the +[benchmarking guide](doc/benchmarking.md). + +## Advanced guide + +### Building with Docker + +We build a common environment based on Debian/Ubuntu using Docker. Other +systems may have different combinations of versions and dependencies that +have not been tested and may not work. For those cases we recommend using the +Docker container as explained in the +[step by step guide](doc/developing_in_docker.md). + +### Building JPEG XL for developers + +For experienced developers, we provide build instructions for several other environments: + +* [Building on Debian](doc/developing_in_debian.md) +* Building on Windows with [vcpkg](doc/developing_in_windows_vcpkg.md) (Visual Studio 2019) +* Building on Windows with [MSYS2](doc/developing_in_windows_msys.md) +* [Cross Compiling for Windows with Crossroad](doc/developing_with_crossroad.md) + +If you encounter any difficulties, please use Docker instead. + +## License + +This software is available under a 3-clause BSD license which can be found in +the [LICENSE](LICENSE) file, with an "Additional IP Rights Grant" as outlined in +the [PATENTS](PATENTS) file. + +Please note that the PATENTS file only mentions Google since Google is the legal +entity receiving the Contributor License Agreements (CLA) from all contributors +to the JPEG XL Project, including the initial main contributors to the JPEG XL +format: Cloudinary and Google. + +## Additional documentation + +### Codec description + +* [JPEG XL Format Overview](doc/format_overview.md) +* [Introductory paper](https://www.spiedigitallibrary.org/proceedings/Download?fullDOI=10.1117%2F12.2529237) (open-access) +* [XL Overview](doc/xl_overview.md) - a brief introduction to the source code modules +* [JPEG XL white paper](https://ds.jpeg.org/whitepapers/jpeg-xl-whitepaper.pdf) +* [JPEG XL official website](https://jpeg.org/jpegxl) +* [JPEG XL community website](https://jpegxl.info) + +### Development process + +* [More information on testing/build options](doc/building_and_testing.md) +* [Git guide for JPEG XL](doc/developing_in_github.md) - for developers +* [Fuzzing](doc/fuzzing.md) - for developers +* [Building Web Assembly artifacts](doc/building_wasm.md) +* [Test coverage on Codecov.io](https://app.codecov.io/gh/libjxl/libjxl) - for + developers +* [libjxl documentation on readthedocs.io](https://libjxl.readthedocs.io/) + +### Contact + +If you encounter a bug or other issue with the software, please open an Issue here. + +There is a [subreddit about JPEG XL](https://www.reddit.com/r/jpegxl/), and +informal chatting with developers and early adopters of `libjxl` can be done on the +[JPEG XL Discord server](https://discord.gg/DqkQgDRTFu). diff --git a/media/libjxl/src/SECURITY.md b/media/libjxl/src/SECURITY.md new file mode 100644 index 000000000..d03012a63 --- /dev/null +++ b/media/libjxl/src/SECURITY.md @@ -0,0 +1,73 @@ +# Security and Vulnerability Policy for libjxl + +## TL;DR: + +CPE prefix: `cpe:2.3:a:libjxl_project:libjxl` + +To report a security issue, please email libjxl-security@google.com. + +Include in your email a description of the issue, the steps you took to create +the issue, affected versions, and if known, mitigations for the issue. Our +vulnerability management team will acknowledge receiving your email within 3 +working days. + +This project follows a 90 day disclosure timeline. + +For all other bugs, where there are no security implications about disclosing +the unpatched bug, open a [new issue](https://github.com/libjxl/libjxl/issues) +checking first for existing similar issues. If in doubt about the security +impact of a bug you discovered, email first. + +## Policy overview + +libjxl's Security Policy is based on the [Google Open Source program +guidelines](https://github.com/google/oss-vulnerability-guide) for coordinated +vulnerability disclosure. + +Early versions of `libjxl` had a different security policy that didn't provide +security and vulnerability disclosure support. Versions up to and including +0.3.7 are not covered and won't receive any security advisory. + +Only released versions, starting from version 0.5, are covered by this policy. +Development branches, arbitrary commits from `main` branch or even releases with +backported features externally patched on top are not covered. Only those +versions with a release tag in `libjxl`'s repository are covered, starting from +version 0.5. + +## What's a "Security bug" + +A security bug is a bug that can potentially be exploited to let an attacker +gain unauthorized access or privileges such as disclosing information or +arbitrary code execution. Not all fuzzer-found bugs and not all assert() +failures are considered security bugs in libjxl. For a detailed explanation and +examples see our [Security Vulnerabilities Playbook](doc/vuln_playbook.md). + +## What to expect + +To report a security issue, please email libjxl-security@google.com with all the +details about the bug you encountered. + + * Include a description of the issue, steps to reproduce, etc. Compiler + versions, flags, exact version used and even CPU are often relevant given our + usage of SIMD and run-time dispatch of SIMD instructions. + + * A member of our security team will reply to you within 3 business days. Note + that business days are different in different countries. + + * We will evaluate the issue and we may require more input from your side to + reproduce it. + + * If the issue fits in the description of a security bug, we will issue a + CVE, publish a fix and make a new minor or patch release with it. There is + a maximum of 90 day disclosure timeline, we ask you to not publish the + details before the 90 day deadline or the release date (whichever comes + first). + + * In the case that we publish a CVE we will credit the external researcher who + reported the issue. When reporting security issues please let us know if you + need to include specific information while doing so, like for example a + company affiliation. + +Our security team follows the [Security Vulnerabilities +Playbook](doc/vuln_playbook.md). For more details about the process and policies +please take a look at it. diff --git a/media/libjxl/src/bash_test.sh b/media/libjxl/src/bash_test.sh new file mode 100644 index 000000000..675026ab2 --- /dev/null +++ b/media/libjxl/src/bash_test.sh @@ -0,0 +1,314 @@ +#!/bin/bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Tests implemented in bash. These typically will run checks about the source +# code rather than the compiled one. + +MYDIR=$(dirname $(realpath "$0")) + +set -u + +test_includes() { + local ret=0 + local f + for f in $(git ls-files | grep -E '(\.cc|\.cpp|\.h)$'); do + if [ ! -e "$f" ]; then + continue + fi + # Check that the public files (in lib/include/ directory) don't use the full + # path to the public header since users of the library will include the + # library as: #include "jxl/foobar.h". + if [[ "${f#lib/include/}" != "${f}" ]]; then + if grep -i -H -n -E '#include\s*[<"]lib/include/jxl' "$f" >&2; then + echo "Don't add \"include/\" to the include path of public headers." >&2 + ret=1 + fi + fi + + if [[ "${f#third_party/}" == "$f" ]]; then + # $f is not in third_party/ + + # Check that local files don't use the full path to third_party/ + # directory since the installed versions will not have that path. + # Add an exception for third_party/dirent.h. + if grep -v -F 'third_party/dirent.h' "$f" | \ + grep -i -H -n -E '#include\s*[<"]third_party/' >&2 && + [[ $ret -eq 0 ]]; then + cat >&2 <&2 + ret=1 + fi + done + return ${ret} +} + +test_copyright() { + local ret=0 + local f + for f in $( + git ls-files | grep -E \ + '(Dockerfile.*|\.c|\.cc|\.cpp|\.gni|\.h|\.java|\.sh|\.m|\.py|\.ui|\.yml)$'); do + if [ ! -e "$f" ]; then + continue + fi + if [[ "${f#third_party/}" == "$f" ]]; then + # $f is not in third_party/ + if ! head -n 10 "$f" | + grep -F 'Copyright (c) the JPEG XL Project Authors.' >/dev/null ; then + echo "$f: Missing Copyright blob near the top of the file." >&2 + ret=1 + fi + if ! head -n 10 "$f" | + grep -F 'Use of this source code is governed by a BSD-style' \ + >/dev/null ; then + echo "$f: Missing License blob near the top of the file." >&2 + ret=1 + fi + fi + done + return ${ret} +} + +# Check that we don't use "%zu" or "%zd" in format string for size_t. +test_printf_size_t() { + local ret=0 + if grep -n -E '%[0-9]*z[udx]' \ + $(git ls-files | grep -E '(\.c|\.cc|\.cpp|\.h)$'); then + echo "Don't use '%zu' or '%zd' in a format string, instead use " \ + "'%\" PRIuS \"' or '%\" PRIdS \"'." >&2 + ret=1 + fi + + if grep -n -E 'gmock\.h' \ + $(git ls-files | grep -E '(\.c|\.cc|\.cpp|\.h)$' | grep -v -F /test_utils.h); then + echo "Don't include gmock directly, instead include 'test_utils.h'. " >&2 + ret=1 + fi + + local f + for f in $(git ls-files | grep -E "\.cc$" | xargs grep 'PRI[udx]S' | + cut -f 1 -d : | uniq); do + if [ ! -e "$f" ]; then + continue + fi + if ! grep -F printf_macros.h "$f" >/dev/null; then + echo "$f: Add lib/jxl/base/printf_macros.h for PRI.S, or use other " \ + "types for code outside lib/jxl library." >&2 + ret=1 + fi + done + + for f in $(git ls-files | grep -E "\.h$" | grep -v -E '(printf_macros\.h|test_utils\.h)' | + xargs grep -n 'PRI[udx]S'); do + # Having PRIuS / PRIdS in a header file means that printf_macros.h may + # be included before a system header, in particular before gtest headers. + # those may re-define PRIuS unconditionally causing a compile error. + echo "$f: Don't use PRI.S in header files. Sorry." + ret=1 + done + + return ${ret} +} + +# Check that "dec_" code doesn't depend on "enc_" headers. +test_dec_enc_deps() { + local ret=0 + local f + for f in $(git ls-files | grep -E '/dec_'); do + if [ ! -e "$f" ]; then + continue + fi + if [[ "${f#third_party/}" == "$f" ]]; then + # $f is not in third_party/ + if grep -n -H -E "#include.*/enc_" "$f" >&2; then + echo "$f: Don't include \"enc_*\" files from \"dec_*\" files." >&2 + ret=1 + fi + fi + done + return ${ret} +} + +# Check for git merge conflict markers. +test_merge_conflict() { + local ret=0 + TEXT_FILES='(\.cc|\.cpp|\.h|\.sh|\.m|\.py|\.md|\.txt|\.cmake)$' + for f in $(git ls-files | grep -E "${TEXT_FILES}"); do + if [ ! -e "$f" ]; then + continue + fi + if grep -E '^<<<<<<< ' "$f"; then + echo "$f: Found git merge conflict marker. Please resolve." >&2 + ret=1 + fi + done + return ${ret} +} + +# Check that the library and the package have the same version. This prevents +# accidentally having them out of sync. +get_version() { + local varname=$1 + local line=$(grep -F "set(${varname} " lib/CMakeLists.txt | head -n 1) + [[ -n "${line}" ]] + line="${line#set(${varname} }" + line="${line%)}" + echo "${line}" +} + +test_version() { + local major=$(get_version JPEGXL_MAJOR_VERSION) + local minor=$(get_version JPEGXL_MINOR_VERSION) + local patch=$(get_version JPEGXL_PATCH_VERSION) + # Check that the version is not empty + if [[ -z "${major}${minor}${patch}" ]]; then + echo "Couldn't parse version from CMakeLists.txt" >&2 + return 1 + fi + local pkg_version=$(head -n 1 debian/changelog) + # Get only the part between the first "jpeg-xl (" and the following ")". + pkg_version="${pkg_version#jpeg-xl (}" + pkg_version="${pkg_version%%)*}" + if [[ -z "${pkg_version}" ]]; then + echo "Couldn't parse version from debian package" >&2 + return 1 + fi + + local lib_version="${major}.${minor}.${patch}" + lib_version="${lib_version%.0}" + if [[ "${pkg_version}" != "${lib_version}"* ]]; then + echo "Debian package version (${pkg_version}) doesn't match library" \ + "version (${lib_version})." >&2 + return 1 + fi + return 0 +} + +# Check that the SHA versions in deps.sh matches the git submodules. +test_deps_version() { + while IFS= read -r line; do + if [[ "${line:0:10}" != "[submodule" ]]; then + continue + fi + line="${line#[submodule \"}" + line="${line%\"]}" + local varname=$(tr '[:lower:]' '[:upper:]' <<< "${line}") + varname="${varname/\//_}" + if ! grep -F "${varname}=" deps.sh >/dev/null; then + # Ignoring submodule not in deps.sh + continue + fi + local deps_sha=$(grep -F "${varname}=" deps.sh | cut -f 2 -d '"') + [[ -n "${deps_sha}" ]] + local git_sha=$(git ls-tree -r HEAD "${line}" | cut -f 1 | cut -f 3 -d ' ') + if [[ "${deps_sha}" != "${git_sha}" ]]; then + cat >&2 </dev/null; then + cat >&2 <&2; then + echo "Don't use \"%n\"." >&2 + ret=1 + fi + done + return ${ret} +} + +main() { + local ret=0 + cd "${MYDIR}" + + if ! git rev-parse >/dev/null 2>/dev/null; then + echo "Not a git checkout, skipping bash_test" + return 0 + fi + + IFS=$'\n' + for f in $(declare -F); do + local test_name=$(echo "$f" | cut -f 3 -d ' ') + # Runs all the local bash functions that start with "test_". + if [[ "${test_name}" == test_* ]]; then + echo "Test ${test_name}: Start" + if ${test_name}; then + echo "Test ${test_name}: PASS" + else + echo "Test ${test_name}: FAIL" + ret=1 + fi + fi + done + return ${ret} +} + +main "$@" diff --git a/media/libjxl/src/ci.sh b/media/libjxl/src/ci.sh new file mode 100644 index 000000000..45d5218d2 --- /dev/null +++ b/media/libjxl/src/ci.sh @@ -0,0 +1,1519 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Continuous integration helper module. This module is meant to be called from +# the .gitlab-ci.yml file during the continuous integration build, as well as +# from the command line for developers. + +set -eu + +OS=`uname -s` + +MYDIR=$(dirname $(realpath "$0")) + +### Environment parameters: +TEST_STACK_LIMIT="${TEST_STACK_LIMIT:-256}" +CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-RelWithDebInfo} +CMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH:-} +CMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER:-} +CMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER:-} +CMAKE_MAKE_PROGRAM=${CMAKE_MAKE_PROGRAM:-} +SKIP_TEST="${SKIP_TEST:-0}" +TEST_SELECTOR="${TEST_SELECTOR:-}" +BUILD_TARGET="${BUILD_TARGET:-}" +ENABLE_WASM_SIMD="${ENABLE_WASM_SIMD:-0}" +if [[ -n "${BUILD_TARGET}" ]]; then + BUILD_DIR="${BUILD_DIR:-${MYDIR}/build-${BUILD_TARGET%%-*}}" +else + BUILD_DIR="${BUILD_DIR:-${MYDIR}/build}" +fi +# Whether we should post a message in the MR when the build fails. +POST_MESSAGE_ON_ERROR="${POST_MESSAGE_ON_ERROR:-1}" + +# Set default compilers to clang if not already set +export CC=${CC:-clang} +export CXX=${CXX:-clang++} + +# Time limit for the "fuzz" command in seconds (0 means no limit). +FUZZER_MAX_TIME="${FUZZER_MAX_TIME:-0}" + +SANITIZER="none" + + +if [[ "${BUILD_TARGET%%-*}" == "x86_64" || + "${BUILD_TARGET%%-*}" == "i686" ]]; then + # Default to building all targets, even if compiler baseline is SSE4 + HWY_BASELINE_TARGETS=${HWY_BASELINE_TARGETS:-HWY_EMU128} +else + HWY_BASELINE_TARGETS=${HWY_BASELINE_TARGETS:-} +fi + +# Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS +CMAKE_FLAGS=${CMAKE_FLAGS:-} +CMAKE_C_FLAGS="${CMAKE_C_FLAGS:-} ${CMAKE_FLAGS}" +CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS:-} ${CMAKE_FLAGS}" + +CMAKE_CROSSCOMPILING_EMULATOR=${CMAKE_CROSSCOMPILING_EMULATOR:-} +CMAKE_EXE_LINKER_FLAGS=${CMAKE_EXE_LINKER_FLAGS:-} +CMAKE_FIND_ROOT_PATH=${CMAKE_FIND_ROOT_PATH:-} +CMAKE_MODULE_LINKER_FLAGS=${CMAKE_MODULE_LINKER_FLAGS:-} +CMAKE_SHARED_LINKER_FLAGS=${CMAKE_SHARED_LINKER_FLAGS:-} +CMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE:-} + +if [[ "${ENABLE_WASM_SIMD}" -ne "0" ]]; then + CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS} -msimd128" + CMAKE_C_FLAGS="${CMAKE_C_FLAGS} -msimd128" + CMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS} -msimd128" +fi + +if [[ ! -z "${HWY_BASELINE_TARGETS}" ]]; then + CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS} -DHWY_BASELINE_TARGETS=${HWY_BASELINE_TARGETS}" +fi + +# Version inferred from the CI variables. +CI_COMMIT_SHA=${CI_COMMIT_SHA:-${GITHUB_SHA:-}} +JPEGXL_VERSION=${JPEGXL_VERSION:-${CI_COMMIT_SHA:0:8}} + +# Benchmark parameters +STORE_IMAGES=${STORE_IMAGES:-1} +BENCHMARK_CORPORA="${MYDIR}/third_party/corpora" + +# Local flags passed to sanitizers. +UBSAN_FLAGS=( + -fsanitize=alignment + -fsanitize=bool + -fsanitize=bounds + -fsanitize=builtin + -fsanitize=enum + -fsanitize=float-cast-overflow + -fsanitize=float-divide-by-zero + -fsanitize=integer-divide-by-zero + -fsanitize=null + -fsanitize=object-size + -fsanitize=pointer-overflow + -fsanitize=return + -fsanitize=returns-nonnull-attribute + -fsanitize=shift-base + -fsanitize=shift-exponent + -fsanitize=unreachable + -fsanitize=vla-bound + + -fno-sanitize-recover=undefined + # Brunsli uses unaligned accesses to uint32_t, so alignment is just a warning. + -fsanitize-recover=alignment +) +# -fsanitize=function doesn't work on aarch64 and arm. +if [[ "${BUILD_TARGET%%-*}" != "aarch64" && + "${BUILD_TARGET%%-*}" != "arm" ]]; then + UBSAN_FLAGS+=( + -fsanitize=function + ) +fi +if [[ "${BUILD_TARGET%%-*}" != "arm" ]]; then + UBSAN_FLAGS+=( + -fsanitize=signed-integer-overflow + ) +fi + +CLANG_TIDY_BIN=$(which clang-tidy-6.0 clang-tidy-7 clang-tidy-8 clang-tidy | head -n 1) +# Default to "cat" if "colordiff" is not installed or if stdout is not a tty. +if [[ -t 1 ]]; then + COLORDIFF_BIN=$(which colordiff cat | head -n 1) +else + COLORDIFF_BIN="cat" +fi +FIND_BIN=$(which gfind find | head -n 1) +# "false" will disable wine64 when not installed. This won't allow +# cross-compiling. +WINE_BIN=$(which wine64 false | head -n 1) + +CLANG_VERSION="${CLANG_VERSION:-}" +# Detect the clang version suffix and store it in CLANG_VERSION. For example, +# "6.0" for clang 6 or "7" for clang 7. +detect_clang_version() { + if [[ -n "${CLANG_VERSION}" ]]; then + return 0 + fi + local clang_version=$("${CC:-clang}" --version | head -n1) + clang_version=${clang_version#"Debian "} + local llvm_tag + case "${clang_version}" in + "clang version 6."*) + CLANG_VERSION="6.0" + ;; + "clang version "*) + # Any other clang version uses just the major version number. + local suffix="${clang_version#clang version }" + CLANG_VERSION="${suffix%%.*}" + ;; + "emcc"*) + # We can't use asan or msan in the emcc case. + ;; + *) + echo "Unknown clang version: ${clang_version}" >&2 + return 1 + esac +} + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -fr "${CLEANUP_FILES[@]}" + fi +} + +# Executed on exit. +on_exit() { + local retcode="$1" + # Always cleanup the CLEANUP_FILES. + cleanup + + # Post a message in the MR when requested with POST_MESSAGE_ON_ERROR but only + # if the run failed and we are not running from a MR pipeline. + if [[ ${retcode} -ne 0 && -n "${CI_BUILD_NAME:-}" && + -n "${POST_MESSAGE_ON_ERROR}" && -z "${CI_MERGE_REQUEST_ID:-}" && + "${CI_BUILD_REF_NAME}" = "master" ]]; then + load_mr_vars_from_commit + { set +xeu; } 2>/dev/null + local message="**Run ${CI_BUILD_NAME} @ ${CI_COMMIT_SHORT_SHA} failed.** + +Check the output of the job at ${CI_JOB_URL:-} to see if this was your problem. +If it was, please rollback this change or fix the problem ASAP, broken builds +slow down development. Check if the error already existed in the previous build +as well. + +Pipeline: ${CI_PIPELINE_URL} + +Previous build commit: ${CI_COMMIT_BEFORE_SHA} +" + cmd_post_mr_comment "${message}" + fi +} + +trap 'retcode=$?; { set +x; } 2>/dev/null; on_exit ${retcode}' INT TERM EXIT + + +# These variables are populated when calling merge_request_commits(). + +# The current hash at the top of the current branch or merge request branch (if +# running from a merge request pipeline). +MR_HEAD_SHA="" +# The common ancestor between the current commit and the tracked branch, such +# as master. This includes a list +MR_ANCESTOR_SHA="" + +# Populate MR_HEAD_SHA and MR_ANCESTOR_SHA. +merge_request_commits() { + { set +x; } 2>/dev/null + # GITHUB_SHA is the current reference being build in GitHub Actions. + if [[ -n "${GITHUB_SHA:-}" ]]; then + # GitHub normally does a checkout of a merge commit on a shallow repository + # by default. We want to get a bit more of the history to be able to diff + # changes on the Pull Request if needed. This fetches 10 more commits which + # should be enough given that PR normally should have 1 commit. + git -C "${MYDIR}" fetch -q origin "${GITHUB_SHA}" --depth 10 + MR_HEAD_SHA="$(git rev-parse "FETCH_HEAD^2" 2>/dev/null || + echo "${GITHUB_SHA}")" + else + # CI_BUILD_REF is the reference currently being build in the CI workflow. + MR_HEAD_SHA=$(git -C "${MYDIR}" rev-parse -q "${CI_BUILD_REF:-HEAD}") + fi + + if [[ -n "${CI_MERGE_REQUEST_IID:-}" ]]; then + # Merge request pipeline in CI. In this case the upstream is called "origin" + # but it refers to the forked project that's the source of the merge + # request. We need to get the target of the merge request, for which we need + # to query that repository using our CI_JOB_TOKEN. + echo "machine gitlab.com login gitlab-ci-token password ${CI_JOB_TOKEN}" \ + >> "${HOME}/.netrc" + git -C "${MYDIR}" fetch "${CI_MERGE_REQUEST_PROJECT_URL}" \ + "${CI_MERGE_REQUEST_TARGET_BRANCH_NAME}" + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" rev-parse -q FETCH_HEAD) + elif [[ -n "${GITHUB_BASE_REF:-}" ]]; then + # Pull request workflow in GitHub Actions. GitHub checkout action uses + # "origin" as the remote for the git checkout. + git -C "${MYDIR}" fetch -q origin "${GITHUB_BASE_REF}" + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" rev-parse -q FETCH_HEAD) + else + # We are in a local branch, not a merge request. + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" rev-parse -q HEAD@{upstream} || true) + fi + + if [[ -z "${MR_ANCESTOR_SHA}" ]]; then + echo "Warning, not tracking any branch, using the last commit in HEAD.">&2 + # This prints the return value with just HEAD. + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" rev-parse -q "${MR_HEAD_SHA}^") + else + # GitHub runs the pipeline on a merge commit, no need to look for the common + # ancestor in that case. + if [[ -z "${GITHUB_BASE_REF:-}" ]]; then + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" merge-base \ + "${MR_ANCESTOR_SHA}" "${MR_HEAD_SHA}") + fi + fi + set -x +} + +# Load the MR iid from the landed commit message when running not from a +# merge request workflow. This is useful to post back results at the merge +# request when running pipelines from master. +load_mr_vars_from_commit() { + { set +x; } 2>/dev/null + if [[ -z "${CI_MERGE_REQUEST_IID:-}" ]]; then + local mr_iid=$(git rev-list --format=%B --max-count=1 HEAD | + grep -F "${CI_PROJECT_URL}" | grep -F "/merge_requests" | head -n 1) + # mr_iid contains a string like this if it matched: + # Part-of: + if [[ -n "${mr_iid}" ]]; then + mr_iid=$(echo "${mr_iid}" | + sed -E 's,^.*merge_requests/([0-9]+)>.*$,\1,') + CI_MERGE_REQUEST_IID="${mr_iid}" + CI_MERGE_REQUEST_PROJECT_ID=${CI_PROJECT_ID} + fi + fi + set -x +} + +# Posts a comment to the current merge request. +cmd_post_mr_comment() { + { set +x; } 2>/dev/null + local comment="$1" + if [[ -n "${BOT_TOKEN:-}" && -n "${CI_MERGE_REQUEST_IID:-}" ]]; then + local url="${CI_API_V4_URL}/projects/${CI_MERGE_REQUEST_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}/notes" + curl -X POST -g \ + -H "PRIVATE-TOKEN: ${BOT_TOKEN}" \ + --data-urlencode "body=${comment}" \ + --output /dev/null \ + "${url}" + fi + set -x +} + +# Set up and export the environment variables needed by the child processes. +export_env() { + if [[ "${BUILD_TARGET}" == *mingw32 ]]; then + # Wine needs to know the paths to the mingw dlls. These should be + # separated by ';'. + WINEPATH=$("${CC:-clang}" -print-search-dirs --target="${BUILD_TARGET}" \ + | grep -F 'libraries: =' | cut -f 2- -d '=' | tr ':' ';') + # We also need our own libraries in the wine path. + local real_build_dir=$(realpath "${BUILD_DIR}") + # Some library .dll dependencies are installed in /bin: + export WINEPATH="${WINEPATH};${real_build_dir};${real_build_dir}/third_party/brotli;/usr/${BUILD_TARGET}/bin" + + local prefix="${BUILD_DIR}/wineprefix" + mkdir -p "${prefix}" + export WINEPREFIX=$(realpath "${prefix}") + fi + # Sanitizers need these variables to print and properly format the stack + # traces: + LLVM_SYMBOLIZER=$("${CC:-clang}" -print-prog-name=llvm-symbolizer || true) + if [[ -n "${LLVM_SYMBOLIZER}" ]]; then + export ASAN_SYMBOLIZER_PATH="${LLVM_SYMBOLIZER}" + export MSAN_SYMBOLIZER_PATH="${LLVM_SYMBOLIZER}" + export UBSAN_SYMBOLIZER_PATH="${LLVM_SYMBOLIZER}" + fi +} + +cmake_configure() { + export_env + + if [[ "${STACK_SIZE:-0}" == 1 ]]; then + # Dump the stack size of each function in the .stack_sizes section for + # analysis. + CMAKE_C_FLAGS+=" -fstack-size-section" + CMAKE_CXX_FLAGS+=" -fstack-size-section" + fi + + local args=( + -B"${BUILD_DIR}" -H"${MYDIR}" + -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" + -G Ninja + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" + -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" + -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" + -DCMAKE_MODULE_LINKER_FLAGS="${CMAKE_MODULE_LINKER_FLAGS}" + -DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}" + -DJPEGXL_VERSION="${JPEGXL_VERSION}" + -DSANITIZER="${SANITIZER}" + # These are not enabled by default in cmake. + -DJPEGXL_ENABLE_VIEWERS=ON + -DJPEGXL_ENABLE_PLUGINS=ON + -DJPEGXL_ENABLE_DEVTOOLS=ON + # We always use libfuzzer in the ci.sh wrapper. + -DJPEGXL_FUZZER_LINK_FLAGS="-fsanitize=fuzzer" + ) + if [[ "${BUILD_TARGET}" != *mingw32 ]]; then + args+=( + -DJPEGXL_WARNINGS_AS_ERRORS=ON + ) + fi + if [[ -n "${BUILD_TARGET}" ]]; then + local system_name="Linux" + if [[ "${BUILD_TARGET}" == *mingw32 ]]; then + # When cross-compiling with mingw the target must be set to Windows and + # run programs with wine. + system_name="Windows" + args+=( + -DCMAKE_CROSSCOMPILING_EMULATOR="${WINE_BIN}" + # Normally CMake automatically defines MINGW=1 when building with the + # mingw compiler (x86_64-w64-mingw32-gcc) but we are normally compiling + # with clang. + -DMINGW=1 + ) + fi + # EMSCRIPTEN toolchain sets the right values itself + if [[ "${BUILD_TARGET}" != wasm* ]]; then + # If set, BUILD_TARGET must be the target triplet such as + # x86_64-unknown-linux-gnu. + args+=( + -DCMAKE_C_COMPILER_TARGET="${BUILD_TARGET}" + -DCMAKE_CXX_COMPILER_TARGET="${BUILD_TARGET}" + # Only the first element of the target triplet. + -DCMAKE_SYSTEM_PROCESSOR="${BUILD_TARGET%%-*}" + -DCMAKE_SYSTEM_NAME="${system_name}" + -DCMAKE_TOOLCHAIN_FILE="${CMAKE_TOOLCHAIN_FILE}" + ) + else + args+=( + # sjpeg confuses WASM SIMD with SSE. + -DSJPEG_ENABLE_SIMD=OFF + # Building shared libs is not very useful for WASM. + -DBUILD_SHARED_LIBS=OFF + ) + fi + args+=( + # These are needed to make googletest work when cross-compiling. + -DCMAKE_CROSSCOMPILING=1 + -DHAVE_STD_REGEX=0 + -DHAVE_POSIX_REGEX=0 + -DHAVE_GNU_POSIX_REGEX=0 + -DHAVE_STEADY_CLOCK=0 + -DHAVE_THREAD_SAFETY_ATTRIBUTES=0 + ) + if [[ -z "${CMAKE_FIND_ROOT_PATH}" ]]; then + # find_package() will look in this prefix for libraries. + CMAKE_FIND_ROOT_PATH="/usr/${BUILD_TARGET}" + fi + if [[ -z "${CMAKE_PREFIX_PATH}" ]]; then + CMAKE_PREFIX_PATH="/usr/${BUILD_TARGET}" + fi + # Use pkg-config for the target. If there's no pkg-config available for the + # target we can set the PKG_CONFIG_PATH to the appropriate path in most + # linux distributions. + local pkg_config=$(which "${BUILD_TARGET}-pkg-config" || true) + if [[ -z "${pkg_config}" ]]; then + pkg_config=$(which pkg-config) + export PKG_CONFIG_LIBDIR="/usr/${BUILD_TARGET}/lib/pkgconfig" + fi + if [[ -n "${pkg_config}" ]]; then + args+=(-DPKG_CONFIG_EXECUTABLE="${pkg_config}") + fi + fi + if [[ -n "${CMAKE_CROSSCOMPILING_EMULATOR}" ]]; then + args+=( + -DCMAKE_CROSSCOMPILING_EMULATOR="${CMAKE_CROSSCOMPILING_EMULATOR}" + ) + fi + if [[ -n "${CMAKE_FIND_ROOT_PATH}" ]]; then + args+=( + -DCMAKE_FIND_ROOT_PATH="${CMAKE_FIND_ROOT_PATH}" + ) + fi + if [[ -n "${CMAKE_PREFIX_PATH}" ]]; then + args+=( + -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" + ) + fi + if [[ -n "${CMAKE_C_COMPILER_LAUNCHER}" ]]; then + args+=( + -DCMAKE_C_COMPILER_LAUNCHER="${CMAKE_C_COMPILER_LAUNCHER}" + ) + fi + if [[ -n "${CMAKE_CXX_COMPILER_LAUNCHER}" ]]; then + args+=( + -DCMAKE_CXX_COMPILER_LAUNCHER="${CMAKE_CXX_COMPILER_LAUNCHER}" + ) + fi + if [[ -n "${CMAKE_MAKE_PROGRAM}" ]]; then + args+=( + -DCMAKE_MAKE_PROGRAM="${CMAKE_MAKE_PROGRAM}" + ) + fi + if [[ "${BUILD_TARGET}" == wasm* ]]; then + emcmake cmake "${args[@]}" "$@" + else + cmake "${args[@]}" "$@" + fi +} + +cmake_build_and_test() { + # gtest_discover_tests() runs the test binaries to discover the list of tests + # at build time, which fails under qemu. + ASAN_OPTIONS=detect_leaks=0 cmake --build "${BUILD_DIR}" -- all doc + # Pack test binaries if requested. + if [[ "${PACK_TEST:-}" == "1" ]]; then + (cd "${BUILD_DIR}" + ${FIND_BIN} -name '*.cmake' -a '!' -path '*CMakeFiles*' + # gtest / gmock / gtest_main shared libs + ${FIND_BIN} lib/ -name 'libg*.so*' + ${FIND_BIN} -type d -name tests -a '!' -path '*CMakeFiles*' + ) | tar -C "${BUILD_DIR}" -cf "${BUILD_DIR}/tests.tar.xz" -T - \ + --use-compress-program="xz --threads=$(nproc --all || echo 1) -6" + du -h "${BUILD_DIR}/tests.tar.xz" + # Pack coverage data if also available. + touch "${BUILD_DIR}/gcno.sentinel" + (cd "${BUILD_DIR}"; echo gcno.sentinel; ${FIND_BIN} -name '*gcno') | \ + tar -C "${BUILD_DIR}" -cvf "${BUILD_DIR}/gcno.tar.xz" -T - \ + --use-compress-program="xz --threads=$(nproc --all || echo 1) -6" + fi + + if [[ "${SKIP_TEST}" -ne "1" ]]; then + (cd "${BUILD_DIR}" + export UBSAN_OPTIONS=print_stacktrace=1 + [[ "${TEST_STACK_LIMIT}" == "none" ]] || ulimit -s "${TEST_STACK_LIMIT}" + ctest -j $(nproc --all || echo 1) ${TEST_SELECTOR} --output-on-failure) + fi +} + +# Configure the build to strip unused functions. This considerably reduces the +# output size, specially for tests which only use a small part of the whole +# library. +strip_dead_code() { + # Emscripten does tree shaking without any extra flags. + if [[ "${BUILD_TARGET}" == wasm* ]]; then + return 0 + fi + # -ffunction-sections, -fdata-sections and -Wl,--gc-sections effectively + # discard all unreachable code, reducing the code size. For this to work, we + # need to also pass --no-export-dynamic to prevent it from exporting all the + # internal symbols (like functions) making them all reachable and thus not a + # candidate for removal. + CMAKE_CXX_FLAGS+=" -ffunction-sections -fdata-sections" + CMAKE_C_FLAGS+=" -ffunction-sections -fdata-sections" + if [[ "${OS}" == "Darwin" ]]; then + CMAKE_EXE_LINKER_FLAGS+=" -dead_strip" + CMAKE_SHARED_LINKER_FLAGS+=" -dead_strip" + else + CMAKE_EXE_LINKER_FLAGS+=" -Wl,--gc-sections -Wl,--no-export-dynamic" + CMAKE_SHARED_LINKER_FLAGS+=" -Wl,--gc-sections -Wl,--no-export-dynamic" + fi +} + +### Externally visible commands + +cmd_debug() { + CMAKE_BUILD_TYPE="Debug" + cmake_configure "$@" + cmake_build_and_test +} + +cmd_release() { + CMAKE_BUILD_TYPE="Release" + strip_dead_code + cmake_configure "$@" + cmake_build_and_test +} + +cmd_opt() { + CMAKE_BUILD_TYPE="RelWithDebInfo" + CMAKE_CXX_FLAGS+=" -DJXL_DEBUG_WARNING -DJXL_DEBUG_ON_ERROR" + cmake_configure "$@" + cmake_build_and_test +} + +cmd_coverage() { + # -O0 prohibits stack space reuse -> causes stack-overflow on dozens of tests. + TEST_STACK_LIMIT="none" + + cmd_release -DJPEGXL_ENABLE_COVERAGE=ON "$@" + + if [[ "${SKIP_TEST}" -ne "1" ]]; then + # If we didn't run the test we also don't print a coverage report. + cmd_coverage_report + fi +} + +cmd_coverage_report() { + LLVM_COV=$("${CC:-clang}" -print-prog-name=llvm-cov) + local real_build_dir=$(realpath "${BUILD_DIR}") + local gcovr_args=( + -r "${real_build_dir}" + --gcov-executable "${LLVM_COV} gcov" + # Only print coverage information for the libjxl directories. The rest + # is not part of the code under test. + --filter '.*jxl/.*' + --exclude '.*_test.cc' + --exclude '.*_testonly..*' + --exclude '.*_debug.*' + --exclude '.*test_utils..*' + --object-directory "${real_build_dir}" + ) + + ( + cd "${real_build_dir}" + gcovr "${gcovr_args[@]}" --html --html-details \ + --output="${real_build_dir}/coverage.html" + gcovr "${gcovr_args[@]}" --print-summary | + tee "${real_build_dir}/coverage.txt" + gcovr "${gcovr_args[@]}" --xml --output="${real_build_dir}/coverage.xml" + ) +} + +cmd_test() { + export_env + # Unpack tests if needed. + if [[ -e "${BUILD_DIR}/tests.tar.xz" && ! -d "${BUILD_DIR}/tests" ]]; then + tar -C "${BUILD_DIR}" -Jxvf "${BUILD_DIR}/tests.tar.xz" + fi + if [[ -e "${BUILD_DIR}/gcno.tar.xz" && ! -d "${BUILD_DIR}/gcno.sentinel" ]]; then + tar -C "${BUILD_DIR}" -Jxvf "${BUILD_DIR}/gcno.tar.xz" + fi + (cd "${BUILD_DIR}" + export UBSAN_OPTIONS=print_stacktrace=1 + [[ "${TEST_STACK_LIMIT}" == "none" ]] || ulimit -s "${TEST_STACK_LIMIT}" + ctest -j $(nproc --all || echo 1) --output-on-failure "$@") +} + +cmd_gbench() { + export_env + (cd "${BUILD_DIR}" + export UBSAN_OPTIONS=print_stacktrace=1 + lib/jxl_gbench \ + --benchmark_counters_tabular=true \ + --benchmark_out_format=json \ + --benchmark_out=gbench.json "$@" + ) +} + +cmd_asanfuzz() { + CMAKE_CXX_FLAGS+=" -fsanitize=fuzzer-no-link -DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION=1" + CMAKE_C_FLAGS+=" -fsanitize=fuzzer-no-link -DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION=1" + cmd_asan -DJPEGXL_ENABLE_FUZZERS=ON "$@" +} + +cmd_msanfuzz() { + # Install msan if needed before changing the flags. + detect_clang_version + local msan_prefix="${HOME}/.msan/${CLANG_VERSION}" + if [[ ! -d "${msan_prefix}" || -e "${msan_prefix}/lib/libc++abi.a" ]]; then + # Install msan libraries for this version if needed or if an older version + # with libc++abi was installed. + cmd_msan_install + fi + + CMAKE_CXX_FLAGS+=" -fsanitize=fuzzer-no-link -DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION=1" + CMAKE_C_FLAGS+=" -fsanitize=fuzzer-no-link -DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION=1" + cmd_msan -DJPEGXL_ENABLE_FUZZERS=ON "$@" +} + +cmd_asan() { + SANITIZER="asan" + CMAKE_C_FLAGS+=" -DJXL_ENABLE_ASSERT=1 -g -DADDRESS_SANITIZER \ + -fsanitize=address ${UBSAN_FLAGS[@]}" + CMAKE_CXX_FLAGS+=" -DJXL_ENABLE_ASSERT=1 -g -DADDRESS_SANITIZER \ + -fsanitize=address ${UBSAN_FLAGS[@]}" + strip_dead_code + cmake_configure "$@" -DJPEGXL_ENABLE_TCMALLOC=OFF + cmake_build_and_test +} + +cmd_tsan() { + SANITIZER="tsan" + local tsan_args=( + -DJXL_ENABLE_ASSERT=1 + -g + -DTHREAD_SANITIZER + ${UBSAN_FLAGS[@]} + -fsanitize=thread + ) + CMAKE_C_FLAGS+=" ${tsan_args[@]}" + CMAKE_CXX_FLAGS+=" ${tsan_args[@]}" + + CMAKE_BUILD_TYPE="RelWithDebInfo" + cmake_configure "$@" -DJPEGXL_ENABLE_TCMALLOC=OFF + cmake_build_and_test +} + +cmd_msan() { + SANITIZER="msan" + detect_clang_version + local msan_prefix="${HOME}/.msan/${CLANG_VERSION}" + if [[ ! -d "${msan_prefix}" || -e "${msan_prefix}/lib/libc++abi.a" ]]; then + # Install msan libraries for this version if needed or if an older version + # with libc++abi was installed. + cmd_msan_install + fi + + local msan_c_flags=( + -fsanitize=memory + -fno-omit-frame-pointer + -fsanitize-memory-track-origins + + -DJXL_ENABLE_ASSERT=1 + -g + -DMEMORY_SANITIZER + + # Force gtest to not use the cxxbai. + -DGTEST_HAS_CXXABI_H_=0 + ) + local msan_cxx_flags=( + "${msan_c_flags[@]}" + + # Some C++ sources don't use the std at all, so the -stdlib=libc++ is unused + # in those cases. Ignore the warning. + -Wno-unused-command-line-argument + -stdlib=libc++ + + # We include the libc++ from the msan directory instead, so we don't want + # the std includes. + -nostdinc++ + -cxx-isystem"${msan_prefix}/include/c++/v1" + ) + + local msan_linker_flags=( + -L"${msan_prefix}"/lib + -Wl,-rpath -Wl,"${msan_prefix}"/lib/ + ) + + CMAKE_C_FLAGS+=" ${msan_c_flags[@]} ${UBSAN_FLAGS[@]}" + CMAKE_CXX_FLAGS+=" ${msan_cxx_flags[@]} ${UBSAN_FLAGS[@]}" + CMAKE_EXE_LINKER_FLAGS+=" ${msan_linker_flags[@]}" + CMAKE_MODULE_LINKER_FLAGS+=" ${msan_linker_flags[@]}" + CMAKE_SHARED_LINKER_FLAGS+=" ${msan_linker_flags[@]}" + strip_dead_code + cmake_configure "$@" \ + -DCMAKE_CROSSCOMPILING=1 -DRUN_HAVE_STD_REGEX=0 -DRUN_HAVE_POSIX_REGEX=0 \ + -DJPEGXL_ENABLE_TCMALLOC=OFF -DJPEGXL_WARNINGS_AS_ERRORS=OFF \ + -DCMAKE_REQUIRED_LINK_OPTIONS="${msan_linker_flags[@]}" + cmake_build_and_test +} + +# Install libc++ libraries compiled with msan in the msan_prefix for the current +# compiler version. +cmd_msan_install() { + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + # Detect the llvm to install: + export CC="${CC:-clang}" + export CXX="${CXX:-clang++}" + detect_clang_version + # Allow overriding the LLVM checkout. + local llvm_root="${LLVM_ROOT:-}" + if [ -z "${llvm_root}" ]; then + local llvm_tag="llvmorg-${CLANG_VERSION}.0.0" + case "${CLANG_VERSION}" in + "6.0") + llvm_tag="llvmorg-6.0.1" + ;; + "7") + llvm_tag="llvmorg-7.0.1" + ;; + esac + local llvm_targz="${tmpdir}/${llvm_tag}.tar.gz" + curl -L --show-error -o "${llvm_targz}" \ + "https://github.com/llvm/llvm-project/archive/${llvm_tag}.tar.gz" + tar -C "${tmpdir}" -zxf "${llvm_targz}" + llvm_root="${tmpdir}/llvm-project-${llvm_tag}" + fi + + local msan_prefix="${HOME}/.msan/${CLANG_VERSION}" + rm -rf "${msan_prefix}" + + declare -A CMAKE_EXTRAS + CMAKE_EXTRAS[libcxx]="\ + -DLIBCXX_CXX_ABI=libstdc++ \ + -DLIBCXX_INSTALL_EXPERIMENTAL_LIBRARY=ON" + + for project in libcxx; do + local proj_build="${tmpdir}/build-${project}" + local proj_dir="${llvm_root}/${project}" + mkdir -p "${proj_build}" + cmake -B"${proj_build}" -H"${proj_dir}" \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_USE_SANITIZER=Memory \ + -DLLVM_PATH="${llvm_root}/llvm" \ + -DLLVM_CONFIG_PATH="$(which llvm-config llvm-config-7 llvm-config-6.0 | \ + head -n1)" \ + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" \ + -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" \ + -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" \ + -DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}" \ + -DCMAKE_INSTALL_PREFIX="${msan_prefix}" \ + ${CMAKE_EXTRAS[${project}]} + cmake --build "${proj_build}" + ninja -C "${proj_build}" install + done +} + +# Internal build step shared between all cmd_ossfuzz_* commands. +_cmd_ossfuzz() { + local sanitizer="$1" + shift + mkdir -p "${BUILD_DIR}" + local real_build_dir=$(realpath "${BUILD_DIR}") + + # oss-fuzz defines three directories: + # * /work, with the working directory to do re-builds + # * /src, with the source code to build + # * /out, with the output directory where to copy over the built files. + # We use $BUILD_DIR as the /work and the script directory as the /src. The + # /out directory is ignored as developers are used to look for the fuzzers in + # $BUILD_DIR/tools/ directly. + + if [[ "${sanitizer}" = "memory" && ! -d "${BUILD_DIR}/msan" ]]; then + sudo docker run --rm -i \ + --user $(id -u):$(id -g) \ + -v "${real_build_dir}":/work \ + gcr.io/oss-fuzz-base/msan-libs-builder \ + bash -c "cp -r /msan /work" + fi + + # Args passed to ninja. These will be evaluated as a string separated by + # spaces. + local jpegxl_extra_args="$@" + + sudo docker run --rm -i \ + -e JPEGXL_UID=$(id -u) \ + -e JPEGXL_GID=$(id -g) \ + -e FUZZING_ENGINE="${FUZZING_ENGINE:-libfuzzer}" \ + -e SANITIZER="${sanitizer}" \ + -e ARCHITECTURE=x86_64 \ + -e FUZZING_LANGUAGE=c++ \ + -e MSAN_LIBS_PATH="/work/msan" \ + -e JPEGXL_EXTRA_ARGS="${jpegxl_extra_args}" \ + -v "${MYDIR}":/src/libjxl \ + -v "${MYDIR}/tools/ossfuzz-build.sh":/src/build.sh \ + -v "${real_build_dir}":/work \ + gcr.io/oss-fuzz/libjxl +} + +cmd_ossfuzz_asan() { + _cmd_ossfuzz address "$@" +} +cmd_ossfuzz_msan() { + _cmd_ossfuzz memory "$@" +} +cmd_ossfuzz_ubsan() { + _cmd_ossfuzz undefined "$@" +} + +cmd_ossfuzz_ninja() { + [[ -e "${BUILD_DIR}/build.ninja" ]] + local real_build_dir=$(realpath "${BUILD_DIR}") + + if [[ -e "${BUILD_DIR}/msan" ]]; then + echo "ossfuzz_ninja doesn't work with msan builds. Use ossfuzz_msan." >&2 + exit 1 + fi + + sudo docker run --rm -i \ + --user $(id -u):$(id -g) \ + -v "${MYDIR}":/src/libjxl \ + -v "${real_build_dir}":/work \ + gcr.io/oss-fuzz/libjxl \ + ninja -C /work "$@" +} + +cmd_fast_benchmark() { + local small_corpus_tar="${BENCHMARK_CORPORA}/jyrki-full.tar" + mkdir -p "${BENCHMARK_CORPORA}" + curl --show-error -o "${small_corpus_tar}" -z "${small_corpus_tar}" \ + "https://storage.googleapis.com/artifacts.jpegxl.appspot.com/corpora/jyrki-full.tar" + + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + tar -xf "${small_corpus_tar}" -C "${tmpdir}" + + run_benchmark "${tmpdir}" 1048576 +} + +cmd_benchmark() { + local nikon_corpus_tar="${BENCHMARK_CORPORA}/nikon-subset.tar" + mkdir -p "${BENCHMARK_CORPORA}" + curl --show-error -o "${nikon_corpus_tar}" -z "${nikon_corpus_tar}" \ + "https://storage.googleapis.com/artifacts.jpegxl.appspot.com/corpora/nikon-subset.tar" + + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + tar -xvf "${nikon_corpus_tar}" -C "${tmpdir}" + + local sem_id="jpegxl_benchmark-$$" + local nprocs=$(nproc --all || echo 1) + images=() + local filename + while IFS= read -r filename; do + # This removes the './' + filename="${filename:2}" + local mode + if [[ "${filename:0:4}" == "srgb" ]]; then + mode="RGB_D65_SRG_Rel_SRG" + elif [[ "${filename:0:5}" == "adobe" ]]; then + mode="RGB_D65_Ado_Rel_Ado" + else + echo "Unknown image colorspace: ${filename}" >&2 + exit 1 + fi + png_filename="${filename%.ppm}.png" + png_filename=$(echo "${png_filename}" | tr '/' '_') + sem --bg --id "${sem_id}" -j"${nprocs}" -- \ + "${BUILD_DIR}/tools/decode_and_encode" \ + "${tmpdir}/${filename}" "${mode}" "${tmpdir}/${png_filename}" + images+=( "${png_filename}" ) + done < <(cd "${tmpdir}"; ${FIND_BIN} . -name '*.ppm' -type f) + sem --id "${sem_id}" --wait + + # We need about 10 GiB per thread on these images. + run_benchmark "${tmpdir}" 10485760 +} + +get_mem_available() { + if [[ "${OS}" == "Darwin" ]]; then + echo $(vm_stat | grep -F 'Pages free:' | awk '{print $3 * 4}') + else + echo $(grep -F MemAvailable: /proc/meminfo | awk '{print $2}') + fi +} + +run_benchmark() { + local src_img_dir="$1" + local mem_per_thread="${2:-10485760}" + + local output_dir="${BUILD_DIR}/benchmark_results" + mkdir -p "${output_dir}" + + # The memory available at the beginning of the benchmark run in kB. The number + # of threads depends on the available memory, and the passed memory per + # thread. We also add a 2 GiB of constant memory. + local mem_available="$(get_mem_available)" + # Check that we actually have a MemAvailable value. + [[ -n "${mem_available}" ]] + local num_threads=$(( (${mem_available} - 1048576) / ${mem_per_thread} )) + if [[ ${num_threads} -le 0 ]]; then + num_threads=1 + fi + + local benchmark_args=( + --input "${src_img_dir}/*.png" + --codec=jpeg:yuv420:q85,webp:q80,jxl:d1:6,jxl:d1:6:downsampling=8,jxl:d5:6,jxl:d5:6:downsampling=8,jxl:m:d0:2,jxl:m:d0:3,jxl:m:d2:2 + --output_dir "${output_dir}" + --noprofiler --show_progress + --num_threads="${num_threads}" + ) + if [[ "${STORE_IMAGES}" == "1" ]]; then + benchmark_args+=(--save_decompressed --save_compressed) + fi + ( + [[ "${TEST_STACK_LIMIT}" == "none" ]] || ulimit -s "${TEST_STACK_LIMIT}" + "${BUILD_DIR}/tools/benchmark_xl" "${benchmark_args[@]}" | \ + tee "${output_dir}/results.txt" + + # Check error code for benckmark_xl command. This will exit if not. + return ${PIPESTATUS[0]} + ) + + if [[ -n "${CI_BUILD_NAME:-}" ]]; then + { set +x; } 2>/dev/null + local message="Results for ${CI_BUILD_NAME} @ ${CI_COMMIT_SHORT_SHA} (job ${CI_JOB_URL:-}): + +$(cat "${output_dir}/results.txt") +" + cmd_post_mr_comment "${message}" + set -x + fi +} + +# Helper function to wait for the CPU temperature to cool down on ARM. +wait_for_temp() { + { set +x; } 2>/dev/null + local temp_limit=${1:-38000} + if [[ -z "${THERMAL_FILE:-}" ]]; then + echo "Must define the THERMAL_FILE with the thermal_zoneX/temp file" \ + "to read the temperature from. This is normally set in the runner." >&2 + exit 1 + fi + local org_temp=$(cat "${THERMAL_FILE}") + if [[ "${org_temp}" -ge "${temp_limit}" ]]; then + echo -n "Waiting for temp to get down from ${org_temp}... " + fi + local temp="${org_temp}" + local secs=0 + while [[ "${temp}" -ge "${temp_limit}" ]]; do + sleep 1 + temp=$(cat "${THERMAL_FILE}") + echo -n "${temp} " + secs=$((secs + 1)) + if [[ ${secs} -ge 5 ]]; then + break + fi + done + if [[ "${org_temp}" -ge "${temp_limit}" ]]; then + echo "Done, temp=${temp}" + fi + set -x +} + +# Helper function to set the cpuset restriction of the current process. +cmd_cpuset() { + [[ "${SKIP_CPUSET:-}" != "1" ]] || return 0 + local newset="$1" + local mycpuset=$(cat /proc/self/cpuset) + mycpuset="/dev/cpuset${mycpuset}" + # Check that the directory exists: + [[ -d "${mycpuset}" ]] + if [[ -e "${mycpuset}/cpuset.cpus" ]]; then + echo "${newset}" >"${mycpuset}/cpuset.cpus" + else + echo "${newset}" >"${mycpuset}/cpus" + fi +} + +# Return the encoding/decoding speed from the Stats output. +_speed_from_output() { + local speed="$1" + local unit="${2:-MP/s}" + if [[ "${speed}" == *"${unit}"* ]]; then + speed="${speed%% ${unit}*}" + speed="${speed##* }" + echo "${speed}" + fi +} + + +# Run benchmarks on ARM for the big and little CPUs. +cmd_arm_benchmark() { + # Flags used for cjxl encoder with .png inputs + local jxl_png_benchmarks=( + # Lossy options: + "--epf=0 --distance=1.0 --speed=cheetah" + "--epf=2 --distance=1.0 --speed=cheetah" + "--epf=0 --distance=8.0 --speed=cheetah" + "--epf=1 --distance=8.0 --speed=cheetah" + "--epf=2 --distance=8.0 --speed=cheetah" + "--epf=3 --distance=8.0 --speed=cheetah" + "--modular -Q 90" + "--modular -Q 50" + # Lossless options: + "--modular" + "--modular -E 0 -I 0" + "--modular -P 5" + "--modular --responsive=1" + # Near-lossless options: + "--epf=0 --distance=0.3 --speed=fast" + "--modular -Q 97" + ) + + # Flags used for cjxl encoder with .jpg inputs. These should do lossless + # JPEG recompression (of pixels or full jpeg). + local jxl_jpeg_benchmarks=( + "--num_reps=3" + ) + + local images=( + "testdata/jxl/flower/flower.png" + ) + + local jpg_images=( + "testdata/jxl/flower/flower.png.im_q85_420.jpg" + ) + + if [[ "${SKIP_CPUSET:-}" == "1" ]]; then + # Use a single cpu config in this case. + local cpu_confs=("?") + else + # Otherwise the CPU config comes from the environment: + local cpu_confs=( + "${RUNNER_CPU_LITTLE}" + "${RUNNER_CPU_BIG}" + # The CPU description is something like 3-7, so these configurations only + # take the first CPU of the group. + "${RUNNER_CPU_LITTLE%%-*}" + "${RUNNER_CPU_BIG%%-*}" + ) + # Check that RUNNER_CPU_ALL is defined. In the SKIP_CPUSET=1 case this will + # be ignored but still evaluated when calling cmd_cpuset. + [[ -n "${RUNNER_CPU_ALL}" ]] + fi + + local jpg_dirname="third_party/corpora/jpeg" + mkdir -p "${jpg_dirname}" + local jpg_qualities=( 50 80 95 ) + for src_img in "${images[@]}"; do + for q in "${jpg_qualities[@]}"; do + local jpeg_name="${jpg_dirname}/"$(basename "${src_img}" .png)"-q${q}.jpg" + convert -sampling-factor 1x1 -quality "${q}" \ + "${src_img}" "${jpeg_name}" + jpg_images+=("${jpeg_name}") + done + done + + local output_dir="${BUILD_DIR}/benchmark_results" + mkdir -p "${output_dir}" + local runs_file="${output_dir}/runs.txt" + + if [[ ! -e "${runs_file}" ]]; then + echo -e "binary\tflags\tsrc_img\tsrc size\tsrc pixels\tcpuset\tenc size (B)\tenc speed (MP/s)\tdec speed (MP/s)\tJPG dec speed (MP/s)\tJPG dec speed (MB/s)" | + tee -a "${runs_file}" + fi + + mkdir -p "${BUILD_DIR}/arm_benchmark" + local flags + local src_img + for src_img in "${jpg_images[@]}" "${images[@]}"; do + local src_img_hash=$(sha1sum "${src_img}" | cut -f 1 -d ' ') + local enc_binaries=("${BUILD_DIR}/tools/cjxl") + local src_ext="${src_img##*.}" + for enc_binary in "${enc_binaries[@]}"; do + local enc_binary_base=$(basename "${enc_binary}") + + # Select the list of flags to use for the current encoder/image pair. + local img_benchmarks + if [[ "${src_ext}" == "jpg" ]]; then + img_benchmarks=("${jxl_jpeg_benchmarks[@]}") + else + img_benchmarks=("${jxl_png_benchmarks[@]}") + fi + + for flags in "${img_benchmarks[@]}"; do + # Encoding step. + local enc_file_hash="${enc_binary_base} || $flags || ${src_img} || ${src_img_hash}" + enc_file_hash=$(echo "${enc_file_hash}" | sha1sum | cut -f 1 -d ' ') + local enc_file="${BUILD_DIR}/arm_benchmark/${enc_file_hash}.jxl" + + for cpu_conf in "${cpu_confs[@]}"; do + cmd_cpuset "${cpu_conf}" + # nproc returns the number of active CPUs, which is given by the cpuset + # mask. + local num_threads="$(nproc)" + + echo "Encoding with: ${enc_binary_base} img=${src_img} cpus=${cpu_conf} enc_flags=${flags}" + local enc_output + if [[ "${flags}" == *"modular"* ]]; then + # We don't benchmark encoding speed in this case. + if [[ ! -f "${enc_file}" ]]; then + cmd_cpuset "${RUNNER_CPU_ALL:-}" + "${enc_binary}" ${flags} "${src_img}" "${enc_file}.tmp" + mv "${enc_file}.tmp" "${enc_file}" + cmd_cpuset "${cpu_conf}" + fi + enc_output=" ?? MP/s" + else + wait_for_temp + enc_output=$("${enc_binary}" ${flags} "${src_img}" "${enc_file}.tmp" \ + 2>&1 | tee /dev/stderr | grep -F "MP/s [") + mv "${enc_file}.tmp" "${enc_file}" + fi + local enc_speed=$(_speed_from_output "${enc_output}") + local enc_size=$(stat -c "%s" "${enc_file}") + + echo "Decoding with: img=${src_img} cpus=${cpu_conf} enc_flags=${flags}" + + local dec_output + wait_for_temp + dec_output=$("${BUILD_DIR}/tools/djxl" "${enc_file}" \ + --num_reps=5 --num_threads="${num_threads}" 2>&1 | tee /dev/stderr | + grep -E "M[BP]/s \[") + local img_size=$(echo "${dec_output}" | cut -f 1 -d ',') + local img_size_x=$(echo "${img_size}" | cut -f 1 -d ' ') + local img_size_y=$(echo "${img_size}" | cut -f 3 -d ' ') + local img_size_px=$(( ${img_size_x} * ${img_size_y} )) + local dec_speed=$(_speed_from_output "${dec_output}") + + # For JPEG lossless recompression modes (where the original is a JPEG) + # decode to JPG as well. + local jpeg_dec_mps_speed="" + local jpeg_dec_mbs_speed="" + if [[ "${src_ext}" == "jpg" ]]; then + wait_for_temp + local dec_file="${BUILD_DIR}/arm_benchmark/${enc_file_hash}.jpg" + dec_output=$("${BUILD_DIR}/tools/djxl" "${enc_file}" \ + "${dec_file}" --num_reps=5 --num_threads="${num_threads}" 2>&1 | \ + tee /dev/stderr | grep -E "M[BP]/s \[") + local jpeg_dec_mps_speed=$(_speed_from_output "${dec_output}") + local jpeg_dec_mbs_speed=$(_speed_from_output "${dec_output}" MB/s) + if ! cmp --quiet "${src_img}" "${dec_file}"; then + # Add a start at the end to signal that the files are different. + jpeg_dec_mbs_speed+="*" + fi + fi + + # Record entry in a tab-separated file. + local src_img_base=$(basename "${src_img}") + echo -e "${enc_binary_base}\t${flags}\t${src_img_base}\t${img_size}\t${img_size_px}\t${cpu_conf}\t${enc_size}\t${enc_speed}\t${dec_speed}\t${jpeg_dec_mps_speed}\t${jpeg_dec_mbs_speed}" | + tee -a "${runs_file}" + done + done + done + done + cmd_cpuset "${RUNNER_CPU_ALL:-}" + cat "${runs_file}" + + if [[ -n "${CI_BUILD_NAME:-}" ]]; then + load_mr_vars_from_commit + { set +x; } 2>/dev/null + local message="Results for ${CI_BUILD_NAME} @ ${CI_COMMIT_SHORT_SHA} (job ${CI_JOB_URL:-}): + +\`\`\` +$(column -t -s " " "${runs_file}") +\`\`\` +" + cmd_post_mr_comment "${message}" + set -x + fi +} + +# Generate a corpus and run the fuzzer on that corpus. +cmd_fuzz() { + local corpus_dir=$(realpath "${BUILD_DIR}/fuzzer_corpus") + local fuzzer_crash_dir=$(realpath "${BUILD_DIR}/fuzzer_crash") + mkdir -p "${corpus_dir}" "${fuzzer_crash_dir}" + # Generate step. + "${BUILD_DIR}/tools/fuzzer_corpus" "${corpus_dir}" + # Run step: + local nprocs=$(nproc --all || echo 1) + ( + cd "${BUILD_DIR}" + "tools/djxl_fuzzer" "${fuzzer_crash_dir}" "${corpus_dir}" \ + -max_total_time="${FUZZER_MAX_TIME}" -jobs=${nprocs} \ + -artifact_prefix="${fuzzer_crash_dir}/" + ) +} + +# Runs the linter (clang-format) on the pending CLs. +cmd_lint() { + merge_request_commits + { set +x; } 2>/dev/null + local versions=(${1:-6.0 7 8 9 10 11}) + local clang_format_bins=("${versions[@]/#/clang-format-}" clang-format) + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + + local ret=0 + local build_patch="${tmpdir}/build_cleaner.patch" + if ! "${MYDIR}/tools/build_cleaner.py" >"${build_patch}"; then + ret=1 + echo "build_cleaner.py findings:" >&2 + "${COLORDIFF_BIN}" <"${build_patch}" + echo "Run \`tools/build_cleaner.py --update\` to apply them" >&2 + fi + + local installed=() + local clang_patch + local clang_format + for clang_format in "${clang_format_bins[@]}"; do + if ! which "${clang_format}" >/dev/null; then + continue + fi + installed+=("${clang_format}") + local tmppatch="${tmpdir}/${clang_format}.patch" + # We include in this linter all the changes including the uncommitted changes + # to avoid printing changes already applied. + set -x + # Ignoring the error that git-clang-format outputs. + git -C "${MYDIR}" "${clang_format}" --binary "${clang_format}" \ + --style=file --diff "${MR_ANCESTOR_SHA}" -- >"${tmppatch}" || true + { set +x; } 2>/dev/null + if grep -E '^--- ' "${tmppatch}">/dev/null; then + if [[ -n "${LINT_OUTPUT:-}" ]]; then + cp "${tmppatch}" "${LINT_OUTPUT}" + fi + clang_patch="${tmppatch}" + else + echo "clang-format check OK" >&2 + return ${ret} + fi + done + + if [[ ${#installed[@]} -eq 0 ]]; then + echo "You must install clang-format for \"git clang-format\"" >&2 + exit 1 + fi + + # clang-format is installed but found problems. + echo "clang-format findings:" >&2 + "${COLORDIFF_BIN}" < "${clang_patch}" + + echo "clang-format found issues in your patches from ${MR_ANCESTOR_SHA}" \ + "to the current patch. Run \`./ci.sh lint | patch -p1\` from the base" \ + "directory to apply them." >&2 + exit 1 +} + +# Runs clang-tidy on the pending CLs. If the "all" argument is passed it runs +# clang-tidy over all the source files instead. +cmd_tidy() { + local what="${1:-}" + + if [[ -z "${CLANG_TIDY_BIN}" ]]; then + echo "ERROR: You must install clang-tidy-7 or newer to use ci.sh tidy" >&2 + exit 1 + fi + + local git_args=() + if [[ "${what}" == "all" ]]; then + git_args=(ls-files) + shift + else + merge_request_commits + git_args=( + diff-tree --no-commit-id --name-only -r "${MR_ANCESTOR_SHA}" + "${MR_HEAD_SHA}" + ) + fi + + # Clang-tidy needs the compilation database generated by cmake. + if [[ ! -e "${BUILD_DIR}/compile_commands.json" ]]; then + # Generate the build options in debug mode, since we need the debug asserts + # enabled for the clang-tidy analyzer to use them. + CMAKE_BUILD_TYPE="Debug" + cmake_configure + # Build the autogen targets to generate the .h files from the .ui files. + local autogen_targets=( + $(ninja -C "${BUILD_DIR}" -t targets | grep -F _autogen: | + cut -f 1 -d :) + ) + if [[ ${#autogen_targets[@]} != 0 ]]; then + ninja -C "${BUILD_DIR}" "${autogen_targets[@]}" + fi + fi + + cd "${MYDIR}" + local nprocs=$(nproc --all || echo 1) + local ret=0 + if ! parallel -j"${nprocs}" --keep-order -- \ + "${CLANG_TIDY_BIN}" -p "${BUILD_DIR}" -format-style=file -quiet "$@" {} \ + < <(git "${git_args[@]}" | grep -E '(\.cc|\.cpp)$') \ + >"${BUILD_DIR}/clang-tidy.txt"; then + ret=1 + fi + { set +x; } 2>/dev/null + echo "Findings statistics:" >&2 + grep -E ' \[[A-Za-z\.,\-]+\]' -o "${BUILD_DIR}/clang-tidy.txt" | sort \ + | uniq -c >&2 + + if [[ $ret -ne 0 ]]; then + cat >&2 </dev/null + local debsdir="${BUILD_DIR}/debs" + local f + while IFS='' read -r -d '' f; do + echo "=====================================================================" + echo "Package $f:" + dpkg --info $f + dpkg --contents $f + done < <(find "${BUILD_DIR}/debs" -maxdepth 1 -mindepth 1 -type f \ + -name '*.deb' -print0) +} + +build_debian_pkg() { + local srcdir="$1" + local srcpkg="$2" + + local debsdir="${BUILD_DIR}/debs" + local builddir="${debsdir}/${srcpkg}" + + # debuild doesn't have an easy way to build out of tree, so we make a copy + # of with all symlinks on the first level. + mkdir -p "${builddir}" + for f in $(find "${srcdir}" -mindepth 1 -maxdepth 1 -printf '%P\n'); do + if [[ ! -L "${builddir}/$f" ]]; then + rm -f "${builddir}/$f" + ln -s "${srcdir}/$f" "${builddir}/$f" + fi + done + ( + cd "${builddir}" + debuild -b -uc -us + ) +} + +cmd_debian_build() { + local srcpkg="${1:-}" + + case "${srcpkg}" in + jpeg-xl) + build_debian_pkg "${MYDIR}" "jpeg-xl" + ;; + highway) + build_debian_pkg "${MYDIR}/third_party/highway" "highway" + ;; + *) + echo "ERROR: Must pass a valid source package name to build." >&2 + ;; + esac +} + +get_version() { + local varname=$1 + local line=$(grep -F "set(${varname} " lib/CMakeLists.txt | head -n 1) + [[ -n "${line}" ]] + line="${line#set(${varname} }" + line="${line%)}" + echo "${line}" +} + +cmd_bump_version() { + local newver="${1:-}" + + if ! which dch >/dev/null; then + echo "Run:\n sudo apt install debhelper" + exit 1 + fi + + if [[ -z "${newver}" ]]; then + local major=$(get_version JPEGXL_MAJOR_VERSION) + local minor=$(get_version JPEGXL_MINOR_VERSION) + local patch=0 + minor=$(( ${minor} + 1)) + else + local major="${newver%%.*}" + newver="${newver#*.}" + local minor="${newver%%.*}" + newver="${newver#${minor}}" + local patch="${newver#.}" + if [[ -z "${patch}" ]]; then + patch=0 + fi + fi + + newver="${major}.${minor}.${patch}" + + echo "Bumping version to ${newver} (${major}.${minor}.${patch})" + sed -E \ + -e "s/(set\\(JPEGXL_MAJOR_VERSION) [0-9]+\\)/\\1 ${major})/" \ + -e "s/(set\\(JPEGXL_MINOR_VERSION) [0-9]+\\)/\\1 ${minor})/" \ + -e "s/(set\\(JPEGXL_PATCH_VERSION) [0-9]+\\)/\\1 ${patch})/" \ + -i lib/CMakeLists.txt + + # Update lib.gni + tools/build_cleaner.py --update + + # Mark the previous version as "unstable". + DEBCHANGE_RELEASE_HEURISTIC=log dch -M --distribution unstable --release '' + DEBCHANGE_RELEASE_HEURISTIC=log dch -M \ + --newversion "${newver}" \ + "Bump JPEG XL version to ${newver}." +} + +# Check that the AUTHORS file contains the email of the committer. +cmd_authors() { + merge_request_commits + local emails + local names + readarray -t emails < <(git log --format='%ae' "${MR_HEAD_SHA}...${MR_ANCESTOR_SHA}") + readarray -t names < <(git log --format='%an' "${MR_HEAD_SHA}...${MR_ANCESTOR_SHA}") + for i in "${!names[@]}"; do + echo "Checking name '${names[$i]}' with email '${emails[$i]}' ..." + "${MYDIR}"/tools/check_author.py "${emails[$i]}" "${names[$i]}" + done +} + +main() { + local cmd="${1:-}" + if [[ -z "${cmd}" ]]; then + cat >&2 < Build the given source package. + debian_stats Print stats about the built packages. + +oss-fuzz commands: + ossfuzz_asan Build the local source inside oss-fuzz docker with asan. + ossfuzz_msan Build the local source inside oss-fuzz docker with msan. + ossfuzz_ubsan Build the local source inside oss-fuzz docker with ubsan. + ossfuzz_ninja Run ninja on the BUILD_DIR inside the oss-fuzz docker. Extra + parameters are passed to ninja, for example "djxl_fuzzer" will + only build that ninja target. Use for faster build iteration + after one of the ossfuzz_*san commands. + +You can pass some optional environment variables as well: + - BUILD_DIR: The output build directory (by default "$$repo/build") + - BUILD_TARGET: The target triplet used when cross-compiling. + - CMAKE_FLAGS: Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS. + - CMAKE_PREFIX_PATH: Installation prefixes to be searched by the find_package. + - ENABLE_WASM_SIMD=1: enable experimental SIMD in WASM build (only). + - FUZZER_MAX_TIME: "fuzz" command fuzzer running timeout in seconds. + - LINT_OUTPUT: Path to the output patch from the "lint" command. + - SKIP_CPUSET=1: Skip modifying the cpuset in the arm_benchmark. + - SKIP_TEST=1: Skip the test stage. + - STORE_IMAGES=0: Makes the benchmark discard the computed images. + - TEST_STACK_LIMIT: Stack size limit (ulimit -s) during tests, in KiB. + - TEST_SELECTOR: pass additional arguments to ctest, e.g. "-R .Resample.". + - STACK_SIZE=1: Generate binaries with the .stack_sizes sections. + +These optional environment variables are forwarded to the cmake call as +parameters: + - CMAKE_BUILD_TYPE + - CMAKE_C_FLAGS + - CMAKE_CXX_FLAGS + - CMAKE_C_COMPILER_LAUNCHER + - CMAKE_CXX_COMPILER_LAUNCHER + - CMAKE_CROSSCOMPILING_EMULATOR + - CMAKE_FIND_ROOT_PATH + - CMAKE_EXE_LINKER_FLAGS + - CMAKE_MAKE_PROGRAM + - CMAKE_MODULE_LINKER_FLAGS + - CMAKE_SHARED_LINKER_FLAGS + - CMAKE_TOOLCHAIN_FILE + +Example: + BUILD_DIR=/tmp/build $0 opt +EOF + exit 1 + fi + + cmd="cmd_${cmd}" + shift + set -x + "${cmd}" "$@" +} + +main "$@" diff --git a/media/libjxl/src/cmake/FindAtomics.cmake b/media/libjxl/src/cmake/FindAtomics.cmake new file mode 100644 index 000000000..9a6cdc39e --- /dev/null +++ b/media/libjxl/src/cmake/FindAtomics.cmake @@ -0,0 +1,53 @@ +# Original issue: +# * https://gitlab.kitware.com/cmake/cmake/-/issues/23021#note_1098733 +# +# For reference: +# * https://gcc.gnu.org/wiki/Atomic/GCCMM +# +# riscv64 specific: +# * https://lists.debian.org/debian-riscv/2022/01/msg00009.html +# +# ATOMICS_FOUND - system has c++ atomics +# ATOMICS_LIBRARIES - libraries needed to use c++ atomics + +include(CheckCXXSourceCompiles) + +# RISC-V only has 32-bit and 64-bit atomic instructions. GCC is supposed +# to convert smaller atomics to those larger ones via masking and +# shifting like LLVM, but it’s a known bug that it does not. This means +# anything that wants to use atomics on 1-byte or 2-byte types needs +# -latomic, but not 4-byte or 8-byte (though it does no harm). +set(atomic_code + " + #include + #include + std::atomic n8 (0); // riscv64 + std::atomic n64 (0); // armel, mipsel, powerpc + int main() { + ++n8; + ++n64; + return 0; + }") + +check_cxx_source_compiles("${atomic_code}" ATOMICS_LOCK_FREE_INSTRUCTIONS) + +if(ATOMICS_LOCK_FREE_INSTRUCTIONS) + set(ATOMICS_FOUND TRUE) + set(ATOMICS_LIBRARIES) +else() + set(CMAKE_REQUIRED_LIBRARIES "-latomic") + check_cxx_source_compiles("${atomic_code}" ATOMICS_IN_LIBRARY) + set(CMAKE_REQUIRED_LIBRARIES) + if(ATOMICS_IN_LIBRARY) + set(ATOMICS_LIBRARY atomic) + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(Atomics DEFAULT_MSG ATOMICS_LIBRARY) + set(ATOMICS_LIBRARIES ${ATOMICS_LIBRARY}) + unset(ATOMICS_LIBRARY) + else() + if(Atomics_FIND_REQUIRED) + message(FATAL_ERROR "Neither lock free instructions nor -latomic found.") + endif() + endif() +endif() +unset(atomic_code) diff --git a/media/libjxl/src/cmake/FindBrotli.cmake b/media/libjxl/src/cmake/FindBrotli.cmake new file mode 100644 index 000000000..5c6cb0987 --- /dev/null +++ b/media/libjxl/src/cmake/FindBrotli.cmake @@ -0,0 +1,85 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set(brlibs brotlicommon brotlienc brotlidec) + +find_package(PkgConfig QUIET) +if (PkgConfig_FOUND) + foreach(brlib IN ITEMS ${brlibs}) + string(TOUPPER "${brlib}" BRPREFIX) + pkg_check_modules("PC_${BRPREFIX}" lib${brlib}) + endforeach() +endif() + +find_path(BROTLI_INCLUDE_DIR + NAMES brotli/decode.h + HINTS ${PC_BROTLICOMMON_INCLUDEDIR} ${PC_BROTLICOMMON_INCLUDE_DIRS} +) + +foreach(brlib IN ITEMS ${brlibs}) + string(TOUPPER "${brlib}" BRPREFIX) + find_library(${BRPREFIX}_LIBRARY + NAMES ${${BRPREFIX}_NAMES} ${brlib} + HINTS ${PC_${BRPREFIX}_LIBDIR} ${PC_${BRPREFIX}_LIBRARY_DIRS} + ) + + if (${BRPREFIX}_LIBRARY AND NOT TARGET ${brlib}) + if(CMAKE_VERSION VERSION_LESS "3.13.5") + add_library(${brlib} INTERFACE IMPORTED GLOBAL) + set_property(TARGET ${brlib} PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${BROTLI_INCLUDE_DIR}) + target_link_libraries(${brlib} INTERFACE ${${BRPREFIX}_LIBRARY}) + set_property(TARGET ${brlib} PROPERTY INTERFACE_COMPILE_OPTIONS ${PC_${BRPREFIX}_CFLAGS_OTHER}) + + add_library(${brlib}-static INTERFACE IMPORTED GLOBAL) + set_property(TARGET ${brlib}-static PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${BROTLI_INCLUDE_DIR}) + target_link_libraries(${brlib}-static INTERFACE ${${BRPREFIX}_LIBRARY}) + set_property(TARGET ${brlib}-static PROPERTY INTERFACE_COMPILE_OPTIONS ${PC_${BRPREFIX}_CFLAGS_OTHER}) + else() + add_library(${brlib} INTERFACE IMPORTED GLOBAL) + target_include_directories(${brlib} + INTERFACE ${BROTLI_INCLUDE_DIR}) + target_link_libraries(${brlib} + INTERFACE ${${BRPREFIX}_LIBRARY}) + target_link_options(${brlib} + INTERFACE ${PC_${BRPREFIX}_LDFLAGS_OTHER}) + target_compile_options(${brlib} + INTERFACE ${PC_${BRPREFIX}_CFLAGS_OTHER}) + + # TODO(deymo): Remove the -static library versions, this target is + # currently needed by brunsli.cmake. When importing it this way, the + # brotli*-static target is just an alias. + add_library(${brlib}-static ALIAS ${brlib}) + endif() + endif() +endforeach() + +if (BROTLICOMMON_FOUND AND BROTLIENC_FOUND AND BROTLIDEC_FOUND) + set(Brotli_FOUND ON) +else () + set(Brotli_FOUND OFF) +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Brotli + FOUND_VAR Brotli_FOUND + REQUIRED_VARS + BROTLI_INCLUDE_DIR + BROTLICOMMON_LIBRARY + BROTLIENC_LIBRARY + BROTLIDEC_LIBRARY + VERSION_VAR Brotli_VERSION +) + +mark_as_advanced( + BROTLI_INCLUDE_DIR + BROTLICOMMON_LIBRARY + BROTLIENC_LIBRARY + BROTLIDEC_LIBRARY +) + +if (Brotli_FOUND) + set(Brotli_LIBRARIES ${BROTLICOMMON_LIBRARY} ${BROTLIENC_LIBRARY} ${BROTLIDEC_LIBRARY}) + set(Brotli_INCLUDE_DIRS ${BROTLI_INCLUDE_DIR}) +endif() diff --git a/media/libjxl/src/cmake/FindHWY.cmake b/media/libjxl/src/cmake/FindHWY.cmake new file mode 100644 index 000000000..c1deb9b85 --- /dev/null +++ b/media/libjxl/src/cmake/FindHWY.cmake @@ -0,0 +1,66 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +find_package(PkgConfig QUIET) +if (PkgConfig_FOUND) + pkg_check_modules(PC_HWY QUIET libhwy) + set(HWY_VERSION ${PC_HWY_VERSION}) +endif () + +find_path(HWY_INCLUDE_DIR + NAMES hwy/highway.h + HINTS ${PC_HWY_INCLUDEDIR} ${PC_HWY_INCLUDE_DIRS} +) + +find_library(HWY_LIBRARY + NAMES ${HWY_NAMES} hwy + HINTS ${PC_HWY_LIBDIR} ${PC_HWY_LIBRARY_DIRS} +) + +if (HWY_INCLUDE_DIR AND NOT HWY_VERSION) + if (EXISTS "${HWY_INCLUDE_DIR}/hwy/highway.h") + file(READ "${HWY_INCLUDE_DIR}/hwy/highway.h" HWY_VERSION_CONTENT) + + string(REGEX MATCH "#define HWY_MAJOR +([0-9]+)" _dummy "${HWY_VERSION_CONTENT}") + set(HWY_VERSION_MAJOR "${CMAKE_MATCH_1}") + + string(REGEX MATCH "#define +HWY_MINOR +([0-9]+)" _dummy "${HWY_VERSION_CONTENT}") + set(HWY_VERSION_MINOR "${CMAKE_MATCH_1}") + + string(REGEX MATCH "#define +HWY_PATCH +([0-9]+)" _dummy "${HWY_VERSION_CONTENT}") + set(HWY_VERSION_PATCH "${CMAKE_MATCH_1}") + + set(HWY_VERSION "${HWY_VERSION_MAJOR}.${HWY_VERSION_MINOR}.${HWY_VERSION_PATCH}") + endif () +endif () + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(HWY + FOUND_VAR HWY_FOUND + REQUIRED_VARS HWY_LIBRARY HWY_INCLUDE_DIR + VERSION_VAR HWY_VERSION +) + +if (HWY_LIBRARY AND NOT TARGET hwy) + add_library(hwy INTERFACE IMPORTED GLOBAL) + + if(CMAKE_VERSION VERSION_LESS "3.13.5") + set_property(TARGET hwy PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${HWY_INCLUDE_DIR}) + target_link_libraries(hwy INTERFACE ${HWY_LIBRARY}) + set_property(TARGET hwy PROPERTY INTERFACE_COMPILE_OPTIONS ${PC_HWY_CFLAGS_OTHER}) + else() + target_include_directories(hwy INTERFACE ${HWY_INCLUDE_DIR}) + target_link_libraries(hwy INTERFACE ${HWY_LIBRARY}) + target_link_options(hwy INTERFACE ${PC_HWY_LDFLAGS_OTHER}) + target_compile_options(hwy INTERFACE ${PC_HWY_CFLAGS_OTHER}) + endif() +endif() + +mark_as_advanced(HWY_INCLUDE_DIR HWY_LIBRARY) + +if (HWY_FOUND) + set(HWY_LIBRARIES ${HWY_LIBRARY}) + set(HWY_INCLUDE_DIRS ${HWY_INCLUDE_DIR}) +endif () diff --git a/media/libjxl/src/cmake/FindLCMS2.cmake b/media/libjxl/src/cmake/FindLCMS2.cmake new file mode 100644 index 000000000..0a7b54eb9 --- /dev/null +++ b/media/libjxl/src/cmake/FindLCMS2.cmake @@ -0,0 +1,59 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +find_package(PkgConfig QUIET) +if (PkgConfig_FOUND) + pkg_check_modules(PC_LCMS2 QUIET libLCMS2) + set(LCMS2_VERSION ${PC_LCMS2_VERSION}) +endif () + +find_path(LCMS2_INCLUDE_DIR + NAMES lcms2.h + HINTS ${PC_LCMS2_INCLUDEDIR} ${PC_LCMS2_INCLUDE_DIRS} +) + +find_library(LCMS2_LIBRARY + NAMES ${LCMS2_NAMES} lcms2 liblcms2 lcms-2 liblcms-2 + HINTS ${PC_LCMS2_LIBDIR} ${PC_LCMS2_LIBRARY_DIRS} +) + +if (LCMS2_INCLUDE_DIR AND NOT LCMS_VERSION) + file(READ ${LCMS2_INCLUDE_DIR}/lcms2.h LCMS2_VERSION_CONTENT) + string(REGEX MATCH "#define[ \t]+LCMS_VERSION[ \t]+([0-9]+)[ \t]*\n" LCMS2_VERSION_MATCH ${LCMS2_VERSION_CONTENT}) + if (LCMS2_VERSION_MATCH) + string(SUBSTRING ${CMAKE_MATCH_1} 0 1 LCMS2_VERSION_MAJOR) + string(SUBSTRING ${CMAKE_MATCH_1} 1 2 LCMS2_VERSION_MINOR) + set(LCMS2_VERSION "${LCMS2_VERSION_MAJOR}.${LCMS2_VERSION_MINOR}") + endif () +endif () + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(LCMS2 + FOUND_VAR LCMS2_FOUND + REQUIRED_VARS LCMS2_LIBRARY LCMS2_INCLUDE_DIR + VERSION_VAR LCMS2_VERSION +) + +if (LCMS2_LIBRARY AND NOT TARGET lcms2) + add_library(lcms2 INTERFACE IMPORTED GLOBAL) + + if(CMAKE_VERSION VERSION_LESS "3.13.5") + set_property(TARGET lcms2 PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${LCMS2_INCLUDE_DIR}) + target_link_libraries(lcms2 INTERFACE ${LCMS2_LIBRARY}) + set_property(TARGET lcms2 PROPERTY INTERFACE_COMPILE_OPTIONS ${PC_LCMS2_CFLAGS_OTHER}) + else() + target_include_directories(lcms2 INTERFACE ${LCMS2_INCLUDE_DIR}) + target_link_libraries(lcms2 INTERFACE ${LCMS2_LIBRARY}) + target_link_options(lcms2 INTERFACE ${PC_LCMS2_LDFLAGS_OTHER}) + target_compile_options(lcms2 INTERFACE ${PC_LCMS2_CFLAGS_OTHER}) + endif() +endif() + +mark_as_advanced(LCMS2_INCLUDE_DIR LCMS2_LIBRARY) + +if (LCMS2_FOUND) + set(LCMS2_LIBRARIES ${LCMS2_LIBRARY}) + set(LCMS2_INCLUDE_DIRS ${LCMS2_INCLUDE_DIR}) +endif () diff --git a/media/libjxl/src/debian/changelog b/media/libjxl/src/debian/changelog new file mode 100644 index 000000000..a63607e5e --- /dev/null +++ b/media/libjxl/src/debian/changelog @@ -0,0 +1,83 @@ +jpeg-xl (0.7) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.7. + + -- JPEG XL Maintainers Mon, 08 Aug 2022 14:43:58 +0000 + +jpeg-xl (0.6) unstable; urgency=medium + + * Bump JPEG XL version to 0.6. + + -- JPEG XL Maintainers Fri, 10 Sep 2021 16:08:17 +0200 + +jpeg-xl (0.5.0) unstable; urgency=medium + + * Bump JPEG XL version to 0.5.0. + + -- JPEG XL Maintainers Thu, 12 Aug 2021 23:49:40 +0200 + +jpeg-xl (0.3.7) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.7. + + -- Sami Boukortt Mon, 29 Mar 2021 12:14:20 +0200 + +jpeg-xl (0.3.6) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.6. + + -- Sami Boukortt Thu, 25 Mar 2021 17:40:58 +0100 + +jpeg-xl (0.3.5) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.5. + + -- Sami Boukortt Tue, 23 Mar 2021 15:20:44 +0100 + +jpeg-xl (0.3.4) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.4. + + -- Sami Boukortt Tue, 16 Mar 2021 12:13:59 +0100 + +jpeg-xl (0.3.3) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.3. + + -- Sami Boukortt Fri, 5 Mar 2021 19:15:26 +0100 + +jpeg-xl (0.3.2) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.2. + + -- Alex Deymo Fri, 12 Feb 2021 21:00:12 +0100 + +jpeg-xl (0.3.1) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.1. + + -- Alex Deymo Tue, 09 Feb 2021 09:48:43 +0100 + +jpeg-xl (0.3) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3. + + -- Alex Deymo Wed, 27 Jan 2021 22:36:32 +0100 + +jpeg-xl (0.2) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.2. + + -- Alex Deymo Wed, 23 Nov 2020 20:42:10 +0100 + +jpeg-xl (0.1) UNRELEASED; urgency=medium + + * JPEG XL format release candidate. + + -- Alex Deymo Fri, 13 Nov 2020 17:42:24 +0100 + +jpeg-xl (0.0.2-1) UNRELEASED; urgency=medium + + * Initial debian package. + + -- Alex Deymo Tue, 27 Oct 2020 15:27:59 +0100 diff --git a/media/libjxl/src/debian/compat b/media/libjxl/src/debian/compat new file mode 100644 index 000000000..f599e28b8 --- /dev/null +++ b/media/libjxl/src/debian/compat @@ -0,0 +1 @@ +10 diff --git a/media/libjxl/src/debian/control b/media/libjxl/src/debian/control new file mode 100644 index 000000000..7a3c502e0 --- /dev/null +++ b/media/libjxl/src/debian/control @@ -0,0 +1,88 @@ +Source: jpeg-xl +Maintainer: JPEG XL Maintainers +Section: misc +Priority: optional +Standards-Version: 3.9.8 +Build-Depends: + asciidoc, + cmake, + debhelper (>= 9), + libbrotli-dev, + libgdk-pixbuf-2.0-dev | libgdk-pixbuf2.0-dev, + libgif-dev, + libgimp2.0-dev, + libgmock-dev, + libgoogle-perftools-dev, + libgtest-dev, + libhwy-dev (>= 0.15.0), + libjpeg-dev, + libopenexr-dev, + libpng-dev, + libwebp-dev, + pkg-config, + xdg-utils, + xmlto, +Homepage: https://github.com/libjxl/libjxl +Rules-Requires-Root: no + +Package: jxl +Architecture: any +Section: utils +Depends: ${misc:Depends}, ${shlibs:Depends} +Description: JPEG XL Image Coding System - "JXL" (command line utility) + The JPEG XL Image Coding System (ISO/IEC 18181) is a lossy and + lossless image compression format. It has a rich feature set and is + particularly optimized for responsive web environments, so that + content renders well on a wide range of devices. Moreover, it includes + several features that help transition from the legacy JPEG format. + . + This package installs the command line utilities. + +Package: libjxl-dev +Architecture: any +Section: libdevel +Depends: libjxl (= ${binary:Version}), ${misc:Depends} + libhwy-dev, +Description: JPEG XL Image Coding System - "JXL" (development files) + The JPEG XL Image Coding System (ISO/IEC 18181) is a lossy and + lossless image compression format. It has a rich feature set and is + particularly optimized for responsive web environments, so that + content renders well on a wide range of devices. Moreover, it includes + several features that help transition from the legacy JPEG format. + . + This package installs development files. + +Package: libjxl +Architecture: any +Multi-Arch: same +Section: libs +Depends: ${shlibs:Depends}, ${misc:Depends} +Pre-Depends: ${misc:Pre-Depends} +Description: JPEG XL Image Coding System - "JXL" (shared libraries) + The JPEG XL Image Coding System (ISO/IEC 18181) is a lossy and + lossless image compression format. It has a rich feature set and is + particularly optimized for responsive web environments, so that + content renders well on a wide range of devices. Moreover, it includes + several features that help transition from the legacy JPEG format. + . + This package installs shared libraries. + +Package: libjxl-gdk-pixbuf +Architecture: any +Multi-Arch: same +Section: libs +Depends: ${shlibs:Depends}, ${misc:Depends} +Pre-Depends: ${misc:Pre-Depends} +Description: JPEG XL Plugin for gdk-pixbuf + This package installs the required files for reading JPEG XL files in + GTK applications. + +Package: libjxl-gimp-plugin +Architecture: any +Multi-Arch: same +Section: graphics +Depends: ${shlibs:Depends}, ${misc:Depends} +Pre-Depends: ${misc:Pre-Depends} +Enhances: gimp +Description: JPEG XL Import and Export Plugin for GIMP + This is a plugin for GIMP version 2.10.x to import and export JPEG XL images. diff --git a/media/libjxl/src/debian/copyright b/media/libjxl/src/debian/copyright new file mode 100644 index 000000000..20225a920 --- /dev/null +++ b/media/libjxl/src/debian/copyright @@ -0,0 +1,194 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ +Upstream-Name: jpeg-xl + +Files: * +Copyright: 2020 the JPEG XL Project +License: BSD-3-clause + +Files: third_party/sjpeg/* +Copyright: 2017 Google, Inc +License: Apache-2.0 + +Files: third_party/skcms/* +Copyright: 2018 Google Inc. +License: BSD-3-clause + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + . + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of Google Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + . + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Files: testdata/external/pngsuite/* +Copyright: Willem van Schaik, 1996, 2011 +License: PngSuite License + See http://www.schaik.com/pngsuite/ for details. + . + Permission to use, copy, modify and distribute these images for any + purpose and without fee is hereby granted. + +Files: testdata/external/raw.pixls/* +Copyright: their respective owners listed in https://raw.pixls.us/ +License: CC0-1.0 + +Files: testdata/external/wesaturate/* +Copyright: their respective owners listed in https://www.wesaturate.com/ +License: CC0-1.0 + +Files: testdata/external/wide-gamut-tests/ +Copyright: github.com/codelogic/wide-gamut-tests authors. +License: Apache-2.0 + +License: Apache-2.0 + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + . + http://www.apache.org/licenses/LICENSE-2.0 + . + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + . + On Debian systems, the complete text of the Apache License, Version 2 + can be found in "/usr/share/common-licenses/Apache-2.0". + +License: CC0 + Creative Commons Zero v1.0 Universal + . + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL + SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN ATTORNEY-CLIENT + RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" + BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE USE OF THIS + DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER, AND DISCLAIMS + LIABILITY FOR DAMAGES RESULTING FROM THE USE OF THIS DOCUMENT OR THE + INFORMATION OR WORKS PROVIDED HEREUNDER. + . + Statement of Purpose + . + The laws of most jurisdictions throughout the world automatically confer + exclusive Copyright and Related Rights (defined below) upon the creator and + subsequent owner(s) (each and all, an "owner") of an original work of + authorship and/or a database (each, a "Work"). + . + Certain owners wish to permanently relinquish those rights to a Work for the + purpose of contributing to a commons of creative, cultural and scientific + works ("Commons") that the public can reliably and without fear of later + claims of infringement build upon, modify, incorporate in other works, reuse + and redistribute as freely as possible in any form whatsoever and for any + purposes, including without limitation commercial purposes. These owners may + contribute to the Commons to promote the ideal of a free culture and the + further production of creative, cultural and scientific works, or to gain + reputation or greater distribution for their Work in part through the use + and efforts of others. + . + For these and/or other purposes and motivations, and without any expectation + of additional consideration or compensation, the person associating CC0 with + a Work (the "Affirmer"), to the extent that he or she is an owner of + Copyright and Related Rights in the Work, voluntarily elects to apply CC0 to + the Work and publicly distribute the Work under its terms, with knowledge of + his or her Copyright and Related Rights in the Work and the meaning and + intended legal effect of CC0 on those rights. + . + 1. Copyright and Related Rights. A Work made available under CC0 may be + protected by copyright and related or neighboring rights ("Copyright and + Related Rights"). Copyright and Related Rights include, but are not limited + to, the following: + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); + iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation thereof, + including any amended or successor version of such directive); and + vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national implementations + thereof. + . + 2. Waiver. To the greatest extent permitted by, but not in contravention of, + applicable law, Affirmer hereby overtly, fully, permanently, irrevocably and + unconditionally waives, abandons, and surrenders all of Affirmer's Copyright + and Related Rights and associated claims and causes of action, whether now + known or unknown (including existing as well as future claims and causes of + action), in the Work (i) in all territories worldwide, (ii) for the maximum + duration provided by applicable law or treaty (including future time + extensions), (iii) in any current or future medium and for any number of + copies, and (iv) for any purpose whatsoever, including without limitation + commercial, advertising or promotional purposes (the "Waiver"). Affirmer + makes the Waiver for the benefit of each member of the public at large and + to the detriment of Affirmer's heirs and successors, fully intending that + such Waiver shall not be subject to revocation, rescission, cancellation, + termination, or any other legal or equitable action to disrupt the quiet + enjoyment of the Work by the public as contemplated by Affirmer's express + Statement of Purpose. + . + 3. Public License Fallback. Should any part of the Waiver for any reason be + judged legally invalid or ineffective under applicable law, then the Waiver + shall be preserved to the maximum extent permitted taking into account + Affirmer's express Statement of Purpose. In addition, to the extent the + Waiver is so judged Affirmer hereby grants to each affected person a + royalty-free, non transferable, non sublicensable, non exclusive, + irrevocable and unconditional license to exercise Affirmer's Copyright and + Related Rights in the Work (i) in all territories worldwide, (ii) for the + maximum duration provided by applicable law or treaty (including future time + extensions), (iii) in any current or future medium and for any number of + copies, and (iv) for any purpose whatsoever, including without limitation + commercial, advertising or promotional purposes (the "License"). The License + shall be deemed effective as of the date CC0 was applied by Affirmer to the + Work. Should any part of the License for any reason be judged legally + invalid or ineffective under applicable law, such partial invalidity or + ineffectiveness shall not invalidate the remainder of the License, and in + such case Affirmer hereby affirms that he or she will not (i) exercise any + of his or her remaining Copyright and Related Rights in the Work or (ii) + assert any associated claims and causes of action with respect to the Work, + in either case contrary to Affirmer's express Statement of Purpose. + . + 4. Limitations and Disclaimers. + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, statutory or + otherwise, including without limitation warranties of title, + merchantability, fitness for a particular purpose, non infringement, or the + absence of latent or other defects, accuracy, or the present or absence of + errors, whether or not discoverable, all to the greatest extent permissible + under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without limitation + any person's Copyright and Related Rights in the Work. Further, Affirmer + disclaims responsibility for obtaining any necessary consents, permissions + or other rights required for any use of the Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to this + CC0 or use of the Work. + . + For more information, please see: + http://creativecommons.org/publicdomain/zero/1.0/> + diff --git a/media/libjxl/src/debian/jxl.install b/media/libjxl/src/debian/jxl.install new file mode 100644 index 000000000..c3bae3ed1 --- /dev/null +++ b/media/libjxl/src/debian/jxl.install @@ -0,0 +1,3 @@ +usr/bin/* +usr/share/man/man1/cjxl.1 +usr/share/man/man1/djxl.1 diff --git a/media/libjxl/src/debian/libjxl-dev.install b/media/libjxl/src/debian/libjxl-dev.install new file mode 100644 index 000000000..b735ec2c2 --- /dev/null +++ b/media/libjxl/src/debian/libjxl-dev.install @@ -0,0 +1,4 @@ +usr/include/jxl/*.h +usr/lib/*/*.a +usr/lib/*/*.so +usr/lib/*/pkgconfig/*.pc diff --git a/media/libjxl/src/debian/libjxl-gdk-pixbuf.install b/media/libjxl/src/debian/libjxl-gdk-pixbuf.install new file mode 100644 index 000000000..12d2ab250 --- /dev/null +++ b/media/libjxl/src/debian/libjxl-gdk-pixbuf.install @@ -0,0 +1,3 @@ +usr/lib/*/gdk-pixbuf-*/*/loaders/* +usr/share/mime/packages/image-jxl.xml +usr/share/thumbnailers/jxl.thumbnailer diff --git a/media/libjxl/src/debian/libjxl-gimp-plugin.install b/media/libjxl/src/debian/libjxl-gimp-plugin.install new file mode 100644 index 000000000..353431dba --- /dev/null +++ b/media/libjxl/src/debian/libjxl-gimp-plugin.install @@ -0,0 +1 @@ +usr/lib/gimp diff --git a/media/libjxl/src/debian/libjxl.install b/media/libjxl/src/debian/libjxl.install new file mode 100644 index 000000000..cd157a7a5 --- /dev/null +++ b/media/libjxl/src/debian/libjxl.install @@ -0,0 +1 @@ +usr/lib/*/libjxl*.so.* diff --git a/media/libjxl/src/debian/rules b/media/libjxl/src/debian/rules new file mode 100644 index 000000000..efed75d50 --- /dev/null +++ b/media/libjxl/src/debian/rules @@ -0,0 +1,17 @@ +#!/usr/bin/make -f + +include /usr/share/dpkg/pkg-info.mk + +%: + dh $@ --buildsystem=cmake + +override_dh_auto_configure: + # TODO(deymo): Remove the DCMAKE_BUILD_TYPE once builds without NDEBUG + # are as useful as Release builds. + dh_auto_configure -- \ + -DJPEGXL_VERSION=$(DEB_VERSION) \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DJPEGXL_FORCE_SYSTEM_GTEST=ON \ + -DJPEGXL_FORCE_SYSTEM_BROTLI=ON \ + -DJPEGXL_FORCE_SYSTEM_HWY=ON \ + -DJPEGXL_ENABLE_PLUGINS=ON diff --git a/media/libjxl/src/debian/source/format b/media/libjxl/src/debian/source/format new file mode 100644 index 000000000..163aaf8d8 --- /dev/null +++ b/media/libjxl/src/debian/source/format @@ -0,0 +1 @@ +3.0 (quilt) diff --git a/media/libjxl/src/deps.sh b/media/libjxl/src/deps.sh new file mode 100644 index 000000000..9aaabba2e --- /dev/null +++ b/media/libjxl/src/deps.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# This file downloads the dependencies needed to build JPEG XL into third_party. +# These dependencies are normally pulled by gtest. + +set -eu + +MYDIR=$(dirname $(realpath "$0")) + +# Git revisions we use for the given submodules. Update these whenever you +# update a git submodule. +THIRD_PARTY_BROTLI="35ef5c554d888bef217d449346067de05e269b30" +THIRD_PARTY_HIGHWAY="22e3d7276f4157d4a47586ba9fd91dd6303f441a" +THIRD_PARTY_SKCMS="64374756e03700d649f897dbd98c95e78c30c7da" +THIRD_PARTY_SJPEG="868ab558fad70fcbe8863ba4e85179eeb81cc840" +THIRD_PARTY_ZLIB="cacf7f1d4e3d44d871b605da3b647f07d718623f" +THIRD_PARTY_LIBPNG="a40189cf881e9f0db80511c382292a5604c3c3d1" + +# Download the target revision from GitHub. +download_github() { + local path="$1" + local project="$2" + + local varname=`echo "$path" | tr '[:lower:]' '[:upper:]'` + varname="${varname/\//_}" + local sha + eval "sha=\${${varname}}" + + local down_dir="${MYDIR}/downloads" + local local_fn="${down_dir}/${sha}.tar.gz" + if [[ -e "${local_fn}" && -d "${MYDIR}/${path}" ]]; then + echo "${path} already up to date." >&2 + return 0 + fi + + local url + local strip_components=0 + if [[ "${project:0:4}" == "http" ]]; then + # "project" is a googlesource.com base url. + url="${project}${sha}.tar.gz" + else + # GitHub files have a top-level directory + strip_components=1 + url="https://github.com/${project}/tarball/${sha}" + fi + + echo "Downloading ${path} version ${sha}..." >&2 + mkdir -p "${down_dir}" + curl -L --show-error -o "${local_fn}.tmp" "${url}" + mkdir -p "${MYDIR}/${path}" + tar -zxf "${local_fn}.tmp" -C "${MYDIR}/${path}" \ + --strip-components="${strip_components}" + mv "${local_fn}.tmp" "${local_fn}" +} + + +main() { + if git -C "${MYDIR}" rev-parse; then + cat >&2 < + +## Why + +The vast majority of web images are still sRGB. However, wide-gamut material is +increasingly being produced (photography, cinema, 4K). Screens covering most of +the Adobe RGB gamut are readily available and some also cover most of DCI P3 +(iPhone, Pixel2) or even BT.2020. + +Currently, after a camera records a very saturated red pixel, most raw +processors would clip it to the rather small sRGB gamut before saving as JPEG. +In keeping with our high-quality goal, we prevent such loss by allowing wider +input color spaces. + +## Which color space + +Even wide gamuts could be expressed relative to the sRGB primaries, but the +resulting coordinates may be outside the valid 0..1 range. Surprisingly, such +'unbounded' coordinates can be passed through color transforms provided the +transfer functions are expressed as parametric functions (not lookup tables). +However, most image file formats (including PNG and PNM) lack min/max metadata +and thus do not support unbounded coordinates. + +Instead, we need a larger working gamut to ensure most pixel coordinates are +within bounds and thus not clipped. However, larger gamuts result in lower +precision/resolution when using <= 16 bit encodings (as opposed to 32-bit float +in PFM). BT.2100 or P3 DCI appear to be good compromises. + +## CMS library + +Transforms with unbounded pixels are desirable because they reduce round-trip +error in tests. This requires parametric curves, which are only supported for +the common sRGB case in ICC v4 profiles. ArgyllCMS does not support v4. The +other popular open-source CMS is LittleCMS. It is also used by color-managed +editors (Krita/darktable), which increases the chances of interoperability. +However, LCMS has race conditions and overflow issues that prevent fuzzing. We +will later switch to the newer skcms. Note that this library does not intend to +support multiProcessElements, so HDR transfer functions cannot be represented +accurately. Thus in the long term, we will probably migrate away from ICC +profiles entirely. + +## Which viewer + +On Linux, Krita and darktable support loading our PNG output images and their +ICC profile. + +## How to compress/decompress + +### Embedded ICC profile + +- Create an 8-bit or 16-bit PNG with an iCCP chunk, e.g. using darktable. +- Pass it to `cjxl`, then `djxl` with no special arguments. The decoded output + will have the same bit depth (can override with `--output_bit_depth`) and + color space. + +### Images without metadata (e.g. HDR) + +- Create a PGM/PPM/PFM file in a known color space. +- Invoke `cjxl` with `-x color_space=RGB_D65_202_Rel_Lin` (linear 2020). For + details/possible values, see color_encoding.cc `Description`. +- Invoke `djxl` as above with no special arguments. diff --git a/media/libjxl/src/doc/developing_in_debian.md b/media/libjxl/src/doc/developing_in_debian.md new file mode 100644 index 000000000..a88b682ff --- /dev/null +++ b/media/libjxl/src/doc/developing_in_debian.md @@ -0,0 +1,57 @@ +# Developing in Debian + +These instructions assume an up-to-date Debian/Ubuntu system. +For other platforms, please instead use the following: + +* [Developing in Docker](developing_in_docker.md). +* [Cross Compiling for Windows with Crossroad](developing_with_crossroad.md). + +## Minimum build dependencies + +Apart from the dependencies in `third_party`, some of the tools use external +dependencies that need to be installed on your system first: + +```bash +sudo apt install cmake clang doxygen g++ extra-cmake-modules \ + libgif-dev libjpeg-dev ninja-build libgoogle-perftools-dev +``` + +Make sure your default `clang` compiler is at least version 6 by running + +```bash +clang --version +``` + +If it still shows an old version despite having, for example, `clang-7` installed, you need +to update the default `clang` compiler. On Debian-based systems run: + +```bash +sudo update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-7 100 +sudo update-alternatives --install /usr/bin/clang clang /usr/bin/clang-7 100 +``` + +Optionally, to compile some of the extra tool support and tests you can install +the following packages: + +```bash +sudo apt install qtbase5-dev libqt5x11extras5-dev libwebp-dev libgimp2.0-dev \ + libopenexr-dev libgtest-dev libgmock-dev libbenchmark-dev libbenchmark-tools +``` + +For the lint/coverage commands, you will also need additional packages: + +```bash +sudo apt install clang-format clang-tidy curl parallel gcovr +``` + +## Building + +The `libjxl` project uses CMake to build. We provide a script that simplifies the +invocation. To build and test the project, run + +```bash +./ci.sh opt +``` + +This writes binaries to `build/tools` and runs unit tests. More information +on [build modes and testing](building_and_testing.md) is available. diff --git a/media/libjxl/src/doc/developing_in_docker.md b/media/libjxl/src/doc/developing_in_docker.md new file mode 100644 index 000000000..f104a44b2 --- /dev/null +++ b/media/libjxl/src/doc/developing_in_docker.md @@ -0,0 +1,114 @@ +# Developing in Docker + +Docker allows software to be run in a packaged container, isolated from the +host system. This allows code to be run in a standard environment instead +of dealing with different build environments during development. It also +simplifies resolving external dependencies by including them in the automated +setup of the container environment. + +## Set up the container + +You can read installation instructions and download Docker for your +operating system at [Get Docker](https://docs.docker.com/get-docker/). + +The image used by our builders is an Ubuntu Bionic image with all the +required dependencies and build tools installed. You can pull this image +from `gcr.io/jpegxl/jpegxl-builder` using the following command: + +```bash +sudo docker pull gcr.io/jpegxl/jpegxl-builder +``` + +To use the Docker image you can run the following command: + +```bash +sudo docker run -it --rm \ + --user $(id -u):$(id -g) \ + -v $HOME/jpeg-xl:/jpeg-xl -w /jpeg-xl \ + gcr.io/jpegxl/jpegxl-builder bash +``` + +This creates and runs a container that will be deleted after you exit the +terminal (`--rm` flag). + +The `-v` flag is to map the directory containing your jpeg-xl checkout in your +host (assumed to be at `$HOME/jpeg-xl`) to a directory inside the container at +/jpeg-xl. Since the container is accessing the host folder directly, +changes made on the host will will be seen immediately in the container, +and vice versa. + +On OSX, the path must be one of those shared and whitelisted with Docker. $HOME +(which is a subdirectory of /Users/) is known to work with the factory-default +settings of Docker. + +On OSX, you may ignore the warning that Docker "cannot find name for group ID". +This warning may also appear on some Linux computers. + +On Windows, you can run the following from the jpeg-xl directory obtained from +Gitlab: + +```bash +docker run -u root:root -it --rm -v %cd%:/jpeg-xl -w /jpeg-xl \ + gcr.io/jpegxl/jpegxl-builder +``` + +## Basic building + +Inside the Docker container, you can compile everything and run unit tests. +We need to specify `clang-7` because the default `clang` compiler is +not installed on the image. + +```bash +CC=clang-7 CXX=clang++-7 ./ci.sh opt +``` + +This writes binaries to `/jpeg-xl/build/tools` and runs unit tests. +More information on [build modes and testing](building_and_testing.md) is +available. + +If a `build` directory already exists and was configured for a different +compiler, `cmake` will complain. This can be avoided by renaming or removing +the existing `build` directory or setting the `BUILD_DIR` environment variable. + +## Cross-compiling environments (optional) + +We have installed the required cross-compiling tools in the main Docker image +`jpegxl-builder`. This allows compiling for other architectures, such as arm. +Tests will be emulated under `qemu`. + +The Docker container has several `qemu-*-static` binaries (such as +`qemu-aarch64-static`) that emulate other architectures on x86_64. These +binaries are automatically used when running foreign architecture programs +in the container only if `binfmt` is installed and configured on the *host* +to use binaries from `/usr/bin` . This is the default location on Ubuntu/Debian. + +You need to install both `binfmt-support` and `qemu-user-static` on the host, +since `binfmt-support` configures only `binfmt` signatures of architectures +that are installed. If these are configured elsewhere on other distributions, +you can symlink them to `/usr/bin/qemu-*-static` inside the Docker container. + +To install binfmt support in your Ubuntu host run *outside* the container: + +```bash +sudo apt install binfmt-support qemu-user-static +``` + +Then to cross-compile and run unit tests execute the following commands: + +```bash +export BUILD_TARGET=aarch64-linux-gnu CC=clang-7 CXX=clang++-7 +./ci.sh release +``` + +The `BUILD_TARGET=aarch64-linux-gnu` environment variable tells the `ci.sh` +script to cross-compile for that target. This also changes the default +`BUILD_DIR` to `build-aarch64` since you never want to mix them with the `build` +of your host. You can also explicitly set a `BUILD_DIR` environment variable +that will be used instead. The list of supported `BUILD_TARGET` values for this +container is: + +* *the empty string* (for native x86_64 support) +* aarch64-linux-gnu +* arm-linux-gnueabihf +* i686-linux-gnu +* x86_64-w64-mingw32 (for Windows builds) diff --git a/media/libjxl/src/doc/developing_in_github.md b/media/libjxl/src/doc/developing_in_github.md new file mode 100644 index 000000000..ecda64fc8 --- /dev/null +++ b/media/libjxl/src/doc/developing_in_github.md @@ -0,0 +1,357 @@ +# Developing in GitHub + +This document describes the development steps related to handling the git +repository. + +If you are new to GitHub, there's a nice [quickstart +guide](https://docs.github.com/en/github/getting-started-with-github/quickstart) +on GitHub explaining the basics. + +## Initial setup + +You need to perform this set up at least once if you haven't use GitHub before. +Read through the quickstart guide [Set up +Git](https://docs.github.com/en/github/getting-started-with-github/set-up-git) +page to get your git up and running. You will need to Fork a repository next. +After that "Life of a Pull Request" describes the common everyday workflows. + +### Configure your SSH access + +The easiest way to configure access to your Github repository is to use SSH +keys. For that you need an SSH private and public key, ideally a strong one. You +can use different keys for different sites if you want. In this example, we will +create one for using in GitHub only. + +Create the `~/.ssh/id_rsa_github` file executing the following. (Here and +elsewhere, {{X}} are placeholders for your email/username) + +```bash +ssh-keygen -t rsa -b 4096 -C "{{EMAIL}}" -f ~/.ssh/id_rsa_github +``` + +Go to your [SSH and GPG keys](https://github.com/settings/keys) settings and +paste the contents of your *public key* (the one ending in `.pub`), that would +be the output of this command: + +```bash +cat ~/.ssh/id_rsa_github.pub +``` + +To use a specific key when SSHing to the github.com domain, you can add this +snippet of config to your .ssh/config file executing the following. + +```bash +cat >> ~/.ssh/config <github.com/*{{USERNAME}}*/libjxl + +where {{USERNAME}} denotes your GitHub username. + +### Checkout the JPEG XL code from GitHub + +To get the source code on your computer you need to "clone" it. There are two +repositories at play here, the upstream repository (`libjxl/lbjxl`) and your +fork (`{{USERNAME}}/libjxl`). You will be normally fetching new changes from +the upstream repository and push changes to your fork. Getting your changes from +your fork to the upstream repository is done through the Web interface, via Pull +Requests. + +The [Fork a +repo](https://docs.github.com/en/github/getting-started-with-github/fork-a-repo) +goes in great detail, but uses the git remote names `upstream` for the shared +upstream repository and `origin` for your work. This guide proposes an +alternative naming scheme, used in the examples below. + +In this guide `origin` is the upstream shared repository and `myfork` is your +fork. You can use any other name for your fork if you want. Use the following +commands to set things up, replacing `{{USERNAME}}` with your GitHub username: + +```bash +git clone git https://github.com/libjxl/libjxl --recursive +cd libjxl +git remote set-url --push origin git@github.com:{{USERNAME}}/libjxl.git +git remote add myfork git@github.com:{{USERNAME}}/libjxl.git +git remote -vv +``` + +These commands did three things: + + * Created the repository with `origin` as the upstream remote, + * Changed the "push" URL to point to your fork, and + * Create a new remote pointing to your fork. + +The last step is optional. Since the "fetch" URL of `origin` points to the +shared repository and the "push" URL points to your fork, fetching from `origin` +always gets the latest changes from the upstream repository regardless of the +contents of your fork. + +Having a second origin called `myfork` is only useful if you need to download +pending changes from your fork from a different computer. For example, if you +work on multiple computers, each one with this setup, you can push to your +fork from one, and then fetch from `myfork` from another computer to get those. + +# Life of a Pull Request + +The general [GitHub flow +guide](https://docs.github.com/en/github/getting-started-with-github/github-flow) +applies to sending Pull Requests to this project. + +All the commands here assume you are in a git checkout as setup here. + +### Sync to the latest version + +```bash +git fetch origin +``` + +The last upstream version is now on `origin/main` and none of your local +branches have been modified by this command. + +### Start a new branch + +To start a new change you need a local branch. Each branch will represent a list +of individual commits which can then be requested to be merged as a single merge +request. So in general one branch is one code review, but each branch can have +multiple individual commits in it. + +```bash +git checkout origin/main -b mybranch +``` + +This will create a new branch `mybranch` tracking `origin/main`. A branch can +track any remove or local branch, which is used by some tools. Running `git +branch -vv` will show all the branches you have have, what are they tracking and +how many commits are ahead or behind. If you create a branch without tracking +any other, you can add or change the tracking branch of the current branch +running `git branch --set-upstream-to=...`. + +### Add changes to your branch + +Follow any of the many online tutorials, for example +[The basics](https://git-scm.com/book/en/v2/Git-Basics-Getting-a-Git-Repository) +chapter from the https://git-scm.com/doc website is a good starting guide. +Create, change or delete files and do a git commit with a message. + +The commit message is required. A commit message should follow the 50/72 rule: + +* First line is 50 characters or less. +* Then a blank line. +* Remaining text should be wrapped at 72 characters. + +The first line should identify your commit, since that's what most tools will +show to the user. First lines like "Some fixes" are not useful. Explain what the +commit contains and why. + +We follow the [Google C++ Coding +Style](https://google.github.io/styleguide/cppguide.html). A +[clang-format](https://clang.llvm.org/docs/ClangFormat.html) configuration +file is available to automatically format your code, you can invoke it with +the `./ci.sh lint` helper tool. + +Read the [CONTRIBUTING.md](../CONTRIBUTING.md) file for more information about +contributing to libjxl. + +### Upload your changes for review + +The first step is a local review of your changes to see what will you be sending +for review. `gitg` is a nice Gtk UI for reviewing your local changes, or `tig` +for similar ncurses console-based interface. Otherwise, from the terminal you +can run: + +```bash +git branch -vv +``` + +To show the current status of your local branches. In particular, since your +branch is tracking origin/main (as seen in the output) git will tell you that +you are one commit ahead of the tracking branch. + +``` +* mybranch e74ae1a [origin/main: ahead 1] Improved decoding speed by 40% +``` + +It is a good idea before uploading to sync again with upstream (`git fetch +origin`) and then run `git branch -vv` to check whether there are new changes +upstream. If that is the case, you will see a "behind" flag in the output: + +``` +* mybranch e74ae1a [origin/main: ahead 1, behind 2] Improved decoding speed by 40% +``` + +To sync your changes on top of the latest changes in upstream you need to +rebase: + +```bash +git rebase +``` + +This will by default rebase your current branch changes on top of the tracking +branch. In this case, this will try to apply the current commit on top of the +latest origin/main (which has 2 more commits than the ones we have in our +branch) and your branch will now include that. There could be conflicts that you +have to deal with. A shortcut to do both fetch and rebase is to run `git pull +-r`, where the `-r` stands for "rebase" and will rebase the local commits on top +of the remote ones. + +Before uploading a patch, make sure your patch conforms to the +[contributing guidelines](../CONTRIBUTING.md) and it +[builds and passes tests](building_and_testing.md). + +Once you are ready to send your branch for review, upload it to *your* fork: + +```bash +git push origin mybranch +``` + +This will push your local branch "mybranch" to a remote in your fork called +"mybranch". The name can be anything, but keep in mind that it is public. A link +to the URL to create a merge request will be displayed. + +``` +Enumerating objects: 627, done. +Counting objects: 100% (627/627), done. +Delta compression using up to 56 threads +Compressing objects: 100% (388/388), done. +Writing objects: 100% (389/389), 10.71 MiB | 8.34 MiB/s, done. +Total 389 (delta 236), reused 0 (delta 0) +emote: +remote: Create a pull request for 'mybranch' on GitHub by visiting: +remote: https://github.com/{{USERNAME}}/libjxl/pull/new/mybranch +remote: +To github.com:{{USERNAME}}/libjxl.git + * [new branch] mybranch -> mybranch +``` + +### Updating submodules + +The repository uses submodules for external library dependencies in +third_party. Each submodule points to a particular external commit of the +external repository by the hash code of that external commit. Just like +regular source code files, this hash code is part of the current branch and +jpeg xl commit you have checked out. + +When changing branches or when doing `git rebase`, git will unfortunately +*not* automatically set those hashes to the ones of the branch or jpeg xl +commit you changed to nor set the source files of the third_party submodules +to the new state. That is, even though git will have updated the jpeg xl +source code files on your disk to the new ones, it will leave the submodule +hashes and the files in third_party in your workspace to the ones they were +before you changed branches. This will show up in a git diff because this +is seen as a change compared to the branch you switched to. The git diff shows +the difference in hash codes (as if you are changing to the old ones), it does +not show changes in files inside the third_party directory. + +This mismatch can cause at least two problems: + +*) the jpeg xl codebase may not compile due to third_party library version +mismatch if e.g. API changed or a submodule was added/removed. + +*) when using `commit -a` your commit, which may be a technical change +unrelated to submodule changes, will unintentionally contain a change to the +submodules hash code, which is undesired unless you actually want to change +the version of third_party libraries. + +To resolve this, the submodules must be updated manually with +the following command after those actions (at least when the submodules +changed): + +``` +git submodule update --init --recursive +``` + +Here, the init flag ensures new modules get added when encessary and the +recursive flag is required for the submodules depending on other submodules. + +If you checkout a different branch, you can spot that submodules changed +when it shows a message similar to this: + +``` +M third_party/brotli +M third_party/lcms +``` + +If you do a rebase you may end up in a harder to solve situation, where +`git submodule update --init --recursive` itself fails with errors such as: + +``` +Unable to checkout '35ef5c554d888bef217d449346067de05e269b30' in submodule path 'third_party/brotli' +``` + +In that case, you can use the force flag: + +``` +git submodule update --init --recursive --force +``` + +### Iterating changes in your merge request + +To address reviewer changes you need to amend the local changes in your branch +first. Make the changes you need in your commit locally by running `git commit +--amend file1 file2 file3 ...` or `git commit --amend -a` to amend all the +changes from all the staged files. + +Once you have the new version of the "mybranch" branch to re-upload, you need to +force push it to the same branch in your fork. Since you are pushing a different +version of the same commit (as opposed to another commit on top of the existing +ones), you need to force the operation to replace the old version. + +```bash +git push origin mybranch --force +``` + +The merge request should now be updated with the new changes. + +### Merging your changes + +We use "rebase" as a merge policy, which means that there a no "merge" commits +(commits with more than one parent) but instead only a linear history of +changes. + +It is possible that other changes where added to the main branch since the last +time you rebased your changes. These changes could create a conflict with your +Pull Request, if so you need to `git fetch`, `git rebase` and push again your +changes which need to go through the continuous integration workflow again to +verify that all the tests pass again after including the latest changes. + +### Trying locally a pending Pull Request + +If you want to review in your computer a pending pull request proposed by +another user you can fetch the merge request commit with the following command, +replacing `NNNN` with the pull request number: + +```bash +git fetch origin refs/pull/NNNN/head +git checkout FETCH_HEAD +``` + +The first command will add to your local git repository the remote commit for +the pending pull request and store a temporary reference called `FETCH_HEAD`. +The second command then checks out that reference. From this point you can +review the files in your computer, create a local branch for this FETCH_HEAD or +build on top of it. diff --git a/media/libjxl/src/doc/developing_in_windows_msys.md b/media/libjxl/src/doc/developing_in_windows_msys.md new file mode 100644 index 000000000..3e86d5dd8 --- /dev/null +++ b/media/libjxl/src/doc/developing_in_windows_msys.md @@ -0,0 +1,168 @@ +# Developing for Windows with MSYS2 + +[MSYS2](https://www.msys2.org/) ("minimal system 2") is a software distribution and a development platform based on MinGW and Cygwin. It provides a Unix-like environment to build code on Windows. These instructions were written with a 64-bit instance of Windows 10 running on a VM. They may also work on native instances of Windows and other versions of Windows. + +## Build Environments + +MSYS2 provides multiple development [environments](https://www.msys2.org/docs/environments/). By convention, they are referred to in uppercase. They target slightly different platforms, runtime libraries, and compiler toolchains. For example, to build for 32-bit Windows, use the MINGW32 environment. For interoperability with Visual Studio projects, use the UCRT64 environment. + +Since all of the build environments are built on top of the MSYS environment, **all updates and package installation must be done from within the MSYS environment**. After making any package changes, `exit` all MSYS2 terminals and restart the desired build-environment. This reminder is repeated multiple times throughout this guide. + +* **MINGW32:** To compile for 32-bit Windows (on 64-bit Windows), use packages from the `mingw32` group. Package names are prefixed with `mingw-w64-i686`. The naming scheme may be different on the 32-bit version of MSYS2. + +* **MINGW64:** This is the primary environment to building for 64-bit Windows. It uses the older MSVCRT runtime, which is widely available across Windows systems. Package names are prefixed with `mingw-w64-x86_64`. + +* **UCRT64:** The Universal C Runtime (UCRT) is used by recent versions of Microsoft Visual Studio. It ships by default with Windows 10. For older versions of Windows, it must be provided with the application or installed by the user. Package names are prefixed with `mingw-w64-ucrt-x86_64`. + +* **CLANG64:** Unfortunately, the `gimp` packages are not available for the CLANG64 environment. However, `libjxl` will otherwise build in this environment if the appropriate packages are installed. Packages are prefixed with `mingw-w64-clang-x86_64`. + +## Install and Upgrade MSYS2 + +Download MSYS2 from the homepage. Install at a location without any spaces on a drive with ample free space. After installing the packages used in this guide, MSYS2 used about 15GB of space. + +Toward the end of installation, select the option to run MSYS2 now. A command-line window will open. Run the following command, and answer the prompts to update the repository and close the terminal. + +```bash +pacman -Syu +``` + +Now restart the MSYS environment and run the following command to complete updates: + +```bash +pacman -Su +``` + +## Package Management + +Packages are organized in groups, which share the build environment name, but in lower case. Then they have name prefixes that indicate which group they belong to. Consider this package search: `pacman -Ss cmake` + +``` +mingw32/mingw-w64-i686-cmake +mingw64/mingw-w64-x86_64-cmake +ucrt64/mingw-w64-ucrt-x86_64-cmake +clang64/mingw-w64-clang-x86_64-cmake +msys/cmake +``` + +We can see the organization `group/prefix-name`. When installing packages, the group name is optional. + +```bash +pacman -S mingw-w64-x86_64-cmake +``` + +For tools that need to be aware of the compiler to function, install the package that corresponds with the specific build-environment you plan to use. For `cmake`, install the `mingw64` version. The generic `msys/cmake` will not function correctly because it will not find the compiler. For other tools, the generic `msys` version is adequate, like `msys/git`. + +To remove packages, use: + +```bash +pacman -Rsc [package-name] +``` + +## Worst-Case Scenario... + +If packages management is done within a build environment other than MSYS, the environment structure will be disrupted and compilation will likely fail. If this happens, it may be necessary to reinstall MSYS2. + +1. Rename the `msys64` folder to `msys64.bak`. + +2. Use the installer to reinstall MSYS2 to `msys64`. + +3. Copy packages from `msys64.bak/var/cache/pacman/pkg/` to the new installation to save download time and bandwidth. + +4. Use `pacman` from within the MSYS environment to install and update packages. + +5. After successfully building a project, it is safe to delete `msys64.bak` + +## The MING64 Environment + +Next set up the MING64 environment. The following commands should be run within the MSYS environment. `pacman -S` is used to install packages. The `--needed` argument prevents packages from being reinstalled. + +```bash +pacman -S --needed base-devel mingw-w64-x86_64-toolchain +pacman -S git mingw-w64-x86_64-cmake mingw-w64-x86_64-ninja \ + mingw-w64-x86_64-gtest mingw-w64-x86_64-giflib \ + mingw-w64-x86_64-libpng mingw-w64-x86_64-libjpeg-turbo +``` + +## Build `libjxl` + +Download the source from the libjxl [releases](https://github.com/libjxl/libjxl/releases) page. Alternatively, you may obtain the latest development version with `git`. Run `./deps.sh` to ensure additional third-party dependencies are downloaded. + +Start the MINGW64 environment, create a build directory within the source directory, and configure with `cmake`. + +```bash +mkdir build +cd build +cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=OFF \ + -DJPEGXL_ENABLE_BENCHMARK=OFF -DJPEGXL_ENABLE_PLUGINS=ON \ + -DJPEGXL_ENABLE_MANPAGES=OFF -DJPEGXL_FORCE_SYSTEM_BROTLI=ON \ + -DJPEGXL_FORCE_SYSTEM_GTEST=ON .. +``` + +Check the output to see if any dependencies were missed and need to be installed. Adding `-G Ninja` may be helpful, but on my computer, Ninja was selected by default. Remember that package changes must be done from the MSYS environment. Then exit all MSYS2 terminals and restart the build environment. + +If all went well, you may now run `cmake` to build `libjxl`: + +```bash +cmake --build . +``` + +Do not be alarmed by the compiler warnings. They are a caused by differences between gcc/g++ and clang. The build should complete successfully. Then `cjxl`, `djxl`, `jxlinfo`, and others can be run from within the build environment. Moving them into the native Windows environment requires resolving `dll` issues that are beyond the scope of this document. + +## The `clang` Compiler + +To use the `clang` compiler, install the packages that correspond with the environment you wish to use. Remember to make package changes from within the MSYS environment. + +``` +mingw-w64-i686-clang +mingw-w64-i686-clang-tools-extra +mingw-w64-i686-clang-compiler-rt + +mingw-w64-x86_64-clang +mingw-w64-x86_64-clang-tools-extra +mingw-w64-x86_64-clang-compiler-rt + +mingw-w64-ucrt64-x86_64-clang +mingw-w64-ucrt64-x86_64-clang-tools-extra +mingw-w64-ucrt64-x86_64-clang-compiler-rt +``` + +After the `clang` compiler is installed, 'libjxl' can be built with the `./ci.sh` script. + +```bash +./ci.sh opt -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=OFF \ + -DJPEGXL_ENABLE_BENCHMARK=OFF -DJPEGXL_ENABLE_MANPAGES=OFF \ + -DJPEGXL_FORCE_SYSTEM_BROTLI=ON -DJPEGXL_FORCE_SYSTEM_GTEST=ON +``` + +On my computer, `doxygen` packages needed to be installed to proceed with building. Use `pacman -Ss doxygen` to find the packages to install. + +## The GIMP Plugin + +To build the GIMP plugin, install the relevant `gimp` package. This will also install dependencies. Again, perform package management tasks from only the MSYS environment. Then restart the build environment. + +```bash +pacman -S mingw-w64-i686-gimp +pacman -S mingw-w64-x86_64-gimp +pacman -S mingw-w64-ucrt-x86_64-gimp +``` + +If `clang` is installed, you can use the `./ci.sh` script to build. Otherwise, navigate to the build directory to reconfigure and build with `cmake`. + +```bash +cd build +rm -r CM* +cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=OFF \ + -DJPEGXL_ENABLE_BENCHMARK=OFF -DJPEGXL_ENABLE_MANPAGES=OFF \ + -DJPEGXL_ENABLE_PLUGINS=ON -DJPEGXL_FORCE_SYSTEM_BROTLI=ON \ + -DJPEGXL_FORCE_SYSTEM_GTEST=ON .. +``` + +The plugin is built statically, so there should be no need to install `dll` files. To try out the plugin: + +1. [Download](https://www.gimp.org/downloads/) and install the stable version of GIMP (currently 2.10.24). + +2. Create a new folder: `C:\Program Files\GIMP 2\lib\gimp\2.0\plug-ins\file-jxl` + +3. Copy `build/plugins/gimp/file-jxl.exe` to the new folder. diff --git a/media/libjxl/src/doc/developing_in_windows_vcpkg.md b/media/libjxl/src/doc/developing_in_windows_vcpkg.md new file mode 100644 index 000000000..332e4513e --- /dev/null +++ b/media/libjxl/src/doc/developing_in_windows_vcpkg.md @@ -0,0 +1,91 @@ +# Developing on Windows with Visual Studio 2019 + +These instructions assume an up-to-date Windows 10 (e.g. build 19041.928) with +**Microsoft Visual Studio 2019** (e.g. Version 16.9.0 Preview 4.0) installed. If +unavailable, please use another build environment: + +* [Docker container](developing_in_docker.md) +* [MSYS2 on Windows](developing_in_windows_msys.md) +* [Crossroad on Linux](developing_with_crossroad.md) (cross compilation for Windows) + +## Minimum build dependencies + +Apart from the dependencies in third_party, some of the tools use external +dependencies that need to be installed in your system first. + +Please install [vcpkg](https://vcpkg.readthedocs.io/en/latest/examples/installing-and-using-packages/) +(tested with version 2019.07.18), and use it to install the following libraries: + +``` +vcpkg install gtest:x64-windows +vcpkg install giflib:x64-windows +vcpkg install libjpeg-turbo:x64-windows +vcpkg install libpng:x64-windows +vcpkg install zlib:x64-windows +``` + +## Building + +From Visual Studio, open the CMakeLists.txt in the JPEG XL root directory. +Right-click the CMakeLists.txt entry in the Folder View of the Solution +Explorer. In the context menu, select CMake Settings. Click on the green plus +to add an x64-Clang configuration and the red minus to remove any non-Clang +configuration (the MSVC compiler is currently not supported). Click on the blue +hyperlink marked "CMakeSettings.json" and an editor will open. Insert the +following text after replacing $VCPKG with the directory where you installed +vcpkg above. + +``` +{ + "configurations": [ + { + "name": "x64-Clang-Release", + "generator": "Ninja", + "configurationType": "MinSizeRel", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "-DCMAKE_TOOLCHAIN_FILE=$VCPKG/scripts/buildsystems/vcpkg.cmake", + "buildCommandArgs": "-v", + "ctestCommandArgs": "", + "inheritEnvironments": [ "clang_cl_x64" ], + "variables": [ + { + "name": "VCPKG_TARGET_TRIPLET", + "value": "x64-windows", + "type": "STRING" + }, + { + "name": "JPEGXL_ENABLE_TCMALLOC", + "value": "False", + "type": "BOOL" + }, + { + "name": "BUILD_GMOCK", + "value": "True", + "type": "BOOL" + }, + { + "name": "gtest_force_shared_crt", + "value": "True", + "type": "BOOL" + }, + { + "name": "JPEGXL_ENABLE_FUZZERS", + "value": "False", + "type": "BOOL" + }, + { + "name": "JPEGXL_ENABLE_VIEWERS", + "value": "False", + "type": "BOOL" + } + ] + } + ] +} +``` + +The project is now ready for use. To build, simply press F7 (or choose +Build All from the Build menu). This writes binaries to +`out/build/x64-Clang-Release/tools`. The main [README.md](../README.md) explains +how to use the encoder/decoder and benchmark binaries. diff --git a/media/libjxl/src/doc/developing_with_crossroad.md b/media/libjxl/src/doc/developing_with_crossroad.md new file mode 100644 index 000000000..e7c2f23f9 --- /dev/null +++ b/media/libjxl/src/doc/developing_with_crossroad.md @@ -0,0 +1,116 @@ +# Cross Compiling for Windows with Crossroad + +[Crossroad](https://pypi.org/project/crossroad/) is a tool to set up cross-compilation environments on GNU/Linux distributions. These instructions assume a Debian/Ubuntu system. However, they can likely be adapted to other Linux environments. Since Ubuntu can be run on Windows through WSL, these instruction may be useful for developing directly on Windows. + +## Install Crossroad + +Crossroad requires tools included with `python3-docutils` and `mingw-w64`. They may be installed using: + +```bash +sudo aptitude install python3-docutils mingw-w64 +``` + +The `zstandard` python package is also required, but is not available in the repositories. It may be installed using `pip`. + +```bash +pip3 install zstandard +``` + +After the dependencies are installed, crossroad itself maybe installed with `pip`. + +```bash +pip3 install crossroad +``` + +If there are errors while running crossroad, it may need to be downloaded and installed directly using `setup.py`. Instructions are on the crossroad homepage. + +## Update Debian Alternatives + +Since `libjxl` uses C++ features that require posix threads, the symlinks used by the Debian alternative system need to be updated: + +```bash +sudo update-alternatives --config x86_64-w64-mingw32-g++ +``` + +Select the option that indicates `posix` usage. Repeat for `gcc` and `i686`: + +```bash +sudo update-alternatives --config x86_64-w64-mingw32-gcc +sudo update-alternatives --config i686-w64-mingw32-gcc +sudo update-alternatives --config i686-w64-mingw32-g++ +``` + +## Create a New Crossroad Project + +Crossroad supports the following platforms: + +``` +native Native platform (x86_64 GNU/Linux) +android-x86 Generic Android/Bionic on x86 +android-mips64 Generic Android/Bionic on MIPS64 +android-x86-64 Generic Android/Bionic on x86-64 +w64 Windows 64-bit +w32 Windows 32-bit +android-arm64 Generic Android/Bionic on ARM64 +android-mips Generic Android/Bionic on MIPS +android-arm Generic Android/Bionic on ARM +``` + +To begin cross compiling for Windows, a new project needs to be created: + +```bash +crossroad w64 [project-name] +``` + +## Install Dependencies + +Since the `gimp` development package is required to build the GIMP plugin and also includes most of the packages required by `libjxl`, install it first. + +```bash +crossroad install gimp +``` + +`gtest` and `brotli` are also required. + +```bash +crossroad install gtest brotli +``` + +If any packages are later found to be missing, you may search for them using: + +```bash +crossroad search [...] +``` + +## Build `libjxl` + +Download the source from the libjxl [releases](https://github.com/libjxl/libjxl/releases) page. Alternatively, you may obtain the latest development version with `git`. Run `./deps.sh` to ensure additional third-party dependencies are downloaded. Unfortunately, the script `./ci.sh` does not work with Crossroad, so `cmake` will need to be called directly. + +Create a build directory within the source directory. If you haven't already, start your crossroad project and run `cmake`: + +```bash +mkdir build +cd build +crossroad w64 libjxl +crossroad cmake -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=OFF \ + -DJPEGXL_ENABLE_BENCHMARK=OFF -DJPEGXL_ENABLE_MANPAGES=OFF \ + -DJPEGXL_ENABLE_PLUGINS=ON -DJPEGXL_FORCE_SYSTEM_BROTLI=ON \ + -DJPEGXL_FORCE_SYSTEM_GTEST=ON .. +``` + +Check the output to see if any dependencies were missed and need to be installed. If all went well, you may now run `cmake` to build `libjxl`: + +```bash +cmake --build . +``` + +## Try out the GIMP Plugin + +The plugin is built statically, so there should be no need to install `dll` files. To try out the plugin: + +1. [Download](https://www.gimp.org/downloads/) and install the stable version of GIMP (currently 2.10.24). + +2. Create a new folder: `C:\Program Files\GIMP 2\lib\gimp\2.0\plug-ins\file-jxl` + +3. Copy `build/plugins/gimp/file-jxl.exe` to the new folder. diff --git a/media/libjxl/src/doc/format_overview.md b/media/libjxl/src/doc/format_overview.md new file mode 100644 index 000000000..4614df550 --- /dev/null +++ b/media/libjxl/src/doc/format_overview.md @@ -0,0 +1,284 @@ +# JPEG XL Format Overview + +This document gives an overview of the JPEG XL file format and codestream, +its features, and the underlying design rationale. +The aim of this document is to provide general insight into the +format capabilities and design, thus helping developers +better understand how to use the `libjxl` API. + +## Codestream and File Format + +The JPEG XL format is defined in ISO/IEC 18181. This standard consists of +four parts: + +* 18181-1: Core codestream +* 18181-2: File format +* 18181-3: Conformance testing +* 18181-4: Reference implementation + +### Core codestream + +The core codestream contains all the data necessary to decode and display +still image or animation data. This includes basic metadata like image dimensions, +the pixel data itself, colorspace information, orientation, upsampling, etc. + +### File format + +The JPEG XL file format can take two forms: + +* A 'naked' codestream. In this case, only the image/animation data itself is +stored, and no additional metadata can be included. Such a file starts with the +bytes `0xFF0A` (the JPEG marker for "start of JPEG XL codestream"). +* An ISOBMFF-based container. This is a box-based container that includes a +JPEG XL codestream box (`jxlc`), and can optionally include other boxes with +additional information, such as Exif metadata. In this case, the file starts with +the bytes `0x0000000C 4A584C20 0D0A870A`. + +### Conformance testing + +This part of the standard defines precision bounds and test cases for conforming +decoders, to verify that they implement all coding tools correctly and accurately. + +### Reference implementation + +The `libjxl` software is the reference implementation of JPEG XL. + + +## Metadata versus Image Data + +JPEG XL makes a clear separation between metadata and image data. +Everything that is needed to correctly display an image is +considered to be image data, and is part of the core codestream. This includes +elements that have traditionally been considered 'metadata', such as ICC profiles +and Exif orientation. The goal is to reduce the ambiguity and potential for +incorrect implementations that can be caused by having a 'black box' codestream +that only contains numerical pixel data, requiring applications to figure out how +to correctly interpret the data (i.e. apply color transforms, upsampling, +orientation, blending, cropping, etc.). By including this functionality in the +codestream itself, the decoder can provide output in a normalized way +(e.g. in RGBA, orientation already applied, frames blended and coalesced), +simplifying things and making it less error-prone for applications. + +The remaining metadata, e.g. Exif or XMP, can be stored in the container format, +but it does not influence image rendering. In the case of Exif orientation, +this field has to be ignored by applications, since the orientation in the +codestream always takes precedence (and will already have been applied +transparently by the decoder). This means that stripping metadata can be done +without affecting the displayed image. + + +## Codestream Features + +### Color Management + +In JPEG XL, images always have a fully defined colorspace, i.e. it is always +unambiguous how to interpret the pixel values. There are two options: + +* Pixel data is in a specified (non-XYB) colorspace, and the decoder will produce +a pixel buffer in this colorspace plus an ICC profile that describes that +colorspace. Mathematically lossless encoding can only use this option. +* Pixel data is in the XYB colorspace, which is an absolute colorspace. +In this case, the decoder can produce a pixel buffer directly in a desired +display space like sRGB, Display-P3 or Rec.2100 PQ. + +The image header always contains a colorspace; however, its meaning depends on +which of the above two options were used: + +* In the first case (non-XYB), the signaled colorspace defines the +interpretation of the pixel data. +* In the second case (XYB), the signaled colorspace is merely a _suggestion_ +of a target colorspace to represent the image in, i.e. it is the colorspace +the original image was in, that has a sufficiently wide gamut and a +suitable transfer curve to represent the image data with high fidelity +using a limited bit depth representation. + +Colorspaces can be signaled in two ways in JPEG XL: + +* CICP-style Enum values: This is a very compact representation that +covers most or all of the common colorspaces. The decoder can convert +XYB to any of these colorspaces without requiring an external color management +library. +* ICC profiles: Arbitrary ICC profiles can also be used, including +CMYK ones. The ICC profile data gets compressed. In this case, external +color management software (e.g. lcms2 or skcms) has to be used for color +conversions. + +### Frames + +A JPEG XL codestream contains one or more frames. In the case of animation, +these frames have a duration and can be looped (infinitely or a number of times). +Zero-duration frames are possible and represent different layers of the image. + +Frames can have a blendmode (Replace, Add, Alpha-blend, Multiply, etc.) and +they can use any previous frame as a base. +They can be smaller than the image canvas, in which case the pixels outside the +crop are copied from the base frame. They can be positioned at an arbitrary +offset from the image canvas; this offset can also be negative and frames can +also be larger than the image canvas, in which case parts of the frame will +be invisible and only the intersection with the image canvas will be shown. + +By default, the decoder will blend and coalesce frames, producing only a single +output frame when there are subsequent zero-duration frames, and all output frames +are of the same size (the size of the image canvas) and have either no duration +(in case of a still image) or a non-zero duration (in case of animation). + +### Pixel Data + +Every frame contains pixel data encoded in one of two modes: + +* VarDCT mode: In this mode, variable-sized DCT transforms are applied +and the image data is encoded in the form of DCT coefficients. This mode is +always lossy, but it can also be used to losslessly represent an existing +(already lossy) JPEG image, in which case only the DCT8x8 is used. +* Modular mode: In this mode, only integer arithmetic is used, which +enables lossless compression. However, this mode can also be used for lossy +compression. Multiple transformations can be used to improve compression or to +obtain other desirable effects: reversible color transforms (RCTs), +(delta) palette transforms, and a modified non-linear Haar transform +called Squeeze, which facilitates (but does not require) lossy compression +and enables progressive decoding. + +Internally, the VarDCT mode uses Modular sub-bitstreams to encode +various auxiliary images, such as the "LF image" (a 1:8 downscaled version +of the image that contains the DC coefficients of DCT8x8 and low-frequency +coefficients of the larger DCT transforms), extra channels besides the +three color channels (e.g. alpha), and weights for adaptive quantization. + +In addition, both modes can separately encode additional 'image features' that +are rendered on top of the decoded image: + +* Patches: rectangles from a previously decoded frame (which can be a +'hidden' frame that is not displayed but only stored to be referenced later) +can be blended using one of the blendmodes on top of the current frame. +This allows the encoder to identify repeating patterns (such as letters of +text) and encode them only once, using patches to insert the pattern in +multiple spots. These patterns are encoded in a previous frame, making +it possible to add Modular-encoded pixels to a VarDCT-encoded frame or +vice versa. +* Splines: centripetal Catmull-Rom splines can be encoded, with a color +and a thickness that can vary along the arclength of the curve. +Although the current encoder does not use this bitstream feature yet, we +anticipate that it can be useful to complement DCT-encoded data, since +thin lines are hard to represent faithfully using the DCT. +* Noise: luma-modulated synthetic noise can be added to an image, e.g. +to emulate photon noise, in a way that avoids poor compression due to +high frequency DCT coefficients. + +Finally, both modes can also optionally apply two filtering methods to +the decoded image, which both have the goal of reducing block artifacts +and ringing: + +* Gabor-like transform ('Gaborish'): a small (3x3) blur that gets +applied across block and group boundaries, reducing blockiness. The +encoder applies the inverse sharpening transform before encoding, +effectively getting the benefits of lapped transforms without the +disadvantages. +* Edge-preserving filter ('EPF'): similar to a bilateral filter, +this smoothing filter avoids blurring edges while reducing ringing. +The strength of this filter is signaled and can locally be adapted. + +### Groups + +In both modes (Modular and VarDCT), the frame data is signaled as +a sequence of groups. These groups can be decoded independently, +and the frame header contains a table of contents (TOC) with bitstream +offsets for the start of each group. This enables parallel decoding, +and also partial decoding of a region of interest or a progressive preview. + +In VarDCT mode, all groups have dimensions 256x256 (or smaller at the +right and bottom borders). First the LF image is encoded, also in +256x256 groups (corresponding to 2048x2048 pixels, since this data +corresponds to the 1:8 image). This means there is always a basic +progressive preview available in VarDCT mode. +Optionally, the LF image can be encoded separately in a (hidden) +LF frame, which can itself recursively be encoded in VarDCT mode +and have its own LF frame. This makes it possible to represent huge +images while still having an overall preview that can be efficiently +decoded. +Then the HF groups are encoded, corresponding to the remaining AC +coefficients. The HF groups can be encoded in multiple passes for +more progressive refinement steps; the coefficients of all passes +are added. Unlike JPEG progressive scan scripts, JPEG XL allows +signaling any amount of detail in any part of the image in any pass. + +In Modular mode, groups can have dimensions 128x128, 256x256, 512x512 +or 1024x1024. If the Squeeze transform was used, the data will +be split in three parts: the Global groups (the top of the Laplacian +pyramid that fits in a single group), the LF groups (the middle part +of the Laplacian pyramid that corresponds to the data needed to +reconstruct the 1:8 image) and the HF groups (the base of the Laplacian +pyramid), where the HF groups are again possibly encoded in multiple +passes (up to three: one for the 1:4 image, one for the 1:2 image, +and one for the 1:1 image). + +In case of a VarDCT image with extra channels (e.g. alpha), the +VarDCT groups and the Modular groups are interleaved in order to +allow progressive previews of all the channels. + +The default group order is to encode the LF and HF groups in +scanline order (top to bottom, left to right), but this order +can be permuted arbitrarily. This allows, for example, a center-first +ordering or a saliency-based ordering, causing the bitstream +to prioritize progressive refinements in a different way. + + +## File Format Features + +Besides the image data itself (stored in the `jxlc` codestream box), +the optional container format allows storing additional information. + +## Metadata + +Three types of metadata can be included in a JPEG XL container: + +* Exif (`Exif`) +* XMP (`xml `) +* JUMBF (`jumb`) + +This metadata can contain information about the image, such as copyright +notices, GPS coordinates, camera settings, etc. +If it contains rendering-impacting information (such as Exif orientation), +the information in the codestream takes precedence. + +## Compressed Metadata + +The container allows the above metadata to be stored either uncompressed +(e.g. plaintext XML in the case of XMP) or by Brotli-compression. +In the latter case, the box type is `brob` (Brotli-compressed Box) and +the first four bytes of the box contents define the actual box type +(e.g. `xml `) it represents. + +## JPEG Bitstream Reconstruction Data + +JPEG XL can losslessly recompress existing JPEG files. +The general design philosophy still applies in this case: +all the image data is stored in the codestream box, including the DCT +coefficients of the original JPEG image and possibly an ICC profile or +Exif orientation. + +In order to allow bit-identical reconstruction of the original JPEG file +(not just the image but the actual file), additional information is needed, +since the same image data can be encoded in multiple ways as a JPEG file. +The `jbrd` box (JPEG Bitstream Reconstruction Data) contains this information. +Typically it is relatively small. Using the image data from the codestream, +the JPEG bitstream reconstruction data, and possibly other metadata boxes +that were present in the JPEG file (Exif/XMP/JUMBF), the exact original +JPEG file can be reconstructed. + +This box is not needed to display a recompressed JPEG image; it is only +needed to reconstruct the original JPEG file. + +## Frame Index + +The container can optionally store a `jxli` box, which contains an index +of offsets to keyframes of a JPEG XL animation. It is not needed to display +the animation, but it does facilitate efficient seeking. + +## Partial Codestream + +The codestream can optionally be split into multiple `jxlp` boxes; +conceptually, this is equivalent to a single `jxlc` box that contains the +concatenation of all partial codestream boxes. +This makes it possible to create a file that starts with +the data needed for a progressive preview of the image, followed by +metadata, followed by the remaining image data. diff --git a/media/libjxl/src/doc/fuzzing.md b/media/libjxl/src/doc/fuzzing.md new file mode 100644 index 000000000..af926596f --- /dev/null +++ b/media/libjxl/src/doc/fuzzing.md @@ -0,0 +1,184 @@ +# Fuzzing + +Fuzzing is a technique to find potential bugs by providing randomly generated +invalid inputs. To detect potential bugs such as programming errors we use +fuzzing in combination with ASan (Address Sanitizer), MSan (Memory Sanitizer), +UBSan (Undefined Behavior Sanitizer) and asserts in the code. An invalid input +will likely produce a decoding error (some API function returning error), which +is absolutely not a problem, but what it should not do is access memory out of +bounds, use uninitialized memory or hit a false assert condition. + +## Automated Fuzzing with oss-fuzz + +libjxl fuzzing is integrated into [oss-fuzz](https://github.com/google/oss-fuzz) +as the project `libjxl`. oss-fuzz regularly runs the fuzzers on the `main` +branch and reports bugs into their bug tracker which remains private until the +bugs are fixed in main. + +## Fuzzer targets + +There are several fuzzer executable targets defined in the `tools/` directory +to fuzz different parts of the code. The main one is `djxl_fuzzer`, which uses +the public C decoder API to attempt to decode an image. The fuzzer input is not +directly the .jxl file, the last few bytes of the fuzzer input are used to +decide *how* will the API be used (if preview is requested, the pixel format +requested, if the .jxl input data is provided altogether, etc) and the rest of +the fuzzer input is provided as the .jxl file to the decoder. Some bugs might +reproduce only if the .jxl input is decoded in certain way. + +The remaining fuzzer targets execute a specific portion the codec that might be +easier to fuzz independently from the whole codec. + +## Reproducing fuzzer bugs + +A fuzzer target, like `djxl_fuzzer` accepts as a parameter one or more files +that will be used as inputs. This runs the fuzzer program in test-only mode +where no new inputs are generated and only the provided files are tested. This +is the easiest way to reproduce a bug found by the fuzzer using the generated +test case from the bug report. + +oss-fuzz uses a specific compiler version and flags, and it is built using +Docker. Different compiler versions will have different support for detecting +certain actions as errors, so we want to reproduce the build from oss-fuzz as +close as possible. To reproduce the build as generated by oss-fuzz there are a +few helper commands in `ci.sh` as explained below. + +### Generate the gcr.io/oss-fuzz/libjxl image + +First you need the ossfuzz libjxl builder image. This is the base oss-fuzz +builder image with a few dependencies installed. To generate it you need to +check out the oss-fuzz project and build it: + +```bash +git clone https://github.com/google/oss-fuzz.git ~/oss-fuzz +cd ~/oss-fuzz +sudo infra/helper.py build_image libjxl +``` + +This will create the `gcr.io/oss-fuzz/libjxl` docker image. You can check if it +was created verifying that it is listed in the output of the `sudo docker image +ls` command. + +### Build the fuzzer targets with oss-fuzz + +To build the fuzzer targets from the current libjxl source checkout, use the +`./ci.sh ossfuzz_msan` command for MSan, `./ci.sh ossfuzz_asan` command for ASan +or `./ci.sh ossfuzz_ubsan` command for UBSan. All the `JXL_ASSERT` and +`JXL_DASSERT` calls are enabled in all the three modes. These ci.sh helpers will +reproduce the oss-fuzz docker call to build libjxl mounting the current source +directory into the Docker container. Ideally you will run this command in a +different build directory separated from your regular builds. + +For example, for MSan builds run: + +```bash +BUILD_DIR=build-fuzzmsan ./ci.sh ossfuzz_msan +``` + +After this, the fuzzer program will be generated in the build directory like +for other build modes: `build-fuzzmsan/tools/djxl_fuzzer`. + +### Iterating changes with oss-fuzz builds + +After modifying the source code to fix the fuzzer-found bug, or to include more +debug information, you can rebuild only a specific fuzzer target to save on +rebuilding time and immediately run the test case again. For example, for +rebuilding and testing only `djxl_fuzzer` in MSan mode we can run: + +```bash +BUILD_DIR=build-fuzzmsan ./ci.sh ossfuzz_msan djxl_fuzzer && build-fuzzmsan/tools/djxl_fuzzer path/to/testcase.bin +``` + +When MSan and ASan fuzzers fail they will print a stack trace at the point where +the error occurred, and some related information. To make these these stack +traces useful we need to convert the addresses to function names and source file +names and lines, which is done with the "symbolizer". For UBSan to print a stack +trace we need to set the `UBSAN_OPTIONS` environment variables when running the +fuzzer. + +Set the following environment variables when testing the fuzzer binaries. Here +`clang` should match the compiler version used by the container, you can pass a +different compiler version in the following example by first installing the +clang package for that version outside the container and using `clang-NN` +(for example `clang-11`) instead of `clang` in the following commands: + +```bash +symbolizer=$($(realpath $(which clang)) -print-prog-name=llvm-symbolizer) +export MSAN_SYMBOLIZER_PATH="${symbolizer}" +export UBSAN_SYMBOLIZER_PATH="${symbolizer}" +export ASAN_SYMBOLIZER_PATH="${symbolizer}" +export ASAN_OPTIONS=detect_leaks=1 +export UBSAN_OPTIONS=print_stacktrace=1 +``` + +Note: The symbolizer binary must be a program called `llvm-symbolizer`, any +other file name will fail. There are normally symlinks already installed with +the right name which the `-print-prog-name` would print. + +## Running the fuzzers locally + +Running the fuzzer targets in fuzzing mode can be achieved by running them with +no parameters, or better with a parameter with the path to a *directory* +containing a seed of files to use as a starting point. Note that passing a +directory is considered a corpus to use for fuzzing while passing a file is +considered an input to evaluate. Multi-process fuzzing is also supported. For +details about all the fuzzing options run: + +```bash +build-fuzzmsan/tools/djxl_fuzzer -help=1 +``` + +## Writing fuzzer-friendly code + +Fuzzing on itself can't find programming bugs unless an input makes the program +perform an invalid operation (read/write out of bounds, perform an undefined +behavior operation, etc). You can help the fuzzer find invalid situations by +adding asserts: + + * `JXL_ASSERT()` is enabled in Release mode by default. It can be disabled + with `-DJXL_ENABLE_ASSERT=0` but the intention is that it will run for all + the users in released code. If performance of the check is not an issue (like + checks done once per image, once per channel, once per group, etc) a + JXL_ASSERT is appropriate. A failed assert is preferable to an out of bounds + write. + + * `JXL_DASSERT()` is only enabled in Debug builds, which includes all the ASan, + MSan and UBSan builds. Performance of these checks is not an issue if kept + within reasonable limits (automated msan/asan test should finish withing 1 + hour for example). Fuzzing is more effective when the given input runs + faster, so keep that in mind when adding a complex DASSERT that runs multiple + times per output pixel. + + * For MSan builds it is also possible to specify that certain values must be + initialized. This is automatic for values that are used to make decisions + (like when used in an `if` statement or in the ternary operator condition) + but those checks can be made explicit for image data using the + `JXL_CHECK_IMAGE_INITIALIZED(image, rect)` macro. This helps document and + check (only in MSan builds) that a given portion of the image is expected to + be initialized, allowing to catch errors earlier in the process. + +## Dealing with use-of-uninitialized memory + +In MSan builds it is considered an error to *use* uninitialized memory. Using +the memory normally requires something like a decision / branch based on the +uninitialized value, just running `memcpy()` or simple arithmetic over +uninitialized memory is not a problem. Notably, computing `DemoteTo()`, +`NearestInt()` or similar expressions that create a branch based on the value of +the uninitialized memory will trigger an MSan error. + +In libjxl we often run vectorized operations over a series of values, rounding +up to the next multiple of a vector size, thus operating over uninitialized +values past the end of the requested region. These values are part of the image +padding but are not initialized. This behavior would not create an MSan error +unless the processing includes operations like `NearestInt()`. For such cases +the preferred solution is to use `msan::UnpoisonMemory` over the portion of +memory of the last SIMD vector before processing, and then running +`msan::PoisonMemory` over the corresponding value in the output side. A note +including why this is safe to do must be added, for example if the processing +doesn't involve any cross-lane computation. + +Initializing padding memory in MSan builds is discouraged because it may hide +bugs in functions that weren't supposed to read from the padding. Initializing +padding memory in all builds, including Release builds, would mitigate the +MSan potential security issue but it would hide the logic bug for a longer time +and potentially incur in a performance hit. diff --git a/media/libjxl/src/doc/jxl.svg b/media/libjxl/src/doc/jxl.svg new file mode 100644 index 000000000..a80778b0b --- /dev/null +++ b/media/libjxl/src/doc/jxl.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/media/libjxl/src/doc/man/cjxl.txt b/media/libjxl/src/doc/man/cjxl.txt new file mode 100644 index 000000000..261742a68 --- /dev/null +++ b/media/libjxl/src/doc/man/cjxl.txt @@ -0,0 +1,102 @@ +cjxl(1) +======= +:doctype: manpage + +Name +---- + +cjxl - compress images to JPEG XL + +Synopsis +-------- + +*cjxl* ['options'...] 'input' ['output.jxl'] + +Description +----------- + +`cjxl` compresses an image or animation to the JPEG XL format. It is intended to +spare users the trouble of determining a set of optimal parameters for each +individual image. Instead, for a given target quality, it should provide +consistent visual results across various kinds of images. The defaults have been +chosen to be sensible, so that the following commands should give satisfactory +results in most cases: + +---- +cjxl input.png output.jxl +cjxl input.jpg output.jxl +cjxl input.gif output.jxl +---- + +Options +------- + +-h:: +--help:: + Displays the options that `cjxl` supports. On its own, it will only show + basic options. It can be combined with `-v` or `-v -v` to show increasingly + advanced options as well. + +-v:: +--verbose:: + Increases verbosity. Can be repeated to increase it further, and also + applies to `--help`. + +-d 'distance':: +--distance='distance':: + The preferred way to specify quality. It is specified in multiples of a + just-noticeable difference. That is, `-d 0` is mathematically lossless, + `-d 1` should be visually lossless, and higher distances yield denser and + denser files with lower and lower fidelity. Lossy sources such as JPEG and + GIF files are compressed losslessly by default, and in the case of JPEG + files specifically, the original JPEG can then be reconstructed bit-for-bit. + For lossless sources, `-d 1` is the default. + +-q 'quality':: +--quality='quality':: + Alternative way to indicate the desired quality. 100 is lossless and lower + values yield smaller files. There is no lower bound to this quality + parameter, but positive values should approximately match the quality + setting of libjpeg. + +-e 'effort':: +--effort='effort':: + Controls the amount of effort that goes into producing an ``optimal'' file + in terms of quality/size. That is to say, all other parameters being equal, + a higher effort should yield a file that is at least as dense and possibly + denser, and with at least as high and possibly higher quality. ++ +Recognized effort settings, from fastest to slowest, are: ++ +- 1 or ``lightning'' +- 2 or ``thunder'' +- 3 or ``falcon'' +- 4 or ``cheetah'' +- 5 or ``hare'' +- 6 or ``wombat'' +- 7 or ``squirrel'' (default) +- 8 or ``kitten'' +- 9 or ``tortoise'' + +Examples +-------- + +---- +# Compress a PNG file to a high-quality JPEG XL version. +$ cjxl input.png output.jxl + +# Compress it at a slightly lower quality, appropriate for web use. +$ cjxl -d 2 input.png output.jxl + +# Compress it losslessly. These are equivalent. +$ cjxl -d 0 input.png lossless.jxl +$ cjxl -q 100 input.png lossless.jxl + +# Compress a JPEG file losslessly. +$ cjxl input.jpeg lossless-jpeg.jxl +---- + +See also +-------- + +*djxl*(1) diff --git a/media/libjxl/src/doc/man/djxl.txt b/media/libjxl/src/doc/man/djxl.txt new file mode 100644 index 000000000..bd57b4420 --- /dev/null +++ b/media/libjxl/src/doc/man/djxl.txt @@ -0,0 +1,61 @@ +djxl(1) +======= +:doctype: manpage + +Name +---- + +djxl - decompress JPEG XL images + +Synopsis +-------- + +*djxl* ['options'...] 'input.jxl' ['output'] + +Description +----------- + +`djxl` decompresses a JPEG XL image or animation. The output format is determined +by the extension of the output file, which can be `.png`, `.jpg`, `.ppm`, `.pfm`. +If the JPEG XL input file contains an animation, multiple output files will be +produced, with names of the form "'output'-*framenumber*.ext". + + +Options +------- + +-h:: +--help:: + Displays the options that `djxl` supports. + +-j:: +--pixels_to_jpeg:: + By default, if the input JPEG XL contains a recompressed JPEG file, + djxl reconstructs the exact original JPEG file if the output file has the + `.jpg` (or `.jpeg`) filename extension. + This flag causes the decoder to instead decode the image to pixels and + encode a new (lossy) JPEG in this case. + + +-q 'quality':: +--jpeg_quality='quality':: + When decoding to `.jpg`, use this output quality. This option implicitly + enables the --pixels_to_jpeg option. + + +Examples +-------- + +---- +# Decompress a JPEG XL file to PNG +$ djxl input.jxl output.png + +# Reconstruct a losslessly-recompressed JPEG file +$ djxl lossless-jpeg.jxl reconstructed.jpeg +---- + + +See also +-------- + +*cjxl*(1) diff --git a/media/libjxl/src/doc/release.md b/media/libjxl/src/doc/release.md new file mode 100644 index 000000000..70f1278a4 --- /dev/null +++ b/media/libjxl/src/doc/release.md @@ -0,0 +1,267 @@ +# libjxl release process + +This guide documents the release process for the libjxl project. + +libjxl follows the [semantic versioning](https://semver.org/spec/v2.0.0.html) +specification for released versions. Releases are distributed as tags in the git +repository with the semantic version prefixed by the letter "v". For example, +release version "0.3.7" will have a git tag "v0.3.7". + +The public API is explicitly defined as C headers in the `lib/include` +directory, normally installed in your include path. All other headers are +internal API and are not covered by the versioning rules. + +## Development and release workflow + +New code development is performed on the `main` branch of the git repository. +Pre-submit checks enforce minimum build and test requirements for new patches +that balance impact and test latency, but not all checks are performed before +pull requests are merged. Several slower checks only run *after* the code has +been merged to `main`, resulting in some errors being detected hours after the +code is merged or even days after in the case of fuzzer-detected bugs. + +Release tags are cut from *release branches*. Each MAJOR.MINOR version has its +own release branch, for example releases `0.7.0`, `0.7.1`, `0.7.2`, ... would +have tags `v0.7.0`, `v0.7.1`, `v0.7.2`, ... on commits from the `v0.7.x` branch. +`v0.7.x` is a branch name, not a tag name, and doesn't represent a released +version since semantic versioning requires that the PATCH is a non-negative +number. Released tags don't each one have their own release branch, all releases +from the same MAJOR.MINOR version will share the same branch. The first commit +after the branch-off points between the main branch and the release branch +should be tagged with the suffix `-snapshot` and the name of the next +MAJOR.MINOR version, in order to get meaningful ouput for `git --describe`. + +The main purpose of the release branch is to stabilize the code before a +release. This involves including fixes to existing bugs but **not** including +new features. New features often come with new bugs which take time to fix, so +having a release branch allows us to cherry-pick *bug fixes* from the `main` +branch into the release branch without including the new *features* from `main`. +For this reason it is important to make small commits in `main` and separate bug +fixes from new features. + +After the initial minor release (`MAJOR.MINOR.PATCH`, for example `0.5.0`) the +release branch is used to continue to cherry-pick fixes to be included in a +patch release, for example a version `0.5.1` release. Patch fixes are only meant +to fix security bugs or other critical bugs that can't wait until the next major +or minor release. + +Release branches *may* continue to be maintained even after the next minor or +major version has been released to support users that can't update to a newer +minor release. In that case, the same process applies to all the maintained +release branches. + +A release branch with specific cherry-picks from `main` means that the release +code is actually a version of the code that never existed in the `main` branch, +so it needs to be tested independently. Pre-submit and post-submit tests run on +release branches (branches matching `v*.*.x`) but extra manual checks should be +performed before a release, specially if multiple bug fixes interact with each +other. Take this into account when selecting which commits to include in a +release. The objective is to have a stable version that can be used without +problems for months. Having the latest improvements at the time the release tag +is created is a non-goal. + +## Creating a release branch + +A new release branch is needed before creating a new major or minor release, +that is, a new release where the MAJOR or MINOR numbers are increased. Patch +releases, where only the PATCH number is increased, reuse the branch from the +previous release of the same MAJOR and MINOR numbers. + +The following instructions assume that you followed the recommended [libjxl git +setup](developing_in_github.md) where `origin` points to the upstream +libjxl/libjxl project, otherwise use the name of your upstream remote repository +instead of `origin`. + +The release branch is normally created from the latest work in `main` at the +time the branch is created, but it is possible to create the branch from an +older commit if the current `main` is particularly unstable or includes commits +that were not intended to be included in the release. The following example +creates the branch `v0.5.x` from the latest commit in main (`origin/main`), if a +different commit is to be used then replace `origin/main` with the SHA of that +commit. Change the `v0.5.x` branch name to the one you are creating. + +```bash +git fetch origin main +git push git@github.com:libjxl/libjxl.git origin/main:refs/heads/v0.5.x +``` + +Here we use the SSH URL explicitly since you are pushing to the `libjxl/libjxl` +project directly to a branch there. If you followed the guide `origin` will have +the HTTPS URL which wouldn't normally let you push since you wouldn't be +authenticated. The `v*.*.x` branches are [GitHub protected +branches](https://docs.github.com/en/github/administering-a-repository/defining-the-mergeability-of-pull-requests/about-protected-branches) +in our repository, however you can push to a protected branch when *creating* it +but you can't directly push to it after it is created. To include more changes +in the release branch see the "Cherry-picking fixes to a release" section below. + +## Creating a merge label + +We use GitHub labels in Pull Requests to keep track of the changes that should +be merged into a given release branch. For this purpose create a new label for +each new MAJOR.MINOR release branch called `merge-MAJOR.MINOR`, for example, +`merge-0.5`. + +In the [edit labels](https://github.com/libjxl/libjxl/issues/labels) page, click +on "New label" and create the label. Pick your favorite color. + +Labels are a GitHub-only concept and are not represented in git. You can add the +label to a Pull Request even after it was merged, whenever it is decided that +the Pull Request should be included in the given release branch. Adding the +label doesn't automatically merge it to the release branch. + +## Update the versioning number + +The version number (as returned by `JxlDecoderVersion`) in the source code in +`main` must match the semantic versioning of a release. After the release +branch is created the code in `main` will only be included in the next major +or minor release. Right after a release branch update the version targeting the +next release. Artifacts from `main` should include the new (unreleased) version, +so it is important to update it. For example, after the `v0.5.x` branch is +created from main, you should update the version on `main` to `0.6.0`. + +To help update it, run this helper command (in a Debian-based system): + +```bash +./ci.sh bump_version 0.6.0 +``` + +This will update the version in the following files: + + * `lib/CMakeLists.txt` + * `lib/lib.gni`, automatically updated with `tools/build_cleaner.py --update`. + * `debian/changelog` to create the Debian package release with the new version. + Debian changelog shouldn't repeat the library changelog, instead it should + include changes to the packaging scripts. + +If there were incompatible API/ABI changes, make sure to also adapt the +corresponding section in +[CMakeLists.txt](https://github.com/libjxl/libjxl/blob/main/lib/CMakeLists.txt#L12). + +## Cherry-pick fixes to a release + +After a Pull Request that should be included in a release branch has been merged +to `main` it can be cherry-picked to the release branch. Before cherry-picking a +change to a release branch it is important to check that it doesn't introduce +more problems, in particular it should run for some time in `main` to make sure +post-submit tests and the fuzzers run on it. Waiting for a day is a good idea. + +Most of the testing is done on the `main` branch, so be careful with what +commits are cherry-picked to a branch. Refactoring code is often not a good +candidate to cherry-pick. + +To cherry-pick a single commit to a release branch (in this example to `v0.5.x`) +you can run: + +```bash +git fetch origin +git checkout origin/v0.5.x -b merge_to_release +git cherry-pick -x SHA_OF_MAIN_COMMIT +# -x will annotate the cherry-pick with the original SHA_OF_MAIN_COMMIT value. +# If not already mentioned in the original commit, add the original PR number to +# the commit, for example add "(cherry picked from PR #NNNN)". +git commit --amend +``` + +The `SHA_OF_MAIN_COMMIT` is the hash of the commit as it landed in main. Use +`git log origin/main` to list the recent main commits and their hashes. + +Making sure that the commit message on the cherry-picked commit contains a +reference to the original pull request (like `#NNNN`) is important. It creates +an automatic comment in the original pull request notifying that it was +mentioned in another commit, helping keep track of the merged pull requests. If +the original commit was merged with the "Squash and merge" policy it will +automatically contain the pull request number on the first line, if this is not +the case you can amend the commit message of the cherry-pick to include a +reference. + +Multiple commits can be cherry-picked and tested at once to save time. Continue +running `git cherry-pick` and `git commit --amend` multiple times for all the +commits you need to cherry-pick, ideally in the same order they were merged on +the `main` branch. At the end you will have a local branch with multiple commits +on top of the release branch. + +Finally, upload your changes to *your fork* like normal, except that when +creating a pull request select the desired release branch as a target: + +```bash +git push myfork merge_to_release +``` + +If you used the [guide](developing_in_github.md) `myfork` would be `origin` in +that example. Click on the URL displayed, which will be something like + + `https://github.com/mygithubusername/libjxl/pull/new/merge_to_release` + +In the "Open a pull request" page, change the drop-down base branch from +"base: main" (the default) to the release branch you are targeting. + +The pull request approval and pre-submit rules apply as with normal pull +requests to the `main` branch. + +**Important:** When merging multiple cherry-picks use "Rebase and merge" policy, +not the squash one since otherwise you would discard the individual commit +message references from the git history in the release branch. + +## Publishing a release + +Once a release tag is created it must not be modified, so you need to prepare +the changes before creating the release. Make sure you checked the following: + + * The semantic version number in the release branch (see `lib/CMakeLists.txt`) + matches the number you intend to release, all three MAJOR, MINOR and PATCH + should match. Otherwise send a pull request to the release branch to + update them. + + * The GitHub Actions checks pass on the release branch. Look for the green + tick next to the last commit on the release branch. This should be visible + on the branch page, for example: https://github.com/libjxl/libjxl/tree/v0.5.x + + * There no open fuzzer-found bugs for the release branch. The most effective + way is to [run the fuzzer](fuzzing.md) on the release branch for a while. You + can seed the fuzzer with corpus generated by oss-fuzz by [downloading + it](https://google.github.io/oss-fuzz/advanced-topics/corpora/#downloading-the-corpus), + for example `djxl_fuzzer` with libFuzzer will use: + gs://libjxl-corpus.clusterfuzz-external.appspot.com/libFuzzer/libjxl_djxl_fuzzer + + * Manually check that images encode/decode ok. + + * Manually check that downstream projects compile with our code. Sometimes + bugs on build scripts are only detected when other projects try to use our + library. For example, test compiling + [imagemagick](https://github.com/ImageMagick/ImageMagick) and Chrome. + +A [GitHub +"release"](https://docs.github.com/en/github/administering-a-repository/releasing-projects-on-github/about-releases) +consists of two different concepts: + + * a git "tag": this is a name (`v` plus the semantic version number) with a + commit hash associated, defined in the git repository. Most external projects + will use git tags or HTTP URLs to these tags to fetch the code. + + * a GitHub "release": this is a GitHub-only concept and is not represented in + git other than by having a git tag associated with the release. A GitHub + release has a given source code commit SHA associated (through the tag) but + it *also* contains release notes and optional binary files attached to the + release. + +Releases from the older GitLab repository only have a git tag in GitHub, while +newer releases have both a git tag and a release entry in GitHub. + +To publish a release open the [New Release +page](https://github.com/libjxl/libjxl/releases/new) and follow these +instructions: + + * Set the "Tag version" as "v" plus the semantic version number. + + * Select the "Target" as your release branch. For example for a "v0.7.1" + release tag you should use the "v0.7.x" branch. + + * Use the version number as the release title. + + * Copy-paste the relevant section of the [CHANGELOG.md](../CHANGELOG.md) to the + release notes into the release notes. Add any other information pertaining + the release itself that are not included in the CHANGELOG.md, although prefer + to include those in the CHANGELOG.md file. You can switch to the Preview tab + to see the results. + + * Finally click "Publish release" and go celebrate with the team. 🎉 diff --git a/media/libjxl/src/doc/software_support.md b/media/libjxl/src/doc/software_support.md new file mode 100644 index 000000000..9dddad7de --- /dev/null +++ b/media/libjxl/src/doc/software_support.md @@ -0,0 +1,65 @@ +# JPEG XL software support + +This document attempts to keep track of software that is using libjxl to support JPEG XL. +This list serves several purposes: + +- thank/acknowledge other projects for integrating jxl support +- point end-users to software that can read/write jxl +- keep track of the adoption status of jxl +- in case of a (security) bug in libjxl, it's easier to see who might be affected and check if they are updated (in case they use static linking) + +Please add missing software to this list. + +## Browsers + +- Chromium: behind a flag since version 91, [tracking bug](https://bugs.chromium.org/p/chromium/issues/detail?id=1178058) +- Firefox: behind a flag since version 90, [tracking bug](https://bugzilla.mozilla.org/show_bug.cgi?id=1539075) +- Safari: not supported, [tracking bug](https://bugs.webkit.org/show_bug.cgi?id=208235) +- Edge: behind a flag since version 91, start with `.\msedge.exe --enable-features=JXL` +- Opera: behind a flag since version 77. +- For all browsers and to track browsers progress see [Can I Use](https://caniuse.com/jpegxl). + +## Image libraries + +- [ImageMagick](https://imagemagick.org/): supported since 7.0.10-54 +- [libvips](https://libvips.github.io/libvips/): supported since 8.11 +- [Imlib2](https://github.com/alistair7/imlib2-jxl) +- [FFmpeg](https://github.com/FFmpeg/FFmpeg/search?q=jpeg-xl&type=commits) +- [GDAL](https://gdal.org/drivers/raster/jpegxl.html): supported since 3.4.0 as a TIFF codec, and 3.6.0 as standalone format + +## OS-level support / UI frameworks / file browser plugins + +- Qt / KDE: [plugin available](https://github.com/novomesk/qt-jpegxl-image-plugin) +- GDK-pixbuf: plugin available in libjxl repo +- [gThumb](https://ubuntuhandbook.org/index.php/2021/04/gthumb-3-11-3-adds-jpeg-xl-support/) +- [MacOS viewer/QuickLook plugin](https://github.com/yllan/JXLook) +- [Windows Imaging Component](https://github.com/mirillis/jpegxl-wic) +- [Windows thumbnail handler](https://github.com/saschanaz/jxl-winthumb) +- [OpenMandriva Lx (since 4.3 RC)](https://www.openmandriva.org/en/news/article/openmandriva-lx-4-3-rc-available-for-testing) +- [KaOS (since 2021.06)](https://news.itsfoss.com/kaos-2021-06-release/) +- [EFL (since 1.27, no external plugin needed)](https://www.enlightenment.org) + +## Image editors + +- [GIMP (since 2.99.8)](https://www.gimp.org/news/2021/10/20/gimp-2-99-8-released/); plugin for older versions available in libjxl repo +- [Krita](https://invent.kde.org/graphics/krita/-/commit/13e5d2e5b9f0eac5c8064b7767f0b62264a0797b) +- Photoshop: no plugin available yet, no official support yet + +## Image viewers + +- [XnView](https://www.xnview.com/en/) +- [ImageGlass](https://imageglass.org/) +- [IrfanView](https://www.irfanview.com/); supported since 4.59 - requires a [plugin](https://www.irfanview.com/plugins.htm) to be downloaded and enabled. +- [Tachiyomi](https://github.com/tachiyomiorg/tachiyomi/releases/tag/v0.12.1) +- Any viewer based on Qt, KDE, GDK-pixbuf, EFL, ImageMagick, libvips or imlib2 (see above) + - Qt viewers: gwenview, digiKam, KolourPaint, KPhotoAlbum, LXImage-Qt, qimgv, qView, nomacs, VookiImageViewer, PhotoQt + - GTK viewers: Eye of Gnome (eog), gThumb, Geeqie + - EFL viewers: entice, ephoto +- [Swayimg](https://github.com/artemsen/swayimg) + +## Online tools + +- [Squoosh](https://squoosh.app/) +- [Cloudinary](https://cloudinary.com/blog/cloudinary_supports_jpeg_xl) +- [MConverter](https://mconverter.eu/) +- [jpegxl.io](https://jpegxl.io/) diff --git a/media/libjxl/src/doc/sphinx/api.rst b/media/libjxl/src/doc/sphinx/api.rst new file mode 100644 index 000000000..56fca09e2 --- /dev/null +++ b/media/libjxl/src/doc/sphinx/api.rst @@ -0,0 +1,15 @@ +API reference +============= + +``libjxl`` exposes a C API for encoding and decoding JPEG XL files with some +C++ header-only helpers for C++ users. + +.. toctree:: + :caption: API REFERENCE + :maxdepth: 2 + + api_decoder + api_encoder + api_common + api_butteraugli + api_threads diff --git a/media/libjxl/src/doc/sphinx/api_butteraugli.rst b/media/libjxl/src/doc/sphinx/api_butteraugli.rst new file mode 100644 index 000000000..4aae44a99 --- /dev/null +++ b/media/libjxl/src/doc/sphinx/api_butteraugli.rst @@ -0,0 +1,6 @@ +Butteraugli API - ``jxl/butteraugli.h`` +======================================= + +.. doxygengroup:: libjxl_butteraugli + :members: + :private-members: diff --git a/media/libjxl/src/doc/sphinx/api_common.rst b/media/libjxl/src/doc/sphinx/api_common.rst new file mode 100644 index 000000000..7114b51cd --- /dev/null +++ b/media/libjxl/src/doc/sphinx/api_common.rst @@ -0,0 +1,6 @@ +Common API concepts +=================== + +.. doxygengroup:: libjxl_common + :members: + :private-members: diff --git a/media/libjxl/src/doc/sphinx/api_decoder.rst b/media/libjxl/src/doc/sphinx/api_decoder.rst new file mode 100644 index 000000000..3f8db228d --- /dev/null +++ b/media/libjxl/src/doc/sphinx/api_decoder.rst @@ -0,0 +1,6 @@ +Decoder API - ``jxl/decode.h`` +============================== + +.. doxygengroup:: libjxl_decoder + :members: + :private-members: diff --git a/media/libjxl/src/doc/sphinx/api_encoder.rst b/media/libjxl/src/doc/sphinx/api_encoder.rst new file mode 100644 index 000000000..0c76cc889 --- /dev/null +++ b/media/libjxl/src/doc/sphinx/api_encoder.rst @@ -0,0 +1,6 @@ +Encoder API - ``jxl/encode.h`` +============================== + +.. doxygengroup:: libjxl_encoder + :members: + :private-members: diff --git a/media/libjxl/src/doc/sphinx/api_threads.rst b/media/libjxl/src/doc/sphinx/api_threads.rst new file mode 100644 index 000000000..78dba657d --- /dev/null +++ b/media/libjxl/src/doc/sphinx/api_threads.rst @@ -0,0 +1,6 @@ +Multi-threaded Encoder/Decoder +============================== + +.. doxygengroup:: libjxl_threads + :members: + :private-members: diff --git a/media/libjxl/src/doc/sphinx/conf.py b/media/libjxl/src/doc/sphinx/conf.py new file mode 100644 index 000000000..1591aefc7 --- /dev/null +++ b/media/libjxl/src/doc/sphinx/conf.py @@ -0,0 +1,110 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Configuration file for the Sphinx documentation builder. +# +# See https://www.sphinx-doc.org/en/master/usage/configuration.html + +import os +import re +import subprocess + +def GetVersion(): + """Function to get the version of the current code.""" + with open(os.path.join( + os.path.dirname(__file__), '../../lib/CMakeLists.txt'), 'r') as f: + cmakevars = {} + for line in f: + m = re.match(r'set\(JPEGXL_([A-Z]+)_VERSION ([^\)]+)\)', line) + if m: + cmakevars[m.group(1)] = m.group(2) + return '%s.%s.%s' % (cmakevars['MAJOR'], cmakevars['MINOR'], cmakevars['PATCH']) + +def ConfigProject(app, config): + # Configure the doxygen xml directory as the "xml" directory next to the + # sphinx output directory. Doxygen generates by default the xml files in a + # "xml" sub-directory of the OUTPUT_DIRECTORY. + build_dir = os.path.dirname(app.outdir) + xml_dir = os.path.join(build_dir, 'xml') + config.breathe_projects['libjxl'] = xml_dir + + # Read the docs build environment doesn't run our cmake script so instead we + # need to run doxygen manually here. + if os.environ.get('READTHEDOCS', None) != 'True': + return + root_dir = os.path.realpath(os.path.join(app.srcdir, '../../')) + doxyfile = os.path.join(build_dir, 'Doxyfile-rtd.doc') + with open(doxyfile, 'w') as f: + f.write(f""" +FILE_PATTERNS = *.c *.h +GENERATE_HTML = NO +GENERATE_LATEX = NO +GENERATE_XML = YES +INPUT = lib/include doc/api.txt +OUTPUT_DIRECTORY = {build_dir} +PROJECT_NAME = LIBJXL +QUIET = YES +RECURSIVE = YES +STRIP_FROM_PATH = lib/include +WARN_AS_ERROR = YES +""") + subprocess.check_call(['doxygen', doxyfile], cwd=root_dir) + +def setup(app): + # Generate doxygen XML on init when running from Read the docs. + app.connect("config-inited", ConfigProject) + +### Project information + +project = 'libjxl' +project_copyright = 'JPEG XL Project Authors' +author = 'JPEG XL Project Authors' +version = GetVersion() + +### General configuration + +extensions = [ + # For integration with doxygen documentation. + 'breathe', + # sphinx readthedocs theme. + 'sphinx_rtd_theme', + # Do we use it? + 'sphinx.ext.graphviz', +] + +breathe_default_project = 'libjxl' +breathe_projects = {} + + +# All the API is in C, except those files that end with cxx.h. +breathe_domain_by_extension = {'h': 'cpp'} +breathe_domain_by_file_pattern = { + '*cxx.h': 'cpp', +} +breathe_implementation_filename_extensions = ['.cc'] + +# These are defined at build time by cmake. +c_id_attributes = [ + 'JXL_EXPORT', + 'JXL_DEPRECATED', + 'JXL_THREADS_EXPORT', +] +cpp_id_attributes = c_id_attributes + + +breathe_projects_source = { + 'libjxl' : ('../../', [ + 'doc/api.txt', + 'lib/include/jxl', + ]) +} + +# Recognized suffixes. +source_suffix = ['.rst', '.md'] + +### Options for HTML output + +# Use the readthedocs.io theme when generating the HTML output. +html_theme = 'sphinx_rtd_theme' diff --git a/media/libjxl/src/doc/sphinx/index.rst b/media/libjxl/src/doc/sphinx/index.rst new file mode 100644 index 000000000..9a57074b0 --- /dev/null +++ b/media/libjxl/src/doc/sphinx/index.rst @@ -0,0 +1,18 @@ +.. libjxl sphinx documentation entrypoint + +JPEG XL image format reference implementation +============================================= + +.. toctree:: + :maxdepth: 3 + :caption: Contents: + + api + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`search` + diff --git a/media/libjxl/src/doc/sphinx/requirements.txt b/media/libjxl/src/doc/sphinx/requirements.txt new file mode 100644 index 000000000..28179eafa --- /dev/null +++ b/media/libjxl/src/doc/sphinx/requirements.txt @@ -0,0 +1,3 @@ +breathe +sphinx +sphinx-rtd-theme diff --git a/media/libjxl/src/doc/tables/adobe.md b/media/libjxl/src/doc/tables/adobe.md new file mode 100644 index 000000000..f3beef75a --- /dev/null +++ b/media/libjxl/src/doc/tables/adobe.md @@ -0,0 +1,6 @@ +#### Table M.8 – "Adobe" marker template + +``` +0xEE, 0x00, 0x0E, 0x41, 0x64, 0x6F, 0x62, 0x65, 0x00, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01 +``` + diff --git a/media/libjxl/src/doc/tables/all_tables.pdf b/media/libjxl/src/doc/tables/all_tables.pdf new file mode 100644 index 000000000..09a85170c --- /dev/null +++ b/media/libjxl/src/doc/tables/all_tables.pdf @@ -0,0 +1,1078 @@ +Electronic Insert I.1 – DCT-II / DCT-III code generator + + ####################################################################### + # DCT-II / DCT-III generator + # + # Based on: + # "A low multiplicative complexity fast recursive DCT-2 algorithm" + # by Maxim Vashkevich and Alexander Petrovsky / arXiv / 20 Jul 2012 + ####################################################################### + + import math + import sys + N=8 + + ####################################################################### + # Base transforms / generators + ####################################################################### + + CNTR = 0 + def makeTmp(): + + global CNTR + result = "t{:02d}".format(CNTR) + CNTR = CNTR + 1 + return result + + def makeVar(i): + return "i{:02d}".format(i) + + def add(x, y): + tmp = makeTmp() + print(tmp + " = " + x + " + " + y + ";") + return tmp + + def sub(x, y): + tmp = makeTmp() + print(tmp + " = " + x + " - " + y + ";") + return tmp + + def mul(x, c): + tmp = makeTmp() + print(tmp + " = " + x + " * " + c + ";") + return tmp + + # 2.0 * math.cos((a + 0.0) / (b + 0.0) * math.pi) + def C2(a, b): + return "c_c2_" + str(a) + "_" + str(b) + +# 1.0 / C2(a, b) +def iC2(a, b): + + return "c_ic2_" + str(a) + "_" + str(b) + +####################################################################### +# Utilities +####################################################################### + +# Generate identity matrix. Usually this matrix is passed to +# DCT algorithm to generate "basis" vectors of the transform. +def makeVars(): + + return [makeVar(i) for i in range(N)] + +# Split list of variables info halves. +def split(x): + + m = len(x) + m2 = m // 2 + return (x[0 : m2], x[m2 : m]) + +# Make a list of variables in a reverse order. +def reverse(varz): + + m = len(varz) + result = [0] * m + for i in range(m): + + result[i] = varz[m - 1 - i] + return result + +# Apply permutation +def permute(x, p): + + return [x[p[i]] for i in range(len(p))] + +def transposePermutation(p): + n = len(p) + result = [0] * n + for i in range(n): + result[p[i]] = i + return result + +# See paper. Split even-odd elements. +def P(n): + + if n == 1: + return [0] + n2 = n // 2 + return [2 * i for i in range(n2)] + [2 * i + 1 for i in range(n2)] + +# See paper. Interleave first and second half. +def Pt(n): + + return transposePermutation(P(n)) + +####################################################################### +# Scheme +####################################################################### + +def B2(x): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + bottom = reverse(bottom) + t = [add(top[i], bottom[i]) for i in range(n2)] + b = [sub(top[i], bottom[i]) for i in range(n2)] + return t + b + +def iB2(x): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + t = [add(top[i], bottom[i]) for i in range(n2)] + b = [sub(top[i], bottom[i]) for i in range(n2)] + return t + reverse(b) + +def B4(x, rn): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + rbottom = reverse(bottom) + t = [sub(top[i], rbottom[i]) for i in range(n2)] + b = [mul(bottom[i], C2(rn, 2 * N)) for i in range(n2)] + top = [add(t[i], b[i]) for i in range(n2)] + bottom = [sub(t[i], b[i]) for i in range(n2)] + return top + bottom + def iB4(x, rn): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + t = [add(top[i], bottom[i]) for i in range(n2)] + b = [sub(top[i], bottom[i]) for i in range(n2)] + bottom = [mul(b[i], iC2(rn, 2 * N)) for i in range(n2)] + rbottom = reverse(bottom) + top = [add(t[i], rbottom[i]) for i in range(n2)] + return top + bottom + +def P4(n): + if n == 1: + return [0] + if n == 2: + return [0, 1] + n2 = n // 2 + result = [0] * n + tc = 0 + bc = 0 + i=0 + result[i] = tc; tc = tc + 1; i = i + 1 + turn = True + while i < n - 1: + if turn: + result[i] = n2 + bc; bc = bc + 1; i = i + 1 + result[i] = n2 + bc; bc = bc + 1; i = i + 1 + else: + result[i] = tc; tc = tc + 1; i = i + 1 + result[i] = tc; tc = tc + 1; i = i + 1 + turn = not turn + result[i] = tc; tc = tc + 1; i = i + 1 + return result + +def iP4(n): + return transposePermutation(P4(n)) + +def d2n(x): + n = len(x) + if n == 1: + return x + y = B2(x) + (top, bottom) = split(y) + return permute(d2n(top) + d4n(bottom, N // 2), Pt(n)) + +def id2n(x): + n = len(x) + if n == 1: + return x + (top, bottom) = split(permute(x, P(n))) + return iB2(id2n(top) + id4n(bottom, N // 2)) + +def d4n(x, rn): + n = len(x) + if n == 1: + return x + y = B4(x, rn) + (top, bottom) = split(y) + rn2 = rn // 2 + return permute(d4n(top, rn2) + d4n(bottom, N - rn2), P4(n)) + +def id4n(x, rn): + n = len(x) + if n == 1: + return x + (top, bottom) = split(permute(x, iP4(n))) + rn2 = rn // 2 + y = id4n(top, rn2) + id4n(bottom, N -rn2) + return iB4(y, rn) + +####################################################################### +# Main. +####################################################################### + +def help(): + print("Usage: %s [N [T]]" % sys.argv[0]) + print(" N should be the power of 2, default is 8") + print(" T is one of {2, 3}, default is 2") + sys.exit() + +def parseInt(s): + try: + return int(s) + except ValueError: + help() + +if __name__ == "__main__": + if len(sys.argv) < 1 or len(sys.argv) > 3: help() + if len(sys.argv) >= 2: + N = parseInt(sys.argv[1]) + if (N & (N - 1)) != 0: help() + + type = 0 + if len(sys.argv) >= 3: + + typeOption = sys.argv[2] + if len(typeOption) != 1: help() + type = "23".index(typeOption) + if type == -1: help() + if type == 0: + vars = d2n(makeVars()) + else: # type == 1 + vars = id2n(makeVars()) + print("Output vector: " + str(vars)) + +Table M.1 – is_zero_base table + + 228, 216, 216, 195, 192, 189, 182, 184, 179, 176, 171, 168, 166, 159, + 156, 151, 151, 150, 150, 146, 144, 138, 138, 137, 135, 131, 127, 126, + 124, 123, 124, 123, 122, 121, 118, 117, 114, 115, 116, 116, 115, 115, + 114, 111, 111, 111, 112, 111, 110, 110, 110, 111, 111, 114, 110, 111, + 112, 113, 116, 120, 126, 131, 147, 160 + +Table M.2 – num_nonzeros_base table + + 251, 252, 117, 249, 161, 136, 83, 238, 184, 126, 137, 129, 140, 119, + 70, 213, 160, 175, 174, 130, 166, 134, 122, 125, 131, 144, 136, 133, + + 139, 123, 79, 216, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128 + + 254, 252, 174, 232, 189, 155, 122, 177, 204, 173, 146, 149, 141, 133, + 103, 109, 167, 187, 168, 142, 154, 147, 125, 139, 144, 138, 138, 153, + 141, 133, 90, 121, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128 + + 251, 240, 197, 176, 184, 177, 114, 89, 194, 165, 153, 161, 158, 136, + 92, 95, 123, 171, 160, 140, 148, 136, 129, 139, 145, 136, 143, 134, + + 138, 124, 92, 154, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128 + 247, 220, 201, 110, 194, 176, 147, 59, 175, 171, 156, 157, 152, 146, +115, 114, 88, 151, 164, 141, 153, 135, 141, 131, 146, 139, 140, 145, +138, 137, 112, 184, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + +238, 179, 203, 63, 194, 173, 149, 71, 139, 169, 154, 159, 150, 146, +117, 143, 78, 122, 152, 137, 149, 138, 138, 133, 134, 142, 142, 142, +148, 128, 118, 199, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + +227, 127, 200, 44, 192, 170, 148, 100, 102, 161, 156, 153, 148, 149, +124, 160, 88, 101, 134, 132, 149, 145, 134, 134, 136, 141, 138, 142, +144, 137, 116, 208, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + +214, 86, 195, 44, 187, 163, 148, 126, 81, 147, 156, 152, 150, 144, +121, 172, 96, 95, 117, 122, 145, 152, 136, 133, 135, 135, 131, 142, +141, 135, 114, 217, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + +198, 56, 191, 54, 171, 162, 147, 144, 74, 128, 152, 149, 150, 142, +119, 177, 101, 100, 106, 111, 135, 154, 136, 137, 136, 132, 133, 142, +144, 130, 117, 222, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + +176, 40, 189, 73, 147, 159, 148, 152, 79, 106, 147, 149, 151, 139, +123, 188, 108, 110, 106, 97, 125, 151, 137, 138, 135, 135, 134, 136, +140, 131, 116, 221, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + +148, 33, 185, 88, 117, 158, 145, 163, 95, 91, 137, 146, 150, 140, +120, 197, 115, 116, 114, 92, 114, 144, 130, 133, 132, 133, 129, 140, +138, 130, 111, 224, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + 117, 31, 180, 104, 93, 150, 143, 166, 99, 85, 124, 139, 148, 142, +118, 201, 105, 120, 120, 90, 107, 135, 127, 130, 131, 131, 132, 140, +142, 133, 114, 229, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 87, 35, 170, 110, 78, 141, 144, 176, 106, 90, 112, 132, 143, 138, +119, 204, 111, 121, 125, 90, 105, 131, 124, 122, 129, 128, 129, 137, +138, 133, 114, 227, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 63, 42, 159, 123, 73, 127, 142, 191, 105, 91, 105, 123, 139, 137, +120, 209, 117, 110, 122, 98, 110, 125, 115, 123, 122, 126, 128, 134, +141, 129, 113, 229, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 45, 53, 146, 135, 71, 114, 138, 193, 100, 98, 98, 113, 133, 135, +118, 222, 113, 111, 139, 103, 107, 126, 111, 119, 121, 122, 127, 135, +141, 128, 114, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 33, 60, 132, 138, 75, 100, 134, 203, 112, 99, 98, 105, 126, 131, +115, 229, 107, 93, 121, 106, 108, 122, 106, 109, 114, 116, 127, 133, +143, 128, 110, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 24, 70, 118, 134, 76, 87, 130, 201, 110, 96, 99, 97, 119, 130, +111, 229, 97, 104, 125, 102, 112, 125, 101, 109, 113, 114, 125, 129, +142, 127, 112, 241, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 17, 65, 100, 121, 80, 75, 124, 174, 117, 100, 94, 93, 114, 128, +110, 216, 103, 94, 113, 122, 118, 126, 113, 108, 105, 108, 122, 128, +141, 125, 113, 238, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + 12, 70, 82, 132, 78, 65, 118, 155, 136, 103, 97, 89, 106, 124, +111, 215, 115, 123, 129, 99, 104, 127, 110, 108, 101, 109, 118, 126, +136, 123, 110, 233, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 8, 66, 61, 117, 91, 59, 108, 195, 101, 112, 99, 99, 99, 116, +106, 230, 127, 99, 144, 101, 118, 137, 117, 111, 106, 104, 116, 121, +134, 122, 110, 223, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 6, 78, 42, 146, 101, 54, 94, 201, 116, 102, 110, 94, 92, 108, +103, 214, 108, 111, 127, 102, 121, 132, 120, 121, 95, 98, 110, 121, +129, 117, 107, 235, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 5, 93, 29, 145, 102, 52, 77, 216, 108, 115, 108, 102, 89, 97, + 94, 229, 89, 103, 139, 120, 103, 151, 102, 100, 97, 96, 99, 111, +125, 116, 104, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 4, 105, 21, 145, 100, 54, 64, 217, 100, 122, 128, 87, 88, 91, + 87, 230, 112, 80, 148, 95, 146, 123, 96, 140, 90, 91, 98, 106, +122, 111, 100, 249, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 4, 130, 14, 142, 104, 56, 51, 208, 116, 135, 100, 89, 82, 84, + 75, 239, 85, 85, 122, 125, 94, 144, 151, 136, 92, 97, 104, 109, +113, 110, 91, 246, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 3, 126, 9, 172, 105, 57, 39, 219, 95, 120, 118, 96, 93, 75, + 66, 241, 102, 134, 96, 156, 146, 162, 130, 112, 82, 89, 97, 101, +116, 103, 82, 254, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + 3, 149, 7, 182, 122, 54, 29, 224, 103, 100, 113, 96, 90, 74, + 55, 250, 127, 94, 118, 93, 135, 160, 113, 130, 95, 117, 106, 96, +111, 97, 77, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 3, 150, 4, 170, 138, 59, 20, 229, 91, 150, 107, 98, 92, 68, + 48, 245, 113, 64, 114, 111, 134, 127, 102, 104, 85, 118, 103, 107, +102, 91, 72, 245, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 3, 171, 3, 165, 137, 62, 14, 211, 96, 127, 132, 121, 95, 62, + 37, 248, 102, 57, 144, 85, 127, 191, 102, 97, 127, 104, 91, 102, +107, 81, 64, 254, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 2, 166, 2, 196, 122, 65, 10, 243, 102, 93, 117, 92, 96, 63, + 29, 251, 169, 159, 149, 96, 91, 139, 157, 40, 100, 89, 120, 92, +109, 79, 58, 247, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 2, 176, 2, 189, 118, 48, 7, 219, 68, 43, 109, 96, 129, 75, + 19, 254, 2, 3, 185, 6, 102, 127, 127, 127, 1, 131, 83, 99, +107, 80, 45, 254, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 1, 205, 2, 208, 64, 89, 4, 223, 29, 169, 29, 123, 118, 76, + 11, 240, 202, 243, 65, 6, 12, 243, 96, 55, 102, 102, 114, 102, +107, 74, 31, 247, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + + 1, 216, 1, 214, 127, 94, 2, 234, 145, 3, 127, 106, 155, 80, + 4, 247, 4, 65, 86, 127, 127, 127, 127, 102, 127, 143, 143, 108, +113, 80, 16, 216, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 + 2, 199, 1, 222, 93, 94, 1, 232, 2, 65, 74, 139, 201, 48, + 2, 254, 169, 127, 52, 243, 251, 249, 102, 86, 202, 153, 65, 65, + 146, 69, 8, 238, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128 + +Table M.3 – Protocol Buffer descriptor of top-level structure of losslessly compressed JPEG stream + + message Header { + optional uint64 width = 1; + optional uint64 height = 2; + required uint64 version_and_component_count_code = 3; + optional uint64 subsampling_code = 4; + + } + + message Jpeg { + required bytes signature = 1; + required Header header = 2; + optional bytes meta_data = 3; + optional bytes jpeg1_internals = 4; + optional bytes quant_data = 5; + optional bytes histogram_data = 6; + optional bytes dc_data = 7; + optional bytes ac_data = 8; + optional bytes original_jpg = 9; + + } + +Table M.4 – APP0 template + + 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01, 0x01, 0x00, 0x00, + 0x01, 0x00, 0x01, 0x00, 0x00 + +Table M.6 – common ICC profile template + + 0xE2, 0x0C, 0x58, 0x49, 0x43, 0x43, 0x5F, 0x50, 0x52, 0x4F, 0x46, 0x49, + 0x4C, 0x45, 0x00, 0x01, 0x01, 0x00, 0x00, 0x0C, 0x48, 0x4C, 0x69, 0x6E, + 0x6F, 0x02, 0x10, 0x00, 0x00, 0x6D, 0x6E, 0x74, 0x72, 0x52, 0x47, 0x42, + 0x20, 0x58, 0x59, 0x5A, 0x20, 0x07, 0xCE, 0x00, 0x02, 0x00, 0x09, 0x00, + 0x06, 0x00, 0x31, 0x00, 0x00, 0x61, 0x63, 0x73, 0x70, 0x4D, 0x53, 0x46, + 0x54, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x43, 0x20, 0x73, 0x52, 0x47, + 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0xF6, 0xD6, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xD3, + 0x2D, 0x48, 0x50, 0x20, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x11, 0x63, 0x70, 0x72, 0x74, 0x00, 0x00, 0x01, +0x50, 0x00, 0x00, 0x00, 0x33, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x01, +0x84, 0x00, 0x00, 0x00, 0x6C, 0x77, 0x74, 0x70, 0x74, 0x00, 0x00, 0x01, +0xF0, 0x00, 0x00, 0x00, 0x14, 0x62, 0x6B, 0x70, 0x74, 0x00, 0x00, 0x02, +0x04, 0x00, 0x00, 0x00, 0x14, 0x72, 0x58, 0x59, 0x5A, 0x00, 0x00, 0x02, +0x18, 0x00, 0x00, 0x00, 0x14, 0x67, 0x58, 0x59, 0x5A, 0x00, 0x00, 0x02, +0x2C, 0x00, 0x00, 0x00, 0x14, 0x62, 0x58, 0x59, 0x5A, 0x00, 0x00, 0x02, +0x40, 0x00, 0x00, 0x00, 0x14, 0x64, 0x6D, 0x6E, 0x64, 0x00, 0x00, 0x02, +0x54, 0x00, 0x00, 0x00, 0x70, 0x64, 0x6D, 0x64, 0x64, 0x00, 0x00, 0x02, +0xC4, 0x00, 0x00, 0x00, 0x88, 0x76, 0x75, 0x65, 0x64, 0x00, 0x00, 0x03, +0x4C, 0x00, 0x00, 0x00, 0x86, 0x76, 0x69, 0x65, 0x77, 0x00, 0x00, 0x03, +0xD4, 0x00, 0x00, 0x00, 0x24, 0x6C, 0x75, 0x6D, 0x69, 0x00, 0x00, 0x03, +0xF8, 0x00, 0x00, 0x00, 0x14, 0x6D, 0x65, 0x61, 0x73, 0x00, 0x00, 0x04, +0x0C, 0x00, 0x00, 0x00, 0x24, 0x74, 0x65, 0x63, 0x68, 0x00, 0x00, 0x04, +0x30, 0x00, 0x00, 0x00, 0x0C, 0x72, 0x54, 0x52, 0x43, 0x00, 0x00, 0x04, +0x3C, 0x00, 0x00, 0x08, 0x0C, 0x67, 0x54, 0x52, 0x43, 0x00, 0x00, 0x04, +0x3C, 0x00, 0x00, 0x08, 0x0C, 0x62, 0x54, 0x52, 0x43, 0x00, 0x00, 0x04, +0x3C, 0x00, 0x00, 0x08, 0x0C, 0x74, 0x65, 0x78, 0x74, 0x00, 0x00, 0x00, +0x00, 0x43, 0x6F, 0x70, 0x79, 0x72, 0x69, 0x67, 0x68, 0x74, 0x20, 0x28, +0x63, 0x29, 0x20, 0x31, 0x39, 0x39, 0x38, 0x20, 0x48, 0x65, 0x77, 0x6C, +0x65, 0x74, 0x74, 0x2D, 0x50, 0x61, 0x63, 0x6B, 0x61, 0x72, 0x64, 0x20, +0x43, 0x6F, 0x6D, 0x70, 0x61, 0x6E, 0x79, 0x00, 0x00, 0x64, 0x65, 0x73, +0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x73, 0x52, 0x47, +0x42, 0x20, 0x49, 0x45, 0x43, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, +0x2E, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x12, 0x73, 0x52, 0x47, 0x42, 0x20, 0x49, 0x45, 0x43, 0x36, 0x31, +0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x58, 0x59, 0x5A, +0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF3, 0x51, 0x00, 0x01, 0x00, +0x00, 0x00, 0x01, 0x16, 0xCC, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6F, +0xA2, 0x00, 0x00, 0x38, 0xF5, 0x00, 0x00, 0x03, 0x90, 0x58, 0x59, 0x5A, +0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0x99, 0x00, 0x00, 0xB7, +0x85, 0x00, 0x00, 0x18, 0xDA, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x24, 0xA0, 0x00, 0x00, 0x0F, 0x84, 0x00, 0x00, 0xB6, +0xCF, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x16, 0x49, 0x45, 0x43, 0x20, 0x68, 0x74, 0x74, 0x70, 0x3A, 0x2F, 0x2F, + 0x77, 0x77, 0x77, 0x2E, 0x69, 0x65, 0x63, 0x2E, 0x63, 0x68, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16, 0x49, 0x45, +0x43, 0x20, 0x68, 0x74, 0x74, 0x70, 0x3A, 0x2F, 0x2F, 0x77, 0x77, 0x77, +0x2E, 0x69, 0x65, 0x63, 0x2E, 0x63, 0x68, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x2E, 0x49, 0x45, 0x43, 0x20, 0x36, 0x31, 0x39, +0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x20, 0x44, 0x65, 0x66, 0x61, 0x75, +0x6C, 0x74, 0x20, 0x52, 0x47, 0x42, 0x20, 0x63, 0x6F, 0x6C, 0x6F, 0x75, +0x72, 0x20, 0x73, 0x70, 0x61, 0x63, 0x65, 0x20, 0x2D, 0x20, 0x73, 0x52, +0x47, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x2E, 0x49, 0x45, 0x43, 0x20, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, +0x32, 0x2E, 0x31, 0x20, 0x44, 0x65, 0x66, 0x61, 0x75, 0x6C, 0x74, 0x20, +0x52, 0x47, 0x42, 0x20, 0x63, 0x6F, 0x6C, 0x6F, 0x75, 0x72, 0x20, 0x73, +0x70, 0x61, 0x63, 0x65, 0x20, 0x2D, 0x20, 0x73, 0x52, 0x47, 0x42, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x65, 0x73, +0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2C, 0x52, 0x65, 0x66, +0x65, 0x72, 0x65, 0x6E, 0x63, 0x65, 0x20, 0x56, 0x69, 0x65, 0x77, 0x69, +0x6E, 0x67, 0x20, 0x43, 0x6F, 0x6E, 0x64, 0x69, 0x74, 0x69, 0x6F, 0x6E, +0x20, 0x69, 0x6E, 0x20, 0x49, 0x45, 0x43, 0x36, 0x31, 0x39, 0x36, 0x36, +0x2D, 0x32, 0x2E, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x2C, 0x52, 0x65, 0x66, 0x65, 0x72, 0x65, 0x6E, 0x63, +0x65, 0x20, 0x56, 0x69, 0x65, 0x77, 0x69, 0x6E, 0x67, 0x20, 0x43, 0x6F, +0x6E, 0x64, 0x69, 0x74, 0x69, 0x6F, 0x6E, 0x20, 0x69, 0x6E, 0x20, 0x49, +0x45, 0x43, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x76, 0x69, 0x65, 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0xA4, +0xFE, 0x00, 0x14, 0x5F, 0x2E, 0x00, 0x10, 0xCF, 0x14, 0x00, 0x03, 0xED, +0xCC, 0x00, 0x04, 0x13, 0x0B, 0x00, 0x03, 0x5C, 0x9E, 0x00, 0x00, 0x00, +0x01, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4C, 0x09, +0x56, 0x00, 0x50, 0x00, 0x00, 0x00, 0x57, 0x1F, 0xE7, 0x6D, 0x65, 0x61, +0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +0x00, 0x00, 0x00, 0x02, 0x8F, 0x00, 0x00, 0x00, 0x02, 0x73, 0x69, 0x67, +0x20, 0x00, 0x00, 0x00, 0x00, 0x43, 0x52, 0x54, 0x20, 0x63, 0x75, 0x72, +0x76, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, +0x05, 0x00, 0x0A, 0x00, 0x0F, 0x00, 0x14, 0x00, 0x19, 0x00, 0x1E, 0x00, +0x23, 0x00, 0x28, 0x00, 0x2D, 0x00, 0x32, 0x00, 0x37, 0x00, 0x3B, 0x00, +0x40, 0x00, 0x45, 0x00, 0x4A, 0x00, 0x4F, 0x00, 0x54, 0x00, 0x59, 0x00, +0x5E, 0x00, 0x63, 0x00, 0x68, 0x00, 0x6D, 0x00, 0x72, 0x00, 0x77, 0x00, +0x7C, 0x00, 0x81, 0x00, 0x86, 0x00, 0x8B, 0x00, 0x90, 0x00, 0x95, 0x00, + 0x9A, 0x00, 0x9F, 0x00, 0xA4, 0x00, 0xA9, 0x00, 0xAE, 0x00, 0xB2, 0x00, +0xB7, 0x00, 0xBC, 0x00, 0xC1, 0x00, 0xC6, 0x00, 0xCB, 0x00, 0xD0, 0x00, +0xD5, 0x00, 0xDB, 0x00, 0xE0, 0x00, 0xE5, 0x00, 0xEB, 0x00, 0xF0, 0x00, +0xF6, 0x00, 0xFB, 0x01, 0x01, 0x01, 0x07, 0x01, 0x0D, 0x01, 0x13, 0x01, +0x19, 0x01, 0x1F, 0x01, 0x25, 0x01, 0x2B, 0x01, 0x32, 0x01, 0x38, 0x01, +0x3E, 0x01, 0x45, 0x01, 0x4C, 0x01, 0x52, 0x01, 0x59, 0x01, 0x60, 0x01, +0x67, 0x01, 0x6E, 0x01, 0x75, 0x01, 0x7C, 0x01, 0x83, 0x01, 0x8B, 0x01, +0x92, 0x01, 0x9A, 0x01, 0xA1, 0x01, 0xA9, 0x01, 0xB1, 0x01, 0xB9, 0x01, +0xC1, 0x01, 0xC9, 0x01, 0xD1, 0x01, 0xD9, 0x01, 0xE1, 0x01, 0xE9, 0x01, +0xF2, 0x01, 0xFA, 0x02, 0x03, 0x02, 0x0C, 0x02, 0x14, 0x02, 0x1D, 0x02, +0x26, 0x02, 0x2F, 0x02, 0x38, 0x02, 0x41, 0x02, 0x4B, 0x02, 0x54, 0x02, +0x5D, 0x02, 0x67, 0x02, 0x71, 0x02, 0x7A, 0x02, 0x84, 0x02, 0x8E, 0x02, +0x98, 0x02, 0xA2, 0x02, 0xAC, 0x02, 0xB6, 0x02, 0xC1, 0x02, 0xCB, 0x02, +0xD5, 0x02, 0xE0, 0x02, 0xEB, 0x02, 0xF5, 0x03, 0x00, 0x03, 0x0B, 0x03, +0x16, 0x03, 0x21, 0x03, 0x2D, 0x03, 0x38, 0x03, 0x43, 0x03, 0x4F, 0x03, +0x5A, 0x03, 0x66, 0x03, 0x72, 0x03, 0x7E, 0x03, 0x8A, 0x03, 0x96, 0x03, +0xA2, 0x03, 0xAE, 0x03, 0xBA, 0x03, 0xC7, 0x03, 0xD3, 0x03, 0xE0, 0x03, +0xEC, 0x03, 0xF9, 0x04, 0x06, 0x04, 0x13, 0x04, 0x20, 0x04, 0x2D, 0x04, +0x3B, 0x04, 0x48, 0x04, 0x55, 0x04, 0x63, 0x04, 0x71, 0x04, 0x7E, 0x04, +0x8C, 0x04, 0x9A, 0x04, 0xA8, 0x04, 0xB6, 0x04, 0xC4, 0x04, 0xD3, 0x04, +0xE1, 0x04, 0xF0, 0x04, 0xFE, 0x05, 0x0D, 0x05, 0x1C, 0x05, 0x2B, 0x05, +0x3A, 0x05, 0x49, 0x05, 0x58, 0x05, 0x67, 0x05, 0x77, 0x05, 0x86, 0x05, +0x96, 0x05, 0xA6, 0x05, 0xB5, 0x05, 0xC5, 0x05, 0xD5, 0x05, 0xE5, 0x05, +0xF6, 0x06, 0x06, 0x06, 0x16, 0x06, 0x27, 0x06, 0x37, 0x06, 0x48, 0x06, +0x59, 0x06, 0x6A, 0x06, 0x7B, 0x06, 0x8C, 0x06, 0x9D, 0x06, 0xAF, 0x06, +0xC0, 0x06, 0xD1, 0x06, 0xE3, 0x06, 0xF5, 0x07, 0x07, 0x07, 0x19, 0x07, +0x2B, 0x07, 0x3D, 0x07, 0x4F, 0x07, 0x61, 0x07, 0x74, 0x07, 0x86, 0x07, +0x99, 0x07, 0xAC, 0x07, 0xBF, 0x07, 0xD2, 0x07, 0xE5, 0x07, 0xF8, 0x08, +0x0B, 0x08, 0x1F, 0x08, 0x32, 0x08, 0x46, 0x08, 0x5A, 0x08, 0x6E, 0x08, +0x82, 0x08, 0x96, 0x08, 0xAA, 0x08, 0xBE, 0x08, 0xD2, 0x08, 0xE7, 0x08, +0xFB, 0x09, 0x10, 0x09, 0x25, 0x09, 0x3A, 0x09, 0x4F, 0x09, 0x64, 0x09, +0x79, 0x09, 0x8F, 0x09, 0xA4, 0x09, 0xBA, 0x09, 0xCF, 0x09, 0xE5, 0x09, +0xFB, 0x0A, 0x11, 0x0A, 0x27, 0x0A, 0x3D, 0x0A, 0x54, 0x0A, 0x6A, 0x0A, +0x81, 0x0A, 0x98, 0x0A, 0xAE, 0x0A, 0xC5, 0x0A, 0xDC, 0x0A, 0xF3, 0x0B, +0x0B, 0x0B, 0x22, 0x0B, 0x39, 0x0B, 0x51, 0x0B, 0x69, 0x0B, 0x80, 0x0B, +0x98, 0x0B, 0xB0, 0x0B, 0xC8, 0x0B, 0xE1, 0x0B, 0xF9, 0x0C, 0x12, 0x0C, +0x2A, 0x0C, 0x43, 0x0C, 0x5C, 0x0C, 0x75, 0x0C, 0x8E, 0x0C, 0xA7, 0x0C, +0xC0, 0x0C, 0xD9, 0x0C, 0xF3, 0x0D, 0x0D, 0x0D, 0x26, 0x0D, 0x40, 0x0D, +0x5A, 0x0D, 0x74, 0x0D, 0x8E, 0x0D, 0xA9, 0x0D, 0xC3, 0x0D, 0xDE, 0x0D, +0xF8, 0x0E, 0x13, 0x0E, 0x2E, 0x0E, 0x49, 0x0E, 0x64, 0x0E, 0x7F, 0x0E, +0x9B, 0x0E, 0xB6, 0x0E, 0xD2, 0x0E, 0xEE, 0x0F, 0x09, 0x0F, 0x25, 0x0F, +0x41, 0x0F, 0x5E, 0x0F, 0x7A, 0x0F, 0x96, 0x0F, 0xB3, 0x0F, 0xCF, 0x0F, +0xEC, 0x10, 0x09, 0x10, 0x26, 0x10, 0x43, 0x10, 0x61, 0x10, 0x7E, 0x10, +0x9B, 0x10, 0xB9, 0x10, 0xD7, 0x10, 0xF5, 0x11, 0x13, 0x11, 0x31, 0x11, +0x4F, 0x11, 0x6D, 0x11, 0x8C, 0x11, 0xAA, 0x11, 0xC9, 0x11, 0xE8, 0x12, + 0x07, 0x12, 0x26, 0x12, 0x45, 0x12, 0x64, 0x12, 0x84, 0x12, 0xA3, 0x12, +0xC3, 0x12, 0xE3, 0x13, 0x03, 0x13, 0x23, 0x13, 0x43, 0x13, 0x63, 0x13, +0x83, 0x13, 0xA4, 0x13, 0xC5, 0x13, 0xE5, 0x14, 0x06, 0x14, 0x27, 0x14, +0x49, 0x14, 0x6A, 0x14, 0x8B, 0x14, 0xAD, 0x14, 0xCE, 0x14, 0xF0, 0x15, +0x12, 0x15, 0x34, 0x15, 0x56, 0x15, 0x78, 0x15, 0x9B, 0x15, 0xBD, 0x15, +0xE0, 0x16, 0x03, 0x16, 0x26, 0x16, 0x49, 0x16, 0x6C, 0x16, 0x8F, 0x16, +0xB2, 0x16, 0xD6, 0x16, 0xFA, 0x17, 0x1D, 0x17, 0x41, 0x17, 0x65, 0x17, +0x89, 0x17, 0xAE, 0x17, 0xD2, 0x17, 0xF7, 0x18, 0x1B, 0x18, 0x40, 0x18, +0x65, 0x18, 0x8A, 0x18, 0xAF, 0x18, 0xD5, 0x18, 0xFA, 0x19, 0x20, 0x19, +0x45, 0x19, 0x6B, 0x19, 0x91, 0x19, 0xB7, 0x19, 0xDD, 0x1A, 0x04, 0x1A, +0x2A, 0x1A, 0x51, 0x1A, 0x77, 0x1A, 0x9E, 0x1A, 0xC5, 0x1A, 0xEC, 0x1B, +0x14, 0x1B, 0x3B, 0x1B, 0x63, 0x1B, 0x8A, 0x1B, 0xB2, 0x1B, 0xDA, 0x1C, +0x02, 0x1C, 0x2A, 0x1C, 0x52, 0x1C, 0x7B, 0x1C, 0xA3, 0x1C, 0xCC, 0x1C, +0xF5, 0x1D, 0x1E, 0x1D, 0x47, 0x1D, 0x70, 0x1D, 0x99, 0x1D, 0xC3, 0x1D, +0xEC, 0x1E, 0x16, 0x1E, 0x40, 0x1E, 0x6A, 0x1E, 0x94, 0x1E, 0xBE, 0x1E, +0xE9, 0x1F, 0x13, 0x1F, 0x3E, 0x1F, 0x69, 0x1F, 0x94, 0x1F, 0xBF, 0x1F, +0xEA, 0x20, 0x15, 0x20, 0x41, 0x20, 0x6C, 0x20, 0x98, 0x20, 0xC4, 0x20, +0xF0, 0x21, 0x1C, 0x21, 0x48, 0x21, 0x75, 0x21, 0xA1, 0x21, 0xCE, 0x21, +0xFB, 0x22, 0x27, 0x22, 0x55, 0x22, 0x82, 0x22, 0xAF, 0x22, 0xDD, 0x23, +0x0A, 0x23, 0x38, 0x23, 0x66, 0x23, 0x94, 0x23, 0xC2, 0x23, 0xF0, 0x24, +0x1F, 0x24, 0x4D, 0x24, 0x7C, 0x24, 0xAB, 0x24, 0xDA, 0x25, 0x09, 0x25, +0x38, 0x25, 0x68, 0x25, 0x97, 0x25, 0xC7, 0x25, 0xF7, 0x26, 0x27, 0x26, +0x57, 0x26, 0x87, 0x26, 0xB7, 0x26, 0xE8, 0x27, 0x18, 0x27, 0x49, 0x27, +0x7A, 0x27, 0xAB, 0x27, 0xDC, 0x28, 0x0D, 0x28, 0x3F, 0x28, 0x71, 0x28, +0xA2, 0x28, 0xD4, 0x29, 0x06, 0x29, 0x38, 0x29, 0x6B, 0x29, 0x9D, 0x29, +0xD0, 0x2A, 0x02, 0x2A, 0x35, 0x2A, 0x68, 0x2A, 0x9B, 0x2A, 0xCF, 0x2B, +0x02, 0x2B, 0x36, 0x2B, 0x69, 0x2B, 0x9D, 0x2B, 0xD1, 0x2C, 0x05, 0x2C, +0x39, 0x2C, 0x6E, 0x2C, 0xA2, 0x2C, 0xD7, 0x2D, 0x0C, 0x2D, 0x41, 0x2D, +0x76, 0x2D, 0xAB, 0x2D, 0xE1, 0x2E, 0x16, 0x2E, 0x4C, 0x2E, 0x82, 0x2E, +0xB7, 0x2E, 0xEE, 0x2F, 0x24, 0x2F, 0x5A, 0x2F, 0x91, 0x2F, 0xC7, 0x2F, +0xFE, 0x30, 0x35, 0x30, 0x6C, 0x30, 0xA4, 0x30, 0xDB, 0x31, 0x12, 0x31, +0x4A, 0x31, 0x82, 0x31, 0xBA, 0x31, 0xF2, 0x32, 0x2A, 0x32, 0x63, 0x32, +0x9B, 0x32, 0xD4, 0x33, 0x0D, 0x33, 0x46, 0x33, 0x7F, 0x33, 0xB8, 0x33, +0xF1, 0x34, 0x2B, 0x34, 0x65, 0x34, 0x9E, 0x34, 0xD8, 0x35, 0x13, 0x35, +0x4D, 0x35, 0x87, 0x35, 0xC2, 0x35, 0xFD, 0x36, 0x37, 0x36, 0x72, 0x36, +0xAE, 0x36, 0xE9, 0x37, 0x24, 0x37, 0x60, 0x37, 0x9C, 0x37, 0xD7, 0x38, +0x14, 0x38, 0x50, 0x38, 0x8C, 0x38, 0xC8, 0x39, 0x05, 0x39, 0x42, 0x39, +0x7F, 0x39, 0xBC, 0x39, 0xF9, 0x3A, 0x36, 0x3A, 0x74, 0x3A, 0xB2, 0x3A, +0xEF, 0x3B, 0x2D, 0x3B, 0x6B, 0x3B, 0xAA, 0x3B, 0xE8, 0x3C, 0x27, 0x3C, +0x65, 0x3C, 0xA4, 0x3C, 0xE3, 0x3D, 0x22, 0x3D, 0x61, 0x3D, 0xA1, 0x3D, +0xE0, 0x3E, 0x20, 0x3E, 0x60, 0x3E, 0xA0, 0x3E, 0xE0, 0x3F, 0x21, 0x3F, +0x61, 0x3F, 0xA2, 0x3F, 0xE2, 0x40, 0x23, 0x40, 0x64, 0x40, 0xA6, 0x40, +0xE7, 0x41, 0x29, 0x41, 0x6A, 0x41, 0xAC, 0x41, 0xEE, 0x42, 0x30, 0x42, +0x72, 0x42, 0xB5, 0x42, 0xF7, 0x43, 0x3A, 0x43, 0x7D, 0x43, 0xC0, 0x44, +0x03, 0x44, 0x47, 0x44, 0x8A, 0x44, 0xCE, 0x45, 0x12, 0x45, 0x55, 0x45, + 0x9A, 0x45, 0xDE, 0x46, 0x22, 0x46, 0x67, 0x46, 0xAB, 0x46, 0xF0, 0x47, +0x35, 0x47, 0x7B, 0x47, 0xC0, 0x48, 0x05, 0x48, 0x4B, 0x48, 0x91, 0x48, +0xD7, 0x49, 0x1D, 0x49, 0x63, 0x49, 0xA9, 0x49, 0xF0, 0x4A, 0x37, 0x4A, +0x7D, 0x4A, 0xC4, 0x4B, 0x0C, 0x4B, 0x53, 0x4B, 0x9A, 0x4B, 0xE2, 0x4C, +0x2A, 0x4C, 0x72, 0x4C, 0xBA, 0x4D, 0x02, 0x4D, 0x4A, 0x4D, 0x93, 0x4D, +0xDC, 0x4E, 0x25, 0x4E, 0x6E, 0x4E, 0xB7, 0x4F, 0x00, 0x4F, 0x49, 0x4F, +0x93, 0x4F, 0xDD, 0x50, 0x27, 0x50, 0x71, 0x50, 0xBB, 0x51, 0x06, 0x51, +0x50, 0x51, 0x9B, 0x51, 0xE6, 0x52, 0x31, 0x52, 0x7C, 0x52, 0xC7, 0x53, +0x13, 0x53, 0x5F, 0x53, 0xAA, 0x53, 0xF6, 0x54, 0x42, 0x54, 0x8F, 0x54, +0xDB, 0x55, 0x28, 0x55, 0x75, 0x55, 0xC2, 0x56, 0x0F, 0x56, 0x5C, 0x56, +0xA9, 0x56, 0xF7, 0x57, 0x44, 0x57, 0x92, 0x57, 0xE0, 0x58, 0x2F, 0x58, +0x7D, 0x58, 0xCB, 0x59, 0x1A, 0x59, 0x69, 0x59, 0xB8, 0x5A, 0x07, 0x5A, +0x56, 0x5A, 0xA6, 0x5A, 0xF5, 0x5B, 0x45, 0x5B, 0x95, 0x5B, 0xE5, 0x5C, +0x35, 0x5C, 0x86, 0x5C, 0xD6, 0x5D, 0x27, 0x5D, 0x78, 0x5D, 0xC9, 0x5E, +0x1A, 0x5E, 0x6C, 0x5E, 0xBD, 0x5F, 0x0F, 0x5F, 0x61, 0x5F, 0xB3, 0x60, +0x05, 0x60, 0x57, 0x60, 0xAA, 0x60, 0xFC, 0x61, 0x4F, 0x61, 0xA2, 0x61, +0xF5, 0x62, 0x49, 0x62, 0x9C, 0x62, 0xF0, 0x63, 0x43, 0x63, 0x97, 0x63, +0xEB, 0x64, 0x40, 0x64, 0x94, 0x64, 0xE9, 0x65, 0x3D, 0x65, 0x92, 0x65, +0xE7, 0x66, 0x3D, 0x66, 0x92, 0x66, 0xE8, 0x67, 0x3D, 0x67, 0x93, 0x67, +0xE9, 0x68, 0x3F, 0x68, 0x96, 0x68, 0xEC, 0x69, 0x43, 0x69, 0x9A, 0x69, +0xF1, 0x6A, 0x48, 0x6A, 0x9F, 0x6A, 0xF7, 0x6B, 0x4F, 0x6B, 0xA7, 0x6B, +0xFF, 0x6C, 0x57, 0x6C, 0xAF, 0x6D, 0x08, 0x6D, 0x60, 0x6D, 0xB9, 0x6E, +0x12, 0x6E, 0x6B, 0x6E, 0xC4, 0x6F, 0x1E, 0x6F, 0x78, 0x6F, 0xD1, 0x70, +0x2B, 0x70, 0x86, 0x70, 0xE0, 0x71, 0x3A, 0x71, 0x95, 0x71, 0xF0, 0x72, +0x4B, 0x72, 0xA6, 0x73, 0x01, 0x73, 0x5D, 0x73, 0xB8, 0x74, 0x14, 0x74, +0x70, 0x74, 0xCC, 0x75, 0x28, 0x75, 0x85, 0x75, 0xE1, 0x76, 0x3E, 0x76, +0x9B, 0x76, 0xF8, 0x77, 0x56, 0x77, 0xB3, 0x78, 0x11, 0x78, 0x6E, 0x78, +0xCC, 0x79, 0x2A, 0x79, 0x89, 0x79, 0xE7, 0x7A, 0x46, 0x7A, 0xA5, 0x7B, +0x04, 0x7B, 0x63, 0x7B, 0xC2, 0x7C, 0x21, 0x7C, 0x81, 0x7C, 0xE1, 0x7D, +0x41, 0x7D, 0xA1, 0x7E, 0x01, 0x7E, 0x62, 0x7E, 0xC2, 0x7F, 0x23, 0x7F, +0x84, 0x7F, 0xE5, 0x80, 0x47, 0x80, 0xA8, 0x81, 0x0A, 0x81, 0x6B, 0x81, +0xCD, 0x82, 0x30, 0x82, 0x92, 0x82, 0xF4, 0x83, 0x57, 0x83, 0xBA, 0x84, +0x1D, 0x84, 0x80, 0x84, 0xE3, 0x85, 0x47, 0x85, 0xAB, 0x86, 0x0E, 0x86, +0x72, 0x86, 0xD7, 0x87, 0x3B, 0x87, 0x9F, 0x88, 0x04, 0x88, 0x69, 0x88, +0xCE, 0x89, 0x33, 0x89, 0x99, 0x89, 0xFE, 0x8A, 0x64, 0x8A, 0xCA, 0x8B, +0x30, 0x8B, 0x96, 0x8B, 0xFC, 0x8C, 0x63, 0x8C, 0xCA, 0x8D, 0x31, 0x8D, +0x98, 0x8D, 0xFF, 0x8E, 0x66, 0x8E, 0xCE, 0x8F, 0x36, 0x8F, 0x9E, 0x90, +0x06, 0x90, 0x6E, 0x90, 0xD6, 0x91, 0x3F, 0x91, 0xA8, 0x92, 0x11, 0x92, +0x7A, 0x92, 0xE3, 0x93, 0x4D, 0x93, 0xB6, 0x94, 0x20, 0x94, 0x8A, 0x94, +0xF4, 0x95, 0x5F, 0x95, 0xC9, 0x96, 0x34, 0x96, 0x9F, 0x97, 0x0A, 0x97, +0x75, 0x97, 0xE0, 0x98, 0x4C, 0x98, 0xB8, 0x99, 0x24, 0x99, 0x90, 0x99, +0xFC, 0x9A, 0x68, 0x9A, 0xD5, 0x9B, 0x42, 0x9B, 0xAF, 0x9C, 0x1C, 0x9C, +0x89, 0x9C, 0xF7, 0x9D, 0x64, 0x9D, 0xD2, 0x9E, 0x40, 0x9E, 0xAE, 0x9F, +0x1D, 0x9F, 0x8B, 0x9F, 0xFA, 0xA0, 0x69, 0xA0, 0xD8, 0xA1, 0x47, 0xA1, +0xB6, 0xA2, 0x26, 0xA2, 0x96, 0xA3, 0x06, 0xA3, 0x76, 0xA3, 0xE6, 0xA4, + 0x56, 0xA4, 0xC7, 0xA5, 0x38, 0xA5, 0xA9, 0xA6, 0x1A, 0xA6, 0x8B, 0xA6, + 0xFD, 0xA7, 0x6E, 0xA7, 0xE0, 0xA8, 0x52, 0xA8, 0xC4, 0xA9, 0x37, 0xA9, + 0xA9, 0xAA, 0x1C, 0xAA, 0x8F, 0xAB, 0x02, 0xAB, 0x75, 0xAB, 0xE9, 0xAC, + 0x5C, 0xAC, 0xD0, 0xAD, 0x44, 0xAD, 0xB8, 0xAE, 0x2D, 0xAE, 0xA1, 0xAF, + 0x16, 0xAF, 0x8B, 0xB0, 0x00, 0xB0, 0x75, 0xB0, 0xEA, 0xB1, 0x60, 0xB1, + 0xD6, 0xB2, 0x4B, 0xB2, 0xC2, 0xB3, 0x38, 0xB3, 0xAE, 0xB4, 0x25, 0xB4, + 0x9C, 0xB5, 0x13, 0xB5, 0x8A, 0xB6, 0x01, 0xB6, 0x79, 0xB6, 0xF0, 0xB7, + 0x68, 0xB7, 0xE0, 0xB8, 0x59, 0xB8, 0xD1, 0xB9, 0x4A, 0xB9, 0xC2, 0xBA, + 0x3B, 0xBA, 0xB5, 0xBB, 0x2E, 0xBB, 0xA7, 0xBC, 0x21, 0xBC, 0x9B, 0xBD, + 0x15, 0xBD, 0x8F, 0xBE, 0x0A, 0xBE, 0x84, 0xBE, 0xFF, 0xBF, 0x7A, 0xBF, + 0xF5, 0xC0, 0x70, 0xC0, 0xEC, 0xC1, 0x67, 0xC1, 0xE3, 0xC2, 0x5F, 0xC2, + 0xDB, 0xC3, 0x58, 0xC3, 0xD4, 0xC4, 0x51, 0xC4, 0xCE, 0xC5, 0x4B, 0xC5, + 0xC8, 0xC6, 0x46, 0xC6, 0xC3, 0xC7, 0x41, 0xC7, 0xBF, 0xC8, 0x3D, 0xC8, + 0xBC, 0xC9, 0x3A, 0xC9, 0xB9, 0xCA, 0x38, 0xCA, 0xB7, 0xCB, 0x36, 0xCB, + 0xB6, 0xCC, 0x35, 0xCC, 0xB5, 0xCD, 0x35, 0xCD, 0xB5, 0xCE, 0x36, 0xCE, + 0xB6, 0xCF, 0x37, 0xCF, 0xB8, 0xD0, 0x39, 0xD0, 0xBA, 0xD1, 0x3C, 0xD1, + 0xBE, 0xD2, 0x3F, 0xD2, 0xC1, 0xD3, 0x44, 0xD3, 0xC6, 0xD4, 0x49, 0xD4, + 0xCB, 0xD5, 0x4E, 0xD5, 0xD1, 0xD6, 0x55, 0xD6, 0xD8, 0xD7, 0x5C, 0xD7, + 0xE0, 0xD8, 0x64, 0xD8, 0xE8, 0xD9, 0x6C, 0xD9, 0xF1, 0xDA, 0x76, 0xDA, + 0xFB, 0xDB, 0x80, 0xDC, 0x05, 0xDC, 0x8A, 0xDD, 0x10, 0xDD, 0x96, 0xDE, + 0x1C, 0xDE, 0xA2, 0xDF, 0x29, 0xDF, 0xAF, 0xE0, 0x36, 0xE0, 0xBD, 0xE1, + 0x44, 0xE1, 0xCC, 0xE2, 0x53, 0xE2, 0xDB, 0xE3, 0x63, 0xE3, 0xEB, 0xE4, + 0x73, 0xE4, 0xFC, 0xE5, 0x84, 0xE6, 0x0D, 0xE6, 0x96, 0xE7, 0x1F, 0xE7, + 0xA9, 0xE8, 0x32, 0xE8, 0xBC, 0xE9, 0x46, 0xE9, 0xD0, 0xEA, 0x5B, 0xEA, + 0xE5, 0xEB, 0x70, 0xEB, 0xFB, 0xEC, 0x86, 0xED, 0x11, 0xED, 0x9C, 0xEE, + 0x28, 0xEE, 0xB4, 0xEF, 0x40, 0xEF, 0xCC, 0xF0, 0x58, 0xF0, 0xE5, 0xF1, + 0x72, 0xF1, 0xFF, 0xF2, 0x8C, 0xF3, 0x19, 0xF3, 0xA7, 0xF4, 0x34, 0xF4, + 0xC2, 0xF5, 0x50, 0xF5, 0xDE, 0xF6, 0x6D, 0xF6, 0xFB, 0xF7, 0x8A, 0xF8, + 0x19, 0xF8, 0xA8, 0xF9, 0x38, 0xF9, 0xC7, 0xFA, 0x57, 0xFA, 0xE7, 0xFB, + 0x77, 0xFC, 0x07, 0xFC, 0x98, 0xFD, 0x29, 0xFD, 0xBA, 0xFE, 0x4B, 0xFE, + 0xDC, 0xFF, 0x6D, 0xFF, 0xFF + +Table M.7 – "Ducky" marker template + + 0xEC, 0x00, 0x11, 0x44, 0x75, 0x63, 0x6B, 0x79, 0x00, 0x01, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x64, 0x00, 0x00 + +Table M.8 – "Adobe" marker template + + 0xEE, 0x00, 0x0E, 0x41, 0x64, 0x6F, 0x62, 0x65, 0x00, 0x64, 0x00, 0x00, + 0x00, 0x00, 0x01 + +Table M.9 – stock counts arrays + is_ac == 0, stock_index == 0: + + 0, 0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 0 + +is_ac == 0, stock_index == 1: + + 0, 0, 1, 5, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0 + +is_ac == 1, stock_index == 0: + + 0, 0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 126 + +is_ac == 1, stock_index == 1: + + 0, 0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 120 + +Table M.10 – stock values arrays + +is_ac == 0, stock_index == 0: + + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 256 + +is_ac == 0, stock_index == 1: + + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 256 + +is_ac == 1, stock_index == 0: + + 1, 2, 3, 0, 4, 17, 5, 18, 33, 49, 65, 6, 19, 81, + 97, 7, 34, 113, 20, 50, 129, 145, 161, 8, 35, 66, 177, 193, + 21, 82, 209, 240, 36, 51, 98, 114, 130, 9, 10, 22, 23, 24, + 25, 26, 37, 38, 39, 40, 41, 42, 52, 53, 54, 55, 56, 57, + 58, 67, 68, 69, 70, 71, 72, 73, 74, 83, 84, 85, 86, 87, + 88, 89, 90, 99, 100, 101, 102, 103, 104, 105, 106, 115, 116, 117, + 118, 119, 120, 121, 122, 131, 132, 133, 134, 135, 136, 137, 138, 146, + 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 164, 165, 166, 167, + 168, 169, 170, 178, 179, 180, 181, 182, 183, 184, 185, 186, 194, 195, + 196, 197, 198, 199, 200, 201, 202, 210, 211, 212, 213, 214, 215, 216, + 217, 218, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 241, 242, + 243, 244, 245, 246, 247, 248, 249, 250, 256 + +is_ac == 1, stock_index == 1: + 0, 1, 2, 3, 17, 4, 5, 33, 49, 6, 18, 65, 81, 7, + 97, 113, 19, 34, 50, 129, 8, 20, 66, 145, 161, 177, 193, 9, + 35, 51, 82, 240, 21, 98, 114, 209, 10, 22, 36, 52, 225, 37, + 241, 23, 24, 25, 26, 38, 39, 40, 41, 42, 53, 54, 55, 56, + 57, 58, 67, 68, 69, 70, 71, 72, 73, 74, 83, 84, 85, 86, + 87, 88, 89, 90, 99, 100, 101, 102, 103, 104, 105, 106, 115, 116, + 117, 118, 119, 120, 121, 122, 130, 131, 132, 133, 134, 135, 136, 137, + 138, 146, 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 164, 165, + 166, 167, 168, 169, 170, 178, 179, 180, 181, 182, 183, 184, 185, 186, + 194, 195, 196, 197, 198, 199, 200, 201, 202, 210, 211, 212, 213, 214, + 215, 216, 217, 218, 226, 227, 228, 229, 230, 231, 232, 233, 234, 242, + 243, 244, 245, 246, 247, 248, 249, 250, 256 + +Table M.11 – predefined symbol order + +is_ac == 0: + + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +is_ac == 1: + + 1, 0, 2, 3, 17, 4, 5, 33, 18, 49, 65, 6, 81, 19, + 97, 7, 34, 113, 50, 129, 20, 145, 161, 8, 35, 66, 177, 193, + 21, 82, 209, 240, 36, 51, 98, 114, 9, 130, 10, 22, 52, 225, + 23, 37, 241, 24, 25, 26, 38, 39, 40, 41, 42, 53, 54, 55, + 56, 57, 58, 67, 68, 69, 70, 71, 72, 73, 74, 83, 84, 85, + 86, 87, 88, 89, 90, 99, 100, 101, 102, 103, 104, 105, 106, 115, + 116, 117, 118, 119, 120, 121, 122, 131, 132, 133, 134, 135, 136, 137, + 138, 146, 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 164, 165, + 166, 167, 168, 169, 170, 178, 179, 180, 181, 182, 183, 184, 185, 186, + 194, 195, 196, 197, 198, 199, 200, 201, 202, 210, 211, 212, 213, 214, + 215, 216, 217, 218, 226, 227, 228, 229, 230, 231, 232, 233, 234, 242, + 243, 244, 245, 246, 247, 248, 249, 250, 16, 32, 48, 64, 80, 96, + 112, 128, 144, 160, 176, 192, 208, 11, 12, 13, 14, 15, 27, 28, + 29, 30, 31, 43, 44, 45, 46, 47, 59, 60, 61, 62, 63, 75, + 76, 77, 78, 79, 91, 92, 93, 94, 95, 107, 108, 109, 110, 111, + 123, 124, 125, 126, 127, 139, 140, 141, 142, 143, 155, 156, 157, 158, + 159, 171, 172, 173, 174, 175, 187, 188, 189, 190, 191, 203, 204, 205, + 206, 207, 219, 220, 221, 222, 223, 224, 235, 236, 237, 238, 239, 251, + 252, 253, 254, 255 + +Table M.12 – stock quant tables + is_luma == true, stock_index == 0: + + 3, 2, 2, 3, 5, 8, 10, 12, 2, 2, 3, 4, 5, 12, 12, 11, 3, 3, + 3, 5, 8, 11, 14, 11, 3, 3, 4, 6, 10, 17, 16, 12, 4, 4, 7, 11, + 14, 22, 21, 15, 5, 7, 11, 13, 16, 21, 23, 18, 10, 13, 16, 17, 21, 24, + 24, 20, 14, 18, 19, 20, 22, 20, 21, 20 + +is_luma == true, stock_index == 1: + + 8, 6, 5, 8, 12, 20, 26, 31, 6, 6, 7, 10, 13, 29, 30, 28, 7, 7, + 8, 12, 20, 29, 35, 28, 7, 9, 11, 15, 26, 44, 40, 31, 9, 11, 19, 28, + 34, 55, 52, 39, 12, 18, 28, 32, 41, 52, 57, 46, 25, 32, 39, 44, 52, 61, + 60, 51, 36, 46, 48, 49, 56, 50, 52, 50 + +is_luma == true, stock_index == 2: + + 6, 4, 4, 6, 10, 16, 20, 24, 5, 5, 6, 8, 10, 23, 24, 22, 6, 5, + 6, 10, 16, 23, 28, 22, 6, 7, 9, 12, 20, 35, 32, 25, 7, 9, 15, 22, + 27, 44, 41, 31, 10, 14, 22, 26, 32, 42, 45, 37, 20, 26, 31, 35, 41, 48, + 48, 40, 29, 37, 38, 39, 45, 40, 41, 40 + +is_luma == true, stock_index == 3: + + 5, 3, 3, 5, 7, 12, 15, 18, 4, 4, 4, 6, 8, 17, 18, 17, 4, 4, + 5, 7, 12, 17, 21, 17, 4, 5, 7, 9, 15, 26, 24, 19, 5, 7, 11, 17, + 20, 33, 31, 23, 7, 11, 17, 19, 24, 31, 34, 28, 15, 19, 23, 26, 31, 36, + 36, 30, 22, 28, 29, 29, 34, 30, 31, 30 + +is_luma == true, stock_index == 4: + + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + +is_luma == true, stock_index == 5: + + 2, 1, 1, 2, 2, 4, 5, 6, 1, 1, 1, 2, 3, 6, 6, 6, 1, 1, + 2, 2, 4, 6, 7, 6, 1, 2, 2, 3, 5, 9, 8, 6, 2, 2, 4, 6, + 7, 11, 10, 8, 2, 4, 6, 6, 8, 10, 11, 9, 5, 6, 8, 9, 10, 12, + 12, 10, 7, 9, 10, 10, 11, 10, 10, 10 + +is_luma == true, stock_index == 6: + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, + 1, 2, 2, 3, 1, 1, 1, 1, 2, 2, 3, 3, 1, 1, 1, 2, 2, 3, + 3, 3, 1, 1, 2, 2, 3, 3, 3, 3 + +is_luma == true, stock_index == 7: + + 10, 7, 6, 10, 14, 24, 31, 37, 7, 7, 8, 11, 16, 35, 36, 33, 8, 8, + 10, 14, 24, 34, 41, 34, 8, 10, 13, 17, 31, 52, 48, 37, 11, 13, 22, 34, + 41, 65, 62, 46, 14, 21, 33, 38, 49, 62, 68, 55, 29, 38, 47, 52, 62, 73, + 72, 61, 43, 55, 57, 59, 67, 60, 62, 59 + +is_luma == false, stock_index == 0: + + 9, 9, 9, 12, 11, 12, 24, 13, 13, 24, 50, 33, 28, 33, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50 + +is_luma == false, stock_index == 1: + + 3, 4, 5, 9, 20, 20, 20, 20, 4, 4, 5, 13, 20, 20, 20, 20, 5, 5, + 11, 20, 20, 20, 20, 20, 9, 13, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20 + +is_luma == false, stock_index == 2: + + 9, 9, 12, 24, 50, 50, 50, 50, 9, 11, 13, 33, 50, 50, 50, 50, 12, 13, + 28, 50, 50, 50, 50, 50, 24, 33, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50 + +is_luma == false, stock_index == 3: + + 5, 5, 7, 14, 30, 30, 30, 30, 5, 6, 8, 20, 30, 30, 30, 30, 7, 8, + 17, 30, 30, 30, 30, 30, 14, 20, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, + 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, + 30, 30, 30, 30, 30, 30, 30, 30, 30, 30 + +is_luma == false, stock_index == 4: + 7, 7, 10, 19, 40, 40, 40, 40, 7, 8, 10, 26, 40, 40, 40, 40, 10, 10, + 22, 40, 40, 40, 40, 40, 19, 26, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, + 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, + 40, 40, 40, 40, 40, 40, 40, 40, 40, 40 + +is_luma == false, stock_index == 5: + + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + +is_luma == false, stock_index == 6: + + 2, 2, 2, 5, 10, 10, 10, 10, 2, 2, 3, 7, 10, 10, 10, 10, 2, 3, + 6, 10, 10, 10, 10, 10, 5, 7, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10 + +is_luma == false, stock_index == 7: + + 10, 11, 14, 28, 59, 59, 59, 59, 11, 13, 16, 40, 59, 59, 59, 59, 14, 16, + 34, 59, 59, 59, 59, 59, 28, 40, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, + 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, + 59, 59, 59, 59, 59, 59, 59, 59, 59, 59 + +Table M.13 – template quant tables +is_luma == true: + + 16, 11, 10, 16, 24, 40, 51, 61, 12, 12, 14, 19, 26, 58, 60, + 55, 14, 13, 16, 24, 40, 57, 69, 56, 14, 17, 22, 29, 51, 87, + 80, 62, 18, 22, 37, 56, 68, 109, 103, 77, 24, 35, 55, 64, 81, +104, 113, 92, 49, 64, 78, 87, 103, 121, 120, 101, 72, 92, 95, 98, +112, 100, 103, 99 + +is_luma == false: + +17, 18, 24, 47, 99, 99, 99, 99, 18, 21, 26, 66, 99, 99, 99, 99, 24, 26, +56, 99, 99, 99, 99, 99, 47, 66, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, +99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, +99, 99, 99, 99, 99, 99, 99, 99, 99, 99 + Table M.15 – freq_context + +scheme == 0: + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + +scheme == 1: + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0 + +scheme == 2: + + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1 + +scheme == 3: + + 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 2, 2, 2 + +scheme == 4: + + 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 8, 8, 9, 9, + 9, 9, 10, 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, + 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15 + +scheme == 5: + + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, + 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 24, 24, + 25, 25, 25, 25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, + 29, 29, 30, 30, 30, 30, 31, 31, 31, 31 + +scheme == 6: + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 + +Table M.16 – num_nonzero_context + +scheme == 0: + + 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 + +scheme == 1: + + 0, 2, 2, 4, 4, 4, 6, 6, 6, 6, 8, 8, 8, 8, 8, 8, 10, 10, + 10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14 + +scheme == 2: + + 0, 4, 4, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 20, 20, + 20, 20, 20, 20, 20, 20, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28 + +scheme == 3: + + 0, 8, 8, 16, 16, 16, 24, 24, 24, 24, 32, 32, 32, 32, 32, 32, 40, 40, + 40, 40, 40, 40, 40, 40, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, + 48, 48, 48, 48, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, + 55, 55, 55, 55, 55, 55, 55, 55, 55, 55 + +scheme == 4: + + 0, 16, 16, 32, 32, 32, 48, 48, 48, 48, 64, 64, 64, 64, + 64, 64, 80, 80, 80, 80, 80, 80, 80, 80, 95, 95, 95, 95, + 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 109, 109, + 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, + 109, 109, 109, 109, 109, 109, 109, 109 + +scheme == 5: + 0, 32, 32, 64, 64, 64, 96, 96, 96, 96, 127, 127, 127, 127, + 127, 127, 157, 157, 157, 157, 157, 157, 157, 157, 185, 185, 185, 185, + 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 211, 211, + 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, + 211, 211, 211, 211, 211, 211, 211, 211 + +scheme == 6: + + 0, 64, 64, 127, 127, 127, 188, 188, 188, 188, 246, 246, 246, 246, + 246, 246, 300, 300, 300, 300, 300, 300, 300, 300, 348, 348, 348, 348, + 348, 348, 348, 348, 348, 348, 348, 348, 348, 348, 348, 348, 388, 388, + 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, + 388, 388, 388, 388, 388, 388, 388, 388 + +Table M.17 – nonzero_buckets + + 0, 1, 2, 3, 4, 4, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, + 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10 + +Table M.29 – context_modes table + + 0, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, + 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0 + + 0, 1, 1, 1, 1, 0, 0, 0, 2, 3, 1, 1, 1, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, + 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + \ No newline at end of file diff --git a/media/libjxl/src/doc/tables/all_tables.sh b/media/libjxl/src/doc/tables/all_tables.sh new file mode 100644 index 000000000..6fc98eb8c --- /dev/null +++ b/media/libjxl/src/doc/tables/all_tables.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +cat dct_gen.md \ + is_zero_base.md num_nonzeros_base.md brn_proto.md app0.md icc.md ducky.md \ + adobe.md stock_counts.md stock_values.md symbol_order.md stock_quant.md \ + quant.md freq_context.md num_nonzero_context.md nonzero_buckets.md \ + context_modes.md > all_tables.md diff --git a/media/libjxl/src/doc/tables/app0.md b/media/libjxl/src/doc/tables/app0.md new file mode 100644 index 000000000..266f21047 --- /dev/null +++ b/media/libjxl/src/doc/tables/app0.md @@ -0,0 +1,6 @@ +#### Table M.4 – APP0 template + +``` +0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00 +``` + diff --git a/media/libjxl/src/doc/tables/brn_proto.md b/media/libjxl/src/doc/tables/brn_proto.md new file mode 100644 index 000000000..b5f80b6a7 --- /dev/null +++ b/media/libjxl/src/doc/tables/brn_proto.md @@ -0,0 +1,23 @@ +#### Table M.3 – Protocol Buffer descriptor of top-level structure of losslessly compressed JPEG stream + +```protobuf +message Header { + optional uint64 width = 1; + optional uint64 height = 2; + required uint64 version_and_component_count_code = 3; + optional uint64 subsampling_code = 4; +} + +message Jpeg { + required bytes signature = 1; + required Header header = 2; + optional bytes meta_data = 3; + optional bytes jpeg1_internals = 4; + optional bytes quant_data = 5; + optional bytes histogram_data = 6; + optional bytes dc_data = 7; + optional bytes ac_data = 8; + optional bytes original_jpg = 9; +} +``` + diff --git a/media/libjxl/src/doc/tables/context_modes.md b/media/libjxl/src/doc/tables/context_modes.md new file mode 100644 index 000000000..59bff3649 --- /dev/null +++ b/media/libjxl/src/doc/tables/context_modes.md @@ -0,0 +1,13 @@ +#### Table M.29 – context_modes table + +``` +0, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, +0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, +0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0 +``` + +``` +0, 1, 1, 1, 1, 0, 0, 0, 2, 3, 1, 1, 1, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, +0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 +``` diff --git a/media/libjxl/src/doc/tables/dct_gen.md b/media/libjxl/src/doc/tables/dct_gen.md new file mode 100644 index 000000000..f3b59e248 --- /dev/null +++ b/media/libjxl/src/doc/tables/dct_gen.md @@ -0,0 +1,241 @@ +#### Electronic Insert I.1 – DCT-II / DCT-III code generator + +```python +####################################################################### +# DCT-II / DCT-III generator +# +# Based on: +# "A low multiplicative complexity fast recursive DCT-2 algorithm" +# by Maxim Vashkevich and Alexander Petrovsky / arXiv / 20 Jul 2012 +####################################################################### + +import math +import sys +N = 8 + +####################################################################### +# Base transforms / generators +####################################################################### + +CNTR = 0 +def makeTmp(): + global CNTR + result = "t{:02d}".format(CNTR) + CNTR = CNTR + 1 + return result + +def makeVar(i): + return "i{:02d}".format(i) + +def add(x, y): + tmp = makeTmp() + print(tmp + " = " + x + " + " + y + ";") + return tmp + +def sub(x, y): + tmp = makeTmp() + print(tmp + " = " + x + " - " + y + ";") + return tmp + +def mul(x, c): + tmp = makeTmp() + print(tmp + " = " + x + " * " + c + ";") + return tmp + +# 2.0 * math.cos((a + 0.0) / (b + 0.0) * math.pi) +def C2(a, b): + return "c_c2_" + str(a) + "_" + str(b) + +# 1.0 / C2(a, b) +def iC2(a, b): + return "c_ic2_" + str(a) + "_" + str(b) + +####################################################################### +# Utilities +####################################################################### + +# Generate identity matrix. Usually this matrix is passed to +# DCT algorithm to generate "basis" vectors of the transform. +def makeVars(): + return [makeVar(i) for i in range(N)] + +# Split list of variables info halves. +def split(x): + m = len(x) + m2 = m // 2 + return (x[0 : m2], x[m2 : m]) + +# Make a list of variables in a reverse order. +def reverse(varz): + m = len(varz) + result = [0] * m + for i in range(m): + result[i] = varz[m - 1 - i] + return result + +# Apply permutation +def permute(x, p): + return [x[p[i]] for i in range(len(p))] + +def transposePermutation(p): + n = len(p) + result = [0] * n + for i in range(n): + result[p[i]] = i + return result + +# See paper. Split even-odd elements. +def P(n): + if n == 1: + return [0] + n2 = n // 2 + return [2 * i for i in range(n2)] + [2 * i + 1 for i in range(n2)] + +# See paper. Interleave first and second half. +def Pt(n): + return transposePermutation(P(n)) + +####################################################################### +# Scheme +####################################################################### + +def B2(x): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + bottom = reverse(bottom) + t = [add(top[i], bottom[i]) for i in range(n2)] + b = [sub(top[i], bottom[i]) for i in range(n2)] + return t + b + +def iB2(x): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + t = [add(top[i], bottom[i]) for i in range(n2)] + b = [sub(top[i], bottom[i]) for i in range(n2)] + return t + reverse(b) + +def B4(x, rn): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + rbottom = reverse(bottom) + t = [sub(top[i], rbottom[i]) for i in range(n2)] + b = [mul(bottom[i], C2(rn, 2 * N)) for i in range(n2)] + top = [add(t[i], b[i]) for i in range(n2)] + bottom = [sub(t[i], b[i]) for i in range(n2)] + return top + bottom + +def iB4(x, rn): + n = len(x) + n2 = n // 2 + if n == 1: + raise "ooops" + (top, bottom) = split(x) + t = [add(top[i], bottom[i]) for i in range(n2)] + b = [sub(top[i], bottom[i]) for i in range(n2)] + bottom = [mul(b[i], iC2(rn, 2 * N)) for i in range(n2)] + rbottom = reverse(bottom) + top = [add(t[i], rbottom[i]) for i in range(n2)] + return top + bottom + +def P4(n): + if n == 1: + return [0] + if n == 2: + return [0, 1] + n2 = n // 2 + result = [0] * n + tc = 0 + bc = 0 + i = 0 + result[i] = tc; tc = tc + 1; i = i + 1 + turn = True + while i < n - 1: + if turn: + result[i] = n2 + bc; bc = bc + 1; i = i + 1 + result[i] = n2 + bc; bc = bc + 1; i = i + 1 + else: + result[i] = tc; tc = tc + 1; i = i + 1 + result[i] = tc; tc = tc + 1; i = i + 1 + turn = not turn + result[i] = tc; tc = tc + 1; i = i + 1 + return result + +def iP4(n): + return transposePermutation(P4(n)) + +def d2n(x): + n = len(x) + if n == 1: + return x + y = B2(x) + (top, bottom) = split(y) + return permute(d2n(top) + d4n(bottom, N // 2), Pt(n)) + +def id2n(x): + n = len(x) + if n == 1: + return x + (top, bottom) = split(permute(x, P(n))) + return iB2(id2n(top) + id4n(bottom, N // 2)) + +def d4n(x, rn): + n = len(x) + if n == 1: + return x + y = B4(x, rn) + (top, bottom) = split(y) + rn2 = rn // 2 + return permute(d4n(top, rn2) + d4n(bottom, N - rn2), P4(n)) + +def id4n(x, rn): + n = len(x) + if n == 1: + return x + (top, bottom) = split(permute(x, iP4(n))) + rn2 = rn // 2 + y = id4n(top, rn2) + id4n(bottom, N -rn2) + return iB4(y, rn) + +####################################################################### +# Main. +####################################################################### + +def help(): + print("Usage: %s [N [T]]" % sys.argv[0]) + print(" N should be the power of 2, default is 8") + print(" T is one of {2, 3}, default is 2") + sys.exit() + +def parseInt(s): + try: + return int(s) + except ValueError: + help() + +if __name__ == "__main__": + if len(sys.argv) < 1 or len(sys.argv) > 3: help() + if len(sys.argv) >= 2: + N = parseInt(sys.argv[1]) + if (N & (N - 1)) != 0: help() + type = 0 + if len(sys.argv) >= 3: + typeOption = sys.argv[2] + if len(typeOption) != 1: help() + type = "23".index(typeOption) + if type == -1: help() + if type == 0: + vars = d2n(makeVars()) + else: # type == 1 + vars = id2n(makeVars()) + print("Output vector: " + str(vars)) +``` + diff --git a/media/libjxl/src/doc/tables/ducky.md b/media/libjxl/src/doc/tables/ducky.md new file mode 100644 index 000000000..307f68804 --- /dev/null +++ b/media/libjxl/src/doc/tables/ducky.md @@ -0,0 +1,6 @@ +#### Table M.7 – "Ducky" marker template + +``` +0xEC, 0x00, 0x11, 0x44, 0x75, 0x63, 0x6B, 0x79, 0x00, 0x01, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00 +``` + diff --git a/media/libjxl/src/doc/tables/freq_context.md b/media/libjxl/src/doc/tables/freq_context.md new file mode 100644 index 000000000..3e218fb1b --- /dev/null +++ b/media/libjxl/src/doc/tables/freq_context.md @@ -0,0 +1,54 @@ +#### Table M.15 – freq_context + +`scheme == 0`: +``` +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 +``` + +`scheme == 1`: +``` +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0 +``` + +`scheme == 2`: +``` +0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, +2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, +3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1 +``` + +`scheme == 3`: +``` +0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, +6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, +7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 2, 2, 2 +``` + +`scheme == 4`: +``` + 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 8, 8, 9, 9, + 9, 9, 10, 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, +13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, +15, 15, 15, 15, 15, 15, 15, 15, 15, 15 +``` + +`scheme == 5`: +``` + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, +17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 24, 24, +25, 25, 25, 25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, +29, 29, 30, 30, 30, 30, 31, 31, 31, 31 +``` + +`scheme == 6`: +``` + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, +18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, +36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, +54, 55, 56, 57, 58, 59, 60, 61, 62, 63 +``` + diff --git a/media/libjxl/src/doc/tables/icc.md b/media/libjxl/src/doc/tables/icc.md new file mode 100644 index 000000000..1f3b4cde3 --- /dev/null +++ b/media/libjxl/src/doc/tables/icc.md @@ -0,0 +1,6 @@ +#### Table M.6 – common ICC profile template + +``` +0xE2, 0x0C, 0x58, 0x49, 0x43, 0x43, 0x5F, 0x50, 0x52, 0x4F, 0x46, 0x49, 0x4C, 0x45, 0x00, 0x01, 0x01, 0x00, 0x00, 0x0C, 0x48, 0x4C, 0x69, 0x6E, 0x6F, 0x02, 0x10, 0x00, 0x00, 0x6D, 0x6E, 0x74, 0x72, 0x52, 0x47, 0x42, 0x20, 0x58, 0x59, 0x5A, 0x20, 0x07, 0xCE, 0x00, 0x02, 0x00, 0x09, 0x00, 0x06, 0x00, 0x31, 0x00, 0x00, 0x61, 0x63, 0x73, 0x70, 0x4D, 0x53, 0x46, 0x54, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x43, 0x20, 0x73, 0x52, 0x47, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xF6, 0xD6, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xD3, 0x2D, 0x48, 0x50, 0x20, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x63, 0x70, 0x72, 0x74, 0x00, 0x00, 0x01, 0x50, 0x00, 0x00, 0x00, 0x33, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x01, 0x84, 0x00, 0x00, 0x00, 0x6C, 0x77, 0x74, 0x70, 0x74, 0x00, 0x00, 0x01, 0xF0, 0x00, 0x00, 0x00, 0x14, 0x62, 0x6B, 0x70, 0x74, 0x00, 0x00, 0x02, 0x04, 0x00, 0x00, 0x00, 0x14, 0x72, 0x58, 0x59, 0x5A, 0x00, 0x00, 0x02, 0x18, 0x00, 0x00, 0x00, 0x14, 0x67, 0x58, 0x59, 0x5A, 0x00, 0x00, 0x02, 0x2C, 0x00, 0x00, 0x00, 0x14, 0x62, 0x58, 0x59, 0x5A, 0x00, 0x00, 0x02, 0x40, 0x00, 0x00, 0x00, 0x14, 0x64, 0x6D, 0x6E, 0x64, 0x00, 0x00, 0x02, 0x54, 0x00, 0x00, 0x00, 0x70, 0x64, 0x6D, 0x64, 0x64, 0x00, 0x00, 0x02, 0xC4, 0x00, 0x00, 0x00, 0x88, 0x76, 0x75, 0x65, 0x64, 0x00, 0x00, 0x03, 0x4C, 0x00, 0x00, 0x00, 0x86, 0x76, 0x69, 0x65, 0x77, 0x00, 0x00, 0x03, 0xD4, 0x00, 0x00, 0x00, 0x24, 0x6C, 0x75, 0x6D, 0x69, 0x00, 0x00, 0x03, 0xF8, 0x00, 0x00, 0x00, 0x14, 0x6D, 0x65, 0x61, 0x73, 0x00, 0x00, 0x04, 0x0C, 0x00, 0x00, 0x00, 0x24, 0x74, 0x65, 0x63, 0x68, 0x00, 0x00, 0x04, 0x30, 0x00, 0x00, 0x00, 0x0C, 0x72, 0x54, 0x52, 0x43, 0x00, 0x00, 0x04, 0x3C, 0x00, 0x00, 0x08, 0x0C, 0x67, 0x54, 0x52, 0x43, 0x00, 0x00, 0x04, 0x3C, 0x00, 0x00, 0x08, 0x0C, 0x62, 0x54, 0x52, 0x43, 0x00, 0x00, 0x04, 0x3C, 0x00, 0x00, 0x08, 0x0C, 0x74, 0x65, 0x78, 0x74, 0x00, 0x00, 0x00, 0x00, 0x43, 0x6F, 0x70, 0x79, 0x72, 0x69, 0x67, 0x68, 0x74, 0x20, 0x28, 0x63, 0x29, 0x20, 0x31, 0x39, 0x39, 0x38, 0x20, 0x48, 0x65, 0x77, 0x6C, 0x65, 0x74, 0x74, 0x2D, 0x50, 0x61, 0x63, 0x6B, 0x61, 0x72, 0x64, 0x20, 0x43, 0x6F, 0x6D, 0x70, 0x61, 0x6E, 0x79, 0x00, 0x00, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x73, 0x52, 0x47, 0x42, 0x20, 0x49, 0x45, 0x43, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x73, 0x52, 0x47, 0x42, 0x20, 0x49, 0x45, 0x43, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF3, 0x51, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x16, 0xCC, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6F, 0xA2, 0x00, 0x00, 0x38, 0xF5, 0x00, 0x00, 0x03, 0x90, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0x99, 0x00, 0x00, 0xB7, 0x85, 0x00, 0x00, 0x18, 0xDA, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x24, 0xA0, 0x00, 0x00, 0x0F, 0x84, 0x00, 0x00, 0xB6, 0xCF, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16, 0x49, 0x45, 0x43, 0x20, 0x68, 0x74, 0x74, 0x70, 0x3A, 0x2F, 0x2F, 0x77, 0x77, 0x77, 0x2E, 0x69, 0x65, 0x63, 0x2E, 0x63, 0x68, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16, 0x49, 0x45, 0x43, 0x20, 0x68, 0x74, 0x74, 0x70, 0x3A, 0x2F, 0x2F, 0x77, 0x77, 0x77, 0x2E, 0x69, 0x65, 0x63, 0x2E, 0x63, 0x68, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2E, 0x49, 0x45, 0x43, 0x20, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x20, 0x44, 0x65, 0x66, 0x61, 0x75, 0x6C, 0x74, 0x20, 0x52, 0x47, 0x42, 0x20, 0x63, 0x6F, 0x6C, 0x6F, 0x75, 0x72, 0x20, 0x73, 0x70, 0x61, 0x63, 0x65, 0x20, 0x2D, 0x20, 0x73, 0x52, 0x47, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2E, 0x49, 0x45, 0x43, 0x20, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x20, 0x44, 0x65, 0x66, 0x61, 0x75, 0x6C, 0x74, 0x20, 0x52, 0x47, 0x42, 0x20, 0x63, 0x6F, 0x6C, 0x6F, 0x75, 0x72, 0x20, 0x73, 0x70, 0x61, 0x63, 0x65, 0x20, 0x2D, 0x20, 0x73, 0x52, 0x47, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2C, 0x52, 0x65, 0x66, 0x65, 0x72, 0x65, 0x6E, 0x63, 0x65, 0x20, 0x56, 0x69, 0x65, 0x77, 0x69, 0x6E, 0x67, 0x20, 0x43, 0x6F, 0x6E, 0x64, 0x69, 0x74, 0x69, 0x6F, 0x6E, 0x20, 0x69, 0x6E, 0x20, 0x49, 0x45, 0x43, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2C, 0x52, 0x65, 0x66, 0x65, 0x72, 0x65, 0x6E, 0x63, 0x65, 0x20, 0x56, 0x69, 0x65, 0x77, 0x69, 0x6E, 0x67, 0x20, 0x43, 0x6F, 0x6E, 0x64, 0x69, 0x74, 0x69, 0x6F, 0x6E, 0x20, 0x69, 0x6E, 0x20, 0x49, 0x45, 0x43, 0x36, 0x31, 0x39, 0x36, 0x36, 0x2D, 0x32, 0x2E, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x76, 0x69, 0x65, 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0xA4, 0xFE, 0x00, 0x14, 0x5F, 0x2E, 0x00, 0x10, 0xCF, 0x14, 0x00, 0x03, 0xED, 0xCC, 0x00, 0x04, 0x13, 0x0B, 0x00, 0x03, 0x5C, 0x9E, 0x00, 0x00, 0x00, 0x01, 0x58, 0x59, 0x5A, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4C, 0x09, 0x56, 0x00, 0x50, 0x00, 0x00, 0x00, 0x57, 0x1F, 0xE7, 0x6D, 0x65, 0x61, 0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x8F, 0x00, 0x00, 0x00, 0x02, 0x73, 0x69, 0x67, 0x20, 0x00, 0x00, 0x00, 0x00, 0x43, 0x52, 0x54, 0x20, 0x63, 0x75, 0x72, 0x76, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x0A, 0x00, 0x0F, 0x00, 0x14, 0x00, 0x19, 0x00, 0x1E, 0x00, 0x23, 0x00, 0x28, 0x00, 0x2D, 0x00, 0x32, 0x00, 0x37, 0x00, 0x3B, 0x00, 0x40, 0x00, 0x45, 0x00, 0x4A, 0x00, 0x4F, 0x00, 0x54, 0x00, 0x59, 0x00, 0x5E, 0x00, 0x63, 0x00, 0x68, 0x00, 0x6D, 0x00, 0x72, 0x00, 0x77, 0x00, 0x7C, 0x00, 0x81, 0x00, 0x86, 0x00, 0x8B, 0x00, 0x90, 0x00, 0x95, 0x00, 0x9A, 0x00, 0x9F, 0x00, 0xA4, 0x00, 0xA9, 0x00, 0xAE, 0x00, 0xB2, 0x00, 0xB7, 0x00, 0xBC, 0x00, 0xC1, 0x00, 0xC6, 0x00, 0xCB, 0x00, 0xD0, 0x00, 0xD5, 0x00, 0xDB, 0x00, 0xE0, 0x00, 0xE5, 0x00, 0xEB, 0x00, 0xF0, 0x00, 0xF6, 0x00, 0xFB, 0x01, 0x01, 0x01, 0x07, 0x01, 0x0D, 0x01, 0x13, 0x01, 0x19, 0x01, 0x1F, 0x01, 0x25, 0x01, 0x2B, 0x01, 0x32, 0x01, 0x38, 0x01, 0x3E, 0x01, 0x45, 0x01, 0x4C, 0x01, 0x52, 0x01, 0x59, 0x01, 0x60, 0x01, 0x67, 0x01, 0x6E, 0x01, 0x75, 0x01, 0x7C, 0x01, 0x83, 0x01, 0x8B, 0x01, 0x92, 0x01, 0x9A, 0x01, 0xA1, 0x01, 0xA9, 0x01, 0xB1, 0x01, 0xB9, 0x01, 0xC1, 0x01, 0xC9, 0x01, 0xD1, 0x01, 0xD9, 0x01, 0xE1, 0x01, 0xE9, 0x01, 0xF2, 0x01, 0xFA, 0x02, 0x03, 0x02, 0x0C, 0x02, 0x14, 0x02, 0x1D, 0x02, 0x26, 0x02, 0x2F, 0x02, 0x38, 0x02, 0x41, 0x02, 0x4B, 0x02, 0x54, 0x02, 0x5D, 0x02, 0x67, 0x02, 0x71, 0x02, 0x7A, 0x02, 0x84, 0x02, 0x8E, 0x02, 0x98, 0x02, 0xA2, 0x02, 0xAC, 0x02, 0xB6, 0x02, 0xC1, 0x02, 0xCB, 0x02, 0xD5, 0x02, 0xE0, 0x02, 0xEB, 0x02, 0xF5, 0x03, 0x00, 0x03, 0x0B, 0x03, 0x16, 0x03, 0x21, 0x03, 0x2D, 0x03, 0x38, 0x03, 0x43, 0x03, 0x4F, 0x03, 0x5A, 0x03, 0x66, 0x03, 0x72, 0x03, 0x7E, 0x03, 0x8A, 0x03, 0x96, 0x03, 0xA2, 0x03, 0xAE, 0x03, 0xBA, 0x03, 0xC7, 0x03, 0xD3, 0x03, 0xE0, 0x03, 0xEC, 0x03, 0xF9, 0x04, 0x06, 0x04, 0x13, 0x04, 0x20, 0x04, 0x2D, 0x04, 0x3B, 0x04, 0x48, 0x04, 0x55, 0x04, 0x63, 0x04, 0x71, 0x04, 0x7E, 0x04, 0x8C, 0x04, 0x9A, 0x04, 0xA8, 0x04, 0xB6, 0x04, 0xC4, 0x04, 0xD3, 0x04, 0xE1, 0x04, 0xF0, 0x04, 0xFE, 0x05, 0x0D, 0x05, 0x1C, 0x05, 0x2B, 0x05, 0x3A, 0x05, 0x49, 0x05, 0x58, 0x05, 0x67, 0x05, 0x77, 0x05, 0x86, 0x05, 0x96, 0x05, 0xA6, 0x05, 0xB5, 0x05, 0xC5, 0x05, 0xD5, 0x05, 0xE5, 0x05, 0xF6, 0x06, 0x06, 0x06, 0x16, 0x06, 0x27, 0x06, 0x37, 0x06, 0x48, 0x06, 0x59, 0x06, 0x6A, 0x06, 0x7B, 0x06, 0x8C, 0x06, 0x9D, 0x06, 0xAF, 0x06, 0xC0, 0x06, 0xD1, 0x06, 0xE3, 0x06, 0xF5, 0x07, 0x07, 0x07, 0x19, 0x07, 0x2B, 0x07, 0x3D, 0x07, 0x4F, 0x07, 0x61, 0x07, 0x74, 0x07, 0x86, 0x07, 0x99, 0x07, 0xAC, 0x07, 0xBF, 0x07, 0xD2, 0x07, 0xE5, 0x07, 0xF8, 0x08, 0x0B, 0x08, 0x1F, 0x08, 0x32, 0x08, 0x46, 0x08, 0x5A, 0x08, 0x6E, 0x08, 0x82, 0x08, 0x96, 0x08, 0xAA, 0x08, 0xBE, 0x08, 0xD2, 0x08, 0xE7, 0x08, 0xFB, 0x09, 0x10, 0x09, 0x25, 0x09, 0x3A, 0x09, 0x4F, 0x09, 0x64, 0x09, 0x79, 0x09, 0x8F, 0x09, 0xA4, 0x09, 0xBA, 0x09, 0xCF, 0x09, 0xE5, 0x09, 0xFB, 0x0A, 0x11, 0x0A, 0x27, 0x0A, 0x3D, 0x0A, 0x54, 0x0A, 0x6A, 0x0A, 0x81, 0x0A, 0x98, 0x0A, 0xAE, 0x0A, 0xC5, 0x0A, 0xDC, 0x0A, 0xF3, 0x0B, 0x0B, 0x0B, 0x22, 0x0B, 0x39, 0x0B, 0x51, 0x0B, 0x69, 0x0B, 0x80, 0x0B, 0x98, 0x0B, 0xB0, 0x0B, 0xC8, 0x0B, 0xE1, 0x0B, 0xF9, 0x0C, 0x12, 0x0C, 0x2A, 0x0C, 0x43, 0x0C, 0x5C, 0x0C, 0x75, 0x0C, 0x8E, 0x0C, 0xA7, 0x0C, 0xC0, 0x0C, 0xD9, 0x0C, 0xF3, 0x0D, 0x0D, 0x0D, 0x26, 0x0D, 0x40, 0x0D, 0x5A, 0x0D, 0x74, 0x0D, 0x8E, 0x0D, 0xA9, 0x0D, 0xC3, 0x0D, 0xDE, 0x0D, 0xF8, 0x0E, 0x13, 0x0E, 0x2E, 0x0E, 0x49, 0x0E, 0x64, 0x0E, 0x7F, 0x0E, 0x9B, 0x0E, 0xB6, 0x0E, 0xD2, 0x0E, 0xEE, 0x0F, 0x09, 0x0F, 0x25, 0x0F, 0x41, 0x0F, 0x5E, 0x0F, 0x7A, 0x0F, 0x96, 0x0F, 0xB3, 0x0F, 0xCF, 0x0F, 0xEC, 0x10, 0x09, 0x10, 0x26, 0x10, 0x43, 0x10, 0x61, 0x10, 0x7E, 0x10, 0x9B, 0x10, 0xB9, 0x10, 0xD7, 0x10, 0xF5, 0x11, 0x13, 0x11, 0x31, 0x11, 0x4F, 0x11, 0x6D, 0x11, 0x8C, 0x11, 0xAA, 0x11, 0xC9, 0x11, 0xE8, 0x12, 0x07, 0x12, 0x26, 0x12, 0x45, 0x12, 0x64, 0x12, 0x84, 0x12, 0xA3, 0x12, 0xC3, 0x12, 0xE3, 0x13, 0x03, 0x13, 0x23, 0x13, 0x43, 0x13, 0x63, 0x13, 0x83, 0x13, 0xA4, 0x13, 0xC5, 0x13, 0xE5, 0x14, 0x06, 0x14, 0x27, 0x14, 0x49, 0x14, 0x6A, 0x14, 0x8B, 0x14, 0xAD, 0x14, 0xCE, 0x14, 0xF0, 0x15, 0x12, 0x15, 0x34, 0x15, 0x56, 0x15, 0x78, 0x15, 0x9B, 0x15, 0xBD, 0x15, 0xE0, 0x16, 0x03, 0x16, 0x26, 0x16, 0x49, 0x16, 0x6C, 0x16, 0x8F, 0x16, 0xB2, 0x16, 0xD6, 0x16, 0xFA, 0x17, 0x1D, 0x17, 0x41, 0x17, 0x65, 0x17, 0x89, 0x17, 0xAE, 0x17, 0xD2, 0x17, 0xF7, 0x18, 0x1B, 0x18, 0x40, 0x18, 0x65, 0x18, 0x8A, 0x18, 0xAF, 0x18, 0xD5, 0x18, 0xFA, 0x19, 0x20, 0x19, 0x45, 0x19, 0x6B, 0x19, 0x91, 0x19, 0xB7, 0x19, 0xDD, 0x1A, 0x04, 0x1A, 0x2A, 0x1A, 0x51, 0x1A, 0x77, 0x1A, 0x9E, 0x1A, 0xC5, 0x1A, 0xEC, 0x1B, 0x14, 0x1B, 0x3B, 0x1B, 0x63, 0x1B, 0x8A, 0x1B, 0xB2, 0x1B, 0xDA, 0x1C, 0x02, 0x1C, 0x2A, 0x1C, 0x52, 0x1C, 0x7B, 0x1C, 0xA3, 0x1C, 0xCC, 0x1C, 0xF5, 0x1D, 0x1E, 0x1D, 0x47, 0x1D, 0x70, 0x1D, 0x99, 0x1D, 0xC3, 0x1D, 0xEC, 0x1E, 0x16, 0x1E, 0x40, 0x1E, 0x6A, 0x1E, 0x94, 0x1E, 0xBE, 0x1E, 0xE9, 0x1F, 0x13, 0x1F, 0x3E, 0x1F, 0x69, 0x1F, 0x94, 0x1F, 0xBF, 0x1F, 0xEA, 0x20, 0x15, 0x20, 0x41, 0x20, 0x6C, 0x20, 0x98, 0x20, 0xC4, 0x20, 0xF0, 0x21, 0x1C, 0x21, 0x48, 0x21, 0x75, 0x21, 0xA1, 0x21, 0xCE, 0x21, 0xFB, 0x22, 0x27, 0x22, 0x55, 0x22, 0x82, 0x22, 0xAF, 0x22, 0xDD, 0x23, 0x0A, 0x23, 0x38, 0x23, 0x66, 0x23, 0x94, 0x23, 0xC2, 0x23, 0xF0, 0x24, 0x1F, 0x24, 0x4D, 0x24, 0x7C, 0x24, 0xAB, 0x24, 0xDA, 0x25, 0x09, 0x25, 0x38, 0x25, 0x68, 0x25, 0x97, 0x25, 0xC7, 0x25, 0xF7, 0x26, 0x27, 0x26, 0x57, 0x26, 0x87, 0x26, 0xB7, 0x26, 0xE8, 0x27, 0x18, 0x27, 0x49, 0x27, 0x7A, 0x27, 0xAB, 0x27, 0xDC, 0x28, 0x0D, 0x28, 0x3F, 0x28, 0x71, 0x28, 0xA2, 0x28, 0xD4, 0x29, 0x06, 0x29, 0x38, 0x29, 0x6B, 0x29, 0x9D, 0x29, 0xD0, 0x2A, 0x02, 0x2A, 0x35, 0x2A, 0x68, 0x2A, 0x9B, 0x2A, 0xCF, 0x2B, 0x02, 0x2B, 0x36, 0x2B, 0x69, 0x2B, 0x9D, 0x2B, 0xD1, 0x2C, 0x05, 0x2C, 0x39, 0x2C, 0x6E, 0x2C, 0xA2, 0x2C, 0xD7, 0x2D, 0x0C, 0x2D, 0x41, 0x2D, 0x76, 0x2D, 0xAB, 0x2D, 0xE1, 0x2E, 0x16, 0x2E, 0x4C, 0x2E, 0x82, 0x2E, 0xB7, 0x2E, 0xEE, 0x2F, 0x24, 0x2F, 0x5A, 0x2F, 0x91, 0x2F, 0xC7, 0x2F, 0xFE, 0x30, 0x35, 0x30, 0x6C, 0x30, 0xA4, 0x30, 0xDB, 0x31, 0x12, 0x31, 0x4A, 0x31, 0x82, 0x31, 0xBA, 0x31, 0xF2, 0x32, 0x2A, 0x32, 0x63, 0x32, 0x9B, 0x32, 0xD4, 0x33, 0x0D, 0x33, 0x46, 0x33, 0x7F, 0x33, 0xB8, 0x33, 0xF1, 0x34, 0x2B, 0x34, 0x65, 0x34, 0x9E, 0x34, 0xD8, 0x35, 0x13, 0x35, 0x4D, 0x35, 0x87, 0x35, 0xC2, 0x35, 0xFD, 0x36, 0x37, 0x36, 0x72, 0x36, 0xAE, 0x36, 0xE9, 0x37, 0x24, 0x37, 0x60, 0x37, 0x9C, 0x37, 0xD7, 0x38, 0x14, 0x38, 0x50, 0x38, 0x8C, 0x38, 0xC8, 0x39, 0x05, 0x39, 0x42, 0x39, 0x7F, 0x39, 0xBC, 0x39, 0xF9, 0x3A, 0x36, 0x3A, 0x74, 0x3A, 0xB2, 0x3A, 0xEF, 0x3B, 0x2D, 0x3B, 0x6B, 0x3B, 0xAA, 0x3B, 0xE8, 0x3C, 0x27, 0x3C, 0x65, 0x3C, 0xA4, 0x3C, 0xE3, 0x3D, 0x22, 0x3D, 0x61, 0x3D, 0xA1, 0x3D, 0xE0, 0x3E, 0x20, 0x3E, 0x60, 0x3E, 0xA0, 0x3E, 0xE0, 0x3F, 0x21, 0x3F, 0x61, 0x3F, 0xA2, 0x3F, 0xE2, 0x40, 0x23, 0x40, 0x64, 0x40, 0xA6, 0x40, 0xE7, 0x41, 0x29, 0x41, 0x6A, 0x41, 0xAC, 0x41, 0xEE, 0x42, 0x30, 0x42, 0x72, 0x42, 0xB5, 0x42, 0xF7, 0x43, 0x3A, 0x43, 0x7D, 0x43, 0xC0, 0x44, 0x03, 0x44, 0x47, 0x44, 0x8A, 0x44, 0xCE, 0x45, 0x12, 0x45, 0x55, 0x45, 0x9A, 0x45, 0xDE, 0x46, 0x22, 0x46, 0x67, 0x46, 0xAB, 0x46, 0xF0, 0x47, 0x35, 0x47, 0x7B, 0x47, 0xC0, 0x48, 0x05, 0x48, 0x4B, 0x48, 0x91, 0x48, 0xD7, 0x49, 0x1D, 0x49, 0x63, 0x49, 0xA9, 0x49, 0xF0, 0x4A, 0x37, 0x4A, 0x7D, 0x4A, 0xC4, 0x4B, 0x0C, 0x4B, 0x53, 0x4B, 0x9A, 0x4B, 0xE2, 0x4C, 0x2A, 0x4C, 0x72, 0x4C, 0xBA, 0x4D, 0x02, 0x4D, 0x4A, 0x4D, 0x93, 0x4D, 0xDC, 0x4E, 0x25, 0x4E, 0x6E, 0x4E, 0xB7, 0x4F, 0x00, 0x4F, 0x49, 0x4F, 0x93, 0x4F, 0xDD, 0x50, 0x27, 0x50, 0x71, 0x50, 0xBB, 0x51, 0x06, 0x51, 0x50, 0x51, 0x9B, 0x51, 0xE6, 0x52, 0x31, 0x52, 0x7C, 0x52, 0xC7, 0x53, 0x13, 0x53, 0x5F, 0x53, 0xAA, 0x53, 0xF6, 0x54, 0x42, 0x54, 0x8F, 0x54, 0xDB, 0x55, 0x28, 0x55, 0x75, 0x55, 0xC2, 0x56, 0x0F, 0x56, 0x5C, 0x56, 0xA9, 0x56, 0xF7, 0x57, 0x44, 0x57, 0x92, 0x57, 0xE0, 0x58, 0x2F, 0x58, 0x7D, 0x58, 0xCB, 0x59, 0x1A, 0x59, 0x69, 0x59, 0xB8, 0x5A, 0x07, 0x5A, 0x56, 0x5A, 0xA6, 0x5A, 0xF5, 0x5B, 0x45, 0x5B, 0x95, 0x5B, 0xE5, 0x5C, 0x35, 0x5C, 0x86, 0x5C, 0xD6, 0x5D, 0x27, 0x5D, 0x78, 0x5D, 0xC9, 0x5E, 0x1A, 0x5E, 0x6C, 0x5E, 0xBD, 0x5F, 0x0F, 0x5F, 0x61, 0x5F, 0xB3, 0x60, 0x05, 0x60, 0x57, 0x60, 0xAA, 0x60, 0xFC, 0x61, 0x4F, 0x61, 0xA2, 0x61, 0xF5, 0x62, 0x49, 0x62, 0x9C, 0x62, 0xF0, 0x63, 0x43, 0x63, 0x97, 0x63, 0xEB, 0x64, 0x40, 0x64, 0x94, 0x64, 0xE9, 0x65, 0x3D, 0x65, 0x92, 0x65, 0xE7, 0x66, 0x3D, 0x66, 0x92, 0x66, 0xE8, 0x67, 0x3D, 0x67, 0x93, 0x67, 0xE9, 0x68, 0x3F, 0x68, 0x96, 0x68, 0xEC, 0x69, 0x43, 0x69, 0x9A, 0x69, 0xF1, 0x6A, 0x48, 0x6A, 0x9F, 0x6A, 0xF7, 0x6B, 0x4F, 0x6B, 0xA7, 0x6B, 0xFF, 0x6C, 0x57, 0x6C, 0xAF, 0x6D, 0x08, 0x6D, 0x60, 0x6D, 0xB9, 0x6E, 0x12, 0x6E, 0x6B, 0x6E, 0xC4, 0x6F, 0x1E, 0x6F, 0x78, 0x6F, 0xD1, 0x70, 0x2B, 0x70, 0x86, 0x70, 0xE0, 0x71, 0x3A, 0x71, 0x95, 0x71, 0xF0, 0x72, 0x4B, 0x72, 0xA6, 0x73, 0x01, 0x73, 0x5D, 0x73, 0xB8, 0x74, 0x14, 0x74, 0x70, 0x74, 0xCC, 0x75, 0x28, 0x75, 0x85, 0x75, 0xE1, 0x76, 0x3E, 0x76, 0x9B, 0x76, 0xF8, 0x77, 0x56, 0x77, 0xB3, 0x78, 0x11, 0x78, 0x6E, 0x78, 0xCC, 0x79, 0x2A, 0x79, 0x89, 0x79, 0xE7, 0x7A, 0x46, 0x7A, 0xA5, 0x7B, 0x04, 0x7B, 0x63, 0x7B, 0xC2, 0x7C, 0x21, 0x7C, 0x81, 0x7C, 0xE1, 0x7D, 0x41, 0x7D, 0xA1, 0x7E, 0x01, 0x7E, 0x62, 0x7E, 0xC2, 0x7F, 0x23, 0x7F, 0x84, 0x7F, 0xE5, 0x80, 0x47, 0x80, 0xA8, 0x81, 0x0A, 0x81, 0x6B, 0x81, 0xCD, 0x82, 0x30, 0x82, 0x92, 0x82, 0xF4, 0x83, 0x57, 0x83, 0xBA, 0x84, 0x1D, 0x84, 0x80, 0x84, 0xE3, 0x85, 0x47, 0x85, 0xAB, 0x86, 0x0E, 0x86, 0x72, 0x86, 0xD7, 0x87, 0x3B, 0x87, 0x9F, 0x88, 0x04, 0x88, 0x69, 0x88, 0xCE, 0x89, 0x33, 0x89, 0x99, 0x89, 0xFE, 0x8A, 0x64, 0x8A, 0xCA, 0x8B, 0x30, 0x8B, 0x96, 0x8B, 0xFC, 0x8C, 0x63, 0x8C, 0xCA, 0x8D, 0x31, 0x8D, 0x98, 0x8D, 0xFF, 0x8E, 0x66, 0x8E, 0xCE, 0x8F, 0x36, 0x8F, 0x9E, 0x90, 0x06, 0x90, 0x6E, 0x90, 0xD6, 0x91, 0x3F, 0x91, 0xA8, 0x92, 0x11, 0x92, 0x7A, 0x92, 0xE3, 0x93, 0x4D, 0x93, 0xB6, 0x94, 0x20, 0x94, 0x8A, 0x94, 0xF4, 0x95, 0x5F, 0x95, 0xC9, 0x96, 0x34, 0x96, 0x9F, 0x97, 0x0A, 0x97, 0x75, 0x97, 0xE0, 0x98, 0x4C, 0x98, 0xB8, 0x99, 0x24, 0x99, 0x90, 0x99, 0xFC, 0x9A, 0x68, 0x9A, 0xD5, 0x9B, 0x42, 0x9B, 0xAF, 0x9C, 0x1C, 0x9C, 0x89, 0x9C, 0xF7, 0x9D, 0x64, 0x9D, 0xD2, 0x9E, 0x40, 0x9E, 0xAE, 0x9F, 0x1D, 0x9F, 0x8B, 0x9F, 0xFA, 0xA0, 0x69, 0xA0, 0xD8, 0xA1, 0x47, 0xA1, 0xB6, 0xA2, 0x26, 0xA2, 0x96, 0xA3, 0x06, 0xA3, 0x76, 0xA3, 0xE6, 0xA4, 0x56, 0xA4, 0xC7, 0xA5, 0x38, 0xA5, 0xA9, 0xA6, 0x1A, 0xA6, 0x8B, 0xA6, 0xFD, 0xA7, 0x6E, 0xA7, 0xE0, 0xA8, 0x52, 0xA8, 0xC4, 0xA9, 0x37, 0xA9, 0xA9, 0xAA, 0x1C, 0xAA, 0x8F, 0xAB, 0x02, 0xAB, 0x75, 0xAB, 0xE9, 0xAC, 0x5C, 0xAC, 0xD0, 0xAD, 0x44, 0xAD, 0xB8, 0xAE, 0x2D, 0xAE, 0xA1, 0xAF, 0x16, 0xAF, 0x8B, 0xB0, 0x00, 0xB0, 0x75, 0xB0, 0xEA, 0xB1, 0x60, 0xB1, 0xD6, 0xB2, 0x4B, 0xB2, 0xC2, 0xB3, 0x38, 0xB3, 0xAE, 0xB4, 0x25, 0xB4, 0x9C, 0xB5, 0x13, 0xB5, 0x8A, 0xB6, 0x01, 0xB6, 0x79, 0xB6, 0xF0, 0xB7, 0x68, 0xB7, 0xE0, 0xB8, 0x59, 0xB8, 0xD1, 0xB9, 0x4A, 0xB9, 0xC2, 0xBA, 0x3B, 0xBA, 0xB5, 0xBB, 0x2E, 0xBB, 0xA7, 0xBC, 0x21, 0xBC, 0x9B, 0xBD, 0x15, 0xBD, 0x8F, 0xBE, 0x0A, 0xBE, 0x84, 0xBE, 0xFF, 0xBF, 0x7A, 0xBF, 0xF5, 0xC0, 0x70, 0xC0, 0xEC, 0xC1, 0x67, 0xC1, 0xE3, 0xC2, 0x5F, 0xC2, 0xDB, 0xC3, 0x58, 0xC3, 0xD4, 0xC4, 0x51, 0xC4, 0xCE, 0xC5, 0x4B, 0xC5, 0xC8, 0xC6, 0x46, 0xC6, 0xC3, 0xC7, 0x41, 0xC7, 0xBF, 0xC8, 0x3D, 0xC8, 0xBC, 0xC9, 0x3A, 0xC9, 0xB9, 0xCA, 0x38, 0xCA, 0xB7, 0xCB, 0x36, 0xCB, 0xB6, 0xCC, 0x35, 0xCC, 0xB5, 0xCD, 0x35, 0xCD, 0xB5, 0xCE, 0x36, 0xCE, 0xB6, 0xCF, 0x37, 0xCF, 0xB8, 0xD0, 0x39, 0xD0, 0xBA, 0xD1, 0x3C, 0xD1, 0xBE, 0xD2, 0x3F, 0xD2, 0xC1, 0xD3, 0x44, 0xD3, 0xC6, 0xD4, 0x49, 0xD4, 0xCB, 0xD5, 0x4E, 0xD5, 0xD1, 0xD6, 0x55, 0xD6, 0xD8, 0xD7, 0x5C, 0xD7, 0xE0, 0xD8, 0x64, 0xD8, 0xE8, 0xD9, 0x6C, 0xD9, 0xF1, 0xDA, 0x76, 0xDA, 0xFB, 0xDB, 0x80, 0xDC, 0x05, 0xDC, 0x8A, 0xDD, 0x10, 0xDD, 0x96, 0xDE, 0x1C, 0xDE, 0xA2, 0xDF, 0x29, 0xDF, 0xAF, 0xE0, 0x36, 0xE0, 0xBD, 0xE1, 0x44, 0xE1, 0xCC, 0xE2, 0x53, 0xE2, 0xDB, 0xE3, 0x63, 0xE3, 0xEB, 0xE4, 0x73, 0xE4, 0xFC, 0xE5, 0x84, 0xE6, 0x0D, 0xE6, 0x96, 0xE7, 0x1F, 0xE7, 0xA9, 0xE8, 0x32, 0xE8, 0xBC, 0xE9, 0x46, 0xE9, 0xD0, 0xEA, 0x5B, 0xEA, 0xE5, 0xEB, 0x70, 0xEB, 0xFB, 0xEC, 0x86, 0xED, 0x11, 0xED, 0x9C, 0xEE, 0x28, 0xEE, 0xB4, 0xEF, 0x40, 0xEF, 0xCC, 0xF0, 0x58, 0xF0, 0xE5, 0xF1, 0x72, 0xF1, 0xFF, 0xF2, 0x8C, 0xF3, 0x19, 0xF3, 0xA7, 0xF4, 0x34, 0xF4, 0xC2, 0xF5, 0x50, 0xF5, 0xDE, 0xF6, 0x6D, 0xF6, 0xFB, 0xF7, 0x8A, 0xF8, 0x19, 0xF8, 0xA8, 0xF9, 0x38, 0xF9, 0xC7, 0xFA, 0x57, 0xFA, 0xE7, 0xFB, 0x77, 0xFC, 0x07, 0xFC, 0x98, 0xFD, 0x29, 0xFD, 0xBA, 0xFE, 0x4B, 0xFE, 0xDC, 0xFF, 0x6D, 0xFF, 0xFF +``` + diff --git a/media/libjxl/src/doc/tables/is_zero_base.md b/media/libjxl/src/doc/tables/is_zero_base.md new file mode 100644 index 000000000..7e2d081f3 --- /dev/null +++ b/media/libjxl/src/doc/tables/is_zero_base.md @@ -0,0 +1,9 @@ +#### Table M.1 – is_zero_base table + +``` +228, 216, 216, 195, 192, 189, 182, 184, 179, 176, 171, 168, 166, 159, +156, 151, 151, 150, 150, 146, 144, 138, 138, 137, 135, 131, 127, 126, +124, 123, 124, 123, 122, 121, 118, 117, 114, 115, 116, 116, 115, 115, +114, 111, 111, 111, 112, 111, 110, 110, 110, 111, 111, 114, 110, 111, +112, 113, 116, 120, 126, 131, 147, 160 +``` diff --git a/media/libjxl/src/doc/tables/markdown-pdf.css b/media/libjxl/src/doc/tables/markdown-pdf.css new file mode 100644 index 000000000..c1efc1cc2 --- /dev/null +++ b/media/libjxl/src/doc/tables/markdown-pdf.css @@ -0,0 +1,36 @@ +/* + settings.json: + "markdown-pdf.styles": ["markdown-pdf.css",], + "markdown-pdf.format": "Letter", + "markdown-pdf.margin.top": "1in", + "markdown-pdf.margin.bottom": "1in", + "markdown-pdf.margin.left": "1in", + "markdown-pdf.margin.right": "1in", + "markdown-pdf.stylesRelativePathFile" : true, + "markdown-pdf.displayHeaderFooter": false, + */ + +body { + font-family: "Times"; + font-size: 10pt; + padding: 0; +} + +h4 { + font-family: "Times New Roman"; + font-size: 10pt; + font-weight: bold; +} + +code { + font-family: Consolas, "Source Code Pro"; + font-size: 10pt; +} + +pre.hljs code > div { + padding: 0px; +} + +:not(pre):not(.hljs) > code { + color: #4d4d4c; +} diff --git a/media/libjxl/src/doc/tables/nonzero_buckets.md b/media/libjxl/src/doc/tables/nonzero_buckets.md new file mode 100644 index 000000000..77a5a396e --- /dev/null +++ b/media/libjxl/src/doc/tables/nonzero_buckets.md @@ -0,0 +1,9 @@ +#### Table M.17 – nonzero_buckets + +``` + 0, 1, 2, 3, 4, 4, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, + 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10 +``` + diff --git a/media/libjxl/src/doc/tables/num_nonzero_context.md b/media/libjxl/src/doc/tables/num_nonzero_context.md new file mode 100644 index 000000000..b73d48c38 --- /dev/null +++ b/media/libjxl/src/doc/tables/num_nonzero_context.md @@ -0,0 +1,60 @@ +#### Table M.16 – num_nonzero_context + +`scheme == 0`: +``` +0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, +6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, +7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 +``` + +`scheme == 1`: +``` + 0, 2, 2, 4, 4, 4, 6, 6, 6, 6, 8, 8, 8, 8, 8, 8, 10, 10, +10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, +12, 12, 12, 12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, +14, 14, 14, 14, 14, 14, 14, 14, 14, 14 +``` + +`scheme == 2`: +``` + 0, 4, 4, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 20, 20, +20, 20, 20, 20, 20, 20, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, +24, 24, 24, 24, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, +28, 28, 28, 28, 28, 28, 28, 28, 28, 28 +``` + +`scheme == 3`: +``` + 0, 8, 8, 16, 16, 16, 24, 24, 24, 24, 32, 32, 32, 32, 32, 32, 40, 40, +40, 40, 40, 40, 40, 40, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, +48, 48, 48, 48, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, +55, 55, 55, 55, 55, 55, 55, 55, 55, 55 +``` + +`scheme == 4`: +``` + 0, 16, 16, 32, 32, 32, 48, 48, 48, 48, 64, 64, 64, 64, + 64, 64, 80, 80, 80, 80, 80, 80, 80, 80, 95, 95, 95, 95, + 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 109, 109, +109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, 109, +109, 109, 109, 109, 109, 109, 109, 109 +``` + +`scheme == 5`: +``` + 0, 32, 32, 64, 64, 64, 96, 96, 96, 96, 127, 127, 127, 127, +127, 127, 157, 157, 157, 157, 157, 157, 157, 157, 185, 185, 185, 185, +185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 211, 211, +211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, +211, 211, 211, 211, 211, 211, 211, 211 +``` + +`scheme == 6`: +``` + 0, 64, 64, 127, 127, 127, 188, 188, 188, 188, 246, 246, 246, 246, +246, 246, 300, 300, 300, 300, 300, 300, 300, 300, 348, 348, 348, 348, +348, 348, 348, 348, 348, 348, 348, 348, 348, 348, 348, 348, 388, 388, +388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, 388, +388, 388, 388, 388, 388, 388, 388, 388 +``` + diff --git a/media/libjxl/src/doc/tables/num_nonzeros_base.md b/media/libjxl/src/doc/tables/num_nonzeros_base.md new file mode 100644 index 000000000..165c7389c --- /dev/null +++ b/media/libjxl/src/doc/tables/num_nonzeros_base.md @@ -0,0 +1,258 @@ +#### Table M.2 – num_nonzeros_base table + +``` +251, 252, 117, 249, 161, 136, 83, 238, 184, 126, 137, 129, 140, 119, + 70, 213, 160, 175, 174, 130, 166, 134, 122, 125, 131, 144, 136, 133, +139, 123, 79, 216, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +254, 252, 174, 232, 189, 155, 122, 177, 204, 173, 146, 149, 141, 133, +103, 109, 167, 187, 168, 142, 154, 147, 125, 139, 144, 138, 138, 153, +141, 133, 90, 121, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +251, 240, 197, 176, 184, 177, 114, 89, 194, 165, 153, 161, 158, 136, + 92, 95, 123, 171, 160, 140, 148, 136, 129, 139, 145, 136, 143, 134, +138, 124, 92, 154, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +247, 220, 201, 110, 194, 176, 147, 59, 175, 171, 156, 157, 152, 146, +115, 114, 88, 151, 164, 141, 153, 135, 141, 131, 146, 139, 140, 145, +138, 137, 112, 184, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +238, 179, 203, 63, 194, 173, 149, 71, 139, 169, 154, 159, 150, 146, +117, 143, 78, 122, 152, 137, 149, 138, 138, 133, 134, 142, 142, 142, +148, 128, 118, 199, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +227, 127, 200, 44, 192, 170, 148, 100, 102, 161, 156, 153, 148, 149, +124, 160, 88, 101, 134, 132, 149, 145, 134, 134, 136, 141, 138, 142, +144, 137, 116, 208, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +214, 86, 195, 44, 187, 163, 148, 126, 81, 147, 156, 152, 150, 144, +121, 172, 96, 95, 117, 122, 145, 152, 136, 133, 135, 135, 131, 142, +141, 135, 114, 217, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +198, 56, 191, 54, 171, 162, 147, 144, 74, 128, 152, 149, 150, 142, +119, 177, 101, 100, 106, 111, 135, 154, 136, 137, 136, 132, 133, 142, +144, 130, 117, 222, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +176, 40, 189, 73, 147, 159, 148, 152, 79, 106, 147, 149, 151, 139, +123, 188, 108, 110, 106, 97, 125, 151, 137, 138, 135, 135, 134, 136, +140, 131, 116, 221, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +148, 33, 185, 88, 117, 158, 145, 163, 95, 91, 137, 146, 150, 140, +120, 197, 115, 116, 114, 92, 114, 144, 130, 133, 132, 133, 129, 140, +138, 130, 111, 224, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` +117, 31, 180, 104, 93, 150, 143, 166, 99, 85, 124, 139, 148, 142, +118, 201, 105, 120, 120, 90, 107, 135, 127, 130, 131, 131, 132, 140, +142, 133, 114, 229, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 87, 35, 170, 110, 78, 141, 144, 176, 106, 90, 112, 132, 143, 138, +119, 204, 111, 121, 125, 90, 105, 131, 124, 122, 129, 128, 129, 137, +138, 133, 114, 227, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 63, 42, 159, 123, 73, 127, 142, 191, 105, 91, 105, 123, 139, 137, +120, 209, 117, 110, 122, 98, 110, 125, 115, 123, 122, 126, 128, 134, +141, 129, 113, 229, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 45, 53, 146, 135, 71, 114, 138, 193, 100, 98, 98, 113, 133, 135, +118, 222, 113, 111, 139, 103, 107, 126, 111, 119, 121, 122, 127, 135, +141, 128, 114, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 33, 60, 132, 138, 75, 100, 134, 203, 112, 99, 98, 105, 126, 131, +115, 229, 107, 93, 121, 106, 108, 122, 106, 109, 114, 116, 127, 133, +143, 128, 110, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 24, 70, 118, 134, 76, 87, 130, 201, 110, 96, 99, 97, 119, 130, +111, 229, 97, 104, 125, 102, 112, 125, 101, 109, 113, 114, 125, 129, +142, 127, 112, 241, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 17, 65, 100, 121, 80, 75, 124, 174, 117, 100, 94, 93, 114, 128, +110, 216, 103, 94, 113, 122, 118, 126, 113, 108, 105, 108, 122, 128, +141, 125, 113, 238, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 12, 70, 82, 132, 78, 65, 118, 155, 136, 103, 97, 89, 106, 124, +111, 215, 115, 123, 129, 99, 104, 127, 110, 108, 101, 109, 118, 126, +136, 123, 110, 233, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 8, 66, 61, 117, 91, 59, 108, 195, 101, 112, 99, 99, 99, 116, +106, 230, 127, 99, 144, 101, 118, 137, 117, 111, 106, 104, 116, 121, +134, 122, 110, 223, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 6, 78, 42, 146, 101, 54, 94, 201, 116, 102, 110, 94, 92, 108, +103, 214, 108, 111, 127, 102, 121, 132, 120, 121, 95, 98, 110, 121, +129, 117, 107, 235, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 5, 93, 29, 145, 102, 52, 77, 216, 108, 115, 108, 102, 89, 97, + 94, 229, 89, 103, 139, 120, 103, 151, 102, 100, 97, 96, 99, 111, +125, 116, 104, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 4, 105, 21, 145, 100, 54, 64, 217, 100, 122, 128, 87, 88, 91, + 87, 230, 112, 80, 148, 95, 146, 123, 96, 140, 90, 91, 98, 106, +122, 111, 100, 249, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 4, 130, 14, 142, 104, 56, 51, 208, 116, 135, 100, 89, 82, 84, + 75, 239, 85, 85, 122, 125, 94, 144, 151, 136, 92, 97, 104, 109, +113, 110, 91, 246, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 3, 126, 9, 172, 105, 57, 39, 219, 95, 120, 118, 96, 93, 75, + 66, 241, 102, 134, 96, 156, 146, 162, 130, 112, 82, 89, 97, 101, +116, 103, 82, 254, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 3, 149, 7, 182, 122, 54, 29, 224, 103, 100, 113, 96, 90, 74, + 55, 250, 127, 94, 118, 93, 135, 160, 113, 130, 95, 117, 106, 96, +111, 97, 77, 242, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 3, 150, 4, 170, 138, 59, 20, 229, 91, 150, 107, 98, 92, 68, + 48, 245, 113, 64, 114, 111, 134, 127, 102, 104, 85, 118, 103, 107, +102, 91, 72, 245, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 3, 171, 3, 165, 137, 62, 14, 211, 96, 127, 132, 121, 95, 62, + 37, 248, 102, 57, 144, 85, 127, 191, 102, 97, 127, 104, 91, 102, +107, 81, 64, 254, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 2, 166, 2, 196, 122, 65, 10, 243, 102, 93, 117, 92, 96, 63, + 29, 251, 169, 159, 149, 96, 91, 139, 157, 40, 100, 89, 120, 92, +109, 79, 58, 247, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 2, 176, 2, 189, 118, 48, 7, 219, 68, 43, 109, 96, 129, 75, + 19, 254, 2, 3, 185, 6, 102, 127, 127, 127, 1, 131, 83, 99, +107, 80, 45, 254, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 1, 205, 2, 208, 64, 89, 4, 223, 29, 169, 29, 123, 118, 76, + 11, 240, 202, 243, 65, 6, 12, 243, 96, 55, 102, 102, 114, 102, +107, 74, 31, 247, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 1, 216, 1, 214, 127, 94, 2, 234, 145, 3, 127, 106, 155, 80, + 4, 247, 4, 65, 86, 127, 127, 127, 127, 102, 127, 143, 143, 108, +113, 80, 16, 216, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + +``` + 2, 199, 1, 222, 93, 94, 1, 232, 2, 65, 74, 139, 201, 48, + 2, 254, 169, 127, 52, 243, 251, 249, 102, 86, 202, 153, 65, 65, +146, 69, 8, 238, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, +128, 128, 128, 128, 128, 128, 128 +``` + diff --git a/media/libjxl/src/doc/tables/quant.md b/media/libjxl/src/doc/tables/quant.md new file mode 100644 index 000000000..1fb80d754 --- /dev/null +++ b/media/libjxl/src/doc/tables/quant.md @@ -0,0 +1,19 @@ +#### Table M.13 – template quant tables + +`is_luma == true`: +``` + 16, 11, 10, 16, 24, 40, 51, 61, 12, 12, 14, 19, 26, 58, 60, + 55, 14, 13, 16, 24, 40, 57, 69, 56, 14, 17, 22, 29, 51, 87, + 80, 62, 18, 22, 37, 56, 68, 109, 103, 77, 24, 35, 55, 64, 81, +104, 113, 92, 49, 64, 78, 87, 103, 121, 120, 101, 72, 92, 95, 98, +112, 100, 103, 99 +``` + +`is_luma == false`: +``` +17, 18, 24, 47, 99, 99, 99, 99, 18, 21, 26, 66, 99, 99, 99, 99, 24, 26, +56, 99, 99, 99, 99, 99, 47, 66, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, +99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, +99, 99, 99, 99, 99, 99, 99, 99, 99, 99 +``` + diff --git a/media/libjxl/src/doc/tables/stock_counts.md b/media/libjxl/src/doc/tables/stock_counts.md new file mode 100644 index 000000000..6e8e4458c --- /dev/null +++ b/media/libjxl/src/doc/tables/stock_counts.md @@ -0,0 +1,22 @@ +#### Table M.9 – stock counts arrays + +`is_ac == 0`, `stock_index == 0`: +``` +0, 0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 0 +``` + +`is_ac == 0`, `stock_index == 1`: +``` +0, 0, 1, 5, 1, 1, 1, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0 +``` + +`is_ac == 1`, `stock_index == 0`: +``` +0, 0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 126 +``` + +`is_ac == 1`, `stock_index == 1`: +``` +0, 0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 120 +``` + diff --git a/media/libjxl/src/doc/tables/stock_quant.md b/media/libjxl/src/doc/tables/stock_quant.md new file mode 100644 index 000000000..b32fd3c7d --- /dev/null +++ b/media/libjxl/src/doc/tables/stock_quant.md @@ -0,0 +1,130 @@ +#### Table M.12 – stock quant tables + +`is_luma == true`, `stock_index == 0`: +``` + 3, 2, 2, 3, 5, 8, 10, 12, 2, 2, 3, 4, 5, 12, 12, 11, 3, 3, + 3, 5, 8, 11, 14, 11, 3, 3, 4, 6, 10, 17, 16, 12, 4, 4, 7, 11, +14, 22, 21, 15, 5, 7, 11, 13, 16, 21, 23, 18, 10, 13, 16, 17, 21, 24, +24, 20, 14, 18, 19, 20, 22, 20, 21, 20 +``` + +`is_luma == true`, `stock_index == 1`: +``` + 8, 6, 5, 8, 12, 20, 26, 31, 6, 6, 7, 10, 13, 29, 30, 28, 7, 7, + 8, 12, 20, 29, 35, 28, 7, 9, 11, 15, 26, 44, 40, 31, 9, 11, 19, 28, +34, 55, 52, 39, 12, 18, 28, 32, 41, 52, 57, 46, 25, 32, 39, 44, 52, 61, +60, 51, 36, 46, 48, 49, 56, 50, 52, 50 +``` + +`is_luma == true`, `stock_index == 2`: +``` + 6, 4, 4, 6, 10, 16, 20, 24, 5, 5, 6, 8, 10, 23, 24, 22, 6, 5, + 6, 10, 16, 23, 28, 22, 6, 7, 9, 12, 20, 35, 32, 25, 7, 9, 15, 22, +27, 44, 41, 31, 10, 14, 22, 26, 32, 42, 45, 37, 20, 26, 31, 35, 41, 48, +48, 40, 29, 37, 38, 39, 45, 40, 41, 40 +``` + +`is_luma == true`, `stock_index == 3`: +``` + 5, 3, 3, 5, 7, 12, 15, 18, 4, 4, 4, 6, 8, 17, 18, 17, 4, 4, + 5, 7, 12, 17, 21, 17, 4, 5, 7, 9, 15, 26, 24, 19, 5, 7, 11, 17, +20, 33, 31, 23, 7, 11, 17, 19, 24, 31, 34, 28, 15, 19, 23, 26, 31, 36, +36, 30, 22, 28, 29, 29, 34, 30, 31, 30 +``` + +`is_luma == true`, `stock_index == 4`: +``` + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 +``` + +`is_luma == true`, `stock_index == 5`: +``` + 2, 1, 1, 2, 2, 4, 5, 6, 1, 1, 1, 2, 3, 6, 6, 6, 1, 1, + 2, 2, 4, 6, 7, 6, 1, 2, 2, 3, 5, 9, 8, 6, 2, 2, 4, 6, + 7, 11, 10, 8, 2, 4, 6, 6, 8, 10, 11, 9, 5, 6, 8, 9, 10, 12, +12, 10, 7, 9, 10, 10, 11, 10, 10, 10 +``` + +`is_luma == true`, `stock_index == 6`: +``` + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, + 1, 2, 2, 3, 1, 1, 1, 1, 2, 2, 3, 3, 1, 1, 1, 2, 2, 3, + 3, 3, 1, 1, 2, 2, 3, 3, 3, 3 +``` + +`is_luma == true`, `stock_index == 7`: +``` +10, 7, 6, 10, 14, 24, 31, 37, 7, 7, 8, 11, 16, 35, 36, 33, 8, 8, +10, 14, 24, 34, 41, 34, 8, 10, 13, 17, 31, 52, 48, 37, 11, 13, 22, 34, +41, 65, 62, 46, 14, 21, 33, 38, 49, 62, 68, 55, 29, 38, 47, 52, 62, 73, +72, 61, 43, 55, 57, 59, 67, 60, 62, 59 +``` + +`is_luma == false`, `stock_index == 0`: +``` + 9, 9, 9, 12, 11, 12, 24, 13, 13, 24, 50, 33, 28, 33, 50, 50, 50, 50, +50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, +50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, +50, 50, 50, 50, 50, 50, 50, 50, 50, 50 +``` + +`is_luma == false`, `stock_index == 1`: +``` + 3, 4, 5, 9, 20, 20, 20, 20, 4, 4, 5, 13, 20, 20, 20, 20, 5, 5, +11, 20, 20, 20, 20, 20, 9, 13, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, +20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, +20, 20, 20, 20, 20, 20, 20, 20, 20, 20 +``` + +`is_luma == false`, `stock_index == 2`: +``` + 9, 9, 12, 24, 50, 50, 50, 50, 9, 11, 13, 33, 50, 50, 50, 50, 12, 13, +28, 50, 50, 50, 50, 50, 24, 33, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, +50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, +50, 50, 50, 50, 50, 50, 50, 50, 50, 50 +``` + +`is_luma == false`, `stock_index == 3`: +``` + 5, 5, 7, 14, 30, 30, 30, 30, 5, 6, 8, 20, 30, 30, 30, 30, 7, 8, +17, 30, 30, 30, 30, 30, 14, 20, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, +30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, +30, 30, 30, 30, 30, 30, 30, 30, 30, 30 +``` + +`is_luma == false`, `stock_index == 4`: +``` + 7, 7, 10, 19, 40, 40, 40, 40, 7, 8, 10, 26, 40, 40, 40, 40, 10, 10, +22, 40, 40, 40, 40, 40, 19, 26, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, +40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, +40, 40, 40, 40, 40, 40, 40, 40, 40, 40 +``` + +`is_luma == false`, `stock_index == 5`: +``` + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 +``` + +`is_luma == false`, `stock_index == 6`: +``` + 2, 2, 2, 5, 10, 10, 10, 10, 2, 2, 3, 7, 10, 10, 10, 10, 2, 3, + 6, 10, 10, 10, 10, 10, 5, 7, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, +10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, +10, 10, 10, 10, 10, 10, 10, 10, 10, 10 +``` + +`is_luma == false`, `stock_index == 7`: +``` +10, 11, 14, 28, 59, 59, 59, 59, 11, 13, 16, 40, 59, 59, 59, 59, 14, 16, +34, 59, 59, 59, 59, 59, 28, 40, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, +59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, +59, 59, 59, 59, 59, 59, 59, 59, 59, 59 +``` + diff --git a/media/libjxl/src/doc/tables/stock_values.md b/media/libjxl/src/doc/tables/stock_values.md new file mode 100644 index 000000000..8e67cff1e --- /dev/null +++ b/media/libjxl/src/doc/tables/stock_values.md @@ -0,0 +1,44 @@ +#### Table M.10 – stock values arrays + +`is_ac == 0`, `stock_index == 0`: +``` +0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 256 +``` + +`is_ac == 0`, `stock_index == 1`: +``` +0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 256 +``` + +`is_ac == 1`, `stock_index == 0`: +``` + 1, 2, 3, 0, 4, 17, 5, 18, 33, 49, 65, 6, 19, 81, + 97, 7, 34, 113, 20, 50, 129, 145, 161, 8, 35, 66, 177, 193, + 21, 82, 209, 240, 36, 51, 98, 114, 130, 9, 10, 22, 23, 24, + 25, 26, 37, 38, 39, 40, 41, 42, 52, 53, 54, 55, 56, 57, + 58, 67, 68, 69, 70, 71, 72, 73, 74, 83, 84, 85, 86, 87, + 88, 89, 90, 99, 100, 101, 102, 103, 104, 105, 106, 115, 116, 117, +118, 119, 120, 121, 122, 131, 132, 133, 134, 135, 136, 137, 138, 146, +147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 164, 165, 166, 167, +168, 169, 170, 178, 179, 180, 181, 182, 183, 184, 185, 186, 194, 195, +196, 197, 198, 199, 200, 201, 202, 210, 211, 212, 213, 214, 215, 216, +217, 218, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 241, 242, +243, 244, 245, 246, 247, 248, 249, 250, 256 +``` + +`is_ac == 1`, `stock_index == 1`: +``` + 0, 1, 2, 3, 17, 4, 5, 33, 49, 6, 18, 65, 81, 7, + 97, 113, 19, 34, 50, 129, 8, 20, 66, 145, 161, 177, 193, 9, + 35, 51, 82, 240, 21, 98, 114, 209, 10, 22, 36, 52, 225, 37, +241, 23, 24, 25, 26, 38, 39, 40, 41, 42, 53, 54, 55, 56, + 57, 58, 67, 68, 69, 70, 71, 72, 73, 74, 83, 84, 85, 86, + 87, 88, 89, 90, 99, 100, 101, 102, 103, 104, 105, 106, 115, 116, +117, 118, 119, 120, 121, 122, 130, 131, 132, 133, 134, 135, 136, 137, +138, 146, 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 164, 165, +166, 167, 168, 169, 170, 178, 179, 180, 181, 182, 183, 184, 185, 186, +194, 195, 196, 197, 198, 199, 200, 201, 202, 210, 211, 212, 213, 214, +215, 216, 217, 218, 226, 227, 228, 229, 230, 231, 232, 233, 234, 242, +243, 244, 245, 246, 247, 248, 249, 250, 256 +``` + diff --git a/media/libjxl/src/doc/tables/symbol_order.md b/media/libjxl/src/doc/tables/symbol_order.md new file mode 100644 index 000000000..a196c0f52 --- /dev/null +++ b/media/libjxl/src/doc/tables/symbol_order.md @@ -0,0 +1,30 @@ +#### Table M.11 – predefined symbol order + +`is_ac == 0`: +``` +0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +``` + +`is_ac == 1`: +``` + 1, 0, 2, 3, 17, 4, 5, 33, 18, 49, 65, 6, 81, 19, + 97, 7, 34, 113, 50, 129, 20, 145, 161, 8, 35, 66, 177, 193, + 21, 82, 209, 240, 36, 51, 98, 114, 9, 130, 10, 22, 52, 225, + 23, 37, 241, 24, 25, 26, 38, 39, 40, 41, 42, 53, 54, 55, + 56, 57, 58, 67, 68, 69, 70, 71, 72, 73, 74, 83, 84, 85, + 86, 87, 88, 89, 90, 99, 100, 101, 102, 103, 104, 105, 106, 115, +116, 117, 118, 119, 120, 121, 122, 131, 132, 133, 134, 135, 136, 137, +138, 146, 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 164, 165, +166, 167, 168, 169, 170, 178, 179, 180, 181, 182, 183, 184, 185, 186, +194, 195, 196, 197, 198, 199, 200, 201, 202, 210, 211, 212, 213, 214, +215, 216, 217, 218, 226, 227, 228, 229, 230, 231, 232, 233, 234, 242, +243, 244, 245, 246, 247, 248, 249, 250, 16, 32, 48, 64, 80, 96, +112, 128, 144, 160, 176, 192, 208, 11, 12, 13, 14, 15, 27, 28, + 29, 30, 31, 43, 44, 45, 46, 47, 59, 60, 61, 62, 63, 75, + 76, 77, 78, 79, 91, 92, 93, 94, 95, 107, 108, 109, 110, 111, +123, 124, 125, 126, 127, 139, 140, 141, 142, 143, 155, 156, 157, 158, +159, 171, 172, 173, 174, 175, 187, 188, 189, 190, 191, 203, 204, 205, +206, 207, 219, 220, 221, 222, 223, 224, 235, 236, 237, 238, 239, 251, +252, 253, 254, 255 +``` + diff --git a/media/libjxl/src/doc/vuln_playbook.md b/media/libjxl/src/doc/vuln_playbook.md new file mode 100644 index 000000000..1326d70a9 --- /dev/null +++ b/media/libjxl/src/doc/vuln_playbook.md @@ -0,0 +1,245 @@ +# Security Vulnerabilities Playbook + +## Reporting security bugs + +Report security bugs by emailing libjxl-security@google.com. + +Don't open a GitHub issue, don't discuss it public forums like Discord and don't +send a Pull Request if you think you have found a security bug. + +## Overview + +This document outlines the guidelines followed by the project when handling +security bugs, their fixes, disclosure and coordination with security +researchers. For more context about this guide, read the [coordinated +vulnerability disclosure +guidelines](https://github.com/google/oss-vulnerability-guide/blob/main/guide.md) +from Google Open Source Programs Office. + +The main target audience of this guide is the coordinator from the libjxl +Vulnerability Management Team (VMT) handling the requests, however it is useful +for other people to understand what to expect from this process. + +Members of the VMT monitor the reports received by email and will coordinate +for these to be addressed. This doesn't mean that said member would fix the bug, +but their responsibility is to make sure it is handled properly according to +this guide. + +## Life of security bug + +The Coordinator from VMT will make sure that the following steps are taken. + +1. Acknowledge the bug report. + +Our policy mandates a maximum of **3 business days** to respond to bug reports +in the given email, but you should respond as soon as possible and keep a fluid +communication with the reporter, who has spent some time looking at the issue. + +2. Determine if the bug is a security bug covered by our policy. + +Not all bugs are security bugs, and not all security bugs are covered by this +vulnerability disclosure policy. See the [What's a Security bug] section below. + +3. Determine the affected versions. + +Often new bugs on stable projects are found on new features or because of those +new features, so only the most recent versions are affected. It is important to +determine both what older versions are affected, so users running those older +versions can patch or update the software, and also what older versions are +*not* affected. It is possible that stable distributions ship older versions +that didn't contain the bug and therefore don't need to patch the code. Often +maintainers of package distributions need to patch older versions instead of +updating due to incompatibilities with newer ones and they need to understand +what's the vulnerable code. + +Security bugs that have already been fixed in `main` or in already released code +but not disclosed as a vulnerability, for example if fixed as a result of a +refactor, should be treated like any other security bug in this policy and +disclosed indicating the range of older affected versions (expect for versions +before 0.5, see below). In such case a new release would likely not be needed if +one already exists, but stable distributions may be still using those version +and need to be aware of the issue and fix. + +If no released version is affected by the bug, for example because it was only +introduced in the `main` branch but not yet released, then no vulnerability +disclosure is needed. + +Note: Versions before 0.5 are not covered by the security policy. Those versions +have multiple security issues and should not be used anyway. + +4. Communicate with the reporter + +Communicate the decision to the reporter. + +If the bug was not considered a security bug or not covered by this policy, +explain why and direct the reporter to open a public [issue in +GitHub](https://github.com/libjxl/libjxl/issues) or open one on their behalf. +You don't need to follow the rest of the guide in this case. + +If the bug *is* a covered security bug then follow the rest of this guide. + +Ask the reporter how they want to be credited in the disclosure: name and +company affiliation if any. Security researchers often value this recognition +and helps them dedicate their time to finding security bugs in our project. + +There's no bug bounty (monetary compensation for security bugs) available for +libjxl. + +5. Create a Security Advisory draft in GitHub + +At this point it was established that the bug is a security issue that requires +a vulnerability disclosure. Start by creating a Security Advisory draft in the +[Security Advisories](https://github.com/libjxl/libjxl/security/advisories) page +in GitHub. + +Add a short description of the bug explaining what's the issue and what's the +impact of the issue. Being 'hard' or 'complex' to exploit is not a reason to +discard the potential impact. You can update this description later, save it as +a draft in GitHub. + +Add the reporter to the security advisory draft if they have a GitHub account, +and add the project members that will be working on a fix for the bug. + +Establish the severity of the issue according to the impact and tag the +appropriate Common Weakness Enumeration (CWE) values. This helps classify the +security issues according to their nature. + +6. Work on a fix in a private branch + +Coordinators can work on the fix themselves, use a proposed fix from the +reporter if there is one, or work with other project members to create one. + +Work on a fix for the bug in *private*. Don't publish a Pull Request with the +fix like you normally do, and don't upload the fix to your libjxl fork. If you +ask another project member to work on it, explain them that they should follow +this guide. + +7. Request a CVE number + +The Common Vulnerabilities and Exposures (CVE) is the system used to disclose +vulnerabilities in software. A CVE number, like CVE-2021-NNNNNN, is a unique +identifier for a given vulnerability. These numbers are assigned by a CVE +Numbering Authority (CNA) with scope on the given project that has the +vulnerability. For libjxl, we use Google's Generic CNA. + +For VMT coordinators at Google, file a bug at +[go/cve-request](https://goto.google.com/cve-request) to request a CVE. See +go/vcp-cna for context. + +When requesting the CVE include: + + * A description of the problem (example: bug when parsing this field) + * A description of the impact of the bug (example: OOB read, remote code + execution, etc) + * The proposed CWE id(s) determined earlier. + * List of affected versions. + * Reporter of the bug and their preferred name/company to include in the + disclosure. + * Links to the issues/fixes (if already public), these can be added later, even + after the CVE is public. + * The CPE prefix of the affected project (`cpe:2.3:a:libjxl_project:libjxl`) + +When in doubt, you can discuss these with the security team while requesting it. + +8. File a Security bug in Chromium (if affected). + +libjxl project is in charge of updating and maintaining Chromium's libjxl +integration code, this includes updating the libjxl library when needed. While +the regular CVE disclosure process will eventually create a bug to update +Chromium, filing one at this stage speeds up the process. + +[go/crbug](https://goto.google.com/crbug), select the "Security Bug" template +and complete the details. This bug will be used to keep track of what versions +of Chromium need backporting. The new bug in Chromium will not be public +initially, but will be made public some time after the issue is fixed. + +9. Test the fixes on the intended releases + +When disclosing a vulnerability normally two ways to fix it are offered: + + * A patch or set of patches that fix the issue on `main` branch, and + * A new release that contains the security fix for the user to update to. + +New releases that fix the vulnerability should be PATCH releases, that is, a +previous release (like 1.2.3) plus the patches that fix the vulnerability, +becoming a new version (like 1.2.4). See the [release process](release.md) for +details. At least the latest MINOR release branch should have a PATCH release +with the fix, however it might make sense to also backport the fix to older +minor branch releases, depending on long-term support schedule for certain +releases. For example, if many users are still using a particular older version +of the library and updating to a new version requires significant changes (due +to a redesigned API or new unavailable dependencies) it is helpful to provide a +PATCH release there too. + +In either case, make sure that you test the fix in all the branches that you +intend to release it to. + +The Continuous Integration pipelines don't work on the private forks created by +the Security Advisory, so manual testing of the fix is needed there before +making it public. Don't upload it to your public fork for testing. + +10. Coordinate a date for release of the vulnerability disclosure. + +Agree with the reporter and security folks from the CNA on a release date. There +is a maximum of 90 day disclosure timeline from the day the bug was reported. + +On the disclosure date publish the fixes and tag the new PATCH release with the +fix. You can prepare private drafts of the release for review beforehand to +reduce the workload. + +Update Chromium to the new release version (if affected) and work with Chrome +engineers on the required backports. + +## What's a Security bug + +A security bug is a bug that can potentially be exploited to let an attacker +gain unauthorized access or privileges. For example, gaining code execution in +libjxl decoder by decoding a malicious .jxl file is a security but hitting a +`JXL_ASSERT()` is not necessarily one. + +The supported use cases to consider in the context of security bugs that require +a vulnerability disclosure are "release" builds. The disclosure is intended for +users of the project, to let them know that there is a security issue and that +they should update or patch it. + +Unreleased versions are not relevant in this context. A bug introduced in the +`main` branch that is not yet in any release is not covered by this guide even +if the bug allows a remote code execution. CVEs should have a non-empty list of +affected released versions. + +"Developer only" code is also not covered by this policy. In particular, tools +that are not installed by the build, or not installed when packaging `libjxl` +are not covered. For example, a bug in `tone_map` would not affect users since +is a developer-only tool. The rationale behind this is that users of the +released software will not have the developer code. This developer code is in +the same libjxl repository for convenience. + +When considering the impact of a bug, "release" mode should be assumed. In +release mode `JXL_ASSERT()` and `JXL_CHECK()` are enabled, but `JXL_DASSERT()` +are not. This means that if a `JXL_DASSERT()` protects an out-of-bounds (OOB) +write, then the impact of a bug hitting the `JXL_DASSERT()` is at least an +OOB write. On the other hand, if a bug ends up hitting a `JXL_CHECK()` instead +of continuing, the only impact is the process abort instead of whatever else is +possible after the `JXL_CHECK()`. + +Asserts in `libjxl` *tools* cause the tool process to abort, but don't affect +the caller. Either crashing or returning an error (non-zero exit code) would +have the same effect, so `JXL_ASSERT()` failures in the tools have no security +or functional impact. + +Asserts in `libjxl` libraries, meant to be linked into other processes, cause +the caller process to abort, potentially causing a Denial of Service, however, +Denial of Service issues are *not* considered security bugs by this policy. +These are still issues and should be fixed, but they are not security issues. + +Out-of-bounds (OOB) reads in process memory are considered security +vulnerabilities. OOB reads may allow an attacker to read other buffers from the +same process that it shouldn't have access to, even a small OOB read can +allow the attacker to read an address in the stack or in the heap, defeating +address space randomization techniques. In combination with other bugs these +can enable or simplify attacks to the process using libjxl. OOB reads don't need +to require a segmentation fault to be a problem, leaking process information in +decoded RGB pixels could be used as part of an exploit in some scenarios. + +OOB writes and remote code execution (RCE) are security bugs of at least high +security impact. diff --git a/media/libjxl/src/doc/xl_overview.md b/media/libjxl/src/doc/xl_overview.md new file mode 100644 index 000000000..cfcfcb5f8 --- /dev/null +++ b/media/libjxl/src/doc/xl_overview.md @@ -0,0 +1,181 @@ +# XL Overview + +## Requirements + +JPEG XL was designed for two main requirements: + +* high quality: visually lossless at reasonable bitrates; +* decoding speed: multithreaded decoding should be able to reach around + 400 Megapixel/s on large images. + +These goals apply to various types of images, including HDR content, whose +support is made possible by full-precision (float32) computations and extensive +support of color spaces and transfer functions. + +High performance is achieved by designing the format with careful consideration +of memory bandwidth usage and ease of SIMD/GPU implementation. + +The full requirements for JPEG XL are listed in document wg1m82079. + +## General architecture + +The architecture follows the traditional block transform model with improvements +in the individual components. For a quick overview, we sketch a "block diagram" +of the lossy format decoder in the form of module names in **bold** followed by +a brief description. Note that post-processing modules in [brackets] are +optional - they are unnecessary or even counterproductive at very high quality +settings. + +**Header**: decode metadata (e.g. image dimensions) from compressed fields +(smaller than Exp-Golomb thanks to per-field encodings). The compression and +small number of required fields enables very compact headers - much smaller than +JFIF and HEVC. The container supports multiple images (e.g. animations/bursts) +and passes (progressive). + +**Bitstream**: decode transform coefficient residuals using rANS-encoded +<#bits,bits> symbols + +**Dequantize**: from adaptive quant map side information, plus chroma from luma + +**DC prediction**: expand DC residuals using adaptive (history-based) predictors + +**Chroma from luma**: restore predicted X from B and Y from B + +**IDCT:** 2x2..32x32, floating-point + +**[Gaborish]**: additional deblocking convolution with 3x3 kernel + +**[Edge preserving filter]**: nonlinear adaptive smoothing controlled by side +information + +**[Noise injection]**: add perceptually pleasing noise according to a per-image +noise model + +**Color space conversion**: from perceptual opsin XYB to linear RGB + +**[Converting to other color spaces via ICC]** + +The encoder is basically the reverse: + +**Color space conversion**: from linear RGB to perceptual opsin XYB + +**[Noise estimation]**: compute a noise model for the image + +**[Gaborish]**: sharpening to counteract the blurring on the decoder side + +**DCT**: transform sizes communicated via per-block side information + +**Chroma from luma**: find the best multipliers of Y for X and B channels of +entire image + +**Adaptive quantization**: iterative search for quant map that yields the best +perceived restoration + +**Quantize**: store 16-bit prediction residuals + +**DC prediction**: store residuals (prediction happens in quantized space) + +**Entropy coding**: rANS and context modeling with clustering + + +# File Structure + +A codestream begins with a `FileHeader` followed by one or more "passes" +(= scans: e.g. DC or AC_LF) which are then added together (summing the +respective color components in Opsin space) to form the final image. There is no +limit to the number of passes, so an encoder could choose to send salient parts +first, followed by arbitrary decompositions of the final image (in terms of +resolution, bit depth, quality or spatial location). + +Each pass contains groups of AC and DC data. A group is a subset of pixels that +can be decoded in parallel. DC groups contain 256x256 DCs (from 2048x2048 input +pixels), AC groups cover 256x256 input pixels. + +Each pass starts with a table of contents (sizes of each of their DC+AC +groups), which enables parallel decoding and/or the decoding of a subset. +However, there is no higher-level TOC of passes, as that would prevent +appending additional images and could be too constraining for the encoder. + + +## Lossless + +JPEG XL supports tools for lossless coding designed by Alexander Rhatushnyak and +Jon Sneyers. They are about 60-75% of size of PNG, and smaller than WebP +lossless for photos. + +An adaptive predictor computes 4 from the NW, N, NE and W pixels and combines +them with weights based on previous errors. The error value is encoded in a +bucket chosen based on a heuristic max error. The result is entropy-coded using +the ANS encoder. + +## Current Reference Implementation + +### Conventions + +The software is written in C++ and built using CMake 3.6 or later. + +Error handling is done by having functions return values of type `jxl::Status` +(a thin wrapper around bool which checks that it is not ignored). A convenience +macro named `JXL_RETURN_IF_ERROR` makes this more convenient by automatically +forwarding errors, and another macro named `JXL_FAILURE` exits with an error +message if reached, with no effect in optimized builds. + +To diagnose the cause of encoder/decoder failures (which often only result in a +generic "decode failed" message), build using the following command: + +```bash +CMAKE_FLAGS="-DJXL_CRASH_ON_ERROR" ./ci.sh opt +``` + +In such builds, the first JXL_FAILURE will print a message identifying where the +problem is and the program will exit immediately afterwards. + +### Architecture + +Getting back to the earlier block diagram: + +**Header** handling is implemented in `headers.h` and `field*`. + +**Bitstream**: `entropy_coder.h`, `dec_ans_*`. + +**(De)quantize**: `quantizer.h`. + +**DC prediction**: `predictor.h`. + +**Chroma from luma**: `chroma_from_luma.h` + +**(I)DCT**: `dct*.h`. Instead of operating directly on blocks of memory, the +functions operate on thin wrappers which can handle blocks spread across +multiple image lines. + +**DCT size selection**: `ac_strategy.cc` + +**[Gaborish]**: `gaborish.h`. + +**[Edge preserving filter]**: `epf.h` + +**[Noise injection]**: `noise*` (currently disabled) + +**Color space conversion**: `color_*`, `dec_xyb.h`. + +## Decoder overview + +After decoding headers, the decoder begins processing frames (`dec_frame.cc`). + +For each pass, it will read the DC group table of contents (TOC) and start +decoding, dequantizing and restoring color correlation of each DC group +(covering 2048x2048 pixels in the input image) in parallel +(`compressed_dc.cc`). The DC is split into parts corresponding to each AC group +(with 1px of extra border); the AC group TOC is read and each AC group (256x256 +pixels) is processed in parallel (`dec_group.cc`). + +In each AC group, the decoder reads per-block side information indicating the +kind of DCT transform; this is followed by the quantization field. Then, AC +coefficients are read, dequantized and have color correlation restored on a +tile per tile basis for better locality. + +After all the groups are read, postprocessing is applied: Gaborish smoothing +and edge preserving filter, to reduce blocking and other artifacts. + +Finally, the image is converted back from the XYB color space +(`dec_xyb.cc`) and saved to the output image (`codec_*.cc`). diff --git a/media/libjxl/src/docker/Dockerfile.jpegxl-builder b/media/libjxl/src/docker/Dockerfile.jpegxl-builder new file mode 100644 index 000000000..16e0077ee --- /dev/null +++ b/media/libjxl/src/docker/Dockerfile.jpegxl-builder @@ -0,0 +1,21 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Build an Ubuntu-based docker image with the installed software needed to +# develop and test JPEG XL. + +FROM ubuntu:bionic + +# Set a prompt for when using it locally. +ENV PS1="\[\033[01;33m\]\h\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$ " + +COPY scripts/99_norecommends /etc/apt/apt.conf.d/99_norecommends + +COPY scripts /jpegxl_scripts + +ARG DEBIAN_FRONTEND=noninteractive + +RUN /jpegxl_scripts/jpegxl_builder.sh && \ + rm -rf /jpegxl_scripts diff --git a/media/libjxl/src/docker/Dockerfile.jpegxl-builder-run-aarch64 b/media/libjxl/src/docker/Dockerfile.jpegxl-builder-run-aarch64 new file mode 100644 index 000000000..a9f38a401 --- /dev/null +++ b/media/libjxl/src/docker/Dockerfile.jpegxl-builder-run-aarch64 @@ -0,0 +1,37 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Build an Ubuntu-based docker image for aarch64 with the installed software +# needed to run JPEG XL. This is only useful when running on actual aarch64 +# hardware. + +FROM arm64v8/ubuntu:bionic + +COPY scripts/99_norecommends /etc/apt/apt.conf.d/99_norecommends + +# Set a prompt for when using it locally. +ENV PS1="\[\033[01;33m\]\h\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$ " + +ARG DEBIAN_FRONTEND=noninteractive + +RUN set -ex; \ + apt-get update -y; \ + apt-get install -y \ + bsdmainutils \ + cmake \ + curl \ + ca-certificates \ + extra-cmake-modules \ + git \ + imagemagick \ + libjpeg8 \ + libgif7 \ + libgoogle-perftools4 \ + libopenexr22 \ + libpng16-16 \ + libqt5x11extras5 \ + libsdl2-2.0-0 \ + parallel; \ + rm -rf /var/lib/apt/lists/*; diff --git a/media/libjxl/src/docker/README.md b/media/libjxl/src/docker/README.md new file mode 100644 index 000000000..874df1cb8 --- /dev/null +++ b/media/libjxl/src/docker/README.md @@ -0,0 +1,7 @@ +### Docker container infrastructure for JPEG XL + +This directory contains the requirements to build a docker image for the +JPEG XL project builder. + +Docker images need to be created and upload manually. See ./build.sh for +details. diff --git a/media/libjxl/src/docker/build.sh b/media/libjxl/src/docker/build.sh new file mode 100644 index 000000000..3d4727f6a --- /dev/null +++ b/media/libjxl/src/docker/build.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -eu + +MYDIR=$(dirname $(realpath "$0")) + +declare -a TARGETS + +load_targets() { + # Built-in OSX "find" does not support "-m". + FIND=$(which "gfind" || which "find") + for f in $(${FIND} -maxdepth 1 -name 'Dockerfile.*' | sort); do + local target="${f#*Dockerfile.}" + TARGETS+=("${target}") + done +} + +usage() { + cat >&2 <&2 + done +} + +build_target() { + local target="$1" + + local dockerfile="${MYDIR}/Dockerfile.${target}" + # JPEG XL builder images are stored in the gcr.io/jpegxl project. + local tag="gcr.io/jpegxl/${target}" + + echo "Building ${target}" + if ! sudo docker build --no-cache -t "${tag}" -f "${dockerfile}" "${MYDIR}" \ + >"${target}.log" 2>&1; then + echo "${target} failed. See ${target}.log" >&2 + else + echo "Done, to upload image run:" >&2 + echo " sudo docker push ${tag}" + if [[ "${JPEGXL_PUSH:-}" == "1" ]]; then + echo "sudo docker push ${tag}" >&2 + sudo docker push "${tag}" + # The RepoDigest is only created after it is pushed. + local fulltag=$(sudo docker inspect --format="{{.RepoDigests}}" "${tag}") + fulltag="${fulltag#[}" + fulltag="${fulltag%]}" + echo "Updating .gitlab-ci.yml to ${fulltag}" >&2 + sed -E "s;${tag}@sha256:[0-9a-f]+;${fulltag};" \ + -i "${MYDIR}/../.gitlab-ci.yml" + fi + fi +} + +main() { + cd "${MYDIR}" + local target="${1:-}" + + load_targets + if [[ -z "${target}" ]]; then + usage $0 + exit 1 + fi + + if [[ "${target}" == "all" ]]; then + for target in "${TARGETS[@]}"; do + build_target "${target}" + done + else + for target in "$@"; do + build_target "${target}" + done + fi +} + +main "$@" diff --git a/media/libjxl/src/docker/scripts/99_norecommends b/media/libjxl/src/docker/scripts/99_norecommends new file mode 100644 index 000000000..96d672811 --- /dev/null +++ b/media/libjxl/src/docker/scripts/99_norecommends @@ -0,0 +1 @@ +APT::Install-Recommends "false"; diff --git a/media/libjxl/src/docker/scripts/binutils_align_fix.patch b/media/libjxl/src/docker/scripts/binutils_align_fix.patch new file mode 100644 index 000000000..6066252db --- /dev/null +++ b/media/libjxl/src/docker/scripts/binutils_align_fix.patch @@ -0,0 +1,28 @@ +Description: fix lack of alignment in relocations (crashes on mingw) +See https://sourceware.org/git/?p=binutils-gdb.git;a=patch;h=73af69e74974eaa155eec89867e3ccc77ab39f6d +From: Marc +Date: Fri, 9 Nov 2018 11:13:50 +0000 +Subject: [PATCH] Allow for compilers that do not produce aligned .rdat + sections in PE format files. + +--- a/upstream/ld/scripttempl/pe.sc 2020-05-12 18:45:12.000000000 +0200 ++++ b/upstream/ld/scripttempl/pe.sc 2020-05-12 18:47:12.000000000 +0200 +@@ -143,6 +143,7 @@ + .rdata ${RELOCATING+BLOCK(__section_alignment__)} : + { + ${R_RDATA} ++ . = ALIGN(4); + ${RELOCATING+__rt_psrelocs_start = .;} + ${RELOCATING+KEEP(*(.rdata_runtime_pseudo_reloc))} + ${RELOCATING+__rt_psrelocs_end = .;} +--- a/upstream/ld/scripttempl/pep.sc 2020-05-12 18:45:19.000000000 +0200 ++++ b/upstream/ld/scripttempl/pep.sc 2020-05-12 18:47:18.000000000 +0200 +@@ -143,6 +143,7 @@ + .rdata ${RELOCATING+BLOCK(__section_alignment__)} : + { + ${R_RDATA} ++ . = ALIGN(4); + ${RELOCATING+__rt_psrelocs_start = .;} + ${RELOCATING+KEEP(*(.rdata_runtime_pseudo_reloc))} + ${RELOCATING+__rt_psrelocs_end = .;} + diff --git a/media/libjxl/src/docker/scripts/emsdk_install.sh b/media/libjxl/src/docker/scripts/emsdk_install.sh new file mode 100644 index 000000000..6cf225a9d --- /dev/null +++ b/media/libjxl/src/docker/scripts/emsdk_install.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +EMSDK_URL="https://github.com/emscripten-core/emsdk/archive/main.tar.gz" +EMSDK_DIR="/opt/emsdk" + +EMSDK_RELEASE="2.0.23" + +set -eu -x + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -fr "${CLEANUP_FILES[@]}" + fi +} +trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT + +main() { + local workdir=$(mktemp -d --suffix=emsdk) + CLEANUP_FILES+=("${workdir}") + + local emsdktar="${workdir}/emsdk.tar.gz" + curl --output "${emsdktar}" "${EMSDK_URL}" --location + mkdir -p "${EMSDK_DIR}" + tar -zxf "${emsdktar}" -C "${EMSDK_DIR}" --strip-components=1 + + cd "${EMSDK_DIR}" + ./emsdk install --shallow "${EMSDK_RELEASE}" + ./emsdk activate --embedded "${EMSDK_RELEASE}" +} + +main "$@" diff --git a/media/libjxl/src/docker/scripts/jpegxl_builder.sh b/media/libjxl/src/docker/scripts/jpegxl_builder.sh new file mode 100644 index 000000000..bf9f19d4e --- /dev/null +++ b/media/libjxl/src/docker/scripts/jpegxl_builder.sh @@ -0,0 +1,518 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Main entry point for all the Dockerfile for jpegxl-builder. This centralized +# file helps sharing code and configuration between Dockerfiles. + +set -eux + +MYDIR=$(dirname $(realpath "$0")) + +# libjpeg-turbo. +JPEG_TURBO_RELEASE="2.0.4" +JPEG_TURBO_URL="https://github.com/libjpeg-turbo/libjpeg-turbo/archive/${JPEG_TURBO_RELEASE}.tar.gz" +JPEG_TURBO_SHA256="7777c3c19762940cff42b3ba4d7cd5c52d1671b39a79532050c85efb99079064" + +# zlib (dependency of libpng) +ZLIB_RELEASE="1.2.11" +ZLIB_URL="https://www.zlib.net/zlib-${ZLIB_RELEASE}.tar.gz" +ZLIB_SHA256="c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1" +# The name in the .pc and the .dll generated don't match in zlib for Windows +# because they use different .dll names in Windows. We avoid that by defining +# UNIX=1. We also install all the .dll files to ${prefix}/lib instead of the +# default ${prefix}/bin. +ZLIB_FLAGS='-DUNIX=1 -DINSTALL_PKGCONFIG_DIR=/${CMAKE_INSTALL_PREFIX}/lib/pkgconfig -DINSTALL_BIN_DIR=/${CMAKE_INSTALL_PREFIX}/lib' + +# libpng +LIBPNG_RELEASE="1.6.37" +LIBPNG_URL="https://github.com/glennrp/libpng/archive/v${LIBPNG_RELEASE}.tar.gz" +LIBPNG_SHA256="ca74a0dace179a8422187671aee97dd3892b53e168627145271cad5b5ac81307" + +# giflib +GIFLIB_RELEASE="5.2.1" +GIFLIB_URL="https://netcologne.dl.sourceforge.net/project/giflib/giflib-${GIFLIB_RELEASE}.tar.gz" +GIFLIB_SHA256="31da5562f44c5f15d63340a09a4fd62b48c45620cd302f77a6d9acf0077879bd" + +# A patch needed to compile GIFLIB in mingw. +GIFLIB_PATCH_URL="https://github.com/msys2/MINGW-packages/raw/3afde38fcee7b3ba2cafd97d76cca8f06934504f/mingw-w64-giflib/001-mingw-build.patch" +GIFLIB_PATCH_SHA256="2b2262ddea87fc07be82e10aeb39eb699239f883c899aa18a16e4d4e40af8ec8" + +# webp +WEBP_RELEASE="1.0.2" +WEBP_URL="https://codeload.github.com/webmproject/libwebp/tar.gz/v${WEBP_RELEASE}" +WEBP_SHA256="347cf85ddc3497832b5fa9eee62164a37b249c83adae0ba583093e039bf4881f" + +# Google benchmark +BENCHMARK_RELEASE="1.5.2" +BENCHMARK_URL="https://github.com/google/benchmark/archive/v${BENCHMARK_RELEASE}.tar.gz" +BENCHMARK_SHA256="dccbdab796baa1043f04982147e67bb6e118fe610da2c65f88912d73987e700c" +BENCHMARK_FLAGS="-DGOOGLETEST_PATH=${MYDIR}/../../third_party/googletest" +# attribute(format(__MINGW_PRINTF_FORMAT, ...)) doesn't work in our +# environment, so we disable the warning. +BENCHMARK_FLAGS="-DCMAKE_BUILD_TYPE=Release -DBENCHMARK_ENABLE_TESTING=OFF \ + -DCMAKE_CXX_FLAGS=-Wno-ignored-attributes \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON" + +# V8 +V8_VERSION="9.3.22" + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -fr "${CLEANUP_FILES[@]}" + fi +} +trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT + +# List of Ubuntu arch names supported by the builder (such as "i386"). +LIST_ARCHS=( + amd64 + i386 + arm64 + armhf +) + +# List of target triplets supported by the builder. +LIST_TARGETS=( + x86_64-linux-gnu + i686-linux-gnu + arm-linux-gnueabihf + aarch64-linux-gnu +) +LIST_MINGW_TARGETS=( + i686-w64-mingw32 + x86_64-w64-mingw32 +) +LIST_WASM_TARGETS=( + wasm32 +) + +# Setup the apt repositories and supported architectures. +setup_apt() { + apt-get update -y + apt-get install -y curl gnupg ca-certificates + + apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 1E9377A2BA9EF27F + + # node sources. + cat >/etc/apt/sources.list.d/nodesource.list <>"${newlist}" + fi + + main_list=$(echo "${main_list[@]}" | tr ' ' ,) + grep -v -E '^#' "${bkplist}" | + sed -E "s;^deb (http[^ ]+) (.*)\$;deb [arch=${main_list}] \\1 \\2\ndeb-src [arch=${main_list}] \\1 \\2;" \ + >>"${newlist}" + mv "${newlist}" /etc/apt/sources.list +} + +install_pkgs() { + packages=( + # Native compilers (minimum for SIMD is clang-7) + clang-7 clang-format-7 clang-tidy-7 + + # TODO: Consider adding clang-8 to every builder: + # clang-8 clang-format-8 clang-tidy-8 + + # For cross-compiling to Windows with mingw. + mingw-w64 + wine64 + wine-binfmt + + # Native tools. + bsdmainutils + cmake + extra-cmake-modules + git + llvm + nasm + ninja-build + parallel + pkg-config + + # For compiling / testing JNI wrapper. JDK8 is almost 2x smaller than JDK11 + # openjdk-8-jdk-headless would be 50MB smaller, unfortunately, CMake + # does mistakenly thinks it does not contain JNI feature. + openjdk-8-jdk + + # These are used by the ./ci.sh lint in the native builder. + clang-format-7 + clang-format-8 + + # For coverage builds + gcovr + + # For compiling giflib documentation. + xmlto + + # Common libraries. + libstdc++-8-dev + + # We don't use tcmalloc on archs other than amd64. This installs + # libgoogle-perftools4:amd64. + google-perftools + + # NodeJS for running WASM tests + nodejs + + # To generate API documentation. + doxygen + + # Freezes version that builds (passes tests). Newer version + # (2.30-21ubuntu1~18.04.4) claims to fix "On Intel Skylake + # (-march=native) generated avx512 instruction can be wrong", + # but newly added tests does not pass. Perhaps the problem is + # that mingw package is not updated. + binutils-source=2.30-15ubuntu1 + ) + + # Install packages that are arch-dependent. + local ubarch + for ubarch in "${LIST_ARCHS[@]}"; do + packages+=( + # Library dependencies. These normally depend on the target architecture + # we are compiling for and can't usually be installed for multiple + # architectures at the same time. + libgif7:"${ubarch}" + libjpeg-dev:"${ubarch}" + libpng-dev:"${ubarch}" + libqt5x11extras5-dev:"${ubarch}" + + libstdc++-8-dev:"${ubarch}" + qtbase5-dev:"${ubarch}" + + # For OpenEXR: + libilmbase12:"${ubarch}" + libopenexr22:"${ubarch}" + + # TCMalloc dependency + libunwind-dev:"${ubarch}" + + # Cross-compiling tools per arch. + libc6-dev-"${ubarch}"-cross + libstdc++-8-dev-"${ubarch}"-cross + ) + done + + local target + for target in "${LIST_TARGETS[@]}"; do + # Per target cross-compiling tools. + if [[ "${target}" != "x86_64-linux-gnu" ]]; then + packages+=( + binutils-"${target}" + gcc-"${target}" + ) + fi + done + + # Install all the manual packages via "apt install" for the main arch. These + # will be installed for other archs via manual download and unpack. + apt install -y "${packages[@]}" "${UNPACK_PKGS[@]}" +} + +# binutils <2.32 need a patch. +install_binutils() { + local workdir=$(mktemp -d --suffix=_install) + CLEANUP_FILES+=("${workdir}") + pushd "${workdir}" + apt source binutils-mingw-w64 + apt -y build-dep binutils-mingw-w64 + cd binutils-mingw-w64-8ubuntu1 + cp "${MYDIR}/binutils_align_fix.patch" debian/patches + echo binutils_align_fix.patch >> debian/patches/series + dpkg-buildpackage -b + cd .. + dpkg -i *deb + popd +} + +# Install a library from the source code for multiple targets. +# Usage: install_from_source [] +install_from_source() { + local package="$1" + shift + + local url + eval "url=\${${package}_URL}" + local sha256 + eval "sha256=\${${package}_SHA256}" + # Optional package flags + local pkgflags + eval "pkgflags=\${${package}_FLAGS:-}" + + local workdir=$(mktemp -d --suffix=_install) + CLEANUP_FILES+=("${workdir}") + + local tarfile="${workdir}"/$(basename "${url}") + curl -L --output "${tarfile}" "${url}" + if ! echo "${sha256} ${tarfile}" | sha256sum -c --status -; then + echo "SHA256 mismatch for ${url}: expected ${sha256} but found:" + sha256sum "${tarfile}" + exit 1 + fi + + local target + for target in "$@"; do + echo "Installing ${package} for target ${target} from ${url}" + + local srcdir="${workdir}/source-${target}" + mkdir -p "${srcdir}" + tar -zxf "${tarfile}" -C "${srcdir}" --strip-components=1 + + local prefix="/usr" + if [[ "${target}" != "x86_64-linux-gnu" ]]; then + prefix="/usr/${target}" + fi + + # Apply patches to buildfiles. + if [[ "${package}" == "GIFLIB" && "${target}" == *mingw32 ]]; then + # GIFLIB Makefile has several problems so we need to fix them here. We are + # using a patch from MSYS2 that already fixes the compilation for mingw. + local make_patch="${srcdir}/libgif.patch" + curl -L "${GIFLIB_PATCH_URL}" -o "${make_patch}" + echo "${GIFLIB_PATCH_SHA256} ${make_patch}" | sha256sum -c --status - + patch "${srcdir}/Makefile" < "${make_patch}" + elif [[ "${package}" == "LIBPNG" && "${target}" == wasm* ]]; then + # Cut the dependency to libm; there is pull request to fix it, so this + # might not be needed in the future. + sed -i 's/APPLE/EMSCRIPTEN/g' "${srcdir}/CMakeLists.txt" + fi + + local cmake_args=() + local export_args=("CC=clang-7" "CXX=clang++-7") + local cmake="cmake" + local make="make" + local system_name="Linux" + if [[ "${target}" == *mingw32 ]]; then + system_name="Windows" + # When compiling with clang, CMake doesn't detect that we are using mingw. + cmake_args+=( + -DMINGW=1 + # Googletest needs this when cross-compiling to windows + -DCMAKE_CROSSCOMPILING=1 + -DHAVE_STD_REGEX=0 + -DHAVE_POSIX_REGEX=0 + -DHAVE_GNU_POSIX_REGEX=0 + ) + local windres=$(which ${target}-windres || true) + if [[ -n "${windres}" ]]; then + cmake_args+=(-DCMAKE_RC_COMPILER="${windres}") + fi + fi + if [[ "${target}" == wasm* ]]; then + system_name="WASM" + cmake="emcmake cmake" + make="emmake make" + export_args=() + cmake_args+=( + -DCMAKE_FIND_ROOT_PATH="${prefix}" + -DCMAKE_PREFIX_PATH="${prefix}" + ) + # Static and shared library link to the same file -> race condition. + nproc=1 + else + nproc=`nproc --all` + fi + cmake_args+=(-DCMAKE_SYSTEM_NAME="${system_name}") + + if [[ "${target}" != "x86_64-linux-gnu" ]]; then + # Cross-compiling. + cmake_args+=( + -DCMAKE_C_COMPILER_TARGET="${target}" + -DCMAKE_CXX_COMPILER_TARGET="${target}" + -DCMAKE_SYSTEM_PROCESSOR="${target%%-*}" + ) + fi + + if [[ -e "${srcdir}/CMakeLists.txt" ]]; then + # Most packages use cmake for building which is easier to configure for + # cross-compiling. + if [[ "${package}" == "JPEG_TURBO" && "${target}" == wasm* ]]; then + # JT erroneously detects WASM CPU as i386 and tries to use asm. + # Wasm/Emscripten support for dynamic linking is incomplete; disable + # to avoid CMake warning. + cmake_args+=(-DWITH_SIMD=0 -DENABLE_SHARED=OFF) + fi + ( + cd "${srcdir}" + export ${export_args[@]} + ${cmake} \ + -DCMAKE_INSTALL_PREFIX="${prefix}" \ + "${cmake_args[@]}" ${pkgflags} + ${make} -j${nproc} + ${make} install + ) + elif [[ "${package}" == "GIFLIB" ]]; then + # GIFLIB doesn't yet have a cmake build system. There is a pull + # request in giflib for adding CMakeLists.txt so this might not be + # needed in the future. + ( + cd "${srcdir}" + local giflib_make_flags=( + CFLAGS="-O2 --target=${target} -std=gnu99" + PREFIX="${prefix}" + ) + if [[ "${target}" != wasm* ]]; then + giflib_make_flags+=(CC=clang-7) + fi + # giflib make dependencies are not properly set up so parallel building + # doesn't work for everything. + ${make} -j${nproc} libgif.a "${giflib_make_flags[@]}" + ${make} -j${nproc} all "${giflib_make_flags[@]}" + ${make} install "${giflib_make_flags[@]}" + ) + else + echo "Don't know how to install ${package}" + exit 1 + fi + + # CMake mistakenly uses ".so" libraries and EMCC fails to link properly. + if [[ "${target}" == wasm* ]]; then + rm -f "${prefix}/lib"/*.so* + fi + done +} + +# Packages that are manually unpacked for each architecture. +UNPACK_PKGS=( + libgif-dev + libclang-common-7-dev + + # For OpenEXR: + libilmbase-dev + libopenexr-dev + + # TCMalloc + libgoogle-perftools-dev + libtcmalloc-minimal4 + libgoogle-perftools4 +) + +# Main script entry point. +main() { + cd "${MYDIR}" + + # Configure the repositories with the sources for multi-arch cross + # compilation. + setup_apt + apt-get update -y + apt-get dist-upgrade -y + + install_pkgs + install_binutils + apt clean + + # Remove prebuilt Java classes cache. + rm /usr/lib/jvm/java-8-openjdk-amd64/jre/lib/amd64/server/classes.jsa + + # Manually extract packages for the target arch that can't install it directly + # at the same time as the native ones. + local ubarch + for ubarch in "${LIST_ARCHS[@]}"; do + if [[ "${ubarch}" != "amd64" ]]; then + local pkg + for pkg in "${UNPACK_PKGS[@]}"; do + apt download "${pkg}":"${ubarch}" + dpkg -x "${pkg}"_*_"${ubarch}".deb / + done + fi + done + # TODO: Add clang from the llvm repos. This is problematic since we are + # installing libclang-common-7-dev:"${ubarch}" from the ubuntu ports repos + # which is not available in the llvm repos so it might have a different + # version than the ubuntu ones. + + # Remove the win32 libgcc version. The gcc-mingw-w64-x86-64 (and i686) + # packages install two libgcc versions: + # /usr/lib/gcc/x86_64-w64-mingw32/7.3-posix + # /usr/lib/gcc/x86_64-w64-mingw32/7.3-win32 + # (exact libgcc version number depends on the package version). + # + # Clang will pick the best libgcc, sorting by version, but it doesn't + # seem to be a way to specify one or the other one, except by passing + # -nostdlib and setting all the include paths from the command line. + # To check which one is being used you can run: + # clang++-7 --target=x86_64-w64-mingw32 -v -print-libgcc-file-name + # We need to use the "posix" versions for thread support, so here we + # just remove the other one. + local target + for target in "${LIST_MINGW_TARGETS[@]}"; do + update-alternatives --set "${target}-gcc" $(which "${target}-gcc-posix") + local gcc_win32_path=$("${target}-cpp-win32" -print-libgcc-file-name) + rm -rf $(dirname "${gcc_win32_path}") + done + + # TODO: Add msan for the target when cross-compiling. This only installs it + # for amd64. + ./msan_install.sh + + # Build and install qemu user-linux targets. + ./qemu_install.sh + + # Install emscripten SDK. + ./emsdk_install.sh + + # Setup environment for building WASM libraries from sources. + source /opt/emsdk/emsdk_env.sh + + # Install some dependency libraries manually for the different targets. + + install_from_source JPEG_TURBO "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}" + install_from_source ZLIB "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}" + install_from_source LIBPNG "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}" + install_from_source GIFLIB "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}" + # webp in Ubuntu is relatively old so we install it from source for everybody. + install_from_source WEBP "${LIST_TARGETS[@]}" "${LIST_MINGW_TARGETS[@]}" + + install_from_source BENCHMARK "${LIST_TARGETS[@]}" "${LIST_MINGW_TARGETS[@]}" + + # Install v8. v8 has better WASM SIMD support than NodeJS 14 (LTS). + # First we need the installer to install v8. + npm install jsvu -g + # install specific version; + HOME=/opt jsvu --os=linux64 "v8@${V8_VERSION}" + ln -s "/opt/.jsvu/v8-${V8_VERSION}" "/opt/.jsvu/v8" + + # Cleanup. + find /var/lib/apt/lists/ -mindepth 1 -delete +} + +main "$@" diff --git a/media/libjxl/src/docker/scripts/msan_install.sh b/media/libjxl/src/docker/scripts/msan_install.sh new file mode 100644 index 000000000..0216f62b0 --- /dev/null +++ b/media/libjxl/src/docker/scripts/msan_install.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -eu + +MYDIR=$(dirname $(realpath "$0")) + +# Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS +CMAKE_FLAGS=${CMAKE_FLAGS:-} +CMAKE_C_FLAGS=${CMAKE_C_FLAGS:-${CMAKE_FLAGS}} +CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS:-${CMAKE_FLAGS}} +CMAKE_EXE_LINKER_FLAGS=${CMAKE_EXE_LINKER_FLAGS:-} + +CLANG_VERSION="${CLANG_VERSION:-}" +# Detect the clang version suffix and store it in CLANG_VERSION. For example, +# "6.0" for clang 6 or "7" for clang 7. +detect_clang_version() { + if [[ -n "${CLANG_VERSION}" ]]; then + return 0 + fi + local clang_version=$("${CC:-clang}" --version | head -n1) + local llvm_tag + case "${clang_version}" in + "clang version 6."*) + CLANG_VERSION="6.0" + ;; + "clang version 7."*) + CLANG_VERSION="7" + ;; + "clang version 8."*) + CLANG_VERSION="8" + ;; + "clang version 9."*) + CLANG_VERSION="9" + ;; + *) + echo "Unknown clang version: ${clang_version}" >&2 + return 1 + esac +} + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -fr "${CLEANUP_FILES[@]}" + fi +} +trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT + +# Install libc++ libraries compiled with msan in the msan_prefix for the current +# compiler version. +cmd_msan_install() { + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + # Detect the llvm to install: + export CC="${CC:-clang}" + export CXX="${CXX:-clang++}" + detect_clang_version + local llvm_tag + case "${CLANG_VERSION}" in + "6.0") + llvm_tag="llvmorg-6.0.1" + ;; + "7") + llvm_tag="llvmorg-7.0.1" + ;; + "8") + llvm_tag="llvmorg-8.0.0" + ;; + *) + echo "Unknown clang version: ${clang_version}" >&2 + return 1 + esac + local llvm_targz="${tmpdir}/${llvm_tag}.tar.gz" + curl -L --show-error -o "${llvm_targz}" \ + "https://github.com/llvm/llvm-project/archive/${llvm_tag}.tar.gz" + tar -C "${tmpdir}" -zxf "${llvm_targz}" + local llvm_root="${tmpdir}/llvm-project-${llvm_tag}" + + local msan_prefix="${HOME}/.msan/${CLANG_VERSION}" + rm -rf "${msan_prefix}" + + declare -A CMAKE_EXTRAS + CMAKE_EXTRAS[libcxx]="\ + -DLIBCXX_CXX_ABI=libstdc++ \ + -DLIBCXX_INSTALL_EXPERIMENTAL_LIBRARY=ON" + + for project in libcxx; do + local proj_build="${tmpdir}/build-${project}" + local proj_dir="${llvm_root}/${project}" + mkdir -p "${proj_build}" + cmake -B"${proj_build}" -H"${proj_dir}" \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_USE_SANITIZER=Memory \ + -DLLVM_PATH="${llvm_root}/llvm" \ + -DLLVM_CONFIG_PATH="$(which llvm-config llvm-config-7 llvm-config-6.0 | \ + head -n1)" \ + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" \ + -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" \ + -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" \ + -DCMAKE_INSTALL_PREFIX="${msan_prefix}" \ + ${CMAKE_EXTRAS[${project}]} + cmake --build "${proj_build}" + ninja -C "${proj_build}" install + done +} + +main() { + set -x + for version in 6.0 7 8; do + if ! which "clang-${version}" >/dev/null; then + echo "Skipping msan install for clang version ${version}" + continue + fi + ( + trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT + export CLANG_VERSION=${version} + export CC=clang-${version} + export CXX=clang++-${version} + cmd_msan_install + ) & + done + wait +} + +main "$@" diff --git a/media/libjxl/src/docker/scripts/qemu_install.sh b/media/libjxl/src/docker/scripts/qemu_install.sh new file mode 100644 index 000000000..8106c4471 --- /dev/null +++ b/media/libjxl/src/docker/scripts/qemu_install.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +QEMU_RELEASE="4.1.0" +QEMU_URL="https://download.qemu.org/qemu-${QEMU_RELEASE}.tar.xz" +QEMU_ARCHS=( + aarch64 + arm + i386 + # TODO: Consider adding these: + # aarch64_be + # mips64el + # mips64 + # mips + # ppc64 + # ppc +) + +# Ubuntu packages not installed that are needed to build qemu. +QEMU_BUILD_DEPS=( + libglib2.0-dev + libpixman-1-dev + flex + bison +) + +set -eu -x + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -fr "${CLEANUP_FILES[@]}" + fi +} +trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT + +main() { + local workdir=$(mktemp -d --suffix=qemu) + CLEANUP_FILES+=("${workdir}") + + apt install -y "${QEMU_BUILD_DEPS[@]}" + + local qemutar="${workdir}/qemu.tar.gz" + curl --output "${qemutar}" "${QEMU_URL}" + tar -Jxf "${qemutar}" -C "${workdir}" + local srcdir="${workdir}/qemu-${QEMU_RELEASE}" + + local builddir="${workdir}/build" + local prefixdir="${workdir}/prefix" + mkdir -p "${builddir}" + + # List of targets to build. + local targets="" + local make_targets=() + local target + for target in "${QEMU_ARCHS[@]}"; do + targets="${targets} ${target}-linux-user" + # Build just the linux-user targets. + make_targets+=("${target}-linux-user/all") + done + + cd "${builddir}" + "${srcdir}/configure" \ + --prefix="${prefixdir}" \ + --static --disable-system --enable-linux-user \ + --target-list="${targets}" + + make -j $(nproc --all || echo 1) "${make_targets[@]}" + + # Manually install these into the non-standard location. This script runs as + # root anyway. + for target in "${QEMU_ARCHS[@]}"; do + cp "${target}-linux-user/qemu-${target}" "/usr/bin/qemu-${target}-static" + done + + apt autoremove -y --purge "${QEMU_BUILD_DEPS[@]}" +} + +main "$@" diff --git a/media/libjxl/src/examples/CMakeLists.txt b/media/libjxl/src/examples/CMakeLists.txt new file mode 100644 index 000000000..88dc27c49 --- /dev/null +++ b/media/libjxl/src/examples/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Example project using libjxl. + +cmake_minimum_required(VERSION 3.10) + +project(SAMPLE_LIBJXL LANGUAGES C CXX) + +# Use pkg-config to find libjxl. +find_package(PkgConfig) +pkg_check_modules(Jxl REQUIRED IMPORTED_TARGET libjxl) +pkg_check_modules(JxlThreads REQUIRED IMPORTED_TARGET libjxl_threads) + +# Build the example encoder/decoder binaries using the default shared libraries +# installed. +add_executable(decode_oneshot decode_oneshot.cc) +target_link_libraries(decode_oneshot PkgConfig::Jxl PkgConfig::JxlThreads) + +add_executable(decode_progressive decode_progressive.cc) +target_link_libraries(decode_progressive PkgConfig::Jxl PkgConfig::JxlThreads) + +add_executable(encode_oneshot encode_oneshot.cc) +target_link_libraries(encode_oneshot PkgConfig::Jxl PkgConfig::JxlThreads) + + +# Building a static binary with the static libjxl dependencies. How to load +# static library configs from pkg-config and how to build static binaries +# depends on the platform, and building static binaries in general has problems. +# If you don't need static binaries you can remove this section. +add_library(StaticJxl INTERFACE IMPORTED GLOBAL) +set_target_properties(StaticJxl PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${Jxl_STATIC_INCLUDE_DIR}" + INTERFACE_COMPILE_OPTIONS "${Jxl_STATIC_CFLAGS_OTHER}" + INTERFACE_LINK_LIBRARIES "${Jxl_STATIC_LDFLAGS}" +) +add_library(StaticJxlThreads INTERFACE IMPORTED GLOBAL) +set_target_properties(StaticJxlThreads PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${JxlThreads_STATIC_INCLUDE_DIR}" + INTERFACE_COMPILE_OPTIONS "${JxlThreads_STATIC_CFLAGS_OTHER}" + # libgcc uses weak symbols for pthread which means that -lpthread is not + # linked when compiling a static binary. This is a platform-specific fix for + # that. + INTERFACE_LINK_LIBRARIES + "${JxlThreads_STATIC_LDFLAGS} -Wl,--whole-archive -lpthread -Wl,--no-whole-archive" +) + +add_executable(decode_oneshot_static decode_oneshot.cc) +target_link_libraries(decode_oneshot_static + -static StaticJxl StaticJxlThreads) + +add_executable(encode_oneshot_static encode_oneshot.cc) +target_link_libraries(encode_oneshot_static + -static StaticJxl StaticJxlThreads) diff --git a/media/libjxl/src/examples/decode_exif_metadata.cc b/media/libjxl/src/examples/decode_exif_metadata.cc new file mode 100644 index 000000000..adfe5f842 --- /dev/null +++ b/media/libjxl/src/examples/decode_exif_metadata.cc @@ -0,0 +1,173 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This C++ example decodes a JPEG XL image in one shot (all input bytes +// available at once). The example outputs the pixels and color information to a +// floating point image and an ICC profile on disk. + +#include +#include +#include +#include + +#include + +#include "jxl/decode.h" +#include "jxl/decode_cxx.h" + +bool DecodeJpegXlExif(const uint8_t* jxl, size_t size, + std::vector* exif) { + auto dec = JxlDecoderMake(nullptr); + + // We're only interested in the Exif boxes in this example, so don't + // subscribe to events related to pixel data. + if (JXL_DEC_SUCCESS != JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_BOX)) { + fprintf(stderr, "JxlDecoderSubscribeEvents failed\n"); + return false; + } + bool support_decompression = true; + if (JXL_DEC_SUCCESS != JxlDecoderSetDecompressBoxes(dec.get(), JXL_TRUE)) { + fprintf(stderr, + "NOTE: decompressing brob boxes not supported with the currently " + "used jxl library.\n"); + support_decompression = false; + } + + JxlDecoderSetInput(dec.get(), jxl, size); + JxlDecoderCloseInput(dec.get()); + + const constexpr size_t kChunkSize = 65536; + size_t output_pos = 0; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + fprintf(stderr, "Decoder error\n"); + return false; + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + fprintf(stderr, "Error, already provided all input\n"); + return false; + } else if (status == JXL_DEC_BOX) { + if (!exif->empty()) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec.get()); + exif->resize(exif->size() - remaining); + // No need to wait for JXL_DEC_SUCCESS or decode other boxes. + return true; + } + JxlBoxType type; + if (JXL_DEC_SUCCESS != + JxlDecoderGetBoxType(dec.get(), type, support_decompression)) { + fprintf(stderr, "Error, failed to get box type\n"); + return false; + } + if (!memcmp(type, "Exif", 4)) { + exif->resize(kChunkSize); + JxlDecoderSetBoxBuffer(dec.get(), exif->data(), exif->size()); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec.get()); + output_pos += kChunkSize - remaining; + exif->resize(exif->size() + kChunkSize); + JxlDecoderSetBoxBuffer(dec.get(), exif->data() + output_pos, + exif->size() - output_pos); + } else if (status == JXL_DEC_SUCCESS) { + if (!exif->empty()) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec.get()); + exif->resize(exif->size() - remaining); + return true; + } + return true; + } else { + fprintf(stderr, "Unknown decoder status\n"); + return false; + } + } +} + +bool LoadFile(const char* filename, std::vector* out) { + FILE* file = fopen(filename, "rb"); + if (!file) { + return false; + } + + if (fseek(file, 0, SEEK_END) != 0) { + fclose(file); + return false; + } + + long size = ftell(file); + // Avoid invalid file or directory. + if (size >= LONG_MAX || size < 0) { + fclose(file); + return false; + } + + if (fseek(file, 0, SEEK_SET) != 0) { + fclose(file); + return false; + } + + out->resize(size); + size_t readsize = fread(out->data(), 1, size, file); + if (fclose(file) != 0) { + return false; + } + + return readsize == static_cast(size); +} + +bool WriteFile(const char* filename, const uint8_t* data, size_t size) { + FILE* file = fopen(filename, "wb"); + if (!file) { + fprintf(stderr, "Could not open %s for writing", filename); + return false; + } + fwrite(data, 1, size, file); + if (fclose(file) != 0) { + return false; + } + return true; +} + +int main(int argc, char* argv[]) { + if (argc != 3) { + fprintf(stderr, + "Usage: %s \n" + "Where:\n" + " jxl = input JPEG XL image filename\n" + " exif = output exif filename\n" + "Output files will be overwritten.\n", + argv[0]); + return 1; + } + + const char* jxl_filename = argv[1]; + const char* exif_filename = argv[2]; + + std::vector jxl; + if (!LoadFile(jxl_filename, &jxl)) { + fprintf(stderr, "couldn't load %s\n", jxl_filename); + return 1; + } + + std::vector exif; + if (!DecodeJpegXlExif(jxl.data(), jxl.size(), &exif)) { + fprintf(stderr, "Error while decoding the jxl file\n"); + return 1; + } + if (exif.empty()) { + printf("No exif data present in this image\n"); + } else { + // TODO(lode): the exif box data contains the 4-byte TIFF header at the + // beginning, check whether this is desired to be part of the output, or + // should be removed. + if (!WriteFile(exif_filename, exif.data(), exif.size())) { + fprintf(stderr, "Error while writing the exif file\n"); + return 1; + } + printf("Successfully wrote %s\n", exif_filename); + } + return 0; +} diff --git a/media/libjxl/src/examples/decode_oneshot.cc b/media/libjxl/src/examples/decode_oneshot.cc new file mode 100644 index 000000000..932193fd1 --- /dev/null +++ b/media/libjxl/src/examples/decode_oneshot.cc @@ -0,0 +1,248 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This C++ example decodes a JPEG XL image in one shot (all input bytes +// available at once). The example outputs the pixels and color information to a +// floating point image and an ICC profile on disk. + +#include +#include +#include +#include +#include + +#include + +#include "jxl/decode.h" +#include "jxl/decode_cxx.h" +#include "jxl/resizable_parallel_runner.h" +#include "jxl/resizable_parallel_runner_cxx.h" + +/** Decodes JPEG XL image to floating point pixels and ICC Profile. Pixel are + * stored as floating point, as interleaved RGBA (4 floating point values per + * pixel), line per line from top to bottom. Pixel values have nominal range + * 0..1 but may go beyond this range for HDR or wide gamut. The ICC profile + * describes the color format of the pixel data. + */ +bool DecodeJpegXlOneShot(const uint8_t* jxl, size_t size, + std::vector* pixels, size_t* xsize, + size_t* ysize, std::vector* icc_profile) { + // Multi-threaded parallel runner. + auto runner = JxlResizableParallelRunnerMake(nullptr); + + auto dec = JxlDecoderMake(nullptr); + if (JXL_DEC_SUCCESS != + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)) { + fprintf(stderr, "JxlDecoderSubscribeEvents failed\n"); + return false; + } + + if (JXL_DEC_SUCCESS != JxlDecoderSetParallelRunner(dec.get(), + JxlResizableParallelRunner, + runner.get())) { + fprintf(stderr, "JxlDecoderSetParallelRunner failed\n"); + return false; + } + + JxlBasicInfo info; + JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + + JxlDecoderSetInput(dec.get(), jxl, size); + JxlDecoderCloseInput(dec.get()); + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + + if (status == JXL_DEC_ERROR) { + fprintf(stderr, "Decoder error\n"); + return false; + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + fprintf(stderr, "Error, already provided all input\n"); + return false; + } else if (status == JXL_DEC_BASIC_INFO) { + if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec.get(), &info)) { + fprintf(stderr, "JxlDecoderGetBasicInfo failed\n"); + return false; + } + *xsize = info.xsize; + *ysize = info.ysize; + JxlResizableParallelRunnerSetThreads( + runner.get(), + JxlResizableParallelRunnerSuggestThreads(info.xsize, info.ysize)); + } else if (status == JXL_DEC_COLOR_ENCODING) { + // Get the ICC color profile of the pixel data + size_t icc_size; + if (JXL_DEC_SUCCESS != + JxlDecoderGetICCProfileSize( + dec.get(), &format, JXL_COLOR_PROFILE_TARGET_DATA, &icc_size)) { + fprintf(stderr, "JxlDecoderGetICCProfileSize failed\n"); + return false; + } + icc_profile->resize(icc_size); + if (JXL_DEC_SUCCESS != JxlDecoderGetColorAsICCProfile( + dec.get(), &format, + JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile->data(), icc_profile->size())) { + fprintf(stderr, "JxlDecoderGetColorAsICCProfile failed\n"); + return false; + } + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + size_t buffer_size; + if (JXL_DEC_SUCCESS != + JxlDecoderImageOutBufferSize(dec.get(), &format, &buffer_size)) { + fprintf(stderr, "JxlDecoderImageOutBufferSize failed\n"); + return false; + } + if (buffer_size != *xsize * *ysize * 16) { + fprintf(stderr, "Invalid out buffer size %" PRIu64 " %" PRIu64 "\n", + static_cast(buffer_size), + static_cast(*xsize * *ysize * 16)); + return false; + } + pixels->resize(*xsize * *ysize * 4); + void* pixels_buffer = (void*)pixels->data(); + size_t pixels_buffer_size = pixels->size() * sizeof(float); + if (JXL_DEC_SUCCESS != JxlDecoderSetImageOutBuffer(dec.get(), &format, + pixels_buffer, + pixels_buffer_size)) { + fprintf(stderr, "JxlDecoderSetImageOutBuffer failed\n"); + return false; + } + } else if (status == JXL_DEC_FULL_IMAGE) { + // Nothing to do. Do not yet return. If the image is an animation, more + // full frames may be decoded. This example only keeps the last one. + } else if (status == JXL_DEC_SUCCESS) { + // All decoding successfully finished. + // It's not required to call JxlDecoderReleaseInput(dec.get()) here since + // the decoder will be destroyed. + return true; + } else { + fprintf(stderr, "Unknown decoder status\n"); + return false; + } + } +} + +/** Writes to .pfm file (Portable FloatMap). Gimp, tev viewer and ImageMagick + * support viewing this format. + * The input pixels are given as 32-bit floating point with 4-channel RGBA. + * The alpha channel will not be written since .pfm does not support it. + */ +bool WritePFM(const char* filename, const float* pixels, size_t xsize, + size_t ysize) { + FILE* file = fopen(filename, "wb"); + if (!file) { + fprintf(stderr, "Could not open %s for writing", filename); + return false; + } + uint32_t endian_test = 1; + uint8_t little_endian[4]; + memcpy(little_endian, &endian_test, 4); + + fprintf(file, "PF\n%d %d\n%s\n", (int)xsize, (int)ysize, + little_endian[0] ? "-1.0" : "1.0"); + for (int y = ysize - 1; y >= 0; y--) { + for (size_t x = 0; x < xsize; x++) { + for (size_t c = 0; c < 3; c++) { + const float* f = &pixels[(y * xsize + x) * 4 + c]; + fwrite(f, 4, 1, file); + } + } + } + if (fclose(file) != 0) { + return false; + } + return true; +} + +bool LoadFile(const char* filename, std::vector* out) { + FILE* file = fopen(filename, "rb"); + if (!file) { + return false; + } + + if (fseek(file, 0, SEEK_END) != 0) { + fclose(file); + return false; + } + + long size = ftell(file); + // Avoid invalid file or directory. + if (size >= LONG_MAX || size < 0) { + fclose(file); + return false; + } + + if (fseek(file, 0, SEEK_SET) != 0) { + fclose(file); + return false; + } + + out->resize(size); + size_t readsize = fread(out->data(), 1, size, file); + if (fclose(file) != 0) { + return false; + } + + return readsize == static_cast(size); +} + +bool WriteFile(const char* filename, const uint8_t* data, size_t size) { + FILE* file = fopen(filename, "wb"); + if (!file) { + fprintf(stderr, "Could not open %s for writing", filename); + return false; + } + fwrite(data, 1, size, file); + if (fclose(file) != 0) { + return false; + } + return true; +} + +int main(int argc, char* argv[]) { + if (argc != 4) { + fprintf(stderr, + "Usage: %s \n" + "Where:\n" + " jxl = input JPEG XL image filename\n" + " pfm = output Portable FloatMap image filename\n" + " icc = output ICC color profile filename\n" + "Output files will be overwritten.\n", + argv[0]); + return 1; + } + + const char* jxl_filename = argv[1]; + const char* pfm_filename = argv[2]; + const char* icc_filename = argv[3]; + + std::vector jxl; + if (!LoadFile(jxl_filename, &jxl)) { + fprintf(stderr, "couldn't load %s\n", jxl_filename); + return 1; + } + + std::vector pixels; + std::vector icc_profile; + size_t xsize = 0, ysize = 0; + if (!DecodeJpegXlOneShot(jxl.data(), jxl.size(), &pixels, &xsize, &ysize, + &icc_profile)) { + fprintf(stderr, "Error while decoding the jxl file\n"); + return 1; + } + if (!WritePFM(pfm_filename, pixels.data(), xsize, ysize)) { + fprintf(stderr, "Error while writing the PFM image file\n"); + return 1; + } + if (!WriteFile(icc_filename, icc_profile.data(), icc_profile.size())) { + fprintf(stderr, "Error while writing the ICC profile file\n"); + return 1; + } + printf("Successfully wrote %s and %s\n", pfm_filename, icc_filename); + return 0; +} diff --git a/media/libjxl/src/examples/decode_progressive.cc b/media/libjxl/src/examples/decode_progressive.cc new file mode 100644 index 000000000..77d2a0f5c --- /dev/null +++ b/media/libjxl/src/examples/decode_progressive.cc @@ -0,0 +1,238 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This C++ example decodes a JPEG XL image progressively (input bytes are +// passed in chunks). The example outputs the intermediate steps to PAM files. + +#include +#include +#include +#include +#include + +#include + +#include "jxl/decode.h" +#include "jxl/decode_cxx.h" +#include "jxl/resizable_parallel_runner.h" +#include "jxl/resizable_parallel_runner_cxx.h" + +bool WritePAM(const char* filename, const uint8_t* buffer, size_t w, size_t h) { + FILE* fp = fopen(filename, "wb"); + if (!fp) { + fprintf(stderr, "Could not open %s for writing", filename); + return false; + } + fprintf(fp, + "P7\nWIDTH %" PRIu64 "\nHEIGHT %" PRIu64 + "\nDEPTH 4\nMAXVAL 255\nTUPLTYPE " + "RGB_ALPHA\nENDHDR\n", + static_cast(w), static_cast(h)); + fwrite(buffer, 1, w * h * 4, fp); + if (fclose(fp) != 0) { + return false; + } + return true; +} + +/** Decodes JPEG XL image to 8-bit integer RGBA pixels and an ICC Profile, in a + * progressive way, saving the intermediate steps. + */ +bool DecodeJpegXlProgressive(const uint8_t* jxl, size_t size, + const char* filename, size_t chunksize) { + std::vector pixels; + std::vector icc_profile; + size_t xsize = 0, ysize = 0; + + // Multi-threaded parallel runner. + auto runner = JxlResizableParallelRunnerMake(nullptr); + + auto dec = JxlDecoderMake(nullptr); + if (JXL_DEC_SUCCESS != + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)) { + fprintf(stderr, "JxlDecoderSubscribeEvents failed\n"); + return false; + } + + if (JXL_DEC_SUCCESS != JxlDecoderSetParallelRunner(dec.get(), + JxlResizableParallelRunner, + runner.get())) { + fprintf(stderr, "JxlDecoderSetParallelRunner failed\n"); + return false; + } + + JxlBasicInfo info; + JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + + size_t seen = 0; + JxlDecoderSetInput(dec.get(), jxl, chunksize); + size_t remaining = chunksize; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + + if (status == JXL_DEC_ERROR) { + fprintf(stderr, "Decoder error\n"); + return false; + } else if (status == JXL_DEC_NEED_MORE_INPUT || status == JXL_DEC_SUCCESS || + status == JXL_DEC_FULL_IMAGE) { + seen += remaining - JxlDecoderReleaseInput(dec.get()); + printf("Flushing after %" PRIu64 " bytes\n", static_cast(seen)); + if (status == JXL_DEC_NEED_MORE_INPUT && + JXL_DEC_SUCCESS != JxlDecoderFlushImage(dec.get())) { + printf("flush error (no preview yet)\n"); + } else { + char fname[1024]; + if (snprintf(fname, 1024, "%s-%" PRIu64 ".pam", filename, + static_cast(seen)) >= 1024) { + fprintf(stderr, "Filename too long\n"); + return false; + }; + if (!WritePAM(fname, pixels.data(), xsize, ysize)) { + fprintf(stderr, "Error writing progressive output\n"); + } + } + remaining = size - seen; + if (remaining > chunksize) remaining = chunksize; + if (remaining == 0) { + if (status == JXL_DEC_NEED_MORE_INPUT) { + fprintf(stderr, "Error, already provided all input\n"); + return false; + } else { + return true; + } + } + JxlDecoderSetInput(dec.get(), jxl + seen, remaining); + } else if (status == JXL_DEC_BASIC_INFO) { + if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec.get(), &info)) { + fprintf(stderr, "JxlDecoderGetBasicInfo failed\n"); + return false; + } + xsize = info.xsize; + ysize = info.ysize; + JxlResizableParallelRunnerSetThreads( + runner.get(), + JxlResizableParallelRunnerSuggestThreads(info.xsize, info.ysize)); + } else if (status == JXL_DEC_COLOR_ENCODING) { + // Get the ICC color profile of the pixel data + size_t icc_size; + if (JXL_DEC_SUCCESS != + JxlDecoderGetICCProfileSize(dec.get(), &format, + JXL_COLOR_PROFILE_TARGET_ORIGINAL, + &icc_size)) { + fprintf(stderr, "JxlDecoderGetICCProfileSize failed\n"); + return false; + } + icc_profile.resize(icc_size); + if (JXL_DEC_SUCCESS != JxlDecoderGetColorAsICCProfile( + dec.get(), &format, + JXL_COLOR_PROFILE_TARGET_ORIGINAL, + icc_profile.data(), icc_profile.size())) { + fprintf(stderr, "JxlDecoderGetColorAsICCProfile failed\n"); + return false; + } + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + size_t buffer_size; + if (JXL_DEC_SUCCESS != + JxlDecoderImageOutBufferSize(dec.get(), &format, &buffer_size)) { + fprintf(stderr, "JxlDecoderImageOutBufferSize failed\n"); + return false; + } + if (buffer_size != xsize * ysize * 4) { + fprintf(stderr, "Invalid out buffer size %" PRIu64 " != %" PRIu64 "\n", + static_cast(buffer_size), + static_cast(xsize * ysize * 4)); + return false; + } + pixels.resize(xsize * ysize * 4); + void* pixels_buffer = (void*)pixels.data(); + size_t pixels_buffer_size = pixels.size() * sizeof(float); + if (JXL_DEC_SUCCESS != JxlDecoderSetImageOutBuffer(dec.get(), &format, + pixels_buffer, + pixels_buffer_size)) { + fprintf(stderr, "JxlDecoderSetImageOutBuffer failed\n"); + return false; + } + } else { + fprintf(stderr, "Unknown decoder status\n"); + return false; + } + } +} + +bool LoadFile(const char* filename, std::vector* out) { + FILE* file = fopen(filename, "rb"); + if (!file) { + return false; + } + + if (fseek(file, 0, SEEK_END) != 0) { + fclose(file); + return false; + } + + long size = ftell(file); + // Avoid invalid file or directory. + if (size >= LONG_MAX || size < 0) { + fclose(file); + return false; + } + + if (fseek(file, 0, SEEK_SET) != 0) { + fclose(file); + return false; + } + + out->resize(size); + size_t readsize = fread(out->data(), 1, size, file); + if (fclose(file) != 0) { + return false; + } + + return readsize == static_cast(size); +} + +int main(int argc, char* argv[]) { + if (argc < 3) { + fprintf( + stderr, + "Usage: %s [chunksize]\n" + "Where:\n" + " jxl = input JPEG XL image filename\n" + " basename = prefix of output filenames\n" + " chunksize = loads chunksize bytes at a time and writes\n" + " intermediate results to basename-[bytes loaded].pam\n" + "Output files will be overwritten.\n", + argv[0]); + return 1; + } + + const char* jxl_filename = argv[1]; + const char* png_filename = argv[2]; + + std::vector jxl; + if (!LoadFile(jxl_filename, &jxl)) { + fprintf(stderr, "couldn't load %s\n", jxl_filename); + return 1; + } + size_t chunksize = jxl.size(); + if (argc > 3) { + long cs = atol(argv[3]); + if (cs < 100) { + fprintf(stderr, "Chunk size is too low, try at least 100 bytes\n"); + return 1; + } + chunksize = cs; + } + + if (!DecodeJpegXlProgressive(jxl.data(), jxl.size(), png_filename, + chunksize)) { + fprintf(stderr, "Error while decoding the jxl file\n"); + return 1; + } + return 0; +} diff --git a/media/libjxl/src/examples/encode_oneshot.cc b/media/libjxl/src/examples/encode_oneshot.cc new file mode 100644 index 000000000..f1cd9ab00 --- /dev/null +++ b/media/libjxl/src/examples/encode_oneshot.cc @@ -0,0 +1,276 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This example encodes a file containing a floating point image to another +// file containing JPEG XL image with a single frame. + +#include +#include + +#include +#include +#include + +#include "jxl/encode.h" +#include "jxl/encode_cxx.h" +#include "jxl/thread_parallel_runner.h" +#include "jxl/thread_parallel_runner_cxx.h" + +/** + * Reads from .pfm file (Portable FloatMap) + * + * @param filename name of the file to read + * @param pixels vector to fill with loaded pixels as 32-bit floating point with + * 3-channel RGB + * @param xsize set to width of loaded image + * @param ysize set to height of loaded image + */ +bool ReadPFM(const char* filename, std::vector* pixels, uint32_t* xsize, + uint32_t* ysize) { + FILE* file = fopen(filename, "rb"); + if (!file) { + fprintf(stderr, "Could not open %s for reading.\n", filename); + return false; + } + uint32_t endian_test = 1; + uint8_t little_endian[4]; + memcpy(little_endian, &endian_test, 4); + + if (fseek(file, 0, SEEK_END) != 0) { + fclose(file); + return false; + } + + long size = ftell(file); + // Avoid invalid file or directory. + if (size >= LONG_MAX || size < 0) { + fclose(file); + return false; + } + + if (fseek(file, 0, SEEK_SET) != 0) { + fclose(file); + return false; + } + + std::vector data; + data.resize(size); + + size_t readsize = fread(data.data(), 1, size, file); + if ((long)readsize != size) { + return false; + } + if (fclose(file) != 0) { + return false; + } + + std::stringstream datastream; + std::string datastream_content(data.data(), data.size()); + datastream.str(datastream_content); + + std::string pf_token; + getline(datastream, pf_token, '\n'); + if (pf_token != "PF") { + fprintf(stderr, + "%s doesn't seem to be a 3 channel Portable FloatMap file (missing " + "'PF\\n' " + "bytes).\n", + filename); + return false; + } + + std::string xsize_token; + getline(datastream, xsize_token, ' '); + *xsize = std::stoi(xsize_token); + + std::string ysize_token; + getline(datastream, ysize_token, '\n'); + *ysize = std::stoi(ysize_token); + + std::string endianness_token; + getline(datastream, endianness_token, '\n'); + bool input_little_endian; + if (endianness_token == "1.0") { + input_little_endian = false; + } else if (endianness_token == "-1.0") { + input_little_endian = true; + } else { + fprintf(stderr, + "%s doesn't seem to be a Portable FloatMap file (endianness token " + "isn't '1.0' or '-1.0').\n", + filename); + return false; + } + + size_t offset = pf_token.size() + 1 + xsize_token.size() + 1 + + ysize_token.size() + 1 + endianness_token.size() + 1; + + if (data.size() != *ysize * *xsize * 3 * 4 + offset) { + fprintf(stderr, + "%s doesn't seem to be a Portable FloatMap file (pixel data bytes " + "are %d, but expected %d * %d * 3 * 4 + %d (%d).\n", + filename, (int)data.size(), (int)*ysize, (int)*xsize, (int)offset, + (int)(*ysize * *xsize * 3 * 4 + offset)); + return false; + } + + if (!!little_endian[0] != input_little_endian) { + fprintf(stderr, + "%s has a different endianness than we do, conversion is not " + "supported.\n", + filename); + return false; + } + + pixels->resize(*ysize * *xsize * 3); + + for (int y = *ysize - 1; y >= 0; y--) { + for (int x = 0; x < (int)*xsize; x++) { + for (int c = 0; c < 3; c++) { + memcpy(pixels->data() + (y * *xsize + x) * 3 + c, data.data() + offset, + sizeof(float)); + offset += sizeof(float); + } + } + } + + return true; +} + +/** + * Compresses the provided pixels. + * + * @param pixels input pixels + * @param xsize width of the input image + * @param ysize height of the input image + * @param compressed will be populated with the compressed bytes + */ +bool EncodeJxlOneshot(const std::vector& pixels, const uint32_t xsize, + const uint32_t ysize, std::vector* compressed) { + auto enc = JxlEncoderMake(/*memory_manager=*/nullptr); + auto runner = JxlThreadParallelRunnerMake( + /*memory_manager=*/nullptr, + JxlThreadParallelRunnerDefaultNumWorkerThreads()); + if (JXL_ENC_SUCCESS != JxlEncoderSetParallelRunner(enc.get(), + JxlThreadParallelRunner, + runner.get())) { + fprintf(stderr, "JxlEncoderSetParallelRunner failed\n"); + return false; + } + + JxlPixelFormat pixel_format = {3, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + + JxlBasicInfo basic_info; + JxlEncoderInitBasicInfo(&basic_info); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.bits_per_sample = 32; + basic_info.exponent_bits_per_sample = 8; + basic_info.uses_original_profile = JXL_FALSE; + if (JXL_ENC_SUCCESS != JxlEncoderSetBasicInfo(enc.get(), &basic_info)) { + fprintf(stderr, "JxlEncoderSetBasicInfo failed\n"); + return false; + } + + JxlColorEncoding color_encoding = {}; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + if (JXL_ENC_SUCCESS != + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)) { + fprintf(stderr, "JxlEncoderSetColorEncoding failed\n"); + return false; + } + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), nullptr); + + if (JXL_ENC_SUCCESS != + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + (void*)pixels.data(), + sizeof(float) * pixels.size())) { + fprintf(stderr, "JxlEncoderAddImageFrame failed\n"); + return false; + } + JxlEncoderCloseInput(enc.get()); + + compressed->resize(64); + uint8_t* next_out = compressed->data(); + size_t avail_out = compressed->size() - (next_out - compressed->data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed->data(); + compressed->resize(compressed->size() * 2); + next_out = compressed->data() + offset; + avail_out = compressed->size() - offset; + } + } + compressed->resize(next_out - compressed->data()); + if (JXL_ENC_SUCCESS != process_result) { + fprintf(stderr, "JxlEncoderProcessOutput failed\n"); + return false; + } + + return true; +} + +/** + * Writes bytes to file. + */ +bool WriteFile(const std::vector& bytes, const char* filename) { + FILE* file = fopen(filename, "wb"); + if (!file) { + fprintf(stderr, "Could not open %s for writing\n", filename); + return false; + } + if (fwrite(bytes.data(), sizeof(uint8_t), bytes.size(), file) != + bytes.size()) { + fprintf(stderr, "Could not write bytes to %s\n", filename); + return false; + } + if (fclose(file) != 0) { + fprintf(stderr, "Could not close %s\n", filename); + return false; + } + return true; +} + +int main(int argc, char* argv[]) { + if (argc != 3) { + fprintf(stderr, + "Usage: %s \n" + "Where:\n" + " pfm = input Portable FloatMap image filename\n" + " jxl = output JPEG XL image filename\n" + "Output files will be overwritten.\n", + argv[0]); + return 1; + } + + const char* pfm_filename = argv[1]; + const char* jxl_filename = argv[2]; + + std::vector pixels; + uint32_t xsize; + uint32_t ysize; + if (!ReadPFM(pfm_filename, &pixels, &xsize, &ysize)) { + fprintf(stderr, "Couldn't load %s\n", pfm_filename); + return 2; + } + + std::vector compressed; + if (!EncodeJxlOneshot(pixels, xsize, ysize, &compressed)) { + fprintf(stderr, "Couldn't encode jxl\n"); + return 3; + } + + if (!WriteFile(compressed, jxl_filename)) { + fprintf(stderr, "Couldn't write jxl file\n"); + return 4; + } + + return 0; +} diff --git a/media/libjxl/src/examples/examples.cmake b/media/libjxl/src/examples/examples.cmake new file mode 100644 index 000000000..fd159578b --- /dev/null +++ b/media/libjxl/src/examples/examples.cmake @@ -0,0 +1,11 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +add_executable(decode_oneshot ${CMAKE_CURRENT_LIST_DIR}/decode_oneshot.cc) +target_link_libraries(decode_oneshot jxl_dec jxl_threads) +add_executable(decode_progressive ${CMAKE_CURRENT_LIST_DIR}/decode_progressive.cc) +target_link_libraries(decode_progressive jxl_dec jxl_threads) +add_executable(encode_oneshot ${CMAKE_CURRENT_LIST_DIR}/encode_oneshot.cc) +target_link_libraries(encode_oneshot jxl jxl_threads) diff --git a/media/libjxl/src/experimental/fast_lossless/.gitignore b/media/libjxl/src/experimental/fast_lossless/.gitignore new file mode 100644 index 000000000..567609b12 --- /dev/null +++ b/media/libjxl/src/experimental/fast_lossless/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/media/libjxl/src/experimental/fast_lossless/build-android.sh b/media/libjxl/src/experimental/fast_lossless/build-android.sh new file mode 100644 index 000000000..41452cdf9 --- /dev/null +++ b/media/libjxl/src/experimental/fast_lossless/build-android.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +set -e + +DIR=$(realpath "$(dirname "$0")") + +mkdir -p /tmp/build-android +cd /tmp/build-android + +CXX="$ANDROID_NDK"/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android30-clang++ +if ! command -v "$CXX" >/dev/null ; then + printf >&2 '%s: Android C++ compiler not found, is ANDROID_NDK set properly?\n' "${0##*/}" + exit 1 +fi + +[ -f lodepng.cpp ] || curl -o lodepng.cpp --url 'https://raw.githubusercontent.com/lvandeve/lodepng/8c6a9e30576f07bf470ad6f09458a2dcd7a6a84a/lodepng.cpp' +[ -f lodepng.h ] || curl -o lodepng.h --url 'https://raw.githubusercontent.com/lvandeve/lodepng/8c6a9e30576f07bf470ad6f09458a2dcd7a6a84a/lodepng.h' +[ -f lodepng.o ] || "$CXX" lodepng.cpp -O3 -o lodepng.o -c + +"$CXX" -O3 -DFASTLL_ENABLE_NEON_INTRINSICS -fopenmp \ + -I. lodepng.o \ + "${DIR}"/fast_lossless.cc "${DIR}"/fast_lossless_main.cc \ + -o fast_lossless diff --git a/media/libjxl/src/experimental/fast_lossless/build.sh b/media/libjxl/src/experimental/fast_lossless/build.sh new file mode 100644 index 000000000..b2564c6a5 --- /dev/null +++ b/media/libjxl/src/experimental/fast_lossless/build.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +set -e + +DIR=$(realpath "$(dirname "$0")") +mkdir -p "$DIR"/build +cd "$DIR"/build + +# set CXX to clang++ if not set in the environment +CXX="${CXX-clang++}" +if ! command -v "$CXX" >/dev/null ; then + printf >&2 '%s: C++ compiler not found\n' "${0##*/}" + exit 1 +fi + +[ -f lodepng.cpp ] || curl -o lodepng.cpp --url 'https://raw.githubusercontent.com/lvandeve/lodepng/8c6a9e30576f07bf470ad6f09458a2dcd7a6a84a/lodepng.cpp' +[ -f lodepng.h ] || curl -o lodepng.h --url 'https://raw.githubusercontent.com/lvandeve/lodepng/8c6a9e30576f07bf470ad6f09458a2dcd7a6a84a/lodepng.h' +[ -f lodepng.o ] || "$CXX" lodepng.cpp -O3 -mavx2 -o lodepng.o -c + +"$CXX" -O3 -mavx2 -DFASTLL_ENABLE_AVX2_INTRINSICS -fopenmp \ + -I. lodepng.o \ + "$DIR"/fast_lossless.cc "$DIR"/fast_lossless_main.cc \ + -o fast_lossless diff --git a/media/libjxl/src/experimental/fast_lossless/fast_lossless.cc b/media/libjxl/src/experimental/fast_lossless/fast_lossless.cc new file mode 100644 index 000000000..9b442aac1 --- /dev/null +++ b/media/libjxl/src/experimental/fast_lossless/fast_lossless.cc @@ -0,0 +1,1362 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "fast_lossless.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#if (!defined(__BYTE_ORDER__) || (__BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__)) +#error "system not known to be little endian" +#endif + +struct BitWriter { + void Allocate(size_t maximum_bit_size) { + assert(data == nullptr); + // Leave some padding. + data.reset((uint8_t*)malloc(maximum_bit_size / 8 + 32)); + } + + void Write(uint32_t count, uint64_t bits) { + buffer |= bits << bits_in_buffer; + bits_in_buffer += count; + memcpy(data.get() + bytes_written, &buffer, 8); + size_t bytes_in_buffer = bits_in_buffer / 8; + bits_in_buffer -= bytes_in_buffer * 8; + buffer >>= bytes_in_buffer * 8; + bytes_written += bytes_in_buffer; + } + + void ZeroPadToByte() { + if (bits_in_buffer != 0) { + Write(8 - bits_in_buffer, 0); + } + } + + std::unique_ptr data = {nullptr, free}; + size_t bytes_written = 0; + size_t bits_in_buffer = 0; + uint64_t buffer = 0; +}; + +constexpr size_t kLZ77Offset = 224; +constexpr size_t kLZ77MinLength = 16; + +struct PrefixCode { + static constexpr size_t kNumLZ77 = 17; + static constexpr size_t kNumRaw = 15; + + alignas(32) uint8_t raw_nbits[16] = {}; + alignas(32) uint8_t raw_bits[16] = {}; + uint8_t lz77_nbits[kNumLZ77] = {}; + + uint16_t lz77_bits[kNumLZ77] = {}; + + static uint16_t BitReverse(size_t nbits, uint16_t bits) { + constexpr uint16_t kNibbleLookup[16] = { + 0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110, + 0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111, + }; + uint16_t rev16 = (kNibbleLookup[bits & 0xF] << 12) | + (kNibbleLookup[(bits >> 4) & 0xF] << 8) | + (kNibbleLookup[(bits >> 8) & 0xF] << 4) | + (kNibbleLookup[bits >> 12]); + return rev16 >> (16 - nbits); + } + + // Create the prefix codes given the code lengths. + // Supports the code lengths being split into two halves. + static void ComputeCanonicalCode(const uint8_t* first_chunk_nbits, + uint8_t* first_chunk_bits, + size_t first_chunk_size, + const uint8_t* second_chunk_nbits, + uint16_t* second_chunk_bits, + size_t second_chunk_size) { + uint8_t code_length_counts[16] = {}; + for (size_t i = 0; i < first_chunk_size; i++) { + code_length_counts[first_chunk_nbits[i]]++; + assert(first_chunk_nbits[i] <= 7); + assert(first_chunk_nbits[i] > 0); + } + for (size_t i = 0; i < second_chunk_size; i++) { + code_length_counts[second_chunk_nbits[i]]++; + } + + uint16_t next_code[16] = {}; + + uint16_t code = 0; + for (size_t i = 1; i < 16; i++) { + code = (code + code_length_counts[i - 1]) << 1; + next_code[i] = code; + } + + for (size_t i = 0; i < first_chunk_size; i++) { + first_chunk_bits[i] = + BitReverse(first_chunk_nbits[i], next_code[first_chunk_nbits[i]]++); + } + for (size_t i = 0; i < second_chunk_size; i++) { + second_chunk_bits[i] = + BitReverse(second_chunk_nbits[i], next_code[second_chunk_nbits[i]]++); + } + } + + PrefixCode(uint64_t* raw_counts, uint64_t* lz77_counts) { + // "merge" together all the lz77 counts in a single symbol for the level 1 + // table (containing just the raw symbols, up to length 7). + uint64_t level1_counts[kNumRaw + 1]; + memcpy(level1_counts, raw_counts, kNumRaw * sizeof(uint64_t)); + size_t numraw = kNumRaw; + while (numraw > 0 && level1_counts[numraw - 1] == 0) numraw--; + + level1_counts[numraw] = 0; + for (size_t i = 0; i < kNumLZ77; i++) { + level1_counts[numraw] += lz77_counts[i]; + } + uint8_t level1_nbits[kNumRaw + 1] = {}; + ComputeCodeLengths(level1_counts, numraw + 1, 7, level1_nbits); + + uint8_t level2_nbits[kNumLZ77] = {}; + ComputeCodeLengths(lz77_counts, kNumLZ77, 15 - level1_nbits[numraw], + level2_nbits); + for (size_t i = 0; i < numraw; i++) { + raw_nbits[i] = level1_nbits[i]; + } + for (size_t i = 0; i < kNumLZ77; i++) { + lz77_nbits[i] = + level2_nbits[i] ? level1_nbits[numraw] + level2_nbits[i] : 0; + } + + ComputeCanonicalCode(raw_nbits, raw_bits, numraw, lz77_nbits, lz77_bits, + kNumLZ77); + } + + static void ComputeCodeLengths(uint64_t* freqs, size_t n, size_t limit, + uint8_t* nbits) { + if (n <= 1) return; + assert(n <= (1 << limit)); + assert(n <= 32); + int parent[64] = {}; + int height[64] = {}; + using QElem = std::pair; + std::priority_queue, std::greater> q; + // Standard Huffman code construction. On failure (i.e. if going beyond the + // length limit), try again with halved frequencies. + while (true) { + size_t num_nodes = 0; + for (size_t i = 0; i < n; i++) { + if (freqs[i] == 0) continue; + q.emplace(freqs[i], num_nodes++); + } + if (num_nodes <= 1) return; + while (q.size() > 1) { + QElem n1 = q.top(); + q.pop(); + QElem n2 = q.top(); + q.pop(); + size_t next = num_nodes++; + parent[n1.second] = next; + parent[n2.second] = next; + q.emplace(n1.first + n2.first, next); + } + assert(q.size() == 1); + q.pop(); + bool is_ok = true; + for (size_t i = num_nodes - 1; i-- > 0;) { + height[i] = height[parent[i]] + 1; + is_ok &= height[i] <= limit; + } + if (is_ok) { + num_nodes = 0; + for (size_t i = 0; i < n; i++) { + if (freqs[i] == 0) continue; + nbits[i] = height[num_nodes++]; + } + break; + } else { + for (size_t i = 0; i < n; i++) { + freqs[i] = (freqs[i] + 1) >> 1; + } + } + } + } + + void WriteTo(BitWriter* writer) const { + uint64_t code_length_counts[18] = {}; + code_length_counts[17] = 3 + 2 * (kNumLZ77 - 1); + for (size_t i = 0; i < kNumRaw; i++) { + code_length_counts[raw_nbits[i]]++; + } + for (size_t i = 0; i < kNumLZ77; i++) { + code_length_counts[lz77_nbits[i]]++; + } + uint8_t code_length_nbits[18] = {}; + ComputeCodeLengths(code_length_counts, 18, 5, code_length_nbits); + writer->Write(2, 0b00); // HSKIP = 0, i.e. don't skip code lengths. + + // As per Brotli RFC. + uint8_t code_length_order[18] = {1, 2, 3, 4, 0, 5, 17, 6, 16, + 7, 8, 9, 10, 11, 12, 13, 14, 15}; + uint8_t code_length_length_nbits[] = {2, 4, 3, 2, 2, 4}; + uint8_t code_length_length_bits[] = {0, 7, 3, 2, 1, 15}; + + // Encode lengths of code lengths. + size_t num_code_lengths = 18; + while (code_length_nbits[code_length_order[num_code_lengths - 1]] == 0) { + num_code_lengths--; + } + for (size_t i = 0; i < num_code_lengths; i++) { + int symbol = code_length_nbits[code_length_order[i]]; + writer->Write(code_length_length_nbits[symbol], + code_length_length_bits[symbol]); + } + + // Compute the canonical codes for the codes that represent the lengths of + // the actual codes for data. + uint16_t code_length_bits[18] = {}; + ComputeCanonicalCode(nullptr, nullptr, 0, code_length_nbits, + code_length_bits, 18); + // Encode raw bit code lengths. + for (size_t i = 0; i < kNumRaw; i++) { + writer->Write(code_length_nbits[raw_nbits[i]], + code_length_bits[raw_nbits[i]]); + } + size_t num_lz77 = kNumLZ77; + while (lz77_nbits[num_lz77 - 1] == 0) { + num_lz77--; + } + // Encode 0s until 224 (start of LZ77 symbols). This is in total 224-15 = + // 209. + static_assert(kLZ77Offset == 224, ""); + static_assert(kNumRaw == 15, ""); + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b010); // 5 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b000); // (5-2)*8 + 3 = 27 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b110); // (27-2)*8 + 9 = 209 + // Encode LZ77 symbols, with values 224+i*16. + for (size_t i = 0; i < num_lz77; i++) { + writer->Write(code_length_nbits[lz77_nbits[i]], + code_length_bits[lz77_nbits[i]]); + if (i != num_lz77 - 1) { + // Encode gap between LZ77 symbols: 15 zeros. + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b000); // 3 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b100); // (3-2)*8+7 = 15 + } + } + } +}; + +constexpr size_t kChunkSize = 16; + +void EncodeHybridUint000(uint32_t value, uint32_t* token, uint32_t* nbits, + uint32_t* bits) { + uint32_t n = 31 - __builtin_clz(value); + *token = value ? n + 1 : 0; + *nbits = value ? n : 0; + *bits = value ? value - (1 << n) : 0; +} + +void AppendWriter(BitWriter* dest, const BitWriter* src) { + if (dest->bits_in_buffer == 0) { + memcpy(dest->data.get() + dest->bytes_written, src->data.get(), + src->bytes_written); + dest->bytes_written += src->bytes_written; + } else { + size_t i = 0; + uint64_t buf = dest->buffer; + uint64_t bits_in_buffer = dest->bits_in_buffer; + uint8_t* dest_buf = dest->data.get() + dest->bytes_written; + // Copy 8 bytes at a time until we reach the border. + for (; i + 8 < src->bytes_written; i += 8) { + uint64_t chunk; + memcpy(&chunk, src->data.get() + i, 8); + uint64_t out = buf | (chunk << bits_in_buffer); + memcpy(dest_buf + i, &out, 8); + buf = chunk >> (64 - bits_in_buffer); + } + dest->buffer = buf; + dest->bytes_written += i; + for (; i < src->bytes_written; i++) { + dest->Write(8, src->data[i]); + } + } + dest->Write(src->bits_in_buffer, src->buffer); +} + +void AssembleFrame(size_t width, size_t height, size_t nb_chans, + size_t bitdepth, + const std::vector>& group_data, + BitWriter* output) { + size_t total_size_groups = 0; + std::vector group_sizes(group_data.size()); + for (size_t i = 0; i < group_data.size(); i++) { + size_t sz = 0; + for (size_t j = 0; j < nb_chans; j++) { + const auto& writer = group_data[i][j]; + sz += writer.bytes_written * 8 + writer.bits_in_buffer; + } + sz = (sz + 7) / 8; + group_sizes[i] = sz; + total_size_groups += sz * 8; + } + output->Allocate(1000 + group_data.size() * 32 + total_size_groups); + + // Signature + output->Write(16, 0x0AFF); + + // Size header, hand-crafted. + // Not small + output->Write(1, 0); + + auto wsz = [output](size_t size) { + if (size - 1 < (1 << 9)) { + output->Write(2, 0b00); + output->Write(9, size - 1); + } else if (size - 1 < (1 << 13)) { + output->Write(2, 0b01); + output->Write(13, size - 1); + } else if (size - 1 < (1 << 18)) { + output->Write(2, 0b10); + output->Write(18, size - 1); + } else { + output->Write(2, 0b11); + output->Write(30, size - 1); + } + }; + + wsz(height); + + // No special ratio. + output->Write(3, 0); + + wsz(width); + + // Hand-crafted ImageMetadata. + output->Write(1, 0); // all_default + output->Write(1, 0); // extra_fields + output->Write(1, 0); // bit_depth.floating_point_sample + if (bitdepth == 8) { + output->Write(2, 0b00); // bit_depth.bits_per_sample = 8 + } else if (bitdepth == 10) { + output->Write(2, 0b01); // bit_depth.bits_per_sample = 10 + } else if (bitdepth == 12) { + output->Write(2, 0b10); // bit_depth.bits_per_sample = 12 + } else { + output->Write(2, 0b11); // 1 + u(6) + output->Write(6, bitdepth - 1); + } + output->Write(1, 1); // 16-bit-buffer sufficient + bool have_alpha = (nb_chans == 2 || nb_chans == 4); + if (have_alpha) { + output->Write(2, 0b01); // One extra channel + output->Write(1, 1); // ... all_default (ie. 8-bit alpha) + } else { + output->Write(2, 0b00); // No extra channel + } + output->Write(1, 0); // Not XYB + if (nb_chans > 1) { + output->Write(1, 1); // color_encoding.all_default (sRGB) + } else { + output->Write(1, 0); // color_encoding.all_default false + output->Write(1, 0); // color_encoding.want_icc false + output->Write(2, 1); // grayscale + output->Write(2, 1); // D65 + output->Write(1, 0); // no gamma transfer function + output->Write(2, 0b10); // tf: 2 + u(4) + output->Write(4, 11); // tf of sRGB + output->Write(2, 1); // relative rendering intent + } + output->Write(2, 0b00); // No extensions. + + output->Write(1, 1); // all_default transform data + + // No ICC, no preview. Frame should start at byte boundery. + output->ZeroPadToByte(); + + auto wsz_fh = [output](size_t size) { + if (size < (1 << 8)) { + output->Write(2, 0b00); + output->Write(8, size); + } else if (size - 256 < (1 << 11)) { + output->Write(2, 0b01); + output->Write(11, size - 256); + } else if (size - 2304 < (1 << 14)) { + output->Write(2, 0b10); + output->Write(14, size - 2304); + } else { + output->Write(2, 0b11); + output->Write(30, size - 18688); + } + }; + + // Handcrafted frame header. + output->Write(1, 0); // all_default + output->Write(2, 0b00); // regular frame + output->Write(1, 1); // modular + output->Write(2, 0b00); // default flags + output->Write(1, 0); // not YCbCr + output->Write(2, 0b00); // no upsampling + if (have_alpha) { + output->Write(2, 0b00); // no alpha upsampling + } + output->Write(2, 0b01); // default group size + output->Write(2, 0b00); // exactly one pass + if (width % kChunkSize == 0) { + output->Write(1, 0); // no custom size or origin + } else { + output->Write(1, 1); // custom size + wsz_fh(0); // x0 = 0 + wsz_fh(0); // y0 = 0 + wsz_fh((width + kChunkSize - 1) / kChunkSize * + kChunkSize); // xsize rounded up to chunk size + wsz_fh(height); // ysize same + } + output->Write(2, 0b00); // kReplace blending mode + if (have_alpha) { + output->Write(2, 0b00); // kReplace blending mode for alpha channel + } + output->Write(1, 1); // is_last + output->Write(2, 0b00); // a frame has no name + output->Write(1, 0); // loop filter is not all_default + output->Write(1, 0); // no gaborish + output->Write(2, 0); // 0 EPF iters + output->Write(2, 0b00); // No LF extensions + output->Write(2, 0b00); // No FH extensions + + output->Write(1, 0); // No TOC permutation + output->ZeroPadToByte(); // TOC is byte-aligned. + for (size_t i = 0; i < group_data.size(); i++) { + size_t sz = group_sizes[i]; + if (sz < (1 << 10)) { + output->Write(2, 0b00); + output->Write(10, sz); + } else if (sz - 1024 < (1 << 14)) { + output->Write(2, 0b01); + output->Write(14, sz - 1024); + } else if (sz - 17408 < (1 << 22)) { + output->Write(2, 0b10); + output->Write(22, sz - 17408); + } else { + output->Write(2, 0b11); + output->Write(30, sz - 4211712); + } + } + output->ZeroPadToByte(); // Groups are byte-aligned. + + for (size_t i = 0; i < group_data.size(); i++) { + for (size_t j = 0; j < nb_chans; j++) { + AppendWriter(output, &group_data[i][j]); + } + output->ZeroPadToByte(); + } +} + +void PrepareDCGlobalCommon(bool is_single_group, size_t width, size_t height, + const PrefixCode& code, BitWriter* output) { + output->Allocate(100000 + (is_single_group ? width * height * 16 : 0)); + // No patches, spline or noise. + output->Write(1, 1); // default DC dequantization factors (?) + output->Write(1, 1); // use global tree / histograms + output->Write(1, 0); // no lz77 for the tree + + output->Write(1, 1); // simple code for the tree's context map + output->Write(2, 0); // all contexts clustered together + output->Write(1, 1); // use prefix code for tree + output->Write(4, 15); // don't do hybriduint for tree - 2 symbols anyway + output->Write(7, 0b0100101); // Alphabet size is 6: we need 0 and 5 (var16) + output->Write(2, 1); // simple prefix code + output->Write(2, 1); // with two symbols + output->Write(3, 0); // 0 + output->Write(3, 5); // 5 + output->Write(5, 0b00010); // tree repr: predictor is 5, all else 0 + + output->Write(1, 1); // Enable lz77 for the main bitstream + output->Write(2, 0b00); // lz77 offset 224 + static_assert(kLZ77Offset == 224, ""); + output->Write(10, 0b0000011111); // lz77 min length 16 + static_assert(kLZ77MinLength == 16, ""); + output->Write(4, 4); // 404 hybrid uint config for lz77: 4 + output->Write(3, 0); // 0 + output->Write(3, 4); // 4 + output->Write(1, 1); // simple code for the context map + output->Write(2, 1); // two clusters + output->Write(1, 1); // raw/lz77 length histogram last + output->Write(1, 0); // distance histogram first + output->Write(1, 1); // use prefix codes + output->Write(4, 0); // 000 hybrid uint config for distances (only need 0) + output->Write(4, 0); // 000 hybrid uint config for symbols (only <= 10) + // Distance alphabet size: + output->Write(5, 0b00001); // 2: just need 1 for RLE (i.e. distance 1) + // Symbol + LZ77 alphabet size: + output->Write(1, 1); // > 1 + output->Write(4, 8); // <= 512 + output->Write(8, 255); // == 512 + + // Distance histogram: + output->Write(2, 1); // simple prefix code + output->Write(2, 0); // with one symbol + output->Write(1, 1); // 1 + + // Symbol + lz77 histogram: + code.WriteTo(output); + + // Group header for global modular image. + output->Write(1, 1); // Global tree + output->Write(1, 1); // All default wp +} + +void PrepareDCGlobal(bool is_single_group, size_t width, size_t height, + size_t nb_chans, size_t bitdepth, const PrefixCode& code, + BitWriter* output) { + PrepareDCGlobalCommon(is_single_group, width, height, code, output); + if (nb_chans > 2) { + output->Write(2, 0b01); // 1 transform + output->Write(2, 0b00); // RCT + output->Write(5, 0b00000); // Starting from ch 0 + output->Write(2, 0b00); // YCoCg + } else { + output->Write(2, 0b00); // no transforms + } + if (!is_single_group) { + output->ZeroPadToByte(); + } +} + +void EncodeHybridUint404_Mul16(uint32_t value, uint32_t* token_div16, + uint32_t* nbits, uint32_t* bits) { + // NOTE: token in libjxl is actually << 4. + uint32_t n = 31 - __builtin_clz(value); + *token_div16 = value < 16 ? 0 : n - 3; + *nbits = value < 16 ? 0 : n - 4; + *bits = value < 16 ? 0 : (value >> 4) - (1 << *nbits); +} + +#ifdef FASTLL_ENABLE_AVX2_INTRINSICS +#include +void EncodeChunk(const uint16_t* residuals, const PrefixCode& prefix_code, + BitWriter& output) { + static_assert(kChunkSize == 16, "Chunk size must be 16"); + auto value = _mm256_load_si256((__m256i*)residuals); + + // we know that residuals[i] has at most 12 bits, so we just need 3 nibbles + // and don't need to mask the third. However we do need to set the high + // byte to 0xFF, which will make table lookups return 0. + auto lo_nibble = + _mm256_or_si256(_mm256_and_si256(value, _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto mi_nibble = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi16(value, 4), _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto hi_nibble = + _mm256_or_si256(_mm256_srli_epi16(value, 8), _mm256_set1_epi16(0xFF00)); + + auto lo_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4)); + auto mi_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8)); + auto hi_lut = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12)); + + auto lo_token = _mm256_shuffle_epi8(lo_lut, lo_nibble); + auto mi_token = _mm256_shuffle_epi8(mi_lut, mi_nibble); + auto hi_token = _mm256_shuffle_epi8(hi_lut, hi_nibble); + + auto token = _mm256_max_epi16(lo_token, _mm256_max_epi16(mi_token, hi_token)); + auto nbits = _mm256_subs_epu16(token, _mm256_set1_epi16(1)); + + // Compute 1< 64 bit lanes. + auto nbits_hi32 = _mm256_srli_epi64(nbits, 32); + auto nbits_lo32 = _mm256_and_si256(nbits, _mm256_set1_epi64x(0xFFFFFFFF)); + auto bits_hi32 = _mm256_srli_epi64(bits, 32); + auto bits_lo32 = _mm256_and_si256(bits, _mm256_set1_epi64x(0xFFFFFFFF)); + + nbits = _mm256_add_epi64(nbits_hi32, nbits_lo32); + bits = _mm256_or_si256(_mm256_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32); + + alignas(32) uint64_t nbits_simd[4] = {}; + alignas(32) uint64_t bits_simd[4] = {}; + + _mm256_store_si256((__m256i*)nbits_simd, nbits); + _mm256_store_si256((__m256i*)bits_simd, bits); + + // Manually merge the buffer bits with the SIMD bits. + // Necessary because Write() is only guaranteed to work with <=56 bits. + // Trying to SIMD-fy this code results in slower speed (and definitely less + // clarity). + { + for (size_t i = 0; i < 4; i++) { + output.buffer |= bits_simd[i] << output.bits_in_buffer; + memcpy(output.data.get() + output.bytes_written, &output.buffer, 8); + // If >> 64, next_buffer is unused. + uint64_t next_buffer = bits_simd[i] >> (64 - output.bits_in_buffer); + output.bits_in_buffer += nbits_simd[i]; + // This `if` seems to be faster than using ternaries. + if (output.bits_in_buffer >= 64) { + output.buffer = next_buffer; + output.bits_in_buffer -= 64; + output.bytes_written += 8; + } + } + memcpy(output.data.get() + output.bytes_written, &output.buffer, 8); + size_t bytes_in_buffer = output.bits_in_buffer / 8; + output.bits_in_buffer -= bytes_in_buffer * 8; + output.buffer >>= bytes_in_buffer * 8; + output.bytes_written += bytes_in_buffer; + } +} +#endif + +#ifdef FASTLL_ENABLE_NEON_INTRINSICS +#include + +void EncodeChunk(const uint16_t* residuals, const PrefixCode& code, + BitWriter& output) { + uint16x8_t res = vld1q_u16(residuals); + uint16x8_t token = vsubq_u16(vdupq_n_u16(16), vclzq_u16(res)); + uint16x8_t nbits = vqsubq_u16(token, vdupq_n_u16(1)); + uint16x8_t bits = vqsubq_u16(res, vshlq_s16(vdupq_n_s16(1), nbits)); + uint16x8_t huff_bits = + vandq_u16(vdupq_n_u16(0xFF), vqtbl1q_u8(vld1q_u8(code.raw_bits), token)); + uint16x8_t huff_nbits = + vandq_u16(vdupq_n_u16(0xFF), vqtbl1q_u8(vld1q_u8(code.raw_nbits), token)); + bits = vorrq_u16(vshlq_u16(bits, huff_nbits), huff_bits); + nbits = vaddq_u16(nbits, huff_nbits); + + // Merge nbits and bits from 16-bit to 32-bit lanes. + uint32x4_t nbits_lo16 = vandq_u32(nbits, vdupq_n_u32(0xFFFF)); + uint32x4_t bits_hi16 = vshlq_u32(vshrq_n_u32(bits, 16), nbits_lo16); + uint32x4_t bits_lo16 = vandq_u32(bits, vdupq_n_u32(0xFFFF)); + + uint32x4_t nbits32 = vsraq_n_u32(nbits_lo16, nbits, 16); + uint32x4_t bits32 = vorrq_u32(bits_hi16, bits_lo16); + + // Merging up to 64 bits is not faster. + + // Manually merge the buffer bits with the SIMD bits. + // A bit faster. + for (size_t i = 0; i < 4; i++) { + output.buffer |= bits32[i] << output.bits_in_buffer; + memcpy(output.data.get() + output.bytes_written, &output.buffer, 8); + output.bits_in_buffer += nbits32[i]; + size_t bytes_in_buffer = output.bits_in_buffer / 8; + output.bits_in_buffer -= bytes_in_buffer * 8; + output.buffer >>= bytes_in_buffer * 8; + output.bytes_written += bytes_in_buffer; + } +} +#endif + +template +struct ChunkEncoder { + static void EncodeRle(size_t count, const PrefixCode& code, + BitWriter& output) { + if (count == 0) return; + count -= kLZ77MinLength; + unsigned token_div16, nbits, bits; + EncodeHybridUint404_Mul16(count, &token_div16, &nbits, &bits); + output.Write( + code.lz77_nbits[token_div16] + nbits, + (bits << code.lz77_nbits[token_div16]) | code.lz77_bits[token_div16]); + } + + inline void Chunk(size_t run, uint16_t* residuals) { + EncodeRle(run, *code, *output); +#if defined(FASTLL_ENABLE_AVX2_INTRINSICS) && FASTLL_ENABLE_AVX2_INTRINSICS + if (bytedepth == 1) { + EncodeChunk(residuals, *code, *output); + return; + } +#elif defined(FASTLL_ENABLE_NEON_INTRINSICS) && FASTLL_ENABLE_NEON_INTRINSICS + if (bytedepth == 1) { + EncodeChunk(residuals, *code, *output); + if (kChunkSize > 8) { + EncodeChunk(residuals + 8, *code, *output); + } + return; + } +#endif + for (size_t ix = 0; ix < kChunkSize; ix++) { + unsigned token, nbits, bits; + EncodeHybridUint000(residuals[ix], &token, &nbits, &bits); + output->Write(code->raw_nbits[token] + nbits, + code->raw_bits[token] | bits << code->raw_nbits[token]); + } + } + + inline void Finalize(size_t run) { EncodeRle(run, *code, *output); } + + const PrefixCode* code; + BitWriter* output; +}; + +struct ChunkSampleCollector { + void Rle(size_t count, uint64_t* lz77_counts) { + if (count == 0) return; + count -= kLZ77MinLength; + unsigned token_div16, nbits, bits; + EncodeHybridUint404_Mul16(count, &token_div16, &nbits, &bits); + lz77_counts[token_div16]++; + } + + inline void Chunk(size_t run, uint16_t* residuals) { + // Run is broken. Encode the run and encode the individual vector. + Rle(run, lz77_counts); + for (size_t ix = 0; ix < kChunkSize; ix++) { + unsigned token, nbits, bits; + EncodeHybridUint000(residuals[ix], &token, &nbits, &bits); + raw_counts[token]++; + } + } + + // don't count final run since we don't know how long it really is + void Finalize(size_t run) {} + + uint64_t* raw_counts; + uint64_t* lz77_counts; +}; + +constexpr uint16_t PackSigned(int16_t value) { + return (static_cast(value) << 1) ^ + ((static_cast(~value) >> 15) - 1); +} + +template +struct ChannelRowProcessor { + T* t; + inline void ProcessChunk(const int16_t* row, const int16_t* row_left, + const int16_t* row_top, const int16_t* row_topleft) { + bool continue_rle = true; + alignas(32) uint16_t residuals[kChunkSize] = {}; + for (size_t ix = 0; ix < kChunkSize; ix++) { + int16_t px = row[ix]; + int16_t left = row_left[ix]; + int16_t top = row_top[ix]; + int16_t topleft = row_topleft[ix]; + int16_t ac = left - topleft; + int16_t ab = left - top; + int16_t bc = top - topleft; + int16_t grad = static_cast(static_cast(ac) + + static_cast(top)); + int16_t d = ab ^ bc; + int16_t clamp = d < 0 ? top : left; + int16_t s = ac ^ bc; + int16_t pred = s < 0 ? grad : clamp; + residuals[ix] = PackSigned(px - pred); + continue_rle &= residuals[ix] == last; + } + // Run continues, nothing to do. + if (continue_rle) { + run += kChunkSize; + } else { + // Run is broken. Encode the run and encode the individual vector. + t->Chunk(run, residuals); + run = 0; + } + last = residuals[kChunkSize - 1]; + } + void ProcessRow(const int16_t* row, const int16_t* row_left, + const int16_t* row_top, const int16_t* row_topleft, + size_t xs) { + for (size_t x = 0; x + kChunkSize <= xs; x += kChunkSize) { + ProcessChunk(row + x, row_left + x, row_top + x, row_topleft + x); + } + } + + void Finalize() { t->Finalize(run); } + size_t run = 0; + uint16_t last = 0xFFFF; // Can never appear +}; + +template +void ProcessImageArea(const unsigned char* rgba, size_t x0, size_t y0, + size_t oxs, size_t xs, size_t yskip, size_t ys, + size_t row_stride, Processor* processors) { + constexpr size_t kPadding = 16; + + int16_t group_data[nb_chans][2][256 + kPadding * 2] = {}; + int16_t allzero[nb_chans] = {}; + int16_t allone[nb_chans]; + auto get_pixel = [&](size_t x, size_t y, size_t channel) { + int16_t p = rgba[row_stride * (y0 + y) + (x0 + x) * nb_chans * bytedepth + + channel * bytedepth]; + if (bytedepth == 2) { + p <<= 8; + p |= rgba[row_stride * (y0 + y) + (x0 + x) * nb_chans * 2 + channel * 2 + + 1]; + } + return p; + }; + + for (size_t i = 0; i < nb_chans; i++) allone[i] = 0xffff; + for (size_t y = 0; y < ys; y++) { + // Pre-fill rows with YCoCg converted pixels. + for (size_t x = 0; x < oxs; x++) { + if (nb_chans < 3) { + int16_t luma = get_pixel(x, y, 0); + group_data[0][y & 1][x + kPadding] = luma; + if (nb_chans == 2) { + int16_t a = get_pixel(x, y, 1); + group_data[1][y & 1][x + kPadding] = a; + } + } else { + int16_t r = get_pixel(x, y, 0); + int16_t g = get_pixel(x, y, 1); + int16_t b = get_pixel(x, y, 2); + if (nb_chans == 4) { + int16_t a = get_pixel(x, y, 3); + group_data[3][y & 1][x + kPadding] = a; + group_data[1][y & 1][x + kPadding] = a ? r - b : 0; + int16_t tmp = b + (group_data[1][y & 1][x + kPadding] >> 1); + group_data[2][y & 1][x + kPadding] = a ? g - tmp : 0; + group_data[0][y & 1][x + kPadding] = + a ? tmp + (group_data[2][y & 1][x + kPadding] >> 1) : 0; + } else { + group_data[1][y & 1][x + kPadding] = r - b; + int16_t tmp = b + (group_data[1][y & 1][x + kPadding] >> 1); + group_data[2][y & 1][x + kPadding] = g - tmp; + group_data[0][y & 1][x + kPadding] = + tmp + (group_data[2][y & 1][x + kPadding] >> 1); + } + } + for (size_t c = 0; c < nb_chans; c++) { + allzero[c] |= group_data[c][y & 1][x + kPadding]; + allone[c] &= group_data[c][y & 1][x + kPadding]; + } + } + // Deal with x == 0. + for (size_t c = 0; c < nb_chans; c++) { + group_data[c][y & 1][kPadding - 1] = + y > 0 ? group_data[c][(y - 1) & 1][kPadding] : 0; + // Fix topleft. + group_data[c][(y - 1) & 1][kPadding - 1] = + y > 0 ? group_data[c][(y - 1) & 1][kPadding] : 0; + } + // Fill in padding. + for (size_t c = 0; c < nb_chans; c++) { + for (size_t x = oxs; x < xs; x++) { + group_data[c][y & 1][kPadding + x] = + group_data[c][y & 1][kPadding + oxs - 1]; + } + } + if (y < yskip) continue; + for (size_t c = 0; c < nb_chans; c++) { + if (y > 0 && (allzero[c] == 0 || (allone[c] == 0xff && bytedepth == 1))) { + processors[c].run += xs; + continue; + } + + // Get pointers to px/left/top/topleft data to speedup loop. + const int16_t* row = &group_data[c][y & 1][kPadding]; + const int16_t* row_left = &group_data[c][y & 1][kPadding - 1]; + const int16_t* row_top = + y == 0 ? row_left : &group_data[c][(y - 1) & 1][kPadding]; + const int16_t* row_topleft = + y == 0 ? row_left : &group_data[c][(y - 1) & 1][kPadding - 1]; + + processors[c].ProcessRow(row, row_left, row_top, row_topleft, xs); + } + } + for (size_t c = 0; c < nb_chans; c++) { + processors[c].Finalize(); + } +} + +template +void WriteACSection(const unsigned char* rgba, size_t x0, size_t y0, size_t oxs, + size_t ys, size_t row_stride, bool is_single_group, + const PrefixCode& code, std::array& output) { + size_t xs = (oxs + kChunkSize - 1) / kChunkSize * kChunkSize; + for (size_t i = 0; i < nb_chans; i++) { + if (is_single_group && i == 0) continue; + output[i].Allocate(16 * xs * ys * bytedepth + 4); + } + if (!is_single_group) { + // Group header for modular image. + // When the image is single-group, the global modular image is the one that + // contains the pixel data, and there is no group header. + output[0].Write(1, 1); // Global tree + output[0].Write(1, 1); // All default wp + output[0].Write(2, 0b00); // 0 transforms + } + + ChunkEncoder encoders[nb_chans]; + ChannelRowProcessor> row_encoders[nb_chans]; + for (size_t c = 0; c < nb_chans; c++) { + row_encoders[c].t = &encoders[c]; + encoders[c].output = &output[c]; + encoders[c].code = &code; + } + ProcessImageArea>, nb_chans, + bytedepth>(rgba, x0, y0, oxs, xs, 0, ys, row_stride, + row_encoders); +} + +constexpr int kHashExp = 16; +constexpr uint32_t kHashSize = 1 << kHashExp; +constexpr uint32_t kHashMultiplier = 2654435761; +constexpr int kMaxColors = 512; + +// can be any function that returns a value in 0 .. kHashSize-1 +// has to map 0 to 0 +inline uint32_t pixel_hash(uint32_t p) { + return (p * kHashMultiplier) >> (32 - kHashExp); +} + +template +void ProcessImageAreaPalette(const unsigned char* rgba, size_t x0, size_t y0, + size_t oxs, size_t xs, size_t yskip, size_t ys, + size_t row_stride, const int16_t* lookup, + Processor* processors) { + constexpr size_t kPadding = 16; + + int16_t group_data[2][256 + kPadding * 2] = {}; + Processor& row_encoder = processors[0]; + + for (size_t y = 0; y < ys; y++) { + // Pre-fill rows with palette converted pixels. + const unsigned char* inrow = rgba + row_stride * (y0 + y) + x0 * nb_chans; + for (size_t x = 0; x < oxs; x++) { + uint32_t p = 0; + memcpy(&p, inrow + x * nb_chans, nb_chans); + group_data[y & 1][x + kPadding] = lookup[pixel_hash(p)]; + } + // Deal with x == 0. + group_data[y & 1][kPadding - 1] = + y > 0 ? group_data[(y - 1) & 1][kPadding] : 0; + // Fix topleft. + group_data[(y - 1) & 1][kPadding - 1] = + y > 0 ? group_data[(y - 1) & 1][kPadding] : 0; + // Fill in padding. + for (size_t x = oxs; x < xs; x++) { + group_data[y & 1][kPadding + x] = group_data[y & 1][kPadding + oxs - 1]; + } + // Get pointers to px/left/top/topleft data to speedup loop. + const int16_t* row = &group_data[y & 1][kPadding]; + const int16_t* row_left = &group_data[y & 1][kPadding - 1]; + const int16_t* row_top = + y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding]; + const int16_t* row_topleft = + y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding - 1]; + + row_encoder.ProcessRow(row, row_left, row_top, row_topleft, xs); + } + row_encoder.Finalize(); +} + +template +void WriteACSectionPalette(const unsigned char* rgba, size_t x0, size_t y0, + size_t oxs, size_t ys, size_t row_stride, + bool is_single_group, const PrefixCode& code, + const int16_t* lookup, BitWriter& output) { + size_t xs = (oxs + kChunkSize - 1) / kChunkSize * kChunkSize; + + if (!is_single_group) { + output.Allocate(16 * xs * ys + 4); + // Group header for modular image. + // When the image is single-group, the global modular image is the one that + // contains the pixel data, and there is no group header. + output.Write(1, 1); // Global tree + output.Write(1, 1); // All default wp + output.Write(2, 0b00); // 0 transforms + } + + ChunkEncoder<1> encoder; + ChannelRowProcessor> row_encoder; + + row_encoder.t = &encoder; + encoder.output = &output; + encoder.code = &code; + ProcessImageAreaPalette>, nb_chans>( + rgba, x0, y0, oxs, xs, 0, ys, row_stride, lookup, &row_encoder); +} + +template +void CollectSamples(const unsigned char* rgba, size_t x0, size_t y0, size_t xs, + size_t row_stride, size_t row_count, uint64_t* raw_counts, + uint64_t* lz77_counts, bool palette, + const int16_t* lookup) { + ChunkSampleCollector sample_collectors[nb_chans]; + ChannelRowProcessor row_sample_collectors[nb_chans]; + for (size_t c = 0; c < nb_chans; c++) { + row_sample_collectors[c].t = &sample_collectors[c]; + sample_collectors[c].raw_counts = raw_counts; + sample_collectors[c].lz77_counts = lz77_counts; + } + if (palette) { + assert(bytedepth == 1); + ProcessImageAreaPalette, + nb_chans>(rgba, x0, y0, xs, xs, 1, 1 + row_count, + row_stride, lookup, + row_sample_collectors); + } else { + ProcessImageArea, nb_chans, + bytedepth>(rgba, x0, y0, xs, xs, 1, 1 + row_count, + row_stride, row_sample_collectors); + } +} + +void PrepareDCGlobalPalette(bool is_single_group, size_t width, size_t height, + const PrefixCode& code, + const std::vector& palette, + size_t pcolors_real, BitWriter* output) { + PrepareDCGlobalCommon(is_single_group, width, height, code, output); + output->Write(2, 0b01); // 1 transform + output->Write(2, 0b01); // Palette + output->Write(5, 0b00000); // Starting from ch 0 + output->Write(2, 0b10); // 4-channel palette (RGBA) + size_t pcolors = (pcolors_real + kChunkSize - 1) / kChunkSize * kChunkSize; + // pcolors <= kMaxColors + kChunkSize - 1 + static_assert(kMaxColors + kChunkSize < 1281, + "add code to signal larger palette sizes"); + if (pcolors < 256) { + output->Write(2, 0b00); + output->Write(8, pcolors); + } else { + output->Write(2, 0b01); + output->Write(10, pcolors - 256); + } + + output->Write(2, 0b00); // nb_deltas == 0 + output->Write(4, 0); // Zero predictor for delta palette + // Encode palette + ChunkEncoder<1> encoder; + ChannelRowProcessor> row_encoder; + row_encoder.t = &encoder; + encoder.output = output; + encoder.code = &code; + int16_t p[4][32 + 1024] = {}; + uint8_t prgba[4]; + int i = 0; + int have_zero = 0; + if (palette[pcolors_real - 1] == 0) have_zero = 1; + for (; i < pcolors; i++) { + if (i < pcolors_real) { + memcpy(prgba, &palette[i], 4); + } + p[0][16 + i + have_zero] = prgba[0]; + p[1][16 + i + have_zero] = prgba[1]; + p[2][16 + i + have_zero] = prgba[2]; + p[3][16 + i + have_zero] = prgba[3]; + } + p[0][15] = 0; + row_encoder.ProcessRow(p[0] + 16, p[0] + 15, p[0] + 15, p[0] + 15, pcolors); + p[1][15] = p[0][16]; + p[0][15] = p[0][16]; + row_encoder.ProcessRow(p[1] + 16, p[1] + 15, p[0] + 16, p[0] + 15, pcolors); + p[2][15] = p[1][16]; + p[1][15] = p[1][16]; + row_encoder.ProcessRow(p[2] + 16, p[2] + 15, p[1] + 16, p[1] + 15, pcolors); + p[3][15] = p[2][16]; + p[2][15] = p[2][16]; + row_encoder.ProcessRow(p[3] + 16, p[3] + 15, p[2] + 16, p[2] + 15, pcolors); + row_encoder.Finalize(); + + if (!is_single_group) { + output->ZeroPadToByte(); + } +} + +template +size_t LLEnc(const unsigned char* rgba, size_t width, size_t stride, + size_t height, size_t bitdepth, int effort, + unsigned char** output) { + size_t bytes_per_sample = (bitdepth > 8 ? 2 : 1); + assert(bytedepth == bytes_per_sample); + assert(width != 0); + assert(height != 0); + assert(stride >= nb_chans * bytes_per_sample * width); + (void)bytes_per_sample; + + // Count colors to try palette + std::vector palette(kHashSize); + palette[0] = 1; + int16_t lookup[kHashSize]; + lookup[0] = 0; + int pcolors = 0; + bool collided = + effort < 2 || bitdepth != 8 || nb_chans < 4; // todo: also do rgb palette + for (size_t y = 0; y < height && !collided; y++) { + const unsigned char* r = rgba + stride * y; + size_t x = 0; + if (nb_chans == 4) { + // this is just an unrolling of the next loop + for (; x + 7 < width; x += 8) { + uint32_t p[8], index[8]; + memcpy(p, r + x * 4, 32); + for (int i = 0; i < 8; i++) index[i] = pixel_hash(p[i]); + for (int i = 0; i < 8; i++) { + uint32_t init_entry = index[i] ? 0 : 1; + if (init_entry != palette[index[i]] && p[i] != palette[index[i]]) { + collided = true; + } + } + for (int i = 0; i < 8; i++) palette[index[i]] = p[i]; + } + for (; x < width; x++) { + uint32_t p; + memcpy(&p, r + x * 4, 4); + uint32_t index = pixel_hash(p); + uint32_t init_entry = index ? 0 : 1; + if (init_entry != palette[index] && p != palette[index]) { + collided = true; + } + palette[index] = p; + } + } else { + for (; x < width; x++) { + uint32_t p = 0; + memcpy(&p, r + x * nb_chans, nb_chans); + uint32_t index = pixel_hash(p); + uint32_t init_entry = index ? 0 : 1; + if (init_entry != palette[index] && p != palette[index]) { + collided = true; + } + palette[index] = p; + } + } + } + + int nb_entries = 0; + if (!collided) { + if (palette[0] == 0) pcolors = 1; + if (palette[0] == 1) palette[0] = 0; + bool have_color = false; + uint8_t minG = 255, maxG = 0; + for (int k = 0; k < kHashSize; k++) { + if (palette[k] == 0) continue; + uint8_t p[4]; + memcpy(p, &palette[k], 4); + // move entries to front so sort has less work + palette[nb_entries] = palette[k]; + if (p[0] != p[1] || p[0] != p[2]) have_color = true; + if (p[1] < minG) minG = p[1]; + if (p[1] > maxG) maxG = p[1]; + nb_entries++; + // don't do palette if too many colors are needed + if (nb_entries + pcolors > kMaxColors) { + collided = true; + break; + } + } + if (!have_color) { + // don't do palette if it's just grayscale without many holes + if (maxG - minG < nb_entries * 1.4f) collided = true; + } + } + if (!collided) { + std::sort( + palette.begin(), palette.begin() + nb_entries, + [](uint32_t ap, uint32_t bp) { + if (ap == 0) return false; + if (bp == 0) return true; + uint8_t a[4], b[4]; + memcpy(a, &ap, 4); + memcpy(b, &bp, 4); + float ay, by; + ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f) * a[3]; + by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f) * b[3]; + return ay < by; // sort on alpha*luma + }); + for (int k = 0; k < nb_entries; k++) { + if (palette[k] == 0) break; + lookup[pixel_hash(palette[k])] = pcolors++; + } + } + + // Width gets padded to kChunkSize, but this computation doesn't change + // because of that. + size_t num_groups_x = (width + 255) / 256; + size_t num_groups_y = (height + 255) / 256; + size_t num_dc_groups_x = (width + 2047) / 2048; + size_t num_dc_groups_y = (height + 2047) / 2048; + + uint64_t raw_counts[16] = {}; + uint64_t lz77_counts[17] = {}; + + // sample the middle (effort * 2) rows of every group + for (size_t g = 0; g < num_groups_y * num_groups_x; g++) { + size_t xg = g % num_groups_x; + size_t yg = g / num_groups_x; + int y_offset = yg * 256; + int y_max = std::min(height - yg * 256, 256); + int y_begin = y_offset + std::max(0, y_max - 2 * effort) / 2; + int y_count = + std::min(2 * effort * y_max / 256, y_offset + y_max - y_begin - 1); + int x_max = + std::min(width - xg * 256, 256) / kChunkSize * kChunkSize; + CollectSamples(rgba, xg * 256, y_begin, x_max, stride, + y_count, raw_counts, lz77_counts, + !collided, lookup); + } + + uint64_t base_raw_counts[16] = {3843, 852, 1270, 1214, 1014, 727, 481, 300, + 159, 51, 5, 1, 1, 1, 1, 1}; + + bool doing_ycocg = nb_chans > 2 && collided; + for (size_t i = bitdepth + 2 + (doing_ycocg ? 1 : 0); i < 16; i++) { + base_raw_counts[i] = 0; + } + uint64_t base_lz77_counts[17] = { + // short runs will be sampled, but long ones won't. + // near full-group run is quite common (e.g. all-opaque alpha) + 18, 12, 9, 11, 15, 2, 2, 1, 1, 1, 1, 2, 300, 0, 0, 0, 0}; + + for (size_t i = 0; i < 16; i++) { + raw_counts[i] = (raw_counts[i] << 8) + base_raw_counts[i]; + } + if (!collided) { + unsigned token, nbits, bits; + EncodeHybridUint000(PackSigned(pcolors - 1), &token, &nbits, &bits); + // ensure all palette indices can actually be encoded + for (size_t i = 0; i < token + 1; i++) + raw_counts[i] = std::max(raw_counts[i], 1); + // these tokens are only used for the palette itself so they can get a bad + // code + for (size_t i = token + 1; i < 10; i++) raw_counts[i] = 1; + } + for (size_t i = 0; i < 17; i++) { + lz77_counts[i] = (lz77_counts[i] << 8) + base_lz77_counts[i]; + } + alignas(32) PrefixCode hcode(raw_counts, lz77_counts); + + BitWriter writer; + + bool onegroup = num_groups_x == 1 && num_groups_y == 1; + + size_t num_groups = onegroup ? 1 + : (2 + num_dc_groups_x * num_dc_groups_y + + num_groups_x * num_groups_y); + + std::vector> group_data(num_groups); + if (collided) { + PrepareDCGlobal(onegroup, width, height, nb_chans, bitdepth, hcode, + &group_data[0][0]); + } else { + PrepareDCGlobalPalette(onegroup, width, height, hcode, palette, pcolors, + &group_data[0][0]); + } +#pragma omp parallel for + for (size_t g = 0; g < num_groups_y * num_groups_x; g++) { + size_t xg = g % num_groups_x; + size_t yg = g / num_groups_x; + size_t group_id = + onegroup ? 0 : (2 + num_dc_groups_x * num_dc_groups_y + g); + size_t xs = std::min(width - xg * 256, 256); + size_t ys = std::min(height - yg * 256, 256); + size_t x0 = xg * 256; + size_t y0 = yg * 256; + auto& gd = group_data[group_id]; + if (collided) { + WriteACSection(rgba, x0, y0, xs, ys, stride, + onegroup, hcode, gd); + + } else { + WriteACSectionPalette(rgba, x0, y0, xs, ys, stride, onegroup, + hcode, lookup, gd[0]); + } + } + + AssembleFrame(width, height, nb_chans, bitdepth, group_data, &writer); + + *output = writer.data.release(); + return writer.bytes_written; +} + +size_t FastLosslessEncode(const unsigned char* rgba, size_t width, + size_t stride, size_t height, size_t nb_chans, + size_t bitdepth, int effort, unsigned char** output) { + assert(bitdepth <= 12); + assert(bitdepth > 0); + assert(nb_chans <= 4); + assert(nb_chans != 0); + if (bitdepth <= 8) { + if (nb_chans == 1) { + return LLEnc<1, 1>(rgba, width, stride, height, bitdepth, effort, output); + } + if (nb_chans == 2) { + return LLEnc<2, 1>(rgba, width, stride, height, bitdepth, effort, output); + } + if (nb_chans == 3) { + return LLEnc<3, 1>(rgba, width, stride, height, bitdepth, effort, output); + } + if (nb_chans == 4) { + return LLEnc<4, 1>(rgba, width, stride, height, bitdepth, effort, output); + } + } else { + if (nb_chans == 1) { + return LLEnc<1, 2>(rgba, width, stride, height, bitdepth, effort, output); + } + if (nb_chans == 2) { + return LLEnc<2, 2>(rgba, width, stride, height, bitdepth, effort, output); + } + if (nb_chans == 3) { + return LLEnc<3, 2>(rgba, width, stride, height, bitdepth, effort, output); + } + if (nb_chans == 4) { + return LLEnc<4, 2>(rgba, width, stride, height, bitdepth, effort, output); + } + } + return 0; +} diff --git a/media/libjxl/src/experimental/fast_lossless/fast_lossless.h b/media/libjxl/src/experimental/fast_lossless/fast_lossless.h new file mode 100644 index 000000000..f7940e53f --- /dev/null +++ b/media/libjxl/src/experimental/fast_lossless/fast_lossless.h @@ -0,0 +1,14 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef FAST_LOSSLESS_H +#define FAST_LOSSLESS_H +#include + +size_t FastLosslessEncode(const unsigned char* rgba, size_t width, + size_t row_stride, size_t height, size_t nb_chans, + size_t bitdepth, int effort, unsigned char** output); + +#endif diff --git a/media/libjxl/src/experimental/fast_lossless/fast_lossless_main.cc b/media/libjxl/src/experimental/fast_lossless/fast_lossless_main.cc new file mode 100644 index 000000000..4db1ec49e --- /dev/null +++ b/media/libjxl/src/experimental/fast_lossless/fast_lossless_main.cc @@ -0,0 +1,78 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include + +#include +#include + +#include "fast_lossless.h" +#include "lodepng.h" +#include "pam-input.h" + +int main(int argc, char** argv) { + if (argc < 3) { + fprintf(stderr, "Usage: %s in.png out.jxl [effort] [num_reps]\n", argv[0]); + return 1; + } + + const char* in = argv[1]; + const char* out = argv[2]; + int effort = argc >= 4 ? atoi(argv[3]) : 2; + size_t num_reps = argc >= 5 ? atoi(argv[4]) : 1; + + if (effort < 0 || effort > 127) { + fprintf( + stderr, + "Effort should be between 0 and 127 (default is 2, more is slower)\n"); + return 1; + } + + unsigned char* png; + unsigned w, h; + size_t nb_chans = 4, bitdepth = 8; + + unsigned error = lodepng_decode32_file(&png, &w, &h, in); + + size_t width = w, height = h; + if (error && !DecodePAM(in, &png, &width, &height, &nb_chans, &bitdepth)) { + fprintf(stderr, "lodepng error %u: %s\n", error, lodepng_error_text(error)); + return 1; + } + + size_t encoded_size = 0; + unsigned char* encoded = nullptr; + size_t stride = width * nb_chans * (bitdepth > 8 ? 2 : 1); + + auto start = std::chrono::high_resolution_clock::now(); + for (size_t _ = 0; _ < num_reps; _++) { + free(encoded); + encoded_size = FastLosslessEncode(png, width, stride, height, nb_chans, + bitdepth, effort, &encoded); + } + auto stop = std::chrono::high_resolution_clock::now(); + if (num_reps > 1) { + float us = + std::chrono::duration_cast(stop - start) + .count(); + size_t pixels = size_t{width} * size_t{height} * num_reps; + float mps = pixels / us; + fprintf(stderr, "%10.3f MP/s\n", mps); + fprintf(stderr, "%10.3f bits/pixel\n", + encoded_size * 8.0 / float(width) / float(height)); + } + + FILE* o = fopen(out, "wb"); + if (!o) { + fprintf(stderr, "error opening %s: %s\n", out, strerror(errno)); + return 1; + } + if (fwrite(encoded, 1, encoded_size, o) != encoded_size) { + fprintf(stderr, "error writing to %s: %s\n", out, strerror(errno)); + } + fclose(o); +} diff --git a/media/libjxl/src/experimental/fast_lossless/pam-input.h b/media/libjxl/src/experimental/fast_lossless/pam-input.h new file mode 100644 index 000000000..8bc41ef28 --- /dev/null +++ b/media/libjxl/src/experimental/fast_lossless/pam-input.h @@ -0,0 +1,289 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include + +bool error_msg(const char* message) { + fprintf(stderr, "%s\n", message); + return false; +} +#define return_on_error(X) \ + if (!X) return false; + +size_t Log2(uint32_t value) { return 31 - __builtin_clz(value); } + +struct HeaderPNM { + size_t xsize; + size_t ysize; + bool is_gray; // PGM + bool has_alpha; // PAM + size_t bits_per_sample; +}; + +class Parser { + public: + explicit Parser(uint8_t* data, size_t length) + : pos_(data), end_(data + length) {} + + // Sets "pos" to the first non-header byte/pixel on success. + bool ParseHeader(HeaderPNM* header, const uint8_t** pos) { + // codec.cc ensures we have at least two bytes => no range check here. + if (pos_[0] != 'P') return false; + const uint8_t type = pos_[1]; + pos_ += 2; + + switch (type) { + case '5': + header->is_gray = true; + return ParseHeaderPNM(header, pos); + + case '6': + header->is_gray = false; + return ParseHeaderPNM(header, pos); + + case '7': + return ParseHeaderPAM(header, pos); + } + return false; + } + + // Exposed for testing + bool ParseUnsigned(size_t* number) { + if (pos_ == end_) return error_msg("PNM: reached end before number"); + if (!IsDigit(*pos_)) return error_msg("PNM: expected unsigned number"); + + *number = 0; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number *= 10; + *number += *pos_ - '0'; + ++pos_; + } + + return true; + } + + bool ParseSigned(double* number) { + if (pos_ == end_) return error_msg("PNM: reached end before signed"); + + if (*pos_ != '-' && *pos_ != '+' && !IsDigit(*pos_)) { + return error_msg("PNM: expected signed number"); + } + + // Skip sign + const bool is_neg = *pos_ == '-'; + if (is_neg || *pos_ == '+') { + ++pos_; + if (pos_ == end_) return error_msg("PNM: reached end before digits"); + } + + // Leading digits + *number = 0.0; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number *= 10; + *number += *pos_ - '0'; + ++pos_; + } + + // Decimal places? + if (pos_ < end_ && *pos_ == '.') { + ++pos_; + double place = 0.1; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number += (*pos_ - '0') * place; + place *= 0.1; + ++pos_; + } + } + + if (is_neg) *number = -*number; + return true; + } + + private: + static bool IsDigit(const uint8_t c) { return '0' <= c && c <= '9'; } + static bool IsLineBreak(const uint8_t c) { return c == '\r' || c == '\n'; } + static bool IsWhitespace(const uint8_t c) { + return IsLineBreak(c) || c == '\t' || c == ' '; + } + + bool SkipBlank() { + if (pos_ == end_) return error_msg("PNM: reached end before blank"); + const uint8_t c = *pos_; + if (c != ' ' && c != '\n') return error_msg("PNM: expected blank"); + ++pos_; + return true; + } + + bool SkipSingleWhitespace() { + if (pos_ == end_) return error_msg("PNM: reached end before whitespace"); + if (!IsWhitespace(*pos_)) return error_msg("PNM: expected whitespace"); + ++pos_; + return true; + } + + bool SkipWhitespace() { + if (pos_ == end_) return error_msg("PNM: reached end before whitespace"); + if (!IsWhitespace(*pos_) && *pos_ != '#') { + return error_msg("PNM: expected whitespace/comment"); + } + + while (pos_ < end_ && IsWhitespace(*pos_)) { + ++pos_; + } + + // Comment(s) + while (pos_ != end_ && *pos_ == '#') { + while (pos_ != end_ && !IsLineBreak(*pos_)) { + ++pos_; + } + // Newline(s) + while (pos_ != end_ && IsLineBreak(*pos_)) pos_++; + } + + while (pos_ < end_ && IsWhitespace(*pos_)) { + ++pos_; + } + return true; + } + + bool MatchString(const char* keyword) { + const uint8_t* ppos = pos_; + while (*keyword) { + if (ppos >= end_) return error_msg("PAM: unexpected end of input"); + if (*keyword != *ppos) return false; + ppos++; + keyword++; + } + pos_ = ppos; + return_on_error(SkipWhitespace()); + return true; + } + + bool ParseHeaderPAM(HeaderPNM* header, const uint8_t** pos) { + size_t num_channels = 3; + size_t max_val = 255; + while (!MatchString("ENDHDR")) { + return_on_error(SkipWhitespace()); + if (MatchString("WIDTH")) { + return_on_error(ParseUnsigned(&header->xsize)); + } else if (MatchString("HEIGHT")) { + return_on_error(ParseUnsigned(&header->ysize)); + } else if (MatchString("DEPTH")) { + return_on_error(ParseUnsigned(&num_channels)); + } else if (MatchString("MAXVAL")) { + return_on_error(ParseUnsigned(&max_val)); + } else if (MatchString("TUPLTYPE")) { + if (MatchString("RGB_ALPHA")) { + header->has_alpha = true; + } else if (MatchString("RGB")) { + } else if (MatchString("GRAYSCALE_ALPHA")) { + header->has_alpha = true; + header->is_gray = true; + } else if (MatchString("GRAYSCALE")) { + header->is_gray = true; + } else if (MatchString("BLACKANDWHITE_ALPHA")) { + header->has_alpha = true; + header->is_gray = true; + max_val = 1; + } else if (MatchString("BLACKANDWHITE")) { + header->is_gray = true; + max_val = 1; + } else { + return error_msg("PAM: unknown TUPLTYPE"); + } + } else { + return error_msg("PAM: unknown header keyword"); + } + } + if (num_channels != + (header->has_alpha ? 1 : 0) + (header->is_gray ? 1 : 3)) { + return error_msg("PAM: bad DEPTH"); + } + if (max_val == 0 || max_val >= 65536) { + return error_msg("PAM: bad MAXVAL"); + } + header->bits_per_sample = Log2(max_val + 1); + + *pos = pos_; + return true; + } + + bool ParseHeaderPNM(HeaderPNM* header, const uint8_t** pos) { + return_on_error(SkipWhitespace()); + return_on_error(ParseUnsigned(&header->xsize)); + + return_on_error(SkipWhitespace()); + return_on_error(ParseUnsigned(&header->ysize)); + + return_on_error(SkipWhitespace()); + size_t max_val; + return_on_error(ParseUnsigned(&max_val)); + if (max_val == 0 || max_val >= 65536) { + return error_msg("PNM: bad MaxVal"); + } + header->bits_per_sample = Log2(max_val + 1); + + return_on_error(SkipSingleWhitespace()); + + *pos = pos_; + return true; + } + + const uint8_t* pos_; + const uint8_t* const end_; +}; + +bool load_file(unsigned char** out, size_t* outsize, const char* filename) { + FILE* file; + file = fopen(filename, "rb"); + if (!file) return false; + if (fseek(file, 0, SEEK_END) != 0) { + fclose(file); + return false; + } + *outsize = ftell(file); + if (*outsize == LONG_MAX || *outsize < 9 || fseek(file, 0, SEEK_SET)) { + fclose(file); + return false; + } + *out = (unsigned char*)malloc(*outsize); + if (!(*out)) return false; + size_t readsize; + readsize = fread(*out, 1, *outsize, file); + fclose(file); + if (readsize != *outsize) return false; + return true; +} + +bool DecodePAM(const char* filename, uint8_t** buffer, size_t* w, size_t* h, + size_t* nb_chans, size_t* bitdepth) { + unsigned char* in_file; + size_t in_size; + if (!load_file(&in_file, &in_size, filename)) + return error_msg("Could not read input file"); + Parser parser(in_file, in_size); + HeaderPNM header = {}; + const uint8_t* pos = nullptr; + if (!parser.ParseHeader(&header, &pos)) return false; + + if (header.bits_per_sample == 0 || header.bits_per_sample > 12) { + return error_msg("PNM: bits_per_sample invalid (can do at most 12-bit)"); + } + *w = header.xsize; + *h = header.ysize; + *bitdepth = header.bits_per_sample; + *nb_chans = (header.is_gray ? 1 : 3) + (header.has_alpha ? 1 : 0); + + size_t pnm_remaining_size = in_file + in_size - pos; + size_t buffer_size = *w * *h * *nb_chans * (*bitdepth > 8 ? 2 : 1); + if (pnm_remaining_size < buffer_size) { + return error_msg("PNM file too small"); + } + *buffer = (uint8_t*)malloc(buffer_size); + memcpy(*buffer, pos, buffer_size); + return true; +} diff --git a/media/libjxl/src/lib/CMakeLists.txt b/media/libjxl/src/lib/CMakeLists.txt new file mode 100644 index 000000000..5c8e0bada --- /dev/null +++ b/media/libjxl/src/lib/CMakeLists.txt @@ -0,0 +1,166 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set(JPEGXL_MAJOR_VERSION 0) +set(JPEGXL_MINOR_VERSION 7) +set(JPEGXL_PATCH_VERSION 0) +set(JPEGXL_LIBRARY_VERSION + "${JPEGXL_MAJOR_VERSION}.${JPEGXL_MINOR_VERSION}.${JPEGXL_PATCH_VERSION}") + +# This is the library API/ABI compatibility version. Changing this value makes +# the shared library incompatible with previous version. A program linked +# against this shared library SOVERSION will not run with an older SOVERSION. +# It is important to update this value when making incompatible API/ABI changes +# so that programs that depend on libjxl can update their dependencies. Semantic +# versioning allows 0.y.z to have incompatible changes in minor versions. +set(JPEGXL_SO_MINOR_VERSION 7) +if (JPEGXL_MAJOR_VERSION EQUAL 0) +set(JPEGXL_LIBRARY_SOVERSION + "${JPEGXL_MAJOR_VERSION}.${JPEGXL_SO_MINOR_VERSION}") +else() +set(JPEGXL_LIBRARY_SOVERSION "${JPEGXL_MAJOR_VERSION}") +endif() + + +# List of warning and feature flags for our library and tests. +if (MSVC) +set(JPEGXL_INTERNAL_FLAGS + # TODO(janwas): add flags +) +else () +set(JPEGXL_INTERNAL_FLAGS + # F_FLAGS + -fmerge-all-constants + -fno-builtin-fwrite + -fno-builtin-fread + + # WARN_FLAGS + -Wall + -Wextra + -Wc++11-compat + -Warray-bounds + -Wformat-security + -Wimplicit-fallthrough + -Wno-register # Needed by public headers in lcms + -Wno-unused-function + -Wno-unused-parameter + -Wnon-virtual-dtor + -Woverloaded-virtual + -Wvla +) + +# Warning flags supported by clang. +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + list(APPEND JPEGXL_INTERNAL_FLAGS + -Wdeprecated-increment-bool + # TODO(deymo): Add -Wextra-semi once we update third_party/highway. + # -Wextra-semi + -Wfloat-overflow-conversion + -Wfloat-zero-conversion + -Wfor-loop-analysis + -Wgnu-redeclared-enum + -Winfinite-recursion + -Wliteral-conversion + -Wno-c++98-compat + -Wno-unused-command-line-argument + -Wprivate-header + -Wself-assign + -Wstring-conversion + -Wtautological-overlap-compare + -Wthread-safety-analysis + -Wundefined-func-template + -Wunreachable-code + -Wunused-comparison + ) + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) + list(APPEND HWY_FLAGS -Wc++2a-extensions) + endif() +endif() # Clang + +if (WIN32) + list(APPEND JPEGXL_INTERNAL_FLAGS + -Wno-cast-align + -Wno-double-promotion + -Wno-float-equal + -Wno-format-nonliteral + -Wno-shadow + -Wno-sign-conversion + -Wno-zero-as-null-pointer-constant + ) + + if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + list(APPEND JPEGXL_INTERNAL_FLAGS + -Wno-used-but-marked-unused + -Wno-unused-template + -Wno-unused-member-function + -Wno-shadow-field-in-constructor + -Wno-language-extension-token + -Wno-global-constructors + -Wno-c++98-compat-pedantic + ) + endif() # Clang +else() # WIN32 + list(APPEND JPEGXL_INTERNAL_FLAGS + -fsized-deallocation + -fno-exceptions + + # Language flags + -fmath-errno + ) + + if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + list(APPEND JPEGXL_INTERNAL_FLAGS + -fnew-alignment=8 + -fno-cxx-exceptions + -fno-slp-vectorize + -fno-vectorize + + -disable-free + -disable-llvm-verifier + ) + endif() # Clang +endif() # WIN32 + +# Internal flags for coverage builds: +if(JPEGXL_ENABLE_COVERAGE) +set(JPEGXL_COVERAGE_FLAGS + -g -O0 -fprofile-arcs -ftest-coverage + -DJXL_ENABLE_ASSERT=0 -DJXL_ENABLE_CHECK=0 +) +endif() # JPEGXL_ENABLE_COVERAGE +endif() #!MSVC + +# The jxl library definition. +include(jxl.cmake) + +# Other libraries outside the core jxl library. +if(JPEGXL_ENABLE_TOOLS) + include(jxl_extras.cmake) +endif() +include(jxl_threads.cmake) + +# Install all the library headers from the source and the generated ones. There +# is no distinction on which libraries use which header since it is expected +# that all developer libraries are available together at build time. +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/jxl + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}") +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/include/jxl + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}") + +# Profiler for libjxl +include(jxl_profiler.cmake) + +if(BUILD_TESTING) +# Unittests +cmake_policy(SET CMP0057 NEW) # https://gitlab.kitware.com/cmake/cmake/issues/18198 +include(GoogleTest) + +# Tests for the jxl library. +include(jxl_tests.cmake) + +# Google benchmark for the jxl library +include(jxl_benchmark.cmake) + +endif() # BUILD_TESTING diff --git a/media/libjxl/src/lib/extras/LICENSE.apngdis b/media/libjxl/src/lib/extras/LICENSE.apngdis new file mode 100644 index 000000000..eb0ba7c07 --- /dev/null +++ b/media/libjxl/src/lib/extras/LICENSE.apngdis @@ -0,0 +1,27 @@ +APNG Disassembler 2.8 + +Deconstructs APNG files into individual frames. + +http://apngdis.sourceforge.net + +Copyright (c) 2010-2015 Max Stepin +maxst at users.sourceforge.net + +zlib license +------------ + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. +2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. +3. This notice may not be removed or altered from any source distribution. diff --git a/media/libjxl/src/lib/extras/README.md b/media/libjxl/src/lib/extras/README.md new file mode 100644 index 000000000..06a9b5ea0 --- /dev/null +++ b/media/libjxl/src/lib/extras/README.md @@ -0,0 +1,5 @@ +## JPEG XL "extras" + +The files in this directory do not form part of the library or codec and are +only used by tests or specific internal tools that have access to the internals +of the library. diff --git a/media/libjxl/src/lib/extras/codec.cc b/media/libjxl/src/lib/extras/codec.cc new file mode 100644 index 000000000..774b4ccb6 --- /dev/null +++ b/media/libjxl/src/lib/extras/codec.cc @@ -0,0 +1,189 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/codec.h" + +#include "jxl/decode.h" +#include "jxl/types.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" + +#if JPEGXL_ENABLE_APNG +#include "lib/extras/enc/apng.h" +#endif +#if JPEGXL_ENABLE_JPEG +#include "lib/extras/enc/jpg.h" +#endif +#if JPEGXL_ENABLE_EXR +#include "lib/extras/enc/exr.h" +#endif + +#include "lib/extras/dec/decode.h" +#include "lib/extras/enc/pgx.h" +#include "lib/extras/enc/pnm.h" +#include "lib/extras/packed_image_convert.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { +namespace { + +// Any valid encoding is larger (ensures codecs can read the first few bytes) +constexpr size_t kMinBytes = 9; + +} // namespace + +Status SetFromBytes(const Span bytes, + const extras::ColorHints& color_hints, CodecInOut* io, + ThreadPool* pool, extras::Codec* orig_codec) { + if (bytes.size() < kMinBytes) return JXL_FAILURE("Too few bytes"); + + extras::PackedPixelFile ppf; + if (extras::DecodeBytes(bytes, color_hints, io->constraints, &ppf, + orig_codec)) { + return ConvertPackedPixelFileToCodecInOut(ppf, pool, io); + } + return JXL_FAILURE("Codecs failed to decode"); +} + +Status SetFromFile(const std::string& pathname, + const extras::ColorHints& color_hints, CodecInOut* io, + ThreadPool* pool, extras::Codec* orig_codec) { + std::vector encoded; + JXL_RETURN_IF_ERROR(ReadFile(pathname, &encoded)); + JXL_RETURN_IF_ERROR(SetFromBytes(Span(encoded), color_hints, + io, pool, orig_codec)); + return true; +} + +Status Encode(const CodecInOut& io, const extras::Codec codec, + const ColorEncoding& c_desired, size_t bits_per_sample, + std::vector* bytes, ThreadPool* pool) { + JXL_CHECK(!io.Main().c_current().ICC().empty()); + JXL_CHECK(!c_desired.ICC().empty()); + io.CheckMetadata(); + if (io.Main().IsJPEG()) { + JXL_WARNING("Writing JPEG data as pixels"); + } + JxlPixelFormat format = { + 0, // num_channels is ignored by the converter + bits_per_sample <= 8 ? JXL_TYPE_UINT8 : JXL_TYPE_UINT16, JXL_BIG_ENDIAN, + 0}; + const bool floating_point = bits_per_sample > 16; + std::unique_ptr encoder; + std::ostringstream os; + switch (codec) { + case extras::Codec::kPNG: +#if JPEGXL_ENABLE_APNG + encoder = extras::GetAPNGEncoder(); + break; +#else + return JXL_FAILURE("JPEG XL was built without (A)PNG support"); +#endif + case extras::Codec::kJPG: +#if JPEGXL_ENABLE_JPEG + format.data_type = JXL_TYPE_UINT8; + encoder = extras::GetJPEGEncoder(); + os << io.jpeg_quality; + encoder->SetOption("q", os.str()); + break; +#else + return JXL_FAILURE("JPEG XL was built without JPEG support"); +#endif + case extras::Codec::kPNM: + if (io.Main().HasAlpha()) { + encoder = extras::GetPAMEncoder(); + } else if (io.Main().IsGray()) { + encoder = extras::GetPGMEncoder(); + } else if (!floating_point) { + encoder = extras::GetPPMEncoder(); + } else { + format.data_type = JXL_TYPE_FLOAT; + format.endianness = JXL_NATIVE_ENDIAN; + encoder = extras::GetPFMEncoder(); + } + if (!c_desired.IsSRGB()) { + JXL_WARNING( + "PNM encoder cannot store custom ICC profile; decoder " + "will need hint key=color_space to get the same values"); + } + break; + case extras::Codec::kPGX: + encoder = extras::GetPGXEncoder(); + break; + case extras::Codec::kGIF: + return JXL_FAILURE("Encoding to GIF is not implemented"); + case extras::Codec::kEXR: +#if JPEGXL_ENABLE_EXR + format.data_type = JXL_TYPE_FLOAT; + encoder = extras::GetEXREncoder(); + break; +#else + return JXL_FAILURE("JPEG XL was built without OpenEXR support"); +#endif + case extras::Codec::kUnknown: + return JXL_FAILURE("Cannot encode using Codec::kUnknown"); + } + + if (!encoder) { + return JXL_FAILURE("Invalid codec."); + } + + extras::PackedPixelFile ppf; + JXL_RETURN_IF_ERROR( + ConvertCodecInOutToPackedPixelFile(io, format, c_desired, pool, &ppf)); + extras::EncodedImage encoded_image; + JXL_RETURN_IF_ERROR(encoder->Encode(ppf, &encoded_image, pool)); + JXL_ASSERT(encoded_image.bitstreams.size() == 1); + *bytes = encoded_image.bitstreams[0]; + + return true; +} + +Status EncodeToFile(const CodecInOut& io, const ColorEncoding& c_desired, + size_t bits_per_sample, const std::string& pathname, + ThreadPool* pool) { + const std::string extension = Extension(pathname); + const extras::Codec codec = + extras::CodecFromExtension(extension, &bits_per_sample); + + // Warn about incorrect usage of PGM/PGX/PPM - only the latter supports + // color, but CodecFromExtension lumps them all together. + if (codec == extras::Codec::kPNM && extension != ".pfm") { + if (io.Main().HasAlpha() && extension != ".pam") { + JXL_WARNING( + "For images with alpha, the filename should end with .pam.\n"); + } else if (!io.Main().IsGray() && extension == ".pgm") { + JXL_WARNING("For color images, the filename should end with .ppm.\n"); + } else if (io.Main().IsGray() && extension == ".ppm") { + JXL_WARNING( + "For grayscale images, the filename should not end with .ppm.\n"); + } + if (bits_per_sample > 16) { + JXL_WARNING("PPM only supports up to 16 bits per sample"); + bits_per_sample = 16; + } + } else if (codec == extras::Codec::kPGX && !io.Main().IsGray()) { + JXL_WARNING("Storing color image to PGX - use .ppm extension instead.\n"); + } + if (bits_per_sample > 16 && codec == extras::Codec::kPNG) { + JXL_WARNING("PNG only supports up to 16 bits per sample"); + bits_per_sample = 16; + } + + std::vector encoded; + return Encode(io, codec, c_desired, bits_per_sample, &encoded, pool) && + WriteFile(encoded, pathname); +} + +Status EncodeToFile(const CodecInOut& io, const std::string& pathname, + ThreadPool* pool) { + // TODO(lode): need to take the floating_point_sample field into account + return EncodeToFile(io, io.metadata.m.color_encoding, + io.metadata.m.bit_depth.bits_per_sample, pathname, pool); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/codec.h b/media/libjxl/src/lib/extras/codec.h new file mode 100644 index 000000000..73fdc80be --- /dev/null +++ b/media/libjxl/src/lib/extras/codec.h @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_CODEC_H_ +#define LIB_EXTRAS_CODEC_H_ + +// Facade for image encoders/decoders (PNG, PNM, ...). + +#include +#include + +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/dec/decode.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/field_encodings.h" // MakeBit + +namespace jxl { + +// Decodes "bytes" and sets io->metadata.m. +// color_space_hint may specify the color space, otherwise, defaults to sRGB. +Status SetFromBytes(Span bytes, + const extras::ColorHints& color_hints, CodecInOut* io, + ThreadPool* pool = nullptr, + extras::Codec* orig_codec = nullptr); +// Helper function to use no color_space_hint. +JXL_INLINE Status SetFromBytes(const Span bytes, CodecInOut* io, + ThreadPool* pool = nullptr, + extras::Codec* orig_codec = nullptr) { + return SetFromBytes(bytes, extras::ColorHints(), io, pool, orig_codec); +} + +// Reads from file and calls SetFromBytes. +Status SetFromFile(const std::string& pathname, + const extras::ColorHints& color_hints, CodecInOut* io, + ThreadPool* pool = nullptr, + extras::Codec* orig_codec = nullptr); + +// Replaces "bytes" with an encoding of pixels transformed from c_current +// color space to c_desired. +Status Encode(const CodecInOut& io, extras::Codec codec, + const ColorEncoding& c_desired, size_t bits_per_sample, + std::vector* bytes, ThreadPool* pool = nullptr); + +// Deduces codec, calls Encode and writes to file. +Status EncodeToFile(const CodecInOut& io, const ColorEncoding& c_desired, + size_t bits_per_sample, const std::string& pathname, + ThreadPool* pool = nullptr); +// Same, but defaults to metadata.original color_encoding and bits_per_sample. +Status EncodeToFile(const CodecInOut& io, const std::string& pathname, + ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_EXTRAS_CODEC_H_ diff --git a/media/libjxl/src/lib/extras/codec_test.cc b/media/libjxl/src/lib/extras/codec_test.cc new file mode 100644 index 000000000..19cac3997 --- /dev/null +++ b/media/libjxl/src/lib/extras/codec_test.cc @@ -0,0 +1,556 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/codec.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "lib/extras/dec/pgx.h" +#include "lib/extras/dec/pnm.h" +#include "lib/extras/enc/encode.h" +#include "lib/extras/packed_image_convert.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace extras { +namespace { + +using ::testing::AllOf; +using ::testing::Contains; +using ::testing::Field; +using ::testing::IsEmpty; +using ::testing::NotNull; +using ::testing::SizeIs; + +std::string ExtensionFromCodec(Codec codec, const bool is_gray, + const bool has_alpha, + const size_t bits_per_sample) { + switch (codec) { + case Codec::kJPG: + return ".jpg"; + case Codec::kPGX: + return ".pgx"; + case Codec::kPNG: + return ".png"; + case Codec::kPNM: + if (has_alpha) return ".pam"; + if (is_gray) return ".pgm"; + return (bits_per_sample == 32) ? ".pfm" : ".ppm"; + case Codec::kGIF: + return ".gif"; + case Codec::kEXR: + return ".exr"; + case Codec::kUnknown: + return std::string(); + } + JXL_UNREACHABLE; + return std::string(); +} + +void VerifySameImage(const PackedImage& im0, size_t bits_per_sample0, + const PackedImage& im1, size_t bits_per_sample1, + bool lossless = true) { + ASSERT_EQ(im0.xsize, im1.xsize); + ASSERT_EQ(im0.ysize, im1.ysize); + ASSERT_EQ(im0.format.num_channels, im1.format.num_channels); + auto get_factor = [](JxlPixelFormat f, size_t bits) -> double { + return 1.0 / ((1u << std::min(test::GetPrecision(f.data_type), bits)) - 1); + }; + double factor0 = get_factor(im0.format, bits_per_sample0); + double factor1 = get_factor(im1.format, bits_per_sample1); + auto pixels0 = static_cast(im0.pixels()); + auto pixels1 = static_cast(im1.pixels()); + auto rgba0 = + test::ConvertToRGBA32(pixels0, im0.xsize, im0.ysize, im0.format, factor0); + auto rgba1 = + test::ConvertToRGBA32(pixels1, im1.xsize, im1.ysize, im1.format, factor1); + double tolerance = + lossless ? 0.5 * std::min(factor0, factor1) : 3.0f / 255.0f; + if (bits_per_sample0 == 32 || bits_per_sample1 == 32) { + tolerance = 0.5 * std::max(factor0, factor1); + } + for (size_t y = 0; y < im0.ysize; ++y) { + for (size_t x = 0; x < im0.xsize; ++x) { + for (size_t c = 0; c < im0.format.num_channels; ++c) { + size_t ix = (y * im0.xsize + x) * 4 + c; + double val0 = rgba0[ix]; + double val1 = rgba1[ix]; + ASSERT_NEAR(val1, val0, tolerance) + << "y = " << y << " x = " << x << " c = " << c; + } + } + } +} + +JxlColorEncoding CreateTestColorEncoding(bool is_gray) { + JxlColorEncoding c; + c.color_space = is_gray ? JXL_COLOR_SPACE_GRAY : JXL_COLOR_SPACE_RGB; + c.white_point = JXL_WHITE_POINT_D65; + c.primaries = JXL_PRIMARIES_P3; + c.rendering_intent = JXL_RENDERING_INTENT_RELATIVE; + c.transfer_function = JXL_TRANSFER_FUNCTION_LINEAR; + // Roundtrip through internal color encoding to fill in primaries and white + // point CIE xy coordinates. + ColorEncoding c_internal; + JXL_CHECK(ConvertExternalToInternalColorEncoding(c, &c_internal)); + ConvertInternalToExternalColorEncoding(c_internal, &c); + return c; +} + +std::vector GenerateICC(JxlColorEncoding color_encoding) { + ColorEncoding c; + JXL_CHECK(ConvertExternalToInternalColorEncoding(color_encoding, &c)); + JXL_CHECK(c.CreateICC()); + PaddedBytes icc = c.ICC(); + return std::vector(icc.begin(), icc.end()); +} + +void StoreRandomValue(uint8_t* out, Rng* rng, JxlPixelFormat format, + size_t bits_per_sample) { + uint64_t max_val = (1ull << bits_per_sample) - 1; + if (format.data_type == JXL_TYPE_UINT8) { + *out = rng->UniformU(0, max_val); + } else if (format.data_type == JXL_TYPE_UINT16) { + uint32_t val = rng->UniformU(0, max_val); + if (format.endianness == JXL_BIG_ENDIAN) { + StoreBE16(val, out); + } else { + StoreLE16(val, out); + } + } else { + ASSERT_EQ(format.data_type, JXL_TYPE_FLOAT); + float val = rng->UniformF(0.0, 1.0); + uint32_t uval; + memcpy(&uval, &val, 4); + if (format.endianness == JXL_BIG_ENDIAN) { + StoreBE32(uval, out); + } else { + StoreLE32(uval, out); + } + } +} + +void FillPackedImage(size_t bits_per_sample, PackedImage* image) { + JxlPixelFormat format = image->format; + size_t bytes_per_channel = PackedImage::BitsPerChannel(format.data_type) / 8; + uint8_t* out = static_cast(image->pixels()); + size_t stride = image->xsize * format.num_channels * bytes_per_channel; + ASSERT_EQ(image->pixels_size, image->ysize * stride); + Rng rng(129); + for (size_t y = 0; y < image->ysize; ++y) { + for (size_t x = 0; x < image->xsize; ++x) { + for (size_t c = 0; c < format.num_channels; ++c) { + StoreRandomValue(out, &rng, format, bits_per_sample); + out += bytes_per_channel; + } + } + } +} + +struct TestImageParams { + Codec codec; + size_t xsize; + size_t ysize; + size_t bits_per_sample; + bool is_gray; + bool add_alpha; + bool big_endian; + + bool ShouldTestRoundtrip() const { + if (codec == Codec::kPNG) { + return true; + } else if (codec == Codec::kPNM) { + // TODO(szabadka) Make PNM encoder endianness-aware. + return ((bits_per_sample <= 16 && big_endian) || + (bits_per_sample == 32 && !add_alpha && !big_endian)); + } else if (codec == Codec::kPGX) { + return ((bits_per_sample == 8 || bits_per_sample == 16) && is_gray && + !add_alpha); + } else if (codec == Codec::kEXR) { +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + // OpenEXR 2.3 has a memory leak in IlmThread_2_3::ThreadPool + return false; +#else + return bits_per_sample == 32 && !is_gray; +#endif + } else if (codec == Codec::kJPG) { + return bits_per_sample == 8 && !add_alpha; + } else { + return false; + } + } + + JxlPixelFormat PixelFormat() const { + JxlPixelFormat format; + format.num_channels = (is_gray ? 1 : 3) + (add_alpha ? 1 : 0); + format.data_type = (bits_per_sample == 32 ? JXL_TYPE_FLOAT + : bits_per_sample > 8 ? JXL_TYPE_UINT16 + : JXL_TYPE_UINT8); + format.endianness = big_endian ? JXL_BIG_ENDIAN : JXL_LITTLE_ENDIAN; + format.align = 0; + return format; + } + + std::string DebugString() const { + std::ostringstream os; + os << "bps:" << bits_per_sample << " gr:" << is_gray << " al:" << add_alpha + << " be: " << big_endian; + return os.str(); + } +}; + +void CreateTestImage(const TestImageParams& params, PackedPixelFile* ppf) { + ppf->info.xsize = params.xsize; + ppf->info.ysize = params.ysize; + ppf->info.bits_per_sample = params.bits_per_sample; + ppf->info.exponent_bits_per_sample = params.bits_per_sample == 32 ? 8 : 0; + ppf->info.num_color_channels = params.is_gray ? 1 : 3; + ppf->info.alpha_bits = params.add_alpha ? params.bits_per_sample : 0; + ppf->info.alpha_premultiplied = (params.codec == Codec::kEXR); + + JxlColorEncoding color_encoding = CreateTestColorEncoding(params.is_gray); + ppf->icc = GenerateICC(color_encoding); + ppf->color_encoding = color_encoding; + + PackedFrame frame(params.xsize, params.ysize, params.PixelFormat()); + FillPackedImage(params.bits_per_sample, &frame.color); + ppf->frames.emplace_back(std::move(frame)); +} + +// Ensures reading a newly written file leads to the same image pixels. +void TestRoundTrip(const TestImageParams& params, ThreadPool* pool) { + if (!params.ShouldTestRoundtrip()) return; + + std::string extension = ExtensionFromCodec( + params.codec, params.is_gray, params.add_alpha, params.bits_per_sample); + printf("Codec %s %s\n", extension.c_str(), params.DebugString().c_str()); + + PackedPixelFile ppf_in; + CreateTestImage(params, &ppf_in); + + EncodedImage encoded; + auto encoder = Encoder::FromExtension(extension); + ASSERT_TRUE(encoder.get()); + ASSERT_TRUE(encoder->Encode(ppf_in, &encoded, pool)); + ASSERT_EQ(encoded.bitstreams.size(), 1); + + PackedPixelFile ppf_out; + ASSERT_TRUE(DecodeBytes(Span(encoded.bitstreams[0]), + ColorHints(), SizeConstraints(), &ppf_out)); + + if (params.codec != Codec::kPNM && params.codec != Codec::kPGX && + params.codec != Codec::kEXR) { + EXPECT_EQ(ppf_in.icc, ppf_out.icc); + } + + ASSERT_EQ(ppf_out.frames.size(), 1); + VerifySameImage(ppf_in.frames[0].color, ppf_in.info.bits_per_sample, + ppf_out.frames[0].color, ppf_out.info.bits_per_sample, + /*lossless=*/params.codec != Codec::kJPG); +} + +TEST(CodecTest, TestRoundTrip) { + ThreadPoolInternal pool(12); + + TestImageParams params; + params.xsize = 7; + params.ysize = 4; + + for (Codec codec : AvailableCodecs()) { + for (int bits_per_sample : {4, 8, 10, 12, 16, 32}) { + for (bool is_gray : {false, true}) { + for (bool add_alpha : {false, true}) { + for (bool big_endian : {false, true}) { + params.codec = codec; + params.bits_per_sample = static_cast(bits_per_sample); + params.is_gray = is_gray; + params.add_alpha = add_alpha; + params.big_endian = big_endian; + TestRoundTrip(params, &pool); + } + } + } + } + } +} + +CodecInOut DecodeRoundtrip(const std::string& pathname, ThreadPool* pool, + const ColorHints& color_hints = ColorHints()) { + CodecInOut io; + const PaddedBytes orig = ReadTestData(pathname); + JXL_CHECK( + SetFromBytes(Span(orig), color_hints, &io, pool, nullptr)); + const ImageBundle& ib1 = io.Main(); + + // Encode/Decode again to make sure Encode carries through all metadata. + std::vector encoded; + JXL_CHECK(Encode(io, Codec::kPNG, io.metadata.m.color_encoding, + io.metadata.m.bit_depth.bits_per_sample, &encoded, pool)); + + CodecInOut io2; + JXL_CHECK(SetFromBytes(Span(encoded), color_hints, &io2, pool, + nullptr)); + const ImageBundle& ib2 = io2.Main(); + EXPECT_EQ(Description(ib1.metadata()->color_encoding), + Description(ib2.metadata()->color_encoding)); + EXPECT_EQ(Description(ib1.c_current()), Description(ib2.c_current())); + + size_t bits_per_sample = io2.metadata.m.bit_depth.bits_per_sample; + + // "Same" pixels? + double max_l1 = bits_per_sample <= 12 ? 1.3 : 2E-3; + double max_rel = bits_per_sample <= 12 ? 6E-3 : 1E-4; + if (ib1.metadata()->color_encoding.IsGray()) { + max_rel *= 2.0; + } else if (ib1.metadata()->color_encoding.primaries != Primaries::kSRGB) { + // Need more tolerance for large gamuts (anything but sRGB) + max_l1 *= 1.5; + max_rel *= 3.0; + } + VerifyRelativeError(ib1.color(), ib2.color(), max_l1, max_rel); + + // Simulate the encoder removing profile and decoder restoring it. + if (!ib2.metadata()->color_encoding.WantICC()) { + io2.metadata.m.color_encoding.InternalRemoveICC(); + EXPECT_TRUE(io2.metadata.m.color_encoding.CreateICC()); + } + + return io2; +} + +#if 0 +TEST(CodecTest, TestMetadataSRGB) { + ThreadPoolInternal pool(12); + + const char* paths[] = {"external/raw.pixls/DJI-FC6310-16bit_srgb8_v4_krita.png", + "external/raw.pixls/Google-Pixel2XL-16bit_srgb8_v4_krita.png", + "external/raw.pixls/HUAWEI-EVA-L09-16bit_srgb8_dt.png", + "external/raw.pixls/Nikon-D300-12bit_srgb8_dt.png", + "external/raw.pixls/Sony-DSC-RX1RM2-14bit_srgb8_v4_krita.png"}; + for (const char* relative_pathname : paths) { + const CodecInOut io = + DecodeRoundtrip(relative_pathname, Codec::kPNG, &pool); + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + + EXPECT_EQ(64, io.xsize()); + EXPECT_EQ(64, io.ysize()); + EXPECT_FALSE(io.metadata.m.HasAlpha()); + + const ColorEncoding& c_original = io.metadata.m.color_encoding; + EXPECT_FALSE(c_original.ICC().empty()); + EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace()); + EXPECT_EQ(WhitePoint::kD65, c_original.white_point); + EXPECT_EQ(Primaries::kSRGB, c_original.primaries); + EXPECT_TRUE(c_original.tf.IsSRGB()); + } +} + +TEST(CodecTest, TestMetadataLinear) { + ThreadPoolInternal pool(12); + + const char* paths[3] = { + "external/raw.pixls/Google-Pixel2XL-16bit_acescg_g1_v4_krita.png", + "external/raw.pixls/HUAWEI-EVA-L09-16bit_709_g1_dt.png", + "external/raw.pixls/Nikon-D300-12bit_2020_g1_dt.png", + }; + const WhitePoint white_points[3] = {WhitePoint::kCustom, WhitePoint::kD65, + WhitePoint::kD65}; + const Primaries primaries[3] = {Primaries::kCustom, Primaries::kSRGB, + Primaries::k2100}; + + for (size_t i = 0; i < 3; ++i) { + const CodecInOut io = DecodeRoundtrip(paths[i], Codec::kPNG, &pool); + EXPECT_EQ(16, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + + EXPECT_EQ(64, io.xsize()); + EXPECT_EQ(64, io.ysize()); + EXPECT_FALSE(io.metadata.m.HasAlpha()); + + const ColorEncoding& c_original = io.metadata.m.color_encoding; + EXPECT_FALSE(c_original.ICC().empty()); + EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace()); + EXPECT_EQ(white_points[i], c_original.white_point); + EXPECT_EQ(primaries[i], c_original.primaries); + EXPECT_TRUE(c_original.tf.IsLinear()); + } +} + +TEST(CodecTest, TestMetadataICC) { + ThreadPoolInternal pool(12); + + const char* paths[] = { + "external/raw.pixls/DJI-FC6310-16bit_709_v4_krita.png", + "external/raw.pixls/Sony-DSC-RX1RM2-14bit_709_v4_krita.png", + }; + for (const char* relative_pathname : paths) { + const CodecInOut io = + DecodeRoundtrip(relative_pathname, Codec::kPNG, &pool); + EXPECT_GE(16, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_LE(14, io.metadata.m.bit_depth.bits_per_sample); + + EXPECT_EQ(64, io.xsize()); + EXPECT_EQ(64, io.ysize()); + EXPECT_FALSE(io.metadata.m.HasAlpha()); + + const ColorEncoding& c_original = io.metadata.m.color_encoding; + EXPECT_FALSE(c_original.ICC().empty()); + EXPECT_EQ(RenderingIntent::kPerceptual, c_original.rendering_intent); + EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace()); + EXPECT_EQ(WhitePoint::kD65, c_original.white_point); + EXPECT_EQ(Primaries::kSRGB, c_original.primaries); + EXPECT_EQ(TransferFunction::k709, c_original.tf.GetTransferFunction()); + } +} + +TEST(CodecTest, Testexternal/pngsuite) { + ThreadPoolInternal pool(12); + + // Ensure we can load PNG with text, japanese UTF-8, compressed text. + (void)DecodeRoundtrip("external/pngsuite/ct1n0g04.png", Codec::kPNG, &pool); + (void)DecodeRoundtrip("external/pngsuite/ctjn0g04.png", Codec::kPNG, &pool); + (void)DecodeRoundtrip("external/pngsuite/ctzn0g04.png", Codec::kPNG, &pool); + + // Extract gAMA + const CodecInOut b1 = + DecodeRoundtrip("external/pngsuite/g10n3p04.png", Codec::kPNG, &pool); + EXPECT_TRUE(b1.metadata.color_encoding.tf.IsLinear()); + + // Extract cHRM + const CodecInOut b_p = + DecodeRoundtrip("external/pngsuite/ccwn2c08.png", Codec::kPNG, &pool); + EXPECT_EQ(Primaries::kSRGB, b_p.metadata.color_encoding.primaries); + EXPECT_EQ(WhitePoint::kD65, b_p.metadata.color_encoding.white_point); + + // Extract EXIF from (new-style) dedicated chunk + const CodecInOut b_exif = + DecodeRoundtrip("external/pngsuite/exif2c08.png", Codec::kPNG, &pool); + EXPECT_EQ(978, b_exif.blobs.exif.size()); +} +#endif + +void VerifyWideGamutMetadata(const std::string& relative_pathname, + const Primaries primaries, ThreadPool* pool) { + const CodecInOut io = DecodeRoundtrip(relative_pathname, pool); + + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + + const ColorEncoding& c_original = io.metadata.m.color_encoding; + EXPECT_FALSE(c_original.ICC().empty()); + EXPECT_EQ(RenderingIntent::kAbsolute, c_original.rendering_intent); + EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace()); + EXPECT_EQ(WhitePoint::kD65, c_original.white_point); + EXPECT_EQ(primaries, c_original.primaries); +} + +TEST(CodecTest, TestWideGamut) { + ThreadPoolInternal pool(12); + // VerifyWideGamutMetadata("external/wide-gamut-tests/P3-sRGB-color-bars.png", + // Primaries::kP3, &pool); + VerifyWideGamutMetadata("external/wide-gamut-tests/P3-sRGB-color-ring.png", + Primaries::kP3, &pool); + // VerifyWideGamutMetadata("external/wide-gamut-tests/R2020-sRGB-color-bars.png", + // Primaries::k2100, &pool); + // VerifyWideGamutMetadata("external/wide-gamut-tests/R2020-sRGB-color-ring.png", + // Primaries::k2100, &pool); +} + +TEST(CodecTest, TestPNM) { TestCodecPNM(); } + +TEST(CodecTest, FormatNegotiation) { + const std::vector accepted_formats = { + {/*num_channels=*/4, + /*data_type=*/JXL_TYPE_UINT16, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0}, + {/*num_channels=*/3, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0}, + {/*num_channels=*/3, + /*data_type=*/JXL_TYPE_UINT16, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0}, + {/*num_channels=*/1, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0}, + }; + + JxlBasicInfo info; + JxlEncoderInitBasicInfo(&info); + info.bits_per_sample = 12; + info.num_color_channels = 2; + + JxlPixelFormat format; + EXPECT_FALSE(SelectFormat(accepted_formats, info, &format)); + + info.num_color_channels = 3; + ASSERT_TRUE(SelectFormat(accepted_formats, info, &format)); + EXPECT_EQ(format.num_channels, info.num_color_channels); + // 16 is the smallest accepted format that can accommodate the 12-bit data. + EXPECT_EQ(format.data_type, JXL_TYPE_UINT16); +} + +TEST(CodecTest, EncodeToPNG) { + ThreadPool* const pool = nullptr; + + std::unique_ptr png_encoder = Encoder::FromExtension(".png"); + ASSERT_THAT(png_encoder, NotNull()); + + const PaddedBytes original_png = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + PackedPixelFile ppf; + ASSERT_TRUE(extras::DecodeBytes(Span(original_png), + ColorHints(), SizeConstraints(), &ppf)); + + const JxlPixelFormat& format = ppf.frames.front().color.format; + ASSERT_THAT( + png_encoder->AcceptedFormats(), + Contains(AllOf(Field(&JxlPixelFormat::num_channels, format.num_channels), + Field(&JxlPixelFormat::data_type, format.data_type), + Field(&JxlPixelFormat::endianness, format.endianness)))); + EncodedImage encoded_png; + ASSERT_TRUE(png_encoder->Encode(ppf, &encoded_png, pool)); + EXPECT_THAT(encoded_png.icc, IsEmpty()); + ASSERT_THAT(encoded_png.bitstreams, SizeIs(1)); + + PackedPixelFile decoded_ppf; + ASSERT_TRUE( + extras::DecodeBytes(Span(encoded_png.bitstreams.front()), + ColorHints(), SizeConstraints(), &decoded_ppf)); + + ASSERT_EQ(decoded_ppf.info.bits_per_sample, ppf.info.bits_per_sample); + ASSERT_EQ(decoded_ppf.frames.size(), 1); + VerifySameImage(ppf.frames[0].color, ppf.info.bits_per_sample, + decoded_ppf.frames[0].color, + decoded_ppf.info.bits_per_sample); +} + +} // namespace +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/apng.cc b/media/libjxl/src/lib/extras/dec/apng.cc new file mode 100644 index 000000000..566746665 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/apng.cc @@ -0,0 +1,797 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/apng.h" + +// Parts of this code are taken from apngdis, which has the following license: +/* APNG Disassembler 2.8 + * + * Deconstructs APNG files into individual frames. + * + * http://apngdis.sourceforge.net + * + * Copyright (c) 2010-2015 Max Stepin + * maxst at users.sourceforge.net + * + * zlib license + * ------------ + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + * + */ + +#include +#include + +#include +#include +#include + +#include "jxl/codestream_header.h" +#include "jxl/encode.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/scope_guard.h" +#include "lib/jxl/common.h" +#include "lib/jxl/sanitizers.h" +#include "png.h" /* original (unpatched) libpng is ok */ + +namespace jxl { +namespace extras { + +namespace { + +/* hIST chunk tail is not proccesed properly; skip this chunk completely; + see https://github.com/glennrp/libpng/pull/413 */ +const png_byte kIgnoredPngChunks[] = { + 104, 73, 83, 84, '\0' /* hIST */ +}; + +// Returns floating-point value from the PNG encoding (times 10^5). +static double F64FromU32(const uint32_t x) { + return static_cast(x) * 1E-5; +} + +Status DecodeSRGB(const unsigned char* payload, const size_t payload_size, + JxlColorEncoding* color_encoding) { + if (payload_size != 1) return JXL_FAILURE("Wrong sRGB size"); + // (PNG uses the same values as ICC.) + if (payload[0] >= 4) return JXL_FAILURE("Invalid Rendering Intent"); + color_encoding->rendering_intent = + static_cast(payload[0]); + return true; +} + +Status DecodeGAMA(const unsigned char* payload, const size_t payload_size, + JxlColorEncoding* color_encoding) { + if (payload_size != 4) return JXL_FAILURE("Wrong gAMA size"); + color_encoding->transfer_function = JXL_TRANSFER_FUNCTION_GAMMA; + color_encoding->gamma = F64FromU32(LoadBE32(payload)); + return true; +} + +Status DecodeCHRM(const unsigned char* payload, const size_t payload_size, + JxlColorEncoding* color_encoding) { + if (payload_size != 32) return JXL_FAILURE("Wrong cHRM size"); + + color_encoding->white_point = JXL_WHITE_POINT_CUSTOM; + color_encoding->white_point_xy[0] = F64FromU32(LoadBE32(payload + 0)); + color_encoding->white_point_xy[1] = F64FromU32(LoadBE32(payload + 4)); + + color_encoding->primaries = JXL_PRIMARIES_CUSTOM; + color_encoding->primaries_red_xy[0] = F64FromU32(LoadBE32(payload + 8)); + color_encoding->primaries_red_xy[1] = F64FromU32(LoadBE32(payload + 12)); + color_encoding->primaries_green_xy[0] = F64FromU32(LoadBE32(payload + 16)); + color_encoding->primaries_green_xy[1] = F64FromU32(LoadBE32(payload + 20)); + color_encoding->primaries_blue_xy[0] = F64FromU32(LoadBE32(payload + 24)); + color_encoding->primaries_blue_xy[1] = F64FromU32(LoadBE32(payload + 28)); + return true; +} + +// Retrieves XMP and EXIF/IPTC from itext and text. +class BlobsReaderPNG { + public: + static Status Decode(const png_text_struct& info, PackedMetadata* metadata) { + // We trust these are properly null-terminated by libpng. + const char* key = info.key; + const char* value = info.text; + if (strstr(key, "XML:com.adobe.xmp")) { + metadata->xmp.resize(strlen(value)); // safe, see above + memcpy(metadata->xmp.data(), value, metadata->xmp.size()); + } + + std::string type; + std::vector bytes; + + // Handle text chunks annotated with key "Raw profile type ####", with + // #### a type, which may contain metadata. + const char* kKey = "Raw profile type "; + if (strncmp(key, kKey, strlen(kKey)) != 0) return false; + + if (!MaybeDecodeBase16(key, value, &type, &bytes)) { + JXL_WARNING("Couldn't parse 'Raw format type' text chunk"); + return false; + } + if (type == "exif") { + if (!metadata->exif.empty()) { + JXL_WARNING("overwriting EXIF (%" PRIuS " bytes) with base16 (%" PRIuS + " bytes)", + metadata->exif.size(), bytes.size()); + } + metadata->exif = std::move(bytes); + } else if (type == "iptc") { + // TODO (jon): Deal with IPTC in some way + } else if (type == "8bim") { + // TODO (jon): Deal with 8bim in some way + } else if (type == "xmp") { + if (!metadata->xmp.empty()) { + JXL_WARNING("overwriting XMP (%" PRIuS " bytes) with base16 (%" PRIuS + " bytes)", + metadata->xmp.size(), bytes.size()); + } + metadata->xmp = std::move(bytes); + } else { + JXL_WARNING("Unknown type in 'Raw format type' text chunk: %s: %" PRIuS + " bytes", + type.c_str(), bytes.size()); + } + return true; + } + + private: + // Returns false if invalid. + static JXL_INLINE Status DecodeNibble(const char c, + uint32_t* JXL_RESTRICT nibble) { + if ('a' <= c && c <= 'f') { + *nibble = 10 + c - 'a'; + } else if ('0' <= c && c <= '9') { + *nibble = c - '0'; + } else { + *nibble = 0; + return JXL_FAILURE("Invalid metadata nibble"); + } + JXL_ASSERT(*nibble < 16); + return true; + } + + // Returns false if invalid. + static JXL_INLINE Status DecodeDecimal(const char** pos, const char* end, + uint32_t* JXL_RESTRICT value) { + size_t len = 0; + *value = 0; + while (*pos < end) { + char next = **pos; + if (next >= '0' && next <= '9') { + *value = (*value * 10) + static_cast(next - '0'); + len++; + if (len > 8) { + break; + } + } else { + // Do not consume terminator (non-decimal digit). + break; + } + (*pos)++; + } + if (len == 0 || len > 8) { + return JXL_FAILURE("Failed to parse decimal"); + } + return true; + } + + // Parses a PNG text chunk with key of the form "Raw profile type ####", with + // #### a type. + // Returns whether it could successfully parse the content. + // We trust key and encoded are null-terminated because they come from + // libpng. + static Status MaybeDecodeBase16(const char* key, const char* encoded, + std::string* type, + std::vector* bytes) { + const char* encoded_end = encoded + strlen(encoded); + + const char* kKey = "Raw profile type "; + if (strncmp(key, kKey, strlen(kKey)) != 0) return false; + *type = key + strlen(kKey); + const size_t kMaxTypeLen = 20; + if (type->length() > kMaxTypeLen) return false; // Type too long + + // Header: freeform string and number of bytes + // Expected format is: + // \n + // profile name/description\n + // 40\n (the number of bytes after hex-decoding) + // 01234566789abcdef....\n (72 bytes per line max). + // 012345667\n (last line) + const char* pos = encoded; + + if (*(pos++) != '\n') return false; + while (pos < encoded_end && *pos != '\n') { + pos++; + } + if (pos == encoded_end) return false; + // We parsed so far a \n, some number of non \n characters and are now + // pointing at a \n. + if (*(pos++) != '\n') return false; + uint32_t bytes_to_decode = 0; + JXL_RETURN_IF_ERROR(DecodeDecimal(&pos, encoded_end, &bytes_to_decode)); + + // We need 2*bytes for the hex values plus 1 byte every 36 values, + // plus terminal \n for length. + const unsigned long needed_bytes = + bytes_to_decode * 2 + 1 + DivCeil(bytes_to_decode, 36); + if (needed_bytes != static_cast(encoded_end - pos)) { + return JXL_FAILURE("Not enough bytes to parse %d bytes in hex", + bytes_to_decode); + } + JXL_ASSERT(bytes->empty()); + bytes->reserve(bytes_to_decode); + + // Encoding: base16 with newline after 72 chars. + // pos points to the \n before the first line of hex values. + for (size_t i = 0; i < bytes_to_decode; ++i) { + if (i % 36 == 0) { + if (pos + 1 >= encoded_end) return false; // Truncated base16 1 + if (*pos != '\n') return false; // Expected newline + ++pos; + } + + if (pos + 2 >= encoded_end) return false; // Truncated base16 2; + uint32_t nibble0, nibble1; + JXL_RETURN_IF_ERROR(DecodeNibble(pos[0], &nibble0)); + JXL_RETURN_IF_ERROR(DecodeNibble(pos[1], &nibble1)); + bytes->push_back(static_cast((nibble0 << 4) + nibble1)); + pos += 2; + } + if (pos + 1 != encoded_end) return false; // Too many encoded bytes + if (pos[0] != '\n') return false; // Incorrect metadata terminator + return true; + } +}; + +constexpr bool isAbc(char c) { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); +} + +constexpr uint32_t kId_IHDR = 0x52444849; +constexpr uint32_t kId_acTL = 0x4C546361; +constexpr uint32_t kId_fcTL = 0x4C546366; +constexpr uint32_t kId_IDAT = 0x54414449; +constexpr uint32_t kId_fdAT = 0x54416466; +constexpr uint32_t kId_IEND = 0x444E4549; +constexpr uint32_t kId_iCCP = 0x50434369; +constexpr uint32_t kId_sRGB = 0x42475273; +constexpr uint32_t kId_gAMA = 0x414D4167; +constexpr uint32_t kId_cHRM = 0x4D524863; +constexpr uint32_t kId_eXIf = 0x66495865; + +struct APNGFrame { + std::vector pixels; + std::vector rows; + unsigned int w, h, delay_num, delay_den; +}; + +struct Reader { + const uint8_t* next; + const uint8_t* last; + bool Read(void* data, size_t len) { + size_t cap = last - next; + size_t to_copy = std::min(cap, len); + memcpy(data, next, to_copy); + next += to_copy; + return (len == to_copy); + } + bool Eof() { return next == last; } +}; + +const unsigned long cMaxPNGSize = 1000000UL; +const size_t kMaxPNGChunkSize = 1lu << 30; // 1 GB + +void info_fn(png_structp png_ptr, png_infop info_ptr) { + png_set_expand(png_ptr); + png_set_palette_to_rgb(png_ptr); + png_set_tRNS_to_alpha(png_ptr); + (void)png_set_interlace_handling(png_ptr); + png_read_update_info(png_ptr, info_ptr); +} + +void row_fn(png_structp png_ptr, png_bytep new_row, png_uint_32 row_num, + int pass) { + APNGFrame* frame = (APNGFrame*)png_get_progressive_ptr(png_ptr); + JXL_CHECK(frame); + JXL_CHECK(row_num < frame->rows.size()); + JXL_CHECK(frame->rows[row_num] < frame->pixels.data() + frame->pixels.size()); + png_progressive_combine_row(png_ptr, frame->rows[row_num], new_row); +} + +inline unsigned int read_chunk(Reader* r, std::vector* pChunk) { + unsigned char len[4]; + if (r->Read(&len, 4)) { + const auto size = png_get_uint_32(len); + // Check first, to avoid overflow. + if (size > kMaxPNGChunkSize) { + JXL_WARNING("APNG chunk size is too big"); + return 0; + } + pChunk->resize(size + 12); + memcpy(pChunk->data(), len, 4); + if (r->Read(pChunk->data() + 4, pChunk->size() - 4)) { + return LoadLE32(pChunk->data() + 4); + } + } + return 0; +} + +int processing_start(png_structp& png_ptr, png_infop& info_ptr, void* frame_ptr, + bool hasInfo, std::vector& chunkIHDR, + std::vector>& chunksInfo) { + unsigned char header[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + + png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL); + info_ptr = png_create_info_struct(png_ptr); + if (!png_ptr || !info_ptr) return 1; + + if (setjmp(png_jmpbuf(png_ptr))) { + return 1; + } + + png_set_keep_unknown_chunks(png_ptr, 1, kIgnoredPngChunks, + (int)sizeof(kIgnoredPngChunks) / 5); + + png_set_crc_action(png_ptr, PNG_CRC_QUIET_USE, PNG_CRC_QUIET_USE); + png_set_progressive_read_fn(png_ptr, frame_ptr, info_fn, row_fn, NULL); + + png_process_data(png_ptr, info_ptr, header, 8); + png_process_data(png_ptr, info_ptr, chunkIHDR.data(), chunkIHDR.size()); + + if (hasInfo) { + for (unsigned int i = 0; i < chunksInfo.size(); i++) { + png_process_data(png_ptr, info_ptr, chunksInfo[i].data(), + chunksInfo[i].size()); + } + } + return 0; +} + +int processing_data(png_structp png_ptr, png_infop info_ptr, unsigned char* p, + unsigned int size) { + if (!png_ptr || !info_ptr) return 1; + + if (setjmp(png_jmpbuf(png_ptr))) { + return 1; + } + + png_process_data(png_ptr, info_ptr, p, size); + return 0; +} + +int processing_finish(png_structp png_ptr, png_infop info_ptr, + PackedMetadata* metadata) { + unsigned char footer[12] = {0, 0, 0, 0, 73, 69, 78, 68, 174, 66, 96, 130}; + + if (!png_ptr || !info_ptr) return 1; + + if (setjmp(png_jmpbuf(png_ptr))) { + return 1; + } + + png_process_data(png_ptr, info_ptr, footer, 12); + // before destroying: check if we encountered any metadata chunks + png_textp text_ptr; + int num_text; + png_get_text(png_ptr, info_ptr, &text_ptr, &num_text); + for (int i = 0; i < num_text; i++) { + (void)BlobsReaderPNG::Decode(text_ptr[i], metadata); + } + + return 0; +} + +} // namespace + +Status DecodeImageAPNG(const Span bytes, + const ColorHints& color_hints, + const SizeConstraints& constraints, + PackedPixelFile* ppf) { + Reader r; + unsigned int id, j, w, h, w0, h0, x0, y0; + unsigned int delay_num, delay_den, dop, bop, rowbytes, imagesize; + unsigned char sig[8]; + png_structp png_ptr = nullptr; + png_infop info_ptr = nullptr; + std::vector chunk; + std::vector chunkIHDR; + std::vector> chunksInfo; + bool isAnimated = false; + bool hasInfo = false; + APNGFrame frameRaw = {}; + uint32_t num_channels; + JxlPixelFormat format; + unsigned int bytes_per_pixel = 0; + + struct FrameInfo { + PackedImage data; + uint32_t duration; + size_t x0, xsize; + size_t y0, ysize; + uint32_t dispose_op; + uint32_t blend_op; + }; + + std::vector frames; + + // Make sure png memory is released in any case. + auto scope_guard = MakeScopeGuard([&]() { + png_destroy_read_struct(&png_ptr, &info_ptr, 0); + // Just in case. Not all versions on libpng wipe-out the pointers. + png_ptr = nullptr; + info_ptr = nullptr; + }); + + r = {bytes.data(), bytes.data() + bytes.size()}; + // Not a PNG => not an error + unsigned char png_signature[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + if (!r.Read(sig, 8) || memcmp(sig, png_signature, 8) != 0) { + return false; + } + id = read_chunk(&r, &chunkIHDR); + + ppf->info.exponent_bits_per_sample = 0; + ppf->info.alpha_exponent_bits = 0; + ppf->info.orientation = JXL_ORIENT_IDENTITY; + + ppf->frames.clear(); + + bool have_color = false, have_srgb = false; + bool errorstate = true; + if (id == kId_IHDR && chunkIHDR.size() == 25) { + x0 = 0; + y0 = 0; + delay_num = 1; + delay_den = 10; + dop = 0; + bop = 0; + + w0 = w = png_get_uint_32(chunkIHDR.data() + 8); + h0 = h = png_get_uint_32(chunkIHDR.data() + 12); + if (w > cMaxPNGSize || h > cMaxPNGSize) { + return false; + } + + // default settings in case e.g. only gAMA is given + ppf->color_encoding.color_space = JXL_COLOR_SPACE_RGB; + ppf->color_encoding.white_point = JXL_WHITE_POINT_D65; + ppf->color_encoding.primaries = JXL_PRIMARIES_SRGB; + ppf->color_encoding.transfer_function = JXL_TRANSFER_FUNCTION_SRGB; + + if (!processing_start(png_ptr, info_ptr, (void*)&frameRaw, hasInfo, + chunkIHDR, chunksInfo)) { + while (!r.Eof()) { + id = read_chunk(&r, &chunk); + if (!id) break; + + if (id == kId_acTL && !hasInfo && !isAnimated) { + isAnimated = true; + ppf->info.have_animation = true; + ppf->info.animation.tps_numerator = 1000; + ppf->info.animation.tps_denominator = 1; + } else if (id == kId_IEND || + (id == kId_fcTL && (!hasInfo || isAnimated))) { + if (hasInfo) { + if (!processing_finish(png_ptr, info_ptr, &ppf->metadata)) { + // Allocates the frame buffer. + uint32_t duration = delay_num * 1000 / delay_den; + frames.push_back(FrameInfo{PackedImage(w0, h0, format), duration, + x0, w0, y0, h0, dop, bop}); + auto& frame = frames.back().data; + for (size_t y = 0; y < h0; ++y) { + memcpy(static_cast(frame.pixels()) + frame.stride * y, + frameRaw.rows[y], bytes_per_pixel * w0); + } + } else { + break; + } + } + + if (id == kId_IEND) { + errorstate = false; + break; + } + if (chunk.size() < 34) { + return JXL_FAILURE("Received a chunk that is too small (%" PRIuS + "B)", + chunk.size()); + } + // At this point the old frame is done. Let's start a new one. + w0 = png_get_uint_32(chunk.data() + 12); + h0 = png_get_uint_32(chunk.data() + 16); + x0 = png_get_uint_32(chunk.data() + 20); + y0 = png_get_uint_32(chunk.data() + 24); + delay_num = png_get_uint_16(chunk.data() + 28); + delay_den = png_get_uint_16(chunk.data() + 30); + dop = chunk[32]; + bop = chunk[33]; + + if (!delay_den) delay_den = 100; + + if (w0 > cMaxPNGSize || h0 > cMaxPNGSize || x0 > cMaxPNGSize || + y0 > cMaxPNGSize || x0 + w0 > w || y0 + h0 > h || dop > 2 || + bop > 1) { + break; + } + + if (hasInfo) { + memcpy(chunkIHDR.data() + 8, chunk.data() + 12, 8); + if (processing_start(png_ptr, info_ptr, (void*)&frameRaw, hasInfo, + chunkIHDR, chunksInfo)) { + break; + } + } + } else if (id == kId_IDAT) { + // First IDAT chunk means we now have all header info + hasInfo = true; + JXL_CHECK(w == png_get_image_width(png_ptr, info_ptr)); + JXL_CHECK(h == png_get_image_height(png_ptr, info_ptr)); + int colortype = png_get_color_type(png_ptr, info_ptr); + ppf->info.bits_per_sample = png_get_bit_depth(png_ptr, info_ptr); + png_color_8p sigbits = NULL; + png_get_sBIT(png_ptr, info_ptr, &sigbits); + if (colortype & 1) { + // palette will actually be 8-bit regardless of the index bitdepth + ppf->info.bits_per_sample = 8; + } + if (colortype & 2) { + ppf->info.num_color_channels = 3; + ppf->color_encoding.color_space = JXL_COLOR_SPACE_RGB; + if (sigbits && sigbits->red == sigbits->green && + sigbits->green == sigbits->blue) + ppf->info.bits_per_sample = sigbits->red; + } else { + ppf->info.num_color_channels = 1; + ppf->color_encoding.color_space = JXL_COLOR_SPACE_GRAY; + if (sigbits) ppf->info.bits_per_sample = sigbits->gray; + } + if (colortype & 4 || + png_get_valid(png_ptr, info_ptr, PNG_INFO_tRNS)) { + ppf->info.alpha_bits = ppf->info.bits_per_sample; + if (sigbits) { + if (sigbits->alpha && + sigbits->alpha != ppf->info.bits_per_sample) { + return JXL_FAILURE("Unsupported alpha bit-depth"); + } + ppf->info.alpha_bits = sigbits->alpha; + } + } else { + ppf->info.alpha_bits = 0; + } + ppf->color_encoding.color_space = + (ppf->info.num_color_channels == 1 ? JXL_COLOR_SPACE_GRAY + : JXL_COLOR_SPACE_RGB); + ppf->info.xsize = w; + ppf->info.ysize = h; + JXL_RETURN_IF_ERROR(VerifyDimensions(&constraints, w, h)); + num_channels = + ppf->info.num_color_channels + (ppf->info.alpha_bits ? 1 : 0); + format = { + /*num_channels=*/num_channels, + /*data_type=*/ppf->info.bits_per_sample > 8 ? JXL_TYPE_UINT16 + : JXL_TYPE_UINT8, + /*endianness=*/JXL_BIG_ENDIAN, + /*align=*/0, + }; + bytes_per_pixel = + num_channels * (format.data_type == JXL_TYPE_UINT16 ? 2 : 1); + rowbytes = w * bytes_per_pixel; + imagesize = h * rowbytes; + frameRaw.pixels.resize(imagesize); + frameRaw.rows.resize(h); + for (j = 0; j < h; j++) + frameRaw.rows[j] = frameRaw.pixels.data() + j * rowbytes; + + if (processing_data(png_ptr, info_ptr, chunk.data(), chunk.size())) { + break; + } + } else if (id == kId_fdAT && isAnimated) { + png_save_uint_32(chunk.data() + 4, chunk.size() - 16); + memcpy(chunk.data() + 8, "IDAT", 4); + if (processing_data(png_ptr, info_ptr, chunk.data() + 4, + chunk.size() - 4)) { + break; + } + } else if (id == kId_iCCP) { + if (processing_data(png_ptr, info_ptr, chunk.data(), chunk.size())) { + JXL_WARNING("Corrupt iCCP chunk"); + break; + } + + // TODO(jon): catch special case of PQ and synthesize color encoding + // in that case + int compression_type; + png_bytep profile; + png_charp name; + png_uint_32 proflen = 0; + auto ok = png_get_iCCP(png_ptr, info_ptr, &name, &compression_type, + &profile, &proflen); + if (ok && proflen) { + ppf->icc.assign(profile, profile + proflen); + have_color = true; + } else { + // TODO(eustas): JXL_WARNING? + } + } else if (id == kId_sRGB) { + JXL_RETURN_IF_ERROR(DecodeSRGB(chunk.data() + 8, chunk.size() - 12, + &ppf->color_encoding)); + have_srgb = true; + have_color = true; + } else if (id == kId_gAMA) { + JXL_RETURN_IF_ERROR(DecodeGAMA(chunk.data() + 8, chunk.size() - 12, + &ppf->color_encoding)); + have_color = true; + } else if (id == kId_cHRM) { + JXL_RETURN_IF_ERROR(DecodeCHRM(chunk.data() + 8, chunk.size() - 12, + &ppf->color_encoding)); + have_color = true; + } else if (id == kId_eXIf) { + ppf->metadata.exif.resize(chunk.size() - 12); + memcpy(ppf->metadata.exif.data(), chunk.data() + 8, + chunk.size() - 12); + } else if (!isAbc(chunk[4]) || !isAbc(chunk[5]) || !isAbc(chunk[6]) || + !isAbc(chunk[7])) { + break; + } else { + if (processing_data(png_ptr, info_ptr, chunk.data(), chunk.size())) { + break; + } + if (!hasInfo) { + chunksInfo.push_back(chunk); + continue; + } + } + } + } + + if (have_srgb) { + ppf->color_encoding.white_point = JXL_WHITE_POINT_D65; + ppf->color_encoding.primaries = JXL_PRIMARIES_SRGB; + ppf->color_encoding.transfer_function = JXL_TRANSFER_FUNCTION_SRGB; + ppf->color_encoding.rendering_intent = JXL_RENDERING_INTENT_PERCEPTUAL; + } + JXL_RETURN_IF_ERROR(ApplyColorHints( + color_hints, have_color, ppf->info.num_color_channels == 1, ppf)); + } + + if (errorstate) return false; + + bool has_nontrivial_background = false; + bool previous_frame_should_be_cleared = false; + enum { + DISPOSE_OP_NONE = 0, + DISPOSE_OP_BACKGROUND = 1, + DISPOSE_OP_PREVIOUS = 2, + }; + enum { + BLEND_OP_SOURCE = 0, + BLEND_OP_OVER = 1, + }; + for (size_t i = 0; i < frames.size(); i++) { + auto& frame = frames[i]; + JXL_ASSERT(frame.data.xsize == frame.xsize); + JXL_ASSERT(frame.data.ysize == frame.ysize); + + // Before encountering a DISPOSE_OP_NONE frame, the canvas is filled with 0, + // so DISPOSE_OP_BACKGROUND and DISPOSE_OP_PREVIOUS are equivalent. + if (frame.dispose_op == DISPOSE_OP_NONE) { + has_nontrivial_background = true; + } + bool should_blend = frame.blend_op == BLEND_OP_OVER; + bool use_for_next_frame = + has_nontrivial_background && frame.dispose_op != DISPOSE_OP_PREVIOUS; + size_t x0 = frame.x0; + size_t y0 = frame.y0; + size_t xsize = frame.data.xsize; + size_t ysize = frame.data.ysize; + if (previous_frame_should_be_cleared) { + size_t xs = frame.data.xsize; + size_t ys = frame.data.ysize; + size_t px0 = frames[i - 1].x0; + size_t py0 = frames[i - 1].y0; + size_t pxs = frames[i - 1].xsize; + size_t pys = frames[i - 1].ysize; + if (px0 >= x0 && py0 >= y0 && px0 + pxs <= x0 + xs && + py0 + pys <= y0 + ys && frame.blend_op == BLEND_OP_SOURCE && + use_for_next_frame) { + // If the previous frame is entirely contained in the current frame and + // we are using BLEND_OP_SOURCE, nothing special needs to be done. + ppf->frames.emplace_back(std::move(frame.data)); + } else if (px0 == x0 && py0 == y0 && px0 + pxs == x0 + xs && + py0 + pys == y0 + ys && use_for_next_frame) { + // If the new frame has the same size as the old one, but we are + // blending, we can instead just not blend. + should_blend = false; + ppf->frames.emplace_back(std::move(frame.data)); + } else if (px0 <= x0 && py0 <= y0 && px0 + pxs >= x0 + xs && + py0 + pys >= y0 + ys && use_for_next_frame) { + // If the new frame is contained within the old frame, we can pad the + // new frame with zeros and not blend. + PackedImage new_data(pxs, pys, frame.data.format); + memset(new_data.pixels(), 0, new_data.pixels_size); + for (size_t y = 0; y < ys; y++) { + size_t bytes_per_pixel = + PackedImage::BitsPerChannel(new_data.format.data_type) * + new_data.format.num_channels / 8; + memcpy(static_cast(new_data.pixels()) + + new_data.stride * (y + y0 - py0) + + bytes_per_pixel * (x0 - px0), + static_cast(frame.data.pixels()) + + frame.data.stride * y, + xs * bytes_per_pixel); + } + + x0 = px0; + y0 = py0; + xsize = pxs; + ysize = pys; + should_blend = false; + ppf->frames.emplace_back(std::move(new_data)); + } else { + // If all else fails, insert a dummy blank frame with kReplace. + PackedImage blank(pxs, pys, frame.data.format); + memset(blank.pixels(), 0, blank.pixels_size); + ppf->frames.emplace_back(std::move(blank)); + auto& pframe = ppf->frames.back(); + pframe.frame_info.layer_info.crop_x0 = px0; + pframe.frame_info.layer_info.crop_y0 = py0; + pframe.frame_info.layer_info.xsize = frame.xsize; + pframe.frame_info.layer_info.ysize = frame.ysize; + pframe.frame_info.duration = 0; + pframe.frame_info.layer_info.have_crop = 0; + pframe.frame_info.layer_info.blend_info.blendmode = JXL_BLEND_REPLACE; + pframe.frame_info.layer_info.blend_info.source = 0; + pframe.frame_info.layer_info.save_as_reference = 1; + ppf->frames.emplace_back(std::move(frame.data)); + } + } else { + ppf->frames.emplace_back(std::move(frame.data)); + } + + auto& pframe = ppf->frames.back(); + pframe.frame_info.layer_info.crop_x0 = x0; + pframe.frame_info.layer_info.crop_y0 = y0; + pframe.frame_info.layer_info.xsize = xsize; + pframe.frame_info.layer_info.ysize = ysize; + pframe.frame_info.duration = frame.duration; + pframe.frame_info.layer_info.blend_info.blendmode = + should_blend ? JXL_BLEND_BLEND : JXL_BLEND_REPLACE; + bool is_full_size = x0 == 0 && y0 == 0 && xsize == ppf->info.xsize && + ysize == ppf->info.ysize; + pframe.frame_info.layer_info.have_crop = is_full_size ? 0 : 1; + pframe.frame_info.layer_info.blend_info.source = should_blend ? 1 : 0; + pframe.frame_info.layer_info.blend_info.alpha = 0; + pframe.frame_info.layer_info.save_as_reference = use_for_next_frame ? 1 : 0; + + previous_frame_should_be_cleared = + has_nontrivial_background && frame.dispose_op == DISPOSE_OP_BACKGROUND; + } + if (ppf->frames.empty()) return JXL_FAILURE("No frames decoded"); + ppf->frames.back().frame_info.is_last = true; + + return true; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/apng.h b/media/libjxl/src/lib/extras/dec/apng.h new file mode 100644 index 000000000..a68f6f8ec --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/apng.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_APNG_H_ +#define LIB_EXTRAS_DEC_APNG_H_ + +// Decodes APNG images in memory. + +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Decodes `bytes` into `ppf`. +Status DecodeImageAPNG(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, + PackedPixelFile* ppf); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_APNG_H_ diff --git a/media/libjxl/src/lib/extras/dec/color_description.cc b/media/libjxl/src/lib/extras/dec/color_description.cc new file mode 100644 index 000000000..2325b50f3 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/color_description.cc @@ -0,0 +1,218 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/color_description.h" + +#include + +#include + +namespace jxl { + +namespace { + +template +struct EnumName { + const char* name; + T value; +}; + +const EnumName kJxlColorSpaceNames[] = { + {"RGB", JXL_COLOR_SPACE_RGB}, + {"Gra", JXL_COLOR_SPACE_GRAY}, + {"XYB", JXL_COLOR_SPACE_XYB}, + {"CS?", JXL_COLOR_SPACE_UNKNOWN}, +}; + +const EnumName kJxlWhitePointNames[] = { + {"D65", JXL_WHITE_POINT_D65}, + {"Cst", JXL_WHITE_POINT_CUSTOM}, + {"EER", JXL_WHITE_POINT_E}, + {"DCI", JXL_WHITE_POINT_DCI}, +}; + +const EnumName kJxlPrimariesNames[] = { + {"SRG", JXL_PRIMARIES_SRGB}, + {"Cst", JXL_PRIMARIES_CUSTOM}, + {"202", JXL_PRIMARIES_2100}, + {"DCI", JXL_PRIMARIES_P3}, +}; + +const EnumName kJxlTransferFunctionNames[] = { + {"709", JXL_TRANSFER_FUNCTION_709}, + {"TF?", JXL_TRANSFER_FUNCTION_UNKNOWN}, + {"Lin", JXL_TRANSFER_FUNCTION_LINEAR}, + {"SRG", JXL_TRANSFER_FUNCTION_SRGB}, + {"PeQ", JXL_TRANSFER_FUNCTION_PQ}, + {"DCI", JXL_TRANSFER_FUNCTION_DCI}, + {"HLG", JXL_TRANSFER_FUNCTION_HLG}, + {"", JXL_TRANSFER_FUNCTION_GAMMA}, +}; + +const EnumName kJxlRenderingIntentNames[] = { + {"Per", JXL_RENDERING_INTENT_PERCEPTUAL}, + {"Rel", JXL_RENDERING_INTENT_RELATIVE}, + {"Sat", JXL_RENDERING_INTENT_SATURATION}, + {"Abs", JXL_RENDERING_INTENT_ABSOLUTE}, +}; + +template +Status ParseEnum(const std::string& token, const EnumName* enum_values, + size_t enum_len, T* value) { + for (size_t i = 0; i < enum_len; i++) { + if (enum_values[i].name == token) { + *value = enum_values[i].value; + return true; + } + } + return false; +} +#define ARRAYSIZE(X) (sizeof(X) / sizeof((X)[0])) +#define PARSE_ENUM(type, token, value) \ + ParseEnum(token, k##type##Names, ARRAYSIZE(k##type##Names), value) + +class Tokenizer { + public: + Tokenizer(const std::string* input, char separator) + : input_(input), separator_(separator) {} + + Status Next(std::string* next) { + const size_t end = input_->find(separator_, start_); + if (end == std::string::npos) { + *next = input_->substr(start_); // rest of string + } else { + *next = input_->substr(start_, end - start_); + } + if (next->empty()) return JXL_FAILURE("Missing token"); + start_ = end + 1; + return true; + } + + private: + const std::string* const input_; // not owned + const char separator_; + size_t start_ = 0; // of next token +}; + +Status ParseDouble(const std::string& num, double* d) { + char* end; + errno = 0; + *d = strtod(num.c_str(), &end); + if (*d == 0.0 && end == num.c_str()) { + return JXL_FAILURE("Invalid double: %s", num.c_str()); + } + if (std::isnan(*d)) { + return JXL_FAILURE("Invalid double: %s", num.c_str()); + } + if (errno == ERANGE) { + return JXL_FAILURE("Double out of range: %s", num.c_str()); + } + return true; +} + +Status ParseDouble(Tokenizer* tokenizer, double* d) { + std::string num; + JXL_RETURN_IF_ERROR(tokenizer->Next(&num)); + return ParseDouble(num, d); +} + +Status ParseColorSpace(Tokenizer* tokenizer, JxlColorEncoding* c) { + std::string str; + JXL_RETURN_IF_ERROR(tokenizer->Next(&str)); + JxlColorSpace cs; + if (PARSE_ENUM(JxlColorSpace, str, &cs)) { + c->color_space = cs; + return true; + } + + return JXL_FAILURE("Unknown ColorSpace %s", str.c_str()); +} + +Status ParseWhitePoint(Tokenizer* tokenizer, JxlColorEncoding* c) { + if (c->color_space == JXL_COLOR_SPACE_XYB) { + // Implicit white point. + c->white_point = JXL_WHITE_POINT_D65; + return true; + } + + std::string str; + JXL_RETURN_IF_ERROR(tokenizer->Next(&str)); + if (PARSE_ENUM(JxlWhitePoint, str, &c->white_point)) return true; + + Tokenizer xy_tokenizer(&str, ';'); + c->white_point = JXL_WHITE_POINT_CUSTOM; + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->white_point_xy + 0)); + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->white_point_xy + 1)); + return true; +} + +Status ParsePrimaries(Tokenizer* tokenizer, JxlColorEncoding* c) { + if (c->color_space == JXL_COLOR_SPACE_GRAY || + c->color_space == JXL_COLOR_SPACE_XYB) { + // No primaries case. + return true; + } + + std::string str; + JXL_RETURN_IF_ERROR(tokenizer->Next(&str)); + if (PARSE_ENUM(JxlPrimaries, str, &c->primaries)) return true; + + Tokenizer xy_tokenizer(&str, ';'); + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->primaries_red_xy + 0)); + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->primaries_red_xy + 1)); + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->primaries_green_xy + 0)); + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->primaries_green_xy + 1)); + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->primaries_blue_xy + 0)); + JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, c->primaries_blue_xy + 1)); + c->primaries = JXL_PRIMARIES_CUSTOM; + + return JXL_FAILURE("Invalid primaries %s", str.c_str()); +} + +Status ParseRenderingIntent(Tokenizer* tokenizer, JxlColorEncoding* c) { + std::string str; + JXL_RETURN_IF_ERROR(tokenizer->Next(&str)); + if (PARSE_ENUM(JxlRenderingIntent, str, &c->rendering_intent)) return true; + + return JXL_FAILURE("Invalid RenderingIntent %s\n", str.c_str()); +} + +Status ParseTransferFunction(Tokenizer* tokenizer, JxlColorEncoding* c) { + if (c->color_space == JXL_COLOR_SPACE_XYB) { + // Implicit TF. + c->transfer_function = JXL_TRANSFER_FUNCTION_GAMMA; + c->gamma = 1 / 3.; + return true; + } + + std::string str; + JXL_RETURN_IF_ERROR(tokenizer->Next(&str)); + if (PARSE_ENUM(JxlTransferFunction, str, &c->transfer_function)) { + return true; + } + + if (str[0] == 'g') { + JXL_RETURN_IF_ERROR(ParseDouble(str.substr(1), &c->gamma)); + c->transfer_function = JXL_TRANSFER_FUNCTION_GAMMA; + return true; + } + + return JXL_FAILURE("Invalid gamma %s", str.c_str()); +} + +} // namespace + +Status ParseDescription(const std::string& description, JxlColorEncoding* c) { + *c = {}; + Tokenizer tokenizer(&description, '_'); + JXL_RETURN_IF_ERROR(ParseColorSpace(&tokenizer, c)); + JXL_RETURN_IF_ERROR(ParseWhitePoint(&tokenizer, c)); + JXL_RETURN_IF_ERROR(ParsePrimaries(&tokenizer, c)); + JXL_RETURN_IF_ERROR(ParseRenderingIntent(&tokenizer, c)); + JXL_RETURN_IF_ERROR(ParseTransferFunction(&tokenizer, c)); + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/color_description.h b/media/libjxl/src/lib/extras/dec/color_description.h new file mode 100644 index 000000000..989d5910c --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/color_description.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_COLOR_DESCRIPTION_H_ +#define LIB_EXTRAS_COLOR_DESCRIPTION_H_ + +#include + +#include "jxl/color_encoding.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Parse the color description into a JxlColorEncoding "RGB_D65_SRG_Rel_Lin". +Status ParseDescription(const std::string& description, + JxlColorEncoding* JXL_RESTRICT c); + +} // namespace jxl + +#endif // LIB_EXTRAS_COLOR_DESCRIPTION_H_ diff --git a/media/libjxl/src/lib/extras/dec/color_description_test.cc b/media/libjxl/src/lib/extras/dec/color_description_test.cc new file mode 100644 index 000000000..8ae9e5dce --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/color_description_test.cc @@ -0,0 +1,38 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/color_description.h" + +#include "gtest/gtest.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/test_utils.h" + +namespace jxl { + +// Verify ParseDescription(Description) yields the same ColorEncoding +TEST(ColorDescriptionTest, RoundTripAll) { + for (const auto& cdesc : test::AllEncodings()) { + const ColorEncoding c_original = test::ColorEncodingFromDescriptor(cdesc); + const std::string description = Description(c_original); + printf("%s\n", description.c_str()); + + JxlColorEncoding c_external = {}; + EXPECT_TRUE(ParseDescription(description, &c_external)); + ColorEncoding c_internal; + EXPECT_TRUE( + ConvertExternalToInternalColorEncoding(c_external, &c_internal)); + EXPECT_TRUE(c_original.SameColorEncoding(c_internal)) + << "Where c_original=" << c_original + << " and c_internal=" << c_internal; + } +} + +TEST(ColorDescriptionTest, NanGamma) { + const std::string description = "Gra_2_Per_gnan"; + JxlColorEncoding c; + EXPECT_FALSE(ParseDescription(description, &c)); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/color_hints.cc b/media/libjxl/src/lib/extras/dec/color_hints.cc new file mode 100644 index 000000000..cf7d3e31f --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/color_hints.cc @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/color_hints.h" + +#include "jxl/encode.h" +#include "lib/extras/dec/color_description.h" +#include "lib/jxl/base/file_io.h" + +namespace jxl { +namespace extras { + +Status ApplyColorHints(const ColorHints& color_hints, + const bool color_already_set, const bool is_gray, + PackedPixelFile* ppf) { + if (color_already_set) { + return color_hints.Foreach( + [](const std::string& key, const std::string& /*value*/) { + JXL_WARNING("Decoder ignoring %s hint", key.c_str()); + return true; + }); + } + + bool got_color_space = false; + + JXL_RETURN_IF_ERROR(color_hints.Foreach( + [is_gray, ppf, &got_color_space](const std::string& key, + const std::string& value) -> Status { + if (key == "color_space") { + JxlColorEncoding c_original_external; + if (!ParseDescription(value, &c_original_external)) { + return JXL_FAILURE("Failed to apply color_space"); + } + ppf->color_encoding = c_original_external; + + if (is_gray != + (ppf->color_encoding.color_space == JXL_COLOR_SPACE_GRAY)) { + return JXL_FAILURE("mismatch between file and color_space hint"); + } + + got_color_space = true; + } else if (key == "icc_pathname") { + JXL_RETURN_IF_ERROR(ReadFile(value, &ppf->icc)); + got_color_space = true; + } else { + JXL_WARNING("Ignoring %s hint", key.c_str()); + } + return true; + })); + + if (!got_color_space) { + JXL_WARNING("No color_space/icc_pathname given, assuming sRGB"); + ppf->color_encoding.color_space = + is_gray ? JXL_COLOR_SPACE_GRAY : JXL_COLOR_SPACE_RGB; + ppf->color_encoding.white_point = JXL_WHITE_POINT_D65; + ppf->color_encoding.primaries = JXL_PRIMARIES_SRGB; + ppf->color_encoding.transfer_function = JXL_TRANSFER_FUNCTION_SRGB; + } + + return true; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/color_hints.h b/media/libjxl/src/lib/extras/dec/color_hints.h new file mode 100644 index 000000000..9c7de884f --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/color_hints.h @@ -0,0 +1,72 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_COLOR_HINTS_H_ +#define LIB_EXTRAS_COLOR_HINTS_H_ + +// Not all the formats implemented in the extras lib support bundling color +// information into the file, and those that support it may not have it. +// To allow attaching color information to those file formats the caller can +// define these color hints. + +#include +#include + +#include +#include + +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace extras { + +class ColorHints { + public: + // key=color_space, value=Description(c/pp): specify the ColorEncoding of + // the pixels for decoding. Otherwise, if the codec did not obtain an ICC + // profile from the image, assume sRGB. + // + // Strings are taken from the command line, so avoid spaces for convenience. + void Add(const std::string& key, const std::string& value) { + kv_.emplace_back(key, value); + } + + // Calls `func(key, value)` for each key/value in the order they were added, + // returning false immediately if `func` returns false. + template + Status Foreach(const Func& func) const { + for (const KeyValue& kv : kv_) { + Status ok = func(kv.key, kv.value); + if (!ok) { + return JXL_FAILURE("ColorHints::Foreach returned false"); + } + } + return true; + } + + private: + // Splitting into key/value avoids parsing in each codec. + struct KeyValue { + KeyValue(std::string key, std::string value) + : key(std::move(key)), value(std::move(value)) {} + + std::string key; + std::string value; + }; + + std::vector kv_; +}; + +// Apply the color hints to the decoded image in PackedPixelFile if any. +// color_already_set tells whether the color encoding was already set, in which +// case the hints are ignored if any hint is passed. +Status ApplyColorHints(const ColorHints& color_hints, bool color_already_set, + bool is_gray, PackedPixelFile* ppf); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_COLOR_HINTS_H_ diff --git a/media/libjxl/src/lib/extras/dec/decode.cc b/media/libjxl/src/lib/extras/dec/decode.cc new file mode 100644 index 000000000..8712e03aa --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/decode.cc @@ -0,0 +1,128 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/decode.h" + +#include + +#if JPEGXL_ENABLE_APNG +#include "lib/extras/dec/apng.h" +#endif +#if JPEGXL_ENABLE_EXR +#include "lib/extras/dec/exr.h" +#endif +#if JPEGXL_ENABLE_GIF +#include "lib/extras/dec/gif.h" +#endif +#if JPEGXL_ENABLE_JPEG +#include "lib/extras/dec/jpg.h" +#endif +#include "lib/extras/dec/pgx.h" +#include "lib/extras/dec/pnm.h" + +namespace jxl { +namespace extras { +namespace { + +// Any valid encoding is larger (ensures codecs can read the first few bytes) +constexpr size_t kMinBytes = 9; + +} // namespace + +std::vector AvailableCodecs() { + std::vector out; +#if JPEGXL_ENABLE_APNG + out.push_back(Codec::kPNG); +#endif +#if JPEGXL_ENABLE_EXR + out.push_back(Codec::kEXR); +#endif +#if JPEGXL_ENABLE_GIF + out.push_back(Codec::kGIF); +#endif +#if JPEGXL_ENABLE_JPEG + out.push_back(Codec::kJPG); +#endif + out.push_back(Codec::kPGX); + out.push_back(Codec::kPNM); + return out; +} + +Codec CodecFromExtension(std::string extension, + size_t* JXL_RESTRICT bits_per_sample) { + std::transform( + extension.begin(), extension.end(), extension.begin(), + [](char c) { return std::tolower(c, std::locale::classic()); }); + if (extension == ".png") return Codec::kPNG; + + if (extension == ".jpg") return Codec::kJPG; + if (extension == ".jpeg") return Codec::kJPG; + + if (extension == ".pgx") return Codec::kPGX; + + if (extension == ".pam") return Codec::kPNM; + if (extension == ".pnm") return Codec::kPNM; + if (extension == ".pgm") return Codec::kPNM; + if (extension == ".ppm") return Codec::kPNM; + if (extension == ".pfm") { + if (bits_per_sample != nullptr) *bits_per_sample = 32; + return Codec::kPNM; + } + + if (extension == ".gif") return Codec::kGIF; + + if (extension == ".exr") return Codec::kEXR; + + return Codec::kUnknown; +} + +Status DecodeBytes(const Span bytes, + const ColorHints& color_hints, + const SizeConstraints& constraints, + extras::PackedPixelFile* ppf, Codec* orig_codec) { + if (bytes.size() < kMinBytes) return JXL_FAILURE("Too few bytes"); + + *ppf = extras::PackedPixelFile(); + + // Default values when not set by decoders. + ppf->info.uses_original_profile = true; + ppf->info.orientation = JXL_ORIENT_IDENTITY; + + Codec codec; +#if JPEGXL_ENABLE_APNG + if (DecodeImageAPNG(bytes, color_hints, constraints, ppf)) { + codec = Codec::kPNG; + } else +#endif + if (DecodeImagePGX(bytes, color_hints, constraints, ppf)) { + codec = Codec::kPGX; + } else if (DecodeImagePNM(bytes, color_hints, constraints, ppf)) { + codec = Codec::kPNM; + } +#if JPEGXL_ENABLE_GIF + else if (DecodeImageGIF(bytes, color_hints, constraints, ppf)) { + codec = Codec::kGIF; + } +#endif +#if JPEGXL_ENABLE_JPEG + else if (DecodeImageJPG(bytes, color_hints, constraints, ppf)) { + codec = Codec::kJPG; + } +#endif +#if JPEGXL_ENABLE_EXR + else if (DecodeImageEXR(bytes, color_hints, constraints, ppf)) { + codec = Codec::kEXR; + } +#endif + else { + return JXL_FAILURE("Codecs failed to decode"); + } + if (orig_codec) *orig_codec = codec; + + return true; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/decode.h b/media/libjxl/src/lib/extras/dec/decode.h new file mode 100644 index 000000000..7f0ff70aa --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/decode.h @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_DECODE_H_ +#define LIB_EXTRAS_DEC_DECODE_H_ + +// Facade for image decoders (PNG, PNM, ...). + +#include +#include + +#include +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Codecs supported by CodecInOut::Encode. +enum class Codec : uint32_t { + kUnknown, // for CodecFromExtension + kPNG, + kPNM, + kPGX, + kJPG, + kGIF, + kEXR +}; + +std::vector AvailableCodecs(); + +// If and only if extension is ".pfm", *bits_per_sample is updated to 32 so +// that Encode() would encode to PFM instead of PPM. +Codec CodecFromExtension(std::string extension, + size_t* JXL_RESTRICT bits_per_sample = nullptr); + +// Decodes "bytes" info *ppf. +// color_space_hint may specify the color space, otherwise, defaults to sRGB. +Status DecodeBytes(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, + extras::PackedPixelFile* ppf, Codec* orig_codec = nullptr); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_DECODE_H_ diff --git a/media/libjxl/src/lib/extras/dec/exr.cc b/media/libjxl/src/lib/extras/dec/exr.cc new file mode 100644 index 000000000..ddb6d534e --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/exr.cc @@ -0,0 +1,184 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/exr.h" + +#include +#include +#include +#include + +#include + +namespace jxl { +namespace extras { + +namespace { + +namespace OpenEXR = OPENEXR_IMF_NAMESPACE; +namespace Imath = IMATH_NAMESPACE; + +// OpenEXR::Int64 is deprecated in favor of using uint64_t directly, but using +// uint64_t as recommended causes build failures with previous OpenEXR versions +// on macOS, where the definition for OpenEXR::Int64 was actually not equivalent +// to uint64_t. This alternative should work in all cases. +using ExrInt64 = decltype(std::declval().tellg()); + +constexpr int kExrBitsPerSample = 16; +constexpr int kExrAlphaBits = 16; + +class InMemoryIStream : public OpenEXR::IStream { + public: + // The data pointed to by `bytes` must outlive the InMemoryIStream. + explicit InMemoryIStream(const Span bytes) + : IStream(/*fileName=*/""), bytes_(bytes) {} + + bool isMemoryMapped() const override { return true; } + char* readMemoryMapped(const int n) override { + JXL_ASSERT(pos_ + n <= bytes_.size()); + char* const result = + const_cast(reinterpret_cast(bytes_.data() + pos_)); + pos_ += n; + return result; + } + bool read(char c[], const int n) override { + std::copy_n(readMemoryMapped(n), n, c); + return pos_ < bytes_.size(); + } + + ExrInt64 tellg() override { return pos_; } + void seekg(const ExrInt64 pos) override { + JXL_ASSERT(pos + 1 <= bytes_.size()); + pos_ = pos; + } + + private: + const Span bytes_; + size_t pos_ = 0; +}; + +} // namespace + +Status DecodeImageEXR(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, + PackedPixelFile* ppf) { + InMemoryIStream is(bytes); + +#ifdef __EXCEPTIONS + std::unique_ptr input_ptr; + try { + input_ptr.reset(new OpenEXR::RgbaInputFile(is)); + } catch (...) { + return JXL_FAILURE("OpenEXR failed to parse input"); + } + OpenEXR::RgbaInputFile& input = *input_ptr; +#else + OpenEXR::RgbaInputFile input(is); +#endif + + if ((input.channels() & OpenEXR::RgbaChannels::WRITE_RGB) != + OpenEXR::RgbaChannels::WRITE_RGB) { + return JXL_FAILURE("only RGB OpenEXR files are supported"); + } + const bool has_alpha = (input.channels() & OpenEXR::RgbaChannels::WRITE_A) == + OpenEXR::RgbaChannels::WRITE_A; + + const float intensity_target = OpenEXR::hasWhiteLuminance(input.header()) + ? OpenEXR::whiteLuminance(input.header()) + : kDefaultIntensityTarget; + + auto image_size = input.displayWindow().size(); + // Size is computed as max - min, but both bounds are inclusive. + ++image_size.x; + ++image_size.y; + + ppf->info.xsize = image_size.x; + ppf->info.ysize = image_size.y; + ppf->info.num_color_channels = 3; + + const JxlDataType data_type = + kExrBitsPerSample == 16 ? JXL_TYPE_FLOAT16 : JXL_TYPE_FLOAT; + const JxlPixelFormat format{ + /*num_channels=*/3u + (has_alpha ? 1u : 0u), + /*data_type=*/data_type, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0, + }; + ppf->frames.clear(); + // Allocates the frame buffer. + ppf->frames.emplace_back(image_size.x, image_size.y, format); + const auto& frame = ppf->frames.back(); + + const int row_size = input.dataWindow().size().x + 1; + // Number of rows to read at a time. + // https://www.openexr.com/documentation/ReadingAndWritingImageFiles.pdf + // recommends reading the whole file at once. + const int y_chunk_size = input.displayWindow().size().y + 1; + std::vector input_rows(row_size * y_chunk_size); + for (int start_y = + std::max(input.dataWindow().min.y, input.displayWindow().min.y); + start_y <= + std::min(input.dataWindow().max.y, input.displayWindow().max.y); + start_y += y_chunk_size) { + // Inclusive. + const int end_y = std::min( + start_y + y_chunk_size - 1, + std::min(input.dataWindow().max.y, input.displayWindow().max.y)); + input.setFrameBuffer( + input_rows.data() - input.dataWindow().min.x - start_y * row_size, + /*xStride=*/1, /*yStride=*/row_size); + input.readPixels(start_y, end_y); + for (int exr_y = start_y; exr_y <= end_y; ++exr_y) { + const int image_y = exr_y - input.displayWindow().min.y; + const OpenEXR::Rgba* const JXL_RESTRICT input_row = + &input_rows[(exr_y - start_y) * row_size]; + uint8_t* row = static_cast(frame.color.pixels()) + + frame.color.stride * image_y; + const uint32_t pixel_size = + (3 + (has_alpha ? 1 : 0)) * kExrBitsPerSample / 8; + for (int exr_x = + std::max(input.dataWindow().min.x, input.displayWindow().min.x); + exr_x <= + std::min(input.dataWindow().max.x, input.displayWindow().max.x); + ++exr_x) { + const int image_x = exr_x - input.displayWindow().min.x; + memcpy(row + image_x * pixel_size, + input_row + (exr_x - input.dataWindow().min.x), pixel_size); + } + } + } + + ppf->color_encoding.transfer_function = JXL_TRANSFER_FUNCTION_LINEAR; + ppf->color_encoding.color_space = JXL_COLOR_SPACE_RGB; + ppf->color_encoding.primaries = JXL_PRIMARIES_SRGB; + ppf->color_encoding.white_point = JXL_WHITE_POINT_D65; + if (OpenEXR::hasChromaticities(input.header())) { + ppf->color_encoding.primaries = JXL_PRIMARIES_CUSTOM; + ppf->color_encoding.white_point = JXL_WHITE_POINT_CUSTOM; + const auto& chromaticities = OpenEXR::chromaticities(input.header()); + ppf->color_encoding.primaries_red_xy[0] = chromaticities.red.x; + ppf->color_encoding.primaries_red_xy[1] = chromaticities.red.y; + ppf->color_encoding.primaries_green_xy[0] = chromaticities.green.x; + ppf->color_encoding.primaries_green_xy[1] = chromaticities.green.y; + ppf->color_encoding.primaries_blue_xy[0] = chromaticities.blue.x; + ppf->color_encoding.primaries_blue_xy[1] = chromaticities.blue.y; + ppf->color_encoding.white_point_xy[0] = chromaticities.white.x; + ppf->color_encoding.white_point_xy[1] = chromaticities.white.y; + } + + // EXR uses binary16 or binary32 floating point format. + ppf->info.bits_per_sample = kExrBitsPerSample; + ppf->info.exponent_bits_per_sample = kExrBitsPerSample == 16 ? 5 : 8; + if (has_alpha) { + ppf->info.alpha_bits = kExrAlphaBits; + ppf->info.alpha_exponent_bits = ppf->info.exponent_bits_per_sample; + ppf->info.alpha_premultiplied = true; + } + ppf->info.intensity_target = intensity_target; + return true; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/exr.h b/media/libjxl/src/lib/extras/dec/exr.h new file mode 100644 index 000000000..6af4e6bec --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/exr.h @@ -0,0 +1,29 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_EXR_H_ +#define LIB_EXTRAS_DEC_EXR_H_ + +// Decodes OpenEXR images in memory. + +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Decodes `bytes` into `ppf`. color_hints are ignored. +Status DecodeImageEXR(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, PackedPixelFile* ppf); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_EXR_H_ diff --git a/media/libjxl/src/lib/extras/dec/gif.cc b/media/libjxl/src/lib/extras/dec/gif.cc new file mode 100644 index 000000000..5167bf5fa --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/gif.cc @@ -0,0 +1,414 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/gif.h" + +#include +#include + +#include +#include +#include + +#include "jxl/codestream_header.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { +namespace extras { + +namespace { + +struct ReadState { + Span bytes; +}; + +struct DGifCloser { + void operator()(GifFileType* const ptr) const { DGifCloseFile(ptr, nullptr); } +}; +using GifUniquePtr = std::unique_ptr; + +struct PackedRgba { + uint8_t r, g, b, a; +}; + +struct PackedRgb { + uint8_t r, g, b; +}; + +// Gif does not support partial transparency, so this considers any nonzero +// alpha channel value as opaque. +bool AllOpaque(const PackedImage& color) { + for (size_t y = 0; y < color.ysize; ++y) { + const PackedRgba* const JXL_RESTRICT row = + static_cast(color.pixels()) + y * color.xsize; + for (size_t x = 0; x < color.xsize; ++x) { + if (row[x].a == 0) { + return false; + } + } + } + return true; +} + +void ensure_have_alpha(PackedFrame* frame) { + if (!frame->extra_channels.empty()) return; + const JxlPixelFormat alpha_format{ + /*num_channels=*/1u, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0, + }; + frame->extra_channels.emplace_back(frame->color.xsize, frame->color.ysize, + alpha_format); + // We need to set opaque-by-default. + std::fill_n(static_cast(frame->extra_channels[0].pixels()), + frame->color.xsize * frame->color.ysize, 255u); +} + +} // namespace + +Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, + PackedPixelFile* ppf) { + int error = GIF_OK; + ReadState state = {bytes}; + const auto ReadFromSpan = [](GifFileType* const gif, GifByteType* const bytes, + int n) { + ReadState* const state = reinterpret_cast(gif->UserData); + // giflib API requires the input size `n` to be signed int. + if (static_cast(n) > state->bytes.size()) { + n = state->bytes.size(); + } + memcpy(bytes, state->bytes.data(), n); + state->bytes.remove_prefix(n); + return n; + }; + GifUniquePtr gif(DGifOpen(&state, ReadFromSpan, &error)); + if (gif == nullptr) { + if (error == D_GIF_ERR_NOT_GIF_FILE) { + // Not an error. + return false; + } else { + return JXL_FAILURE("Failed to read GIF: %s", GifErrorString(error)); + } + } + error = DGifSlurp(gif.get()); + if (error != GIF_OK) { + return JXL_FAILURE("Failed to read GIF: %s", GifErrorString(gif->Error)); + } + + msan::UnpoisonMemory(gif.get(), sizeof(*gif)); + if (gif->SColorMap) { + msan::UnpoisonMemory(gif->SColorMap, sizeof(*gif->SColorMap)); + msan::UnpoisonMemory( + gif->SColorMap->Colors, + sizeof(*gif->SColorMap->Colors) * gif->SColorMap->ColorCount); + } + msan::UnpoisonMemory(gif->SavedImages, + sizeof(*gif->SavedImages) * gif->ImageCount); + + JXL_RETURN_IF_ERROR( + VerifyDimensions(&constraints, gif->SWidth, gif->SHeight)); + uint64_t total_pixel_count = + static_cast(gif->SWidth) * gif->SHeight; + for (int i = 0; i < gif->ImageCount; ++i) { + const SavedImage& image = gif->SavedImages[i]; + uint32_t w = image.ImageDesc.Width; + uint32_t h = image.ImageDesc.Height; + JXL_RETURN_IF_ERROR(VerifyDimensions(&constraints, w, h)); + uint64_t pixel_count = static_cast(w) * h; + if (total_pixel_count + pixel_count < total_pixel_count) { + return JXL_FAILURE("Image too big"); + } + total_pixel_count += pixel_count; + if (total_pixel_count > constraints.dec_max_pixels) { + return JXL_FAILURE("Image too big"); + } + } + + if (!gif->SColorMap) { + for (int i = 0; i < gif->ImageCount; ++i) { + if (!gif->SavedImages[i].ImageDesc.ColorMap) { + return JXL_FAILURE("Missing GIF color map"); + } + } + } + + if (gif->ImageCount > 1) { + ppf->info.have_animation = true; + // Delays in GIF are specified in 100ths of a second. + ppf->info.animation.tps_numerator = 100; + ppf->info.animation.tps_denominator = 1; + } + + ppf->frames.clear(); + ppf->frames.reserve(gif->ImageCount); + + ppf->info.xsize = gif->SWidth; + ppf->info.ysize = gif->SHeight; + ppf->info.bits_per_sample = 8; + ppf->info.exponent_bits_per_sample = 0; + // alpha_bits is later set to 8 if we find a frame with transparent pixels. + ppf->info.alpha_bits = 0; + ppf->info.alpha_exponent_bits = 0; + JXL_RETURN_IF_ERROR(ApplyColorHints(color_hints, /*color_already_set=*/false, + /*is_gray=*/false, ppf)); + + ppf->info.num_color_channels = 3; + + // Pixel format for the 'canvas' onto which we paint + // the (potentially individually cropped) GIF frames + // of an animation. + const JxlPixelFormat canvas_format{ + /*num_channels=*/4u, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0, + }; + + // Pixel format for the JXL PackedFrame that goes into the + // PackedPixelFile. Here, we use 3 color channels, and provide + // the alpha channel as an extra_channel wherever it is used. + const JxlPixelFormat packed_frame_format{ + /*num_channels=*/3u, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0, + }; + + GifColorType background_color; + if (gif->SColorMap == nullptr || + gif->SBackGroundColor >= gif->SColorMap->ColorCount) { + background_color = {0, 0, 0}; + } else { + background_color = gif->SColorMap->Colors[gif->SBackGroundColor]; + } + const PackedRgba background_rgba{background_color.Red, background_color.Green, + background_color.Blue, 0}; + PackedFrame canvas(gif->SWidth, gif->SHeight, canvas_format); + std::fill_n(static_cast(canvas.color.pixels()), + canvas.color.xsize * canvas.color.ysize, background_rgba); + Rect canvas_rect{0, 0, canvas.color.xsize, canvas.color.ysize}; + + Rect previous_rect_if_restore_to_background; + + bool replace = true; + bool last_base_was_none = true; + for (int i = 0; i < gif->ImageCount; ++i) { + const SavedImage& image = gif->SavedImages[i]; + msan::UnpoisonMemory(image.RasterBits, sizeof(*image.RasterBits) * + image.ImageDesc.Width * + image.ImageDesc.Height); + const Rect image_rect(image.ImageDesc.Left, image.ImageDesc.Top, + image.ImageDesc.Width, image.ImageDesc.Height); + + Rect total_rect; + if (previous_rect_if_restore_to_background.xsize() != 0 || + previous_rect_if_restore_to_background.ysize() != 0) { + const size_t xbegin = std::min( + image_rect.x0(), previous_rect_if_restore_to_background.x0()); + const size_t ybegin = std::min( + image_rect.y0(), previous_rect_if_restore_to_background.y0()); + const size_t xend = + std::max(image_rect.x0() + image_rect.xsize(), + previous_rect_if_restore_to_background.x0() + + previous_rect_if_restore_to_background.xsize()); + const size_t yend = + std::max(image_rect.y0() + image_rect.ysize(), + previous_rect_if_restore_to_background.y0() + + previous_rect_if_restore_to_background.ysize()); + total_rect = Rect(xbegin, ybegin, xend - xbegin, yend - ybegin); + previous_rect_if_restore_to_background = Rect(); + replace = true; + } else { + total_rect = image_rect; + replace = false; + } + if (!image_rect.IsInside(canvas_rect)) { + return JXL_FAILURE("GIF frame extends outside of the canvas"); + } + + // Allocates the frame buffer. + ppf->frames.emplace_back(total_rect.xsize(), total_rect.ysize(), + packed_frame_format); + PackedFrame* frame = &ppf->frames.back(); + + // We cannot tell right from the start whether there will be a + // need for an alpha channel. This is discovered only as soon as + // we see a transparent pixel. We hence initialize alpha lazily. + auto set_pixel_alpha = [&frame](size_t x, size_t y, uint8_t a) { + // If we do not have an alpha-channel and a==255 (fully opaque), + // we can skip setting this pixel-value and rely on + // "no alpha channel = no transparency". + if (a == 255 && !frame->extra_channels.empty()) return; + ensure_have_alpha(frame); + static_cast( + frame->extra_channels[0].pixels())[y * frame->color.xsize + x] = a; + }; + + const ColorMapObject* const color_map = + image.ImageDesc.ColorMap ? image.ImageDesc.ColorMap : gif->SColorMap; + JXL_CHECK(color_map); + msan::UnpoisonMemory(color_map, sizeof(*color_map)); + msan::UnpoisonMemory(color_map->Colors, + sizeof(*color_map->Colors) * color_map->ColorCount); + GraphicsControlBlock gcb; + DGifSavedExtensionToGCB(gif.get(), i, &gcb); + msan::UnpoisonMemory(&gcb, sizeof(gcb)); + bool is_full_size = total_rect.x0() == 0 && total_rect.y0() == 0 && + total_rect.xsize() == canvas.color.xsize && + total_rect.ysize() == canvas.color.ysize; + if (ppf->info.have_animation) { + frame->frame_info.duration = gcb.DelayTime; + frame->frame_info.layer_info.have_crop = static_cast(!is_full_size); + frame->frame_info.layer_info.crop_x0 = total_rect.x0(); + frame->frame_info.layer_info.crop_y0 = total_rect.y0(); + frame->frame_info.layer_info.xsize = frame->color.xsize; + frame->frame_info.layer_info.ysize = frame->color.ysize; + if (last_base_was_none) { + replace = true; + } + frame->frame_info.layer_info.blend_info.blendmode = + replace ? JXL_BLEND_REPLACE : JXL_BLEND_BLEND; + // We always only reference at most the last frame + frame->frame_info.layer_info.blend_info.source = + last_base_was_none ? 0u : 1u; + frame->frame_info.layer_info.blend_info.clamp = 1; + frame->frame_info.layer_info.blend_info.alpha = 0; + // TODO(veluca): this could in principle be implemented. + if (last_base_was_none && + (total_rect.x0() != 0 || total_rect.y0() != 0 || + total_rect.xsize() != canvas.color.xsize || + total_rect.ysize() != canvas.color.ysize || !replace)) { + return JXL_FAILURE( + "GIF with dispose-to-0 is not supported for non-full or " + "blended frames"); + } + switch (gcb.DisposalMode) { + case DISPOSE_DO_NOT: + case DISPOSE_BACKGROUND: + frame->frame_info.layer_info.save_as_reference = 1u; + last_base_was_none = false; + break; + case DISPOSE_PREVIOUS: + frame->frame_info.layer_info.save_as_reference = 0u; + break; + default: + frame->frame_info.layer_info.save_as_reference = 0u; + last_base_was_none = true; + } + } + + // Update the canvas by creating a copy first. + PackedImage new_canvas_image(canvas.color.xsize, canvas.color.ysize, + canvas.color.format); + memcpy(new_canvas_image.pixels(), canvas.color.pixels(), + new_canvas_image.pixels_size); + for (size_t y = 0, byte_index = 0; y < image_rect.ysize(); ++y) { + // Assumes format.align == 0. row points to the beginning of the y row in + // the image_rect. + PackedRgba* row = static_cast(new_canvas_image.pixels()) + + (y + image_rect.y0()) * new_canvas_image.xsize + + image_rect.x0(); + for (size_t x = 0; x < image_rect.xsize(); ++x, ++byte_index) { + const GifByteType byte = image.RasterBits[byte_index]; + if (byte >= color_map->ColorCount) { + return JXL_FAILURE("GIF color is out of bounds"); + } + + if (byte == gcb.TransparentColor) continue; + GifColorType color = color_map->Colors[byte]; + row[x].r = color.Red; + row[x].g = color.Green; + row[x].b = color.Blue; + row[x].a = 255; + } + } + const PackedImage& sub_frame_image = frame->color; + if (replace) { + // Copy from the new canvas image to the subframe + for (size_t y = 0; y < total_rect.ysize(); ++y) { + const PackedRgba* row_in = + static_cast(new_canvas_image.pixels()) + + (y + total_rect.y0()) * new_canvas_image.xsize + total_rect.x0(); + PackedRgb* row_out = static_cast(sub_frame_image.pixels()) + + y * sub_frame_image.xsize; + for (size_t x = 0; x < sub_frame_image.xsize; ++x) { + row_out[x].r = row_in[x].r; + row_out[x].g = row_in[x].g; + row_out[x].b = row_in[x].b; + set_pixel_alpha(x, y, row_in[x].a); + } + } + } else { + for (size_t y = 0, byte_index = 0; y < image_rect.ysize(); ++y) { + // Assumes format.align == 0 + PackedRgb* row = static_cast(sub_frame_image.pixels()) + + y * sub_frame_image.xsize; + for (size_t x = 0; x < image_rect.xsize(); ++x, ++byte_index) { + const GifByteType byte = image.RasterBits[byte_index]; + if (byte > color_map->ColorCount) { + return JXL_FAILURE("GIF color is out of bounds"); + } + if (byte == gcb.TransparentColor) { + row[x].r = 0; + row[x].g = 0; + row[x].b = 0; + set_pixel_alpha(x, y, 0); + continue; + } + GifColorType color = color_map->Colors[byte]; + row[x].r = color.Red; + row[x].g = color.Green; + row[x].b = color.Blue; + set_pixel_alpha(x, y, 255); + } + } + } + + if (!frame->extra_channels.empty()) { + ppf->info.alpha_bits = 8; + } + + switch (gcb.DisposalMode) { + case DISPOSE_DO_NOT: + canvas.color = std::move(new_canvas_image); + break; + + case DISPOSE_BACKGROUND: + std::fill_n(static_cast(canvas.color.pixels()), + canvas.color.xsize * canvas.color.ysize, background_rgba); + previous_rect_if_restore_to_background = image_rect; + break; + + case DISPOSE_PREVIOUS: + break; + + case DISPOSAL_UNSPECIFIED: + default: + std::fill_n(static_cast(canvas.color.pixels()), + canvas.color.xsize * canvas.color.ysize, background_rgba); + } + } + // Finally, if any frame has an alpha-channel, every frame will need + // to have an alpha-channel. + bool seen_alpha = false; + for (const PackedFrame& frame : ppf->frames) { + if (!frame.extra_channels.empty()) { + seen_alpha = true; + break; + } + } + if (seen_alpha) { + for (PackedFrame& frame : ppf->frames) { + ensure_have_alpha(&frame); + } + } + return true; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/gif.h b/media/libjxl/src/lib/extras/dec/gif.h new file mode 100644 index 000000000..b35951728 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/gif.h @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_GIF_H_ +#define LIB_EXTRAS_DEC_GIF_H_ + +// Decodes GIF images in memory. + +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Decodes `bytes` into `ppf`. color_hints are ignored. +Status DecodeImageGIF(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, PackedPixelFile* ppf); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_GIF_H_ diff --git a/media/libjxl/src/lib/extras/dec/jpg.cc b/media/libjxl/src/lib/extras/dec/jpg.cc new file mode 100644 index 000000000..6b92f4a8a --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/jpg.cc @@ -0,0 +1,289 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/jpg.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { +namespace extras { + +namespace { + +constexpr unsigned char kICCSignature[12] = { + 0x49, 0x43, 0x43, 0x5F, 0x50, 0x52, 0x4F, 0x46, 0x49, 0x4C, 0x45, 0x00}; +constexpr int kICCMarker = JPEG_APP0 + 2; + +constexpr unsigned char kExifSignature[6] = {0x45, 0x78, 0x69, + 0x66, 0x00, 0x00}; +constexpr int kExifMarker = JPEG_APP0 + 1; + +static inline bool IsJPG(const Span bytes) { + if (bytes.size() < 2) return false; + if (bytes[0] != 0xFF || bytes[1] != 0xD8) return false; + return true; +} + +bool MarkerIsICC(const jpeg_saved_marker_ptr marker) { + return marker->marker == kICCMarker && + marker->data_length >= sizeof kICCSignature + 2 && + std::equal(std::begin(kICCSignature), std::end(kICCSignature), + marker->data); +} +bool MarkerIsExif(const jpeg_saved_marker_ptr marker) { + return marker->marker == kExifMarker && + marker->data_length >= sizeof kExifSignature + 2 && + std::equal(std::begin(kExifSignature), std::end(kExifSignature), + marker->data); +} + +Status ReadICCProfile(jpeg_decompress_struct* const cinfo, + std::vector* const icc) { + constexpr size_t kICCSignatureSize = sizeof kICCSignature; + // ICC signature + uint8_t index + uint8_t max_index. + constexpr size_t kICCHeadSize = kICCSignatureSize + 2; + // Markers are 1-indexed, and we keep them that way in this vector to get a + // convenient 0 at the front for when we compute the offsets later. + std::vector marker_lengths; + int num_markers = 0; + int seen_markers_count = 0; + bool has_num_markers = false; + for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; + marker = marker->next) { + // marker is initialized by libjpeg, which we are not instrumenting with + // msan. + msan::UnpoisonMemory(marker, sizeof(*marker)); + msan::UnpoisonMemory(marker->data, marker->data_length); + if (!MarkerIsICC(marker)) continue; + + const int current_marker = marker->data[kICCSignatureSize]; + if (current_marker == 0) { + return JXL_FAILURE("inconsistent JPEG ICC marker numbering"); + } + const int current_num_markers = marker->data[kICCSignatureSize + 1]; + if (current_marker > current_num_markers) { + return JXL_FAILURE("inconsistent JPEG ICC marker numbering"); + } + if (has_num_markers) { + if (current_num_markers != num_markers) { + return JXL_FAILURE("inconsistent numbers of JPEG ICC markers"); + } + } else { + num_markers = current_num_markers; + has_num_markers = true; + marker_lengths.resize(num_markers + 1); + } + + size_t marker_length = marker->data_length - kICCHeadSize; + + if (marker_length == 0) { + // NB: if we allow empty chunks, then the next check is incorrect. + return JXL_FAILURE("Empty ICC chunk"); + } + + if (marker_lengths[current_marker] != 0) { + return JXL_FAILURE("duplicate JPEG ICC marker number"); + } + marker_lengths[current_marker] = marker_length; + seen_markers_count++; + } + + if (marker_lengths.empty()) { + // Not an error. + return false; + } + + if (seen_markers_count != num_markers) { + JXL_DASSERT(has_num_markers); + return JXL_FAILURE("Incomplete set of ICC chunks"); + } + + std::vector offsets = std::move(marker_lengths); + std::partial_sum(offsets.begin(), offsets.end(), offsets.begin()); + icc->resize(offsets.back()); + + for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; + marker = marker->next) { + if (!MarkerIsICC(marker)) continue; + const uint8_t* first = marker->data + kICCHeadSize; + uint8_t current_marker = marker->data[kICCSignatureSize]; + size_t offset = offsets[current_marker - 1]; + size_t marker_length = offsets[current_marker] - offset; + std::copy_n(first, marker_length, icc->data() + offset); + } + + return true; +} + +void ReadExif(jpeg_decompress_struct* const cinfo, + std::vector* const exif) { + constexpr size_t kExifSignatureSize = sizeof kExifSignature; + for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; + marker = marker->next) { + // marker is initialized by libjpeg, which we are not instrumenting with + // msan. + msan::UnpoisonMemory(marker, sizeof(*marker)); + msan::UnpoisonMemory(marker->data, marker->data_length); + if (!MarkerIsExif(marker)) continue; + size_t marker_length = marker->data_length - kExifSignatureSize; + exif->resize(marker_length); + std::copy_n(marker->data + kExifSignatureSize, marker_length, exif->data()); + return; + } +} + +void MyErrorExit(j_common_ptr cinfo) { + jmp_buf* env = static_cast(cinfo->client_data); + (*cinfo->err->output_message)(cinfo); + jpeg_destroy_decompress(reinterpret_cast(cinfo)); + longjmp(*env, 1); +} + +void MyOutputMessage(j_common_ptr cinfo) { +#if JXL_DEBUG_WARNING == 1 + char buf[JMSG_LENGTH_MAX + 1]; + (*cinfo->err->format_message)(cinfo, buf); + buf[JMSG_LENGTH_MAX] = 0; + JXL_WARNING("%s", buf); +#endif +} + +} // namespace + +Status DecodeImageJPG(const Span bytes, + const ColorHints& color_hints, + const SizeConstraints& constraints, + PackedPixelFile* ppf) { + // Don't do anything for non-JPEG files (no need to report an error) + if (!IsJPG(bytes)) return false; + + // TODO(veluca): use JPEGData also for pixels? + + // We need to declare all the non-trivial destructor local variables before + // the call to setjmp(). + std::unique_ptr row; + + const auto try_catch_block = [&]() -> bool { + jpeg_decompress_struct cinfo; + // cinfo is initialized by libjpeg, which we are not instrumenting with + // msan, therefore we need to initialize cinfo here. + msan::UnpoisonMemory(&cinfo, sizeof(cinfo)); + // Setup error handling in jpeg library so we can deal with broken jpegs in + // the fuzzer. + jpeg_error_mgr jerr; + jmp_buf env; + cinfo.err = jpeg_std_error(&jerr); + jerr.error_exit = &MyErrorExit; + jerr.output_message = &MyOutputMessage; + if (setjmp(env)) { + return false; + } + cinfo.client_data = static_cast(&env); + + jpeg_create_decompress(&cinfo); + jpeg_mem_src(&cinfo, reinterpret_cast(bytes.data()), + bytes.size()); + jpeg_save_markers(&cinfo, kICCMarker, 0xFFFF); + jpeg_save_markers(&cinfo, kExifMarker, 0xFFFF); + const auto failure = [&cinfo](const char* str) -> Status { + jpeg_abort_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + return JXL_FAILURE("%s", str); + }; + int read_header_result = jpeg_read_header(&cinfo, TRUE); + // TODO(eustas): what about JPEG_HEADER_TABLES_ONLY? + if (read_header_result == JPEG_SUSPENDED) { + return failure("truncated JPEG input"); + } + if (!VerifyDimensions(&constraints, cinfo.image_width, + cinfo.image_height)) { + return failure("image too big"); + } + // Might cause CPU-zip bomb. + if (cinfo.arith_code) { + return failure("arithmetic code JPEGs are not supported"); + } + int nbcomp = cinfo.num_components; + if (nbcomp != 1 && nbcomp != 3) { + return failure("unsupported number of components in JPEG"); + } + if (!ReadICCProfile(&cinfo, &ppf->icc)) { + ppf->icc.clear(); + // Default to SRGB + // Actually, (cinfo.output_components == nbcomp) will be checked after + // `jpeg_start_decompress`. + ppf->color_encoding.color_space = + (nbcomp == 1) ? JXL_COLOR_SPACE_GRAY : JXL_COLOR_SPACE_RGB; + ppf->color_encoding.white_point = JXL_WHITE_POINT_D65; + ppf->color_encoding.primaries = JXL_PRIMARIES_SRGB; + ppf->color_encoding.transfer_function = JXL_TRANSFER_FUNCTION_SRGB; + ppf->color_encoding.rendering_intent = JXL_RENDERING_INTENT_PERCEPTUAL; + } + ReadExif(&cinfo, &ppf->metadata.exif); + if (!ApplyColorHints(color_hints, /*color_already_set=*/true, + /*is_gray=*/false, ppf)) { + return failure("ApplyColorHints failed"); + } + + ppf->info.xsize = cinfo.image_width; + ppf->info.ysize = cinfo.image_height; + // Original data is uint, so exponent_bits_per_sample = 0. + ppf->info.bits_per_sample = BITS_IN_JSAMPLE; + JXL_ASSERT(BITS_IN_JSAMPLE == 8 || BITS_IN_JSAMPLE == 16); + ppf->info.exponent_bits_per_sample = 0; + ppf->info.uses_original_profile = true; + + // No alpha in JPG + ppf->info.alpha_bits = 0; + ppf->info.alpha_exponent_bits = 0; + + ppf->info.num_color_channels = nbcomp; + ppf->info.orientation = JXL_ORIENT_IDENTITY; + + jpeg_start_decompress(&cinfo); + JXL_ASSERT(cinfo.output_components == nbcomp); + + const JxlPixelFormat format{ + /*num_channels=*/static_cast(nbcomp), + /*data_type=*/BITS_IN_JSAMPLE == 8 ? JXL_TYPE_UINT8 : JXL_TYPE_UINT16, + /*endianness=*/JXL_NATIVE_ENDIAN, + /*align=*/0, + }; + ppf->frames.clear(); + // Allocates the frame buffer. + ppf->frames.emplace_back(cinfo.image_width, cinfo.image_height, format); + const auto& frame = ppf->frames.back(); + JXL_ASSERT(sizeof(JSAMPLE) * cinfo.output_components * cinfo.image_width <= + frame.color.stride); + + for (size_t y = 0; y < cinfo.image_height; ++y) { + JSAMPROW rows[] = {reinterpret_cast( + static_cast(frame.color.pixels()) + + frame.color.stride * y)}; + jpeg_read_scanlines(&cinfo, rows, 1); + msan::UnpoisonMemory(rows[0], sizeof(JSAMPLE) * cinfo.output_components * + cinfo.image_width); + } + + jpeg_finish_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + return true; + }; + + return try_catch_block(); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/jpg.h b/media/libjxl/src/lib/extras/dec/jpg.h new file mode 100644 index 000000000..66b345288 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/jpg.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_JPG_H_ +#define LIB_EXTRAS_DEC_JPG_H_ + +// Decodes JPG pixels and metadata in memory. + +#include + +#include "lib/extras/codec.h" +#include "lib/extras/dec/color_hints.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Decodes `bytes` into `ppf`. color_hints are ignored. +// `elapsed_deinterleave`, if non-null, will be set to the time (in seconds) +// that it took to deinterleave the raw JSAMPLEs to planar floats. +Status DecodeImageJPG(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, PackedPixelFile* ppf); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_JPG_H_ diff --git a/media/libjxl/src/lib/extras/dec/jxl.cc b/media/libjxl/src/lib/extras/dec/jxl.cc new file mode 100644 index 000000000..0e1035646 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/jxl.cc @@ -0,0 +1,480 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/jxl.h" + +#include "jxl/decode.h" +#include "jxl/decode_cxx.h" +#include "jxl/types.h" +#include "lib/extras/dec/color_description.h" +#include "lib/extras/enc/encode.h" +#include "lib/jxl/base/printf_macros.h" + +namespace jxl { +namespace extras { +namespace { + +struct BoxProcessor { + BoxProcessor(JxlDecoder* dec) : dec_(dec) { Reset(); } + + void InitializeOutput(std::vector* out) { + box_data_ = out; + AddMoreOutput(); + } + + bool AddMoreOutput() { + Flush(); + static const size_t kBoxOutputChunkSize = 1 << 16; + box_data_->resize(box_data_->size() + kBoxOutputChunkSize); + next_out_ = box_data_->data() + total_size_; + avail_out_ = box_data_->size() - total_size_; + if (JXL_DEC_SUCCESS != + JxlDecoderSetBoxBuffer(dec_, next_out_, avail_out_)) { + fprintf(stderr, "JxlDecoderSetBoxBuffer failed\n"); + return false; + } + return true; + } + + void FinalizeOutput() { + if (box_data_ == nullptr) return; + Flush(); + box_data_->resize(total_size_); + Reset(); + } + + private: + JxlDecoder* dec_; + std::vector* box_data_; + uint8_t* next_out_; + size_t avail_out_; + size_t total_size_; + + void Reset() { + box_data_ = nullptr; + next_out_ = nullptr; + avail_out_ = 0; + total_size_ = 0; + } + void Flush() { + if (box_data_ == nullptr) return; + size_t remaining = JxlDecoderReleaseBoxBuffer(dec_); + size_t bytes_written = avail_out_ - remaining; + next_out_ += bytes_written; + avail_out_ -= bytes_written; + total_size_ += bytes_written; + } +}; + +} // namespace + +bool DecodeImageJXL(const uint8_t* bytes, size_t bytes_size, + const JXLDecompressParams& dparams, size_t* decoded_bytes, + PackedPixelFile* ppf, std::vector* jpeg_bytes) { + auto decoder = JxlDecoderMake(/*memory_manager=*/nullptr); + JxlDecoder* dec = decoder.get(); + ppf->frames.clear(); + + if (dparams.runner_opaque != nullptr && + JXL_DEC_SUCCESS != JxlDecoderSetParallelRunner(dec, dparams.runner, + dparams.runner_opaque)) { + fprintf(stderr, "JxlEncoderSetParallelRunner failed\n"); + return false; + } + + JxlPixelFormat format; + std::vector accepted_formats = dparams.accepted_formats; + if (accepted_formats.empty()) { + for (const uint32_t num_channels : {1, 2, 3, 4}) { + accepted_formats.push_back( + {num_channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, /*align=*/0}); + } + } + JxlColorEncoding color_encoding; + size_t num_color_channels = 0; + if (!dparams.color_space.empty()) { + if (!jxl::ParseDescription(dparams.color_space, &color_encoding)) { + fprintf(stderr, "Failed to parse color space %s.\n", + dparams.color_space.c_str()); + return false; + } + num_color_channels = + color_encoding.color_space == JXL_COLOR_SPACE_GRAY ? 1 : 3; + } + + bool can_reconstruct_jpeg = false; + std::vector jpeg_data_chunk; + if (jpeg_bytes != nullptr) { + jpeg_data_chunk.resize(16384); + jpeg_bytes->resize(0); + } + + int events = (JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE); + + bool max_passes_defined = + (dparams.max_passes < std::numeric_limits::max()); + if (max_passes_defined || dparams.max_downsampling > 1) { + events |= JXL_DEC_FRAME_PROGRESSION; + if (max_passes_defined) { + JxlDecoderSetProgressiveDetail(dec, JxlProgressiveDetail::kPasses); + } else { + JxlDecoderSetProgressiveDetail(dec, JxlProgressiveDetail::kLastPasses); + } + } + if (jpeg_bytes != nullptr) { + events |= JXL_DEC_JPEG_RECONSTRUCTION; + } else { + events |= (JXL_DEC_COLOR_ENCODING | JXL_DEC_FRAME | JXL_DEC_PREVIEW_IMAGE | + JXL_DEC_BOX); + } + if (JXL_DEC_SUCCESS != JxlDecoderSubscribeEvents(dec, events)) { + fprintf(stderr, "JxlDecoderSubscribeEvents failed\n"); + return false; + } + if (jpeg_bytes == nullptr) { + if (JXL_DEC_SUCCESS != + JxlDecoderSetRenderSpotcolors(dec, dparams.render_spotcolors)) { + fprintf(stderr, "JxlDecoderSetRenderSpotColors failed\n"); + return false; + } + if (JXL_DEC_SUCCESS != + JxlDecoderSetKeepOrientation(dec, dparams.keep_orientation)) { + fprintf(stderr, "JxlDecoderSetKeepOrientation failed\n"); + return false; + } + if (JXL_DEC_SUCCESS != + JxlDecoderSetUnpremultiplyAlpha(dec, dparams.unpremultiply_alpha)) { + fprintf(stderr, "JxlDecoderSetUnpremultiplyAlpha failed\n"); + return false; + } + if (dparams.display_nits > 0 && + JXL_DEC_SUCCESS != + JxlDecoderSetDesiredIntensityTarget(dec, dparams.display_nits)) { + fprintf(stderr, "Decoder failed to set desired intensity target\n"); + return false; + } + if (JXL_DEC_SUCCESS != JxlDecoderSetDecompressBoxes(dec, JXL_TRUE)) { + fprintf(stderr, "JxlDecoderSetDecompressBoxes failed\n"); + return false; + } + } + if (JXL_DEC_SUCCESS != JxlDecoderSetInput(dec, bytes, bytes_size)) { + fprintf(stderr, "Decoder failed to set input\n"); + return false; + } + uint32_t progression_index = 0; + bool codestream_done = false; + BoxProcessor boxes(dec); + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_ERROR) { + fprintf(stderr, "Failed to decode image\n"); + return false; + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + if (codestream_done) { + break; + } + if (dparams.allow_partial_input) { + if (JXL_DEC_SUCCESS != JxlDecoderFlushImage(dec)) { + fprintf(stderr, + "Input file is truncated and there is no preview " + "available yet.\n"); + return false; + } + break; + } + fprintf(stderr, + "Input file is truncated and allow_partial_input was disabled."); + return false; + } else if (status == JXL_DEC_BOX) { + boxes.FinalizeOutput(); + JxlBoxType box_type; + if (JXL_DEC_SUCCESS != JxlDecoderGetBoxType(dec, box_type, JXL_TRUE)) { + fprintf(stderr, "JxlDecoderGetBoxType failed\n"); + return false; + } + std::vector* box_data = nullptr; + if (memcmp(box_type, "Exif", 4) == 0) { + box_data = &ppf->metadata.exif; + } else if (memcmp(box_type, "iptc", 4) == 0) { + box_data = &ppf->metadata.iptc; + } else if (memcmp(box_type, "jumb", 4) == 0) { + box_data = &ppf->metadata.jumbf; + } else if (memcmp(box_type, "xml ", 4) == 0) { + box_data = &ppf->metadata.xmp; + } + if (box_data) { + boxes.InitializeOutput(box_data); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + boxes.AddMoreOutput(); + } else if (status == JXL_DEC_JPEG_RECONSTRUCTION) { + can_reconstruct_jpeg = true; + // Decoding to JPEG. + if (JXL_DEC_SUCCESS != JxlDecoderSetJPEGBuffer(dec, + jpeg_data_chunk.data(), + jpeg_data_chunk.size())) { + fprintf(stderr, "Decoder failed to set JPEG Buffer\n"); + return false; + } + } else if (status == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + // Decoded a chunk to JPEG. + size_t used_jpeg_output = + jpeg_data_chunk.size() - JxlDecoderReleaseJPEGBuffer(dec); + jpeg_bytes->insert(jpeg_bytes->end(), jpeg_data_chunk.data(), + jpeg_data_chunk.data() + used_jpeg_output); + if (used_jpeg_output == 0) { + // Chunk is too small. + jpeg_data_chunk.resize(jpeg_data_chunk.size() * 2); + } + if (JXL_DEC_SUCCESS != JxlDecoderSetJPEGBuffer(dec, + jpeg_data_chunk.data(), + jpeg_data_chunk.size())) { + fprintf(stderr, "Decoder failed to set JPEG Buffer\n"); + return false; + } + } else if (status == JXL_DEC_BASIC_INFO) { + if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec, &ppf->info)) { + fprintf(stderr, "JxlDecoderGetBasicInfo failed\n"); + return false; + } + if (num_color_channels != 0) { + // Mark the change in number of color channels due to the requested + // color space. + ppf->info.num_color_channels = num_color_channels; + } + // Select format according to accepted formats. + if (!jxl::extras::SelectFormat(accepted_formats, ppf->info, &format)) { + fprintf(stderr, "SelectFormat failed\n"); + return false; + } + bool have_alpha = (format.num_channels == 2 || format.num_channels == 4); + if (!have_alpha) { + // Mark in the basic info that alpha channel was dropped. + ppf->info.alpha_bits = 0; + } else if (dparams.unpremultiply_alpha) { + // Mark in the basic info that alpha was unpremultiplied. + ppf->info.alpha_premultiplied = false; + } + bool alpha_found = false; + for (uint32_t i = 0; i < ppf->info.num_extra_channels; ++i) { + JxlExtraChannelInfo eci; + if (JXL_DEC_SUCCESS != JxlDecoderGetExtraChannelInfo(dec, i, &eci)) { + fprintf(stderr, "JxlDecoderGetExtraChannelInfo failed\n"); + return false; + } + if (eci.type == JXL_CHANNEL_ALPHA && have_alpha && !alpha_found) { + // Skip the first alpha channels because it is already present in the + // interleaved image. + alpha_found = true; + continue; + } + std::string name(eci.name_length + 1, 0); + if (JXL_DEC_SUCCESS != + JxlDecoderGetExtraChannelName(dec, i, &name[0], name.size())) { + fprintf(stderr, "JxlDecoderGetExtraChannelName failed\n"); + return false; + } + name.resize(eci.name_length); + ppf->extra_channels_info.push_back({eci, i, name}); + } + } else if (status == JXL_DEC_COLOR_ENCODING) { + if (!dparams.color_space.empty()) { + if (ppf->info.uses_original_profile) { + fprintf(stderr, + "Warning: --color_space ignored because the image is " + "not XYB encoded.\n"); + } else { + if (JXL_DEC_SUCCESS != + JxlDecoderSetPreferredColorProfile(dec, &color_encoding)) { + fprintf(stderr, "Failed to set color space.\n"); + return false; + } + } + } + size_t icc_size = 0; + JxlColorProfileTarget target = JXL_COLOR_PROFILE_TARGET_DATA; + if (JXL_DEC_SUCCESS != + JxlDecoderGetICCProfileSize(dec, nullptr, target, &icc_size)) { + fprintf(stderr, "JxlDecoderGetICCProfileSize failed\n"); + } + if (icc_size != 0) { + ppf->icc.resize(icc_size); + if (JXL_DEC_SUCCESS != + JxlDecoderGetColorAsICCProfile(dec, nullptr, target, + ppf->icc.data(), icc_size)) { + fprintf(stderr, "JxlDecoderGetColorAsICCProfile failed\n"); + return false; + } + } + if (JXL_DEC_SUCCESS != JxlDecoderGetColorAsEncodedProfile( + dec, nullptr, target, &ppf->color_encoding)) { + ppf->color_encoding.color_space = JXL_COLOR_SPACE_UNKNOWN; + } + icc_size = 0; + target = JXL_COLOR_PROFILE_TARGET_ORIGINAL; + if (JXL_DEC_SUCCESS != + JxlDecoderGetICCProfileSize(dec, nullptr, target, &icc_size)) { + fprintf(stderr, "JxlDecoderGetICCProfileSize failed\n"); + } + if (icc_size != 0) { + ppf->orig_icc.resize(icc_size); + if (JXL_DEC_SUCCESS != + JxlDecoderGetColorAsICCProfile(dec, nullptr, target, + ppf->orig_icc.data(), icc_size)) { + fprintf(stderr, "JxlDecoderGetColorAsICCProfile failed\n"); + return false; + } + } + } else if (status == JXL_DEC_FRAME) { + jxl::extras::PackedFrame frame(ppf->info.xsize, ppf->info.ysize, format); + if (JXL_DEC_SUCCESS != JxlDecoderGetFrameHeader(dec, &frame.frame_info)) { + fprintf(stderr, "JxlDecoderGetFrameHeader failed\n"); + return false; + } + frame.name.resize(frame.frame_info.name_length + 1, 0); + if (JXL_DEC_SUCCESS != + JxlDecoderGetFrameName(dec, &frame.name[0], frame.name.size())) { + fprintf(stderr, "JxlDecoderGetFrameName failed\n"); + return false; + } + frame.name.resize(frame.frame_info.name_length); + ppf->frames.emplace_back(std::move(frame)); + progression_index = 0; + } else if (status == JXL_DEC_FRAME_PROGRESSION) { + size_t downsampling = JxlDecoderGetIntendedDownsamplingRatio(dec); + if ((max_passes_defined && progression_index >= dparams.max_passes) || + (!max_passes_defined && downsampling <= dparams.max_downsampling)) { + if (JXL_DEC_SUCCESS != JxlDecoderFlushImage(dec)) { + fprintf(stderr, "JxlDecoderFlushImage failed\n"); + return false; + } + if (ppf->frames.back().frame_info.is_last) { + break; + } + if (JXL_DEC_SUCCESS != JxlDecoderSkipCurrentFrame(dec)) { + fprintf(stderr, "JxlDecoderSkipCurrentFrame failed\n"); + return false; + } + } + ++progression_index; + } else if (status == JXL_DEC_NEED_PREVIEW_OUT_BUFFER) { + size_t buffer_size; + if (JXL_DEC_SUCCESS != + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)) { + fprintf(stderr, "JxlDecoderPreviewOutBufferSize failed\n"); + return false; + } + ppf->preview_frame = std::unique_ptr( + new jxl::extras::PackedFrame(ppf->info.preview.xsize, + ppf->info.preview.ysize, format)); + if (buffer_size != ppf->preview_frame->color.pixels_size) { + fprintf(stderr, "Invalid out buffer size %" PRIuS " %" PRIuS "\n", + buffer_size, ppf->preview_frame->color.pixels_size); + return false; + } + if (JXL_DEC_SUCCESS != + JxlDecoderSetPreviewOutBuffer( + dec, &format, ppf->preview_frame->color.pixels(), buffer_size)) { + fprintf(stderr, "JxlDecoderSetPreviewOutBuffer failed\n"); + return false; + } + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + if (jpeg_bytes != nullptr) { + break; + } + size_t buffer_size; + if (JXL_DEC_SUCCESS != + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)) { + fprintf(stderr, "JxlDecoderImageOutBufferSize failed\n"); + return false; + } + jxl::extras::PackedFrame& frame = ppf->frames.back(); + if (buffer_size != frame.color.pixels_size) { + fprintf(stderr, "Invalid out buffer size %" PRIuS " %" PRIuS "\n", + buffer_size, frame.color.pixels_size); + return false; + } + + if (dparams.use_image_callback) { + auto callback = [](void* opaque, size_t x, size_t y, size_t num_pixels, + const void* pixels) { + auto* ppf = reinterpret_cast(opaque); + jxl::extras::PackedImage& color = ppf->frames.back().color; + uint8_t* pixels_buffer = reinterpret_cast(color.pixels()); + size_t sample_size = color.pixel_stride(); + memcpy(pixels_buffer + (color.stride * y + sample_size * x), pixels, + num_pixels * sample_size); + }; + if (JXL_DEC_SUCCESS != + JxlDecoderSetImageOutCallback(dec, &format, callback, ppf)) { + fprintf(stderr, "JxlDecoderSetImageOutCallback failed\n"); + return false; + } + } else { + if (JXL_DEC_SUCCESS != JxlDecoderSetImageOutBuffer(dec, &format, + frame.color.pixels(), + buffer_size)) { + fprintf(stderr, "JxlDecoderSetImageOutBuffer failed\n"); + return false; + } + } + JxlPixelFormat ec_format = format; + ec_format.num_channels = 1; + for (const auto& eci : ppf->extra_channels_info) { + frame.extra_channels.emplace_back(jxl::extras::PackedImage( + ppf->info.xsize, ppf->info.ysize, ec_format)); + auto& ec = frame.extra_channels.back(); + size_t buffer_size; + if (JXL_DEC_SUCCESS != JxlDecoderExtraChannelBufferSize( + dec, &ec_format, &buffer_size, eci.index)) { + fprintf(stderr, "JxlDecoderExtraChannelBufferSize failed\n"); + return false; + } + if (buffer_size != ec.pixels_size) { + fprintf(stderr, + "Invalid extra channel buffer size" + " %" PRIuS " %" PRIuS "\n", + buffer_size, ec.pixels_size); + return false; + } + if (JXL_DEC_SUCCESS != + JxlDecoderSetExtraChannelBuffer(dec, &ec_format, ec.pixels(), + buffer_size, eci.index)) { + fprintf(stderr, "JxlDecoderSetExtraChannelBuffer failed\n"); + return false; + } + } + } else if (status == JXL_DEC_SUCCESS) { + // Decoding finished successfully. + break; + } else if (status == JXL_DEC_PREVIEW_IMAGE) { + // Nothing to do. + } else if (status == JXL_DEC_FULL_IMAGE) { + if (jpeg_bytes != nullptr || ppf->frames.back().frame_info.is_last) { + codestream_done = true; + } + } else { + fprintf(stderr, "Error: unexpected status: %d\n", + static_cast(status)); + return false; + } + } + boxes.FinalizeOutput(); + if (jpeg_bytes != nullptr) { + if (!can_reconstruct_jpeg) return false; + size_t used_jpeg_output = + jpeg_data_chunk.size() - JxlDecoderReleaseJPEGBuffer(dec); + jpeg_bytes->insert(jpeg_bytes->end(), jpeg_data_chunk.data(), + jpeg_data_chunk.data() + used_jpeg_output); + } + if (decoded_bytes) { + *decoded_bytes = bytes_size - JxlDecoderReleaseInput(dec); + } + return true; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/jxl.h b/media/libjxl/src/lib/extras/dec/jxl.h new file mode 100644 index 000000000..c462fa4b7 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/jxl.h @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_JXL_H_ +#define LIB_EXTRAS_DEC_JXL_H_ + +// Decodes JPEG XL images in memory. + +#include + +#include +#include +#include + +#include "jxl/parallel_runner.h" +#include "jxl/types.h" +#include "lib/extras/packed_image.h" + +namespace jxl { +namespace extras { + +struct JXLDecompressParams { + // If empty, little endian float formats will be accepted. + std::vector accepted_formats; + + // Requested output color space description. + std::string color_space; + // If set, performs tone mapping to this intensity target luminance. + float display_nits = 0.0; + // Whether spot colors are rendered on the image. + bool render_spotcolors = true; + // Whether to keep or undo the orientation given in the header. + bool keep_orientation = false; + + // If runner_opaque is set, the decoder uses this parallel runner. + JxlParallelRunner runner; + void* runner_opaque = nullptr; + + // Whether truncated input should be treated as an error. + bool allow_partial_input = false; + + // How many passes to decode at most. By default, decode everything. + uint32_t max_passes = std::numeric_limits::max(); + + // Alternatively, one can specify the maximum tolerable downscaling factor + // with respect to the full size of the image. By default, nothing less than + // the full size is requested. + size_t max_downsampling = 1; + + // Whether to use the image callback or the image buffer to get the output. + bool use_image_callback = true; + // Whether to unpremultiply colors for associated alpha channels. + bool unpremultiply_alpha = false; +}; + +bool DecodeImageJXL(const uint8_t* bytes, size_t bytes_size, + const JXLDecompressParams& dparams, size_t* decoded_bytes, + PackedPixelFile* ppf, + std::vector* jpeg_bytes = nullptr); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_JXL_H_ diff --git a/media/libjxl/src/lib/extras/dec/pgx.cc b/media/libjxl/src/lib/extras/dec/pgx.cc new file mode 100644 index 000000000..1417348c6 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/pgx.cc @@ -0,0 +1,202 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/pgx.h" + +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { +namespace extras { +namespace { + +struct HeaderPGX { + // NOTE: PGX is always grayscale + size_t xsize; + size_t ysize; + size_t bits_per_sample; + bool big_endian; + bool is_signed; +}; + +class Parser { + public: + explicit Parser(const Span bytes) + : pos_(bytes.data()), end_(pos_ + bytes.size()) {} + + // Sets "pos" to the first non-header byte/pixel on success. + Status ParseHeader(HeaderPGX* header, const uint8_t** pos) { + // codec.cc ensures we have at least two bytes => no range check here. + if (pos_[0] != 'P' || pos_[1] != 'G') return false; + pos_ += 2; + return ParseHeaderPGX(header, pos); + } + + // Exposed for testing + Status ParseUnsigned(size_t* number) { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before number"); + if (!IsDigit(*pos_)) return JXL_FAILURE("PGX: expected unsigned number"); + + *number = 0; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number *= 10; + *number += *pos_ - '0'; + ++pos_; + } + + return true; + } + + private: + static bool IsDigit(const uint8_t c) { return '0' <= c && c <= '9'; } + static bool IsLineBreak(const uint8_t c) { return c == '\r' || c == '\n'; } + static bool IsWhitespace(const uint8_t c) { + return IsLineBreak(c) || c == '\t' || c == ' '; + } + + Status SkipSpace() { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before space"); + const uint8_t c = *pos_; + if (c != ' ') return JXL_FAILURE("PGX: expected space"); + ++pos_; + return true; + } + + Status SkipLineBreak() { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before line break"); + // Line break can be either "\n" (0a) or "\r\n" (0d 0a). + if (*pos_ == '\n') { + pos_++; + return true; + } else if (*pos_ == '\r' && pos_ + 1 != end_ && *(pos_ + 1) == '\n') { + pos_ += 2; + return true; + } + return JXL_FAILURE("PGX: expected line break"); + } + + Status SkipSingleWhitespace() { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before whitespace"); + if (!IsWhitespace(*pos_)) return JXL_FAILURE("PGX: expected whitespace"); + ++pos_; + return true; + } + + Status ParseHeaderPGX(HeaderPGX* header, const uint8_t** pos) { + JXL_RETURN_IF_ERROR(SkipSpace()); + if (pos_ + 2 > end_) return JXL_FAILURE("PGX: header too small"); + if (*pos_ == 'M' && *(pos_ + 1) == 'L') { + header->big_endian = true; + } else if (*pos_ == 'L' && *(pos_ + 1) == 'M') { + header->big_endian = false; + } else { + return JXL_FAILURE("PGX: invalid endianness"); + } + pos_ += 2; + JXL_RETURN_IF_ERROR(SkipSpace()); + if (pos_ == end_) return JXL_FAILURE("PGX: header too small"); + if (*pos_ == '+') { + header->is_signed = false; + } else if (*pos_ == '-') { + header->is_signed = true; + } else { + return JXL_FAILURE("PGX: invalid signedness"); + } + pos_++; + // Skip optional space + if (pos_ < end_ && *pos_ == ' ') pos_++; + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->bits_per_sample)); + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + // 0xa, or 0xd 0xa. + JXL_RETURN_IF_ERROR(SkipLineBreak()); + + // TODO(jon): could do up to 24-bit by converting the values to + // JXL_TYPE_FLOAT. + if (header->bits_per_sample > 16) { + return JXL_FAILURE("PGX: >16 bits not yet supported"); + } + // TODO(lode): support signed integers. This may require changing the way + // external_image works. + if (header->is_signed) { + return JXL_FAILURE("PGX: signed not yet supported"); + } + + size_t numpixels = header->xsize * header->ysize; + size_t bytes_per_pixel = header->bits_per_sample <= 8 ? 1 : 2; + if (pos_ + numpixels * bytes_per_pixel > end_) { + return JXL_FAILURE("PGX: data too small"); + } + + *pos = pos_; + return true; + } + + const uint8_t* pos_; + const uint8_t* const end_; +}; + +} // namespace + +Status DecodeImagePGX(const Span bytes, + const ColorHints& color_hints, + const SizeConstraints& constraints, + PackedPixelFile* ppf) { + Parser parser(bytes); + HeaderPGX header = {}; + const uint8_t* pos; + if (!parser.ParseHeader(&header, &pos)) return false; + JXL_RETURN_IF_ERROR( + VerifyDimensions(&constraints, header.xsize, header.ysize)); + if (header.bits_per_sample == 0 || header.bits_per_sample > 32) { + return JXL_FAILURE("PGX: bits_per_sample invalid"); + } + + JXL_RETURN_IF_ERROR(ApplyColorHints(color_hints, /*color_already_set=*/false, + /*is_gray=*/true, ppf)); + ppf->info.xsize = header.xsize; + ppf->info.ysize = header.ysize; + // Original data is uint, so exponent_bits_per_sample = 0. + ppf->info.bits_per_sample = header.bits_per_sample; + ppf->info.exponent_bits_per_sample = 0; + ppf->info.uses_original_profile = true; + + // No alpha in PGX + ppf->info.alpha_bits = 0; + ppf->info.alpha_exponent_bits = 0; + ppf->info.num_color_channels = 1; // Always grayscale + ppf->info.orientation = JXL_ORIENT_IDENTITY; + + JxlDataType data_type; + if (header.bits_per_sample > 8) { + data_type = JXL_TYPE_UINT16; + } else { + data_type = JXL_TYPE_UINT8; + } + + const JxlPixelFormat format{ + /*num_channels=*/1, + /*data_type=*/data_type, + /*endianness=*/header.big_endian ? JXL_BIG_ENDIAN : JXL_LITTLE_ENDIAN, + /*align=*/0, + }; + ppf->frames.clear(); + // Allocates the frame buffer. + ppf->frames.emplace_back(header.xsize, header.ysize, format); + const auto& frame = ppf->frames.back(); + size_t pgx_remaining_size = bytes.data() + bytes.size() - pos; + if (pgx_remaining_size < frame.color.pixels_size) { + return JXL_FAILURE("PGX file too small"); + } + memcpy(frame.color.pixels(), pos, frame.color.pixels_size); + return true; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/pgx.h b/media/libjxl/src/lib/extras/dec/pgx.h new file mode 100644 index 000000000..38aedf51a --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/pgx.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_PGX_H_ +#define LIB_EXTRAS_DEC_PGX_H_ + +// Decodes PGX pixels in memory. + +#include +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Decodes `bytes` into `ppf`. +Status DecodeImagePGX(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, PackedPixelFile* ppf); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_PGX_H_ diff --git a/media/libjxl/src/lib/extras/dec/pgx_test.cc b/media/libjxl/src/lib/extras/dec/pgx_test.cc new file mode 100644 index 000000000..41e6bf810 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/pgx_test.cc @@ -0,0 +1,80 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/pgx.h" + +#include "gtest/gtest.h" +#include "lib/extras/packed_image_convert.h" + +namespace jxl { +namespace extras { +namespace { + +Span MakeSpan(const char* str) { + return Span(reinterpret_cast(str), + strlen(str)); +} + +TEST(CodecPGXTest, Test8bits) { + std::string pgx = "PG ML + 8 2 3\npixels"; + + PackedPixelFile ppf; + ThreadPool* pool = nullptr; + + EXPECT_TRUE(DecodeImagePGX(MakeSpan(pgx.c_str()), ColorHints(), + SizeConstraints(), &ppf)); + CodecInOut io; + EXPECT_TRUE(ConvertPackedPixelFileToCodecInOut(ppf, pool, &io)); + + ScaleImage(255.f, io.Main().color()); + + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.IsGray()); + EXPECT_EQ(2u, io.xsize()); + EXPECT_EQ(3u, io.ysize()); + + float eps = 1e-5; + EXPECT_NEAR('p', io.Main().color()->Plane(0).Row(0)[0], eps); + EXPECT_NEAR('i', io.Main().color()->Plane(0).Row(0)[1], eps); + EXPECT_NEAR('x', io.Main().color()->Plane(0).Row(1)[0], eps); + EXPECT_NEAR('e', io.Main().color()->Plane(0).Row(1)[1], eps); + EXPECT_NEAR('l', io.Main().color()->Plane(0).Row(2)[0], eps); + EXPECT_NEAR('s', io.Main().color()->Plane(0).Row(2)[1], eps); +} + +TEST(CodecPGXTest, Test16bits) { + std::string pgx = "PG ML + 16 2 3\np_i_x_e_l_s_"; + + PackedPixelFile ppf; + ThreadPool* pool = nullptr; + + EXPECT_TRUE(DecodeImagePGX(MakeSpan(pgx.c_str()), ColorHints(), + SizeConstraints(), &ppf)); + CodecInOut io; + EXPECT_TRUE(ConvertPackedPixelFileToCodecInOut(ppf, pool, &io)); + + ScaleImage(255.f, io.Main().color()); + + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(16u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.IsGray()); + EXPECT_EQ(2u, io.xsize()); + EXPECT_EQ(3u, io.ysize()); + + // Comparing ~16-bit numbers in floats, only ~7 bits left. + float eps = 1e-3; + const auto& plane = io.Main().color()->Plane(0); + EXPECT_NEAR(256.0f * 'p' + '_', plane.Row(0)[0] * 257, eps); + EXPECT_NEAR(256.0f * 'i' + '_', plane.Row(0)[1] * 257, eps); + EXPECT_NEAR(256.0f * 'x' + '_', plane.Row(1)[0] * 257, eps); + EXPECT_NEAR(256.0f * 'e' + '_', plane.Row(1)[1] * 257, eps); + EXPECT_NEAR(256.0f * 'l' + '_', plane.Row(2)[0] * 257, eps); + EXPECT_NEAR(256.0f * 's' + '_', plane.Row(2)[1] * 257, eps); +} + +} // namespace +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/pnm.cc b/media/libjxl/src/lib/extras/dec/pnm.cc new file mode 100644 index 000000000..03aecef29 --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/pnm.cc @@ -0,0 +1,420 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/pnm.h" + +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace extras { +namespace { + +struct HeaderPNM { + size_t xsize; + size_t ysize; + bool is_gray; // PGM + bool has_alpha; // PAM + size_t bits_per_sample; + bool floating_point; + bool big_endian; +}; + +class Parser { + public: + explicit Parser(const Span bytes) + : pos_(bytes.data()), end_(pos_ + bytes.size()) {} + + // Sets "pos" to the first non-header byte/pixel on success. + Status ParseHeader(HeaderPNM* header, const uint8_t** pos) { + // codec.cc ensures we have at least two bytes => no range check here. + if (pos_[0] != 'P') return false; + const uint8_t type = pos_[1]; + pos_ += 2; + + switch (type) { + case '4': + return JXL_FAILURE("pbm not supported"); + + case '5': + header->is_gray = true; + return ParseHeaderPNM(header, pos); + + case '6': + header->is_gray = false; + return ParseHeaderPNM(header, pos); + + case '7': + return ParseHeaderPAM(header, pos); + + case 'F': + header->is_gray = false; + return ParseHeaderPFM(header, pos); + + case 'f': + header->is_gray = true; + return ParseHeaderPFM(header, pos); + } + return false; + } + + // Exposed for testing + Status ParseUnsigned(size_t* number) { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before number"); + if (!IsDigit(*pos_)) return JXL_FAILURE("PNM: expected unsigned number"); + + *number = 0; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number *= 10; + *number += *pos_ - '0'; + ++pos_; + } + + return true; + } + + Status ParseSigned(double* number) { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before signed"); + + if (*pos_ != '-' && *pos_ != '+' && !IsDigit(*pos_)) { + return JXL_FAILURE("PNM: expected signed number"); + } + + // Skip sign + const bool is_neg = *pos_ == '-'; + if (is_neg || *pos_ == '+') { + ++pos_; + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before digits"); + } + + // Leading digits + *number = 0.0; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number *= 10; + *number += *pos_ - '0'; + ++pos_; + } + + // Decimal places? + if (pos_ < end_ && *pos_ == '.') { + ++pos_; + double place = 0.1; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number += (*pos_ - '0') * place; + place *= 0.1; + ++pos_; + } + } + + if (is_neg) *number = -*number; + return true; + } + + private: + static bool IsDigit(const uint8_t c) { return '0' <= c && c <= '9'; } + static bool IsLineBreak(const uint8_t c) { return c == '\r' || c == '\n'; } + static bool IsWhitespace(const uint8_t c) { + return IsLineBreak(c) || c == '\t' || c == ' '; + } + + Status SkipBlank() { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before blank"); + const uint8_t c = *pos_; + if (c != ' ' && c != '\n') return JXL_FAILURE("PNM: expected blank"); + ++pos_; + return true; + } + + Status SkipSingleWhitespace() { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before whitespace"); + if (!IsWhitespace(*pos_)) return JXL_FAILURE("PNM: expected whitespace"); + ++pos_; + return true; + } + + Status SkipWhitespace() { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before whitespace"); + if (!IsWhitespace(*pos_) && *pos_ != '#') { + return JXL_FAILURE("PNM: expected whitespace/comment"); + } + + while (pos_ < end_ && IsWhitespace(*pos_)) { + ++pos_; + } + + // Comment(s) + while (pos_ != end_ && *pos_ == '#') { + while (pos_ != end_ && !IsLineBreak(*pos_)) { + ++pos_; + } + // Newline(s) + while (pos_ != end_ && IsLineBreak(*pos_)) pos_++; + } + + while (pos_ < end_ && IsWhitespace(*pos_)) { + ++pos_; + } + return true; + } + + Status MatchString(const char* keyword, bool skipws = true) { + const uint8_t* ppos = pos_; + while (*keyword) { + if (ppos >= end_) return JXL_FAILURE("PAM: unexpected end of input"); + if (*keyword != *ppos) return false; + ppos++; + keyword++; + } + pos_ = ppos; + if (skipws) { + JXL_RETURN_IF_ERROR(SkipWhitespace()); + } else { + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + } + return true; + } + + Status ParseHeaderPAM(HeaderPNM* header, const uint8_t** pos) { + size_t depth = 3; + size_t max_val = 255; + while (!MatchString("ENDHDR", /*skipws=*/false)) { + JXL_RETURN_IF_ERROR(SkipWhitespace()); + if (MatchString("WIDTH")) { + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + } else if (MatchString("HEIGHT")) { + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + } else if (MatchString("DEPTH")) { + JXL_RETURN_IF_ERROR(ParseUnsigned(&depth)); + } else if (MatchString("MAXVAL")) { + JXL_RETURN_IF_ERROR(ParseUnsigned(&max_val)); + } else if (MatchString("TUPLTYPE")) { + if (MatchString("RGB_ALPHA")) { + header->has_alpha = true; + } else if (MatchString("RGB")) { + } else if (MatchString("GRAYSCALE_ALPHA")) { + header->has_alpha = true; + header->is_gray = true; + } else if (MatchString("GRAYSCALE")) { + header->is_gray = true; + } else if (MatchString("BLACKANDWHITE_ALPHA")) { + header->has_alpha = true; + header->is_gray = true; + max_val = 1; + } else if (MatchString("BLACKANDWHITE")) { + header->is_gray = true; + max_val = 1; + } else { + return JXL_FAILURE("PAM: unknown TUPLTYPE"); + } + } else { + constexpr size_t kMaxHeaderLength = 20; + char unknown_header[kMaxHeaderLength + 1]; + size_t len = std::min(kMaxHeaderLength, end_ - pos_); + strncpy(unknown_header, reinterpret_cast(pos_), len); + unknown_header[len] = 0; + return JXL_FAILURE("PAM: unknown header keyword: %s", unknown_header); + } + } + size_t num_channels = header->is_gray ? 1 : 3; + if (header->has_alpha) num_channels++; + if (num_channels != depth) { + return JXL_FAILURE("PAM: bad DEPTH"); + } + if (max_val == 0 || max_val >= 65536) { + return JXL_FAILURE("PAM: bad MAXVAL"); + } + // e.g When `max_val` is 1 , we want 1 bit: + header->bits_per_sample = FloorLog2Nonzero(max_val) + 1; + if ((1u << header->bits_per_sample) - 1 != max_val) + return JXL_FAILURE("PNM: unsupported MaxVal (expected 2^n - 1)"); + // PAM does not pack bits as in PBM. + + header->floating_point = false; + header->big_endian = true; + *pos = pos_; + return true; + } + + Status ParseHeaderPNM(HeaderPNM* header, const uint8_t** pos) { + JXL_RETURN_IF_ERROR(SkipWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + + JXL_RETURN_IF_ERROR(SkipWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + + JXL_RETURN_IF_ERROR(SkipWhitespace()); + size_t max_val; + JXL_RETURN_IF_ERROR(ParseUnsigned(&max_val)); + if (max_val == 0 || max_val >= 65536) { + return JXL_FAILURE("PNM: bad MaxVal"); + } + header->bits_per_sample = FloorLog2Nonzero(max_val) + 1; + if ((1u << header->bits_per_sample) - 1 != max_val) + return JXL_FAILURE("PNM: unsupported MaxVal (expected 2^n - 1)"); + header->floating_point = false; + header->big_endian = true; + + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + + *pos = pos_; + return true; + } + + Status ParseHeaderPFM(HeaderPNM* header, const uint8_t** pos) { + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + + JXL_RETURN_IF_ERROR(SkipBlank()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + // The scale has no meaning as multiplier, only its sign is used to + // indicate endianness. All software expects nominal range 0..1. + double scale; + JXL_RETURN_IF_ERROR(ParseSigned(&scale)); + if (scale == 0.0) { + return JXL_FAILURE("PFM: bad scale factor value."); + } else if (std::abs(scale) != 1.0) { + JXL_WARNING("PFM: Discarding non-unit scale factor"); + } + header->big_endian = scale > 0.0; + header->bits_per_sample = 32; + header->floating_point = true; + + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + + *pos = pos_; + return true; + } + + const uint8_t* pos_; + const uint8_t* const end_; +}; + +Span MakeSpan(const char* str) { + return Span(reinterpret_cast(str), + strlen(str)); +} + +} // namespace + +Status DecodeImagePNM(const Span bytes, + const ColorHints& color_hints, + const SizeConstraints& constraints, + PackedPixelFile* ppf) { + Parser parser(bytes); + HeaderPNM header = {}; + const uint8_t* pos = nullptr; + if (!parser.ParseHeader(&header, &pos)) return false; + JXL_RETURN_IF_ERROR( + VerifyDimensions(&constraints, header.xsize, header.ysize)); + + if (header.bits_per_sample == 0 || header.bits_per_sample > 32) { + return JXL_FAILURE("PNM: bits_per_sample invalid"); + } + + // PPM specify that in the raster, the sample values are "nonlinear" (BP.709, + // with gamma number of 2.2). Deviate from the specification and assume + // `sRGB` in our implementation. + JXL_RETURN_IF_ERROR(ApplyColorHints(color_hints, /*color_already_set=*/false, + header.is_gray, ppf)); + + ppf->info.xsize = header.xsize; + ppf->info.ysize = header.ysize; + if (header.floating_point) { + ppf->info.bits_per_sample = 32; + ppf->info.exponent_bits_per_sample = 8; + } else { + ppf->info.bits_per_sample = header.bits_per_sample; + ppf->info.exponent_bits_per_sample = 0; + } + + ppf->info.orientation = JXL_ORIENT_IDENTITY; + + // No alpha in PNM and PFM + ppf->info.alpha_bits = (header.has_alpha ? ppf->info.bits_per_sample : 0); + ppf->info.alpha_exponent_bits = 0; + ppf->info.num_color_channels = (header.is_gray ? 1 : 3); + ppf->info.num_extra_channels = (header.has_alpha ? 1 : 0); + + JxlDataType data_type; + if (header.floating_point) { + // There's no float16 pnm version. + data_type = JXL_TYPE_FLOAT; + } else { + if (header.bits_per_sample > 8) { + data_type = JXL_TYPE_UINT16; + } else { + data_type = JXL_TYPE_UINT8; + } + } + + const JxlPixelFormat format{ + /*num_channels=*/ppf->info.num_color_channels + + ppf->info.num_extra_channels, + /*data_type=*/data_type, + /*endianness=*/header.big_endian ? JXL_BIG_ENDIAN : JXL_LITTLE_ENDIAN, + /*align=*/0, + }; + ppf->frames.clear(); + ppf->frames.emplace_back(header.xsize, header.ysize, format); + auto* frame = &ppf->frames.back(); + + size_t pnm_remaining_size = bytes.data() + bytes.size() - pos; + if (pnm_remaining_size < frame->color.pixels_size) { + return JXL_FAILURE("PNM file too small"); + } + const bool flipped_y = header.bits_per_sample == 32; // PFMs are flipped + uint8_t* out = reinterpret_cast(frame->color.pixels()); + for (size_t y = 0; y < header.ysize; ++y) { + size_t y_in = flipped_y ? header.ysize - 1 - y : y; + const uint8_t* row_in = &pos[y_in * frame->color.stride]; + uint8_t* row_out = &out[y * frame->color.stride]; + memcpy(row_out, row_in, frame->color.stride); + } + return true; +} + +void TestCodecPNM() { + size_t u = 77777; // Initialized to wrong value. + double d = 77.77; +// Failing to parse invalid strings results in a crash if `JXL_CRASH_ON_ERROR` +// is defined and hence the tests fail. Therefore we only run these tests if +// `JXL_CRASH_ON_ERROR` is not defined. +#ifndef JXL_CRASH_ON_ERROR + JXL_CHECK(false == Parser(MakeSpan("")).ParseUnsigned(&u)); + JXL_CHECK(false == Parser(MakeSpan("+")).ParseUnsigned(&u)); + JXL_CHECK(false == Parser(MakeSpan("-")).ParseUnsigned(&u)); + JXL_CHECK(false == Parser(MakeSpan("A")).ParseUnsigned(&u)); + + JXL_CHECK(false == Parser(MakeSpan("")).ParseSigned(&d)); + JXL_CHECK(false == Parser(MakeSpan("+")).ParseSigned(&d)); + JXL_CHECK(false == Parser(MakeSpan("-")).ParseSigned(&d)); + JXL_CHECK(false == Parser(MakeSpan("A")).ParseSigned(&d)); +#endif + JXL_CHECK(true == Parser(MakeSpan("1")).ParseUnsigned(&u)); + JXL_CHECK(u == 1); + + JXL_CHECK(true == Parser(MakeSpan("32")).ParseUnsigned(&u)); + JXL_CHECK(u == 32); + + JXL_CHECK(true == Parser(MakeSpan("1")).ParseSigned(&d)); + JXL_CHECK(d == 1.0); + JXL_CHECK(true == Parser(MakeSpan("+2")).ParseSigned(&d)); + JXL_CHECK(d == 2.0); + JXL_CHECK(true == Parser(MakeSpan("-3")).ParseSigned(&d)); + JXL_CHECK(std::abs(d - -3.0) < 1E-15); + JXL_CHECK(true == Parser(MakeSpan("3.141592")).ParseSigned(&d)); + JXL_CHECK(std::abs(d - 3.141592) < 1E-15); + JXL_CHECK(true == Parser(MakeSpan("-3.141592")).ParseSigned(&d)); + JXL_CHECK(std::abs(d - -3.141592) < 1E-15); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/dec/pnm.h b/media/libjxl/src/lib/extras/dec/pnm.h new file mode 100644 index 000000000..f6374830c --- /dev/null +++ b/media/libjxl/src/lib/extras/dec/pnm.h @@ -0,0 +1,38 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_DEC_PNM_H_ +#define LIB_EXTRAS_DEC_PNM_H_ + +// Decodes PBM/PGM/PPM/PFM pixels in memory. + +#include +#include + +// TODO(janwas): workaround for incorrect Win64 codegen (cause unknown) +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Decodes `bytes` into `ppf`. color_hints may specify "color_space", which +// defaults to sRGB. +Status DecodeImagePNM(Span bytes, const ColorHints& color_hints, + const SizeConstraints& constraints, PackedPixelFile* ppf); + +void TestCodecPNM(); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_DEC_PNM_H_ diff --git a/media/libjxl/src/lib/extras/enc/apng.cc b/media/libjxl/src/lib/extras/enc/apng.cc new file mode 100644 index 000000000..db6cf9ef4 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/apng.cc @@ -0,0 +1,369 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/enc/apng.h" + +// Parts of this code are taken from apngdis, which has the following license: +/* APNG Disassembler 2.8 + * + * Deconstructs APNG files into individual frames. + * + * http://apngdis.sourceforge.net + * + * Copyright (c) 2010-2015 Max Stepin + * maxst at users.sourceforge.net + * + * zlib license + * ------------ + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + * + */ + +#include +#include + +#include +#include + +#include "lib/extras/exif.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/printf_macros.h" +#include "png.h" /* original (unpatched) libpng is ok */ + +namespace jxl { +namespace extras { + +namespace { + +class APNGEncoder : public Encoder { + public: + std::vector AcceptedFormats() const override { + std::vector formats; + for (const uint32_t num_channels : {1, 2, 3, 4}) { + for (const JxlDataType data_type : {JXL_TYPE_UINT8, JXL_TYPE_UINT16}) { + formats.push_back(JxlPixelFormat{num_channels, data_type, + JXL_BIG_ENDIAN, /*align=*/0}); + } + } + return formats; + } + Status Encode(const PackedPixelFile& ppf, EncodedImage* encoded_image, + ThreadPool* pool) const override { + JXL_RETURN_IF_ERROR(VerifyBasicInfo(ppf.info)); + encoded_image->icc.clear(); + encoded_image->bitstreams.resize(1); + return EncodePackedPixelFileToAPNG(ppf, pool, + &encoded_image->bitstreams.front()); + } + + private: + Status EncodePackedPixelFileToAPNG(const PackedPixelFile& ppf, + ThreadPool* pool, + std::vector* bytes) const; +}; + +static void PngWrite(png_structp png_ptr, png_bytep data, png_size_t length) { + std::vector* bytes = + static_cast*>(png_get_io_ptr(png_ptr)); + bytes->insert(bytes->end(), data, data + length); +} + +// Stores XMP and EXIF/IPTC into key/value strings for PNG +class BlobsWriterPNG { + public: + static Status Encode(const PackedMetadata& blobs, + std::vector* strings) { + if (!blobs.exif.empty()) { + // PNG viewers typically ignore Exif orientation but not all of them do + // (and e.g. cjxl doesn't), so we overwrite the Exif orientation to the + // identity to avoid repeated orientation. + std::vector exif = blobs.exif; + ResetExifOrientation(exif); + JXL_RETURN_IF_ERROR(EncodeBase16("exif", exif, strings)); + } + if (!blobs.iptc.empty()) { + JXL_RETURN_IF_ERROR(EncodeBase16("iptc", blobs.iptc, strings)); + } + if (!blobs.xmp.empty()) { + JXL_RETURN_IF_ERROR(EncodeBase16("xmp", blobs.xmp, strings)); + } + return true; + } + + private: + static JXL_INLINE char EncodeNibble(const uint8_t nibble) { + JXL_ASSERT(nibble < 16); + return (nibble < 10) ? '0' + nibble : 'a' + nibble - 10; + } + + static Status EncodeBase16(const std::string& type, + const std::vector& bytes, + std::vector* strings) { + // Encoding: base16 with newline after 72 chars. + const size_t base16_size = + 2 * bytes.size() + DivCeil(bytes.size(), size_t(36)) + 1; + std::string base16; + base16.reserve(base16_size); + for (size_t i = 0; i < bytes.size(); ++i) { + if (i % 36 == 0) base16.push_back('\n'); + base16.push_back(EncodeNibble(bytes[i] >> 4)); + base16.push_back(EncodeNibble(bytes[i] & 0x0F)); + } + base16.push_back('\n'); + JXL_ASSERT(base16.length() == base16_size); + + char key[30]; + snprintf(key, sizeof(key), "Raw profile type %s", type.c_str()); + + char header[30]; + snprintf(header, sizeof(header), "\n%s\n%8" PRIuS, type.c_str(), + bytes.size()); + + strings->push_back(std::string(key)); + strings->push_back(std::string(header) + base16); + return true; + } +}; + +void MaybeAddCICP(JxlColorEncoding c_enc, png_structp png_ptr, + png_infop info_ptr) { + png_byte cicp_data[4] = {}; + png_unknown_chunk cicp_chunk; + if (c_enc.color_space != JXL_COLOR_SPACE_RGB) { + return; + } + if (c_enc.primaries == JXL_PRIMARIES_P3) { + if (c_enc.white_point == JXL_WHITE_POINT_D65) { + cicp_data[0] = 12; + } else if (c_enc.white_point == JXL_WHITE_POINT_DCI) { + cicp_data[0] = 11; + } else { + return; + } + } else if (c_enc.primaries != JXL_PRIMARIES_CUSTOM && + c_enc.white_point == JXL_WHITE_POINT_D65) { + cicp_data[0] = static_cast(c_enc.primaries); + } else { + return; + } + if (c_enc.transfer_function == JXL_TRANSFER_FUNCTION_UNKNOWN || + c_enc.transfer_function == JXL_TRANSFER_FUNCTION_GAMMA) { + return; + } + cicp_data[1] = static_cast(c_enc.transfer_function); + cicp_data[2] = 0; + cicp_data[3] = 1; + cicp_chunk.data = cicp_data; + cicp_chunk.size = sizeof(cicp_data); + cicp_chunk.location = PNG_HAVE_PLTE; + memcpy(cicp_chunk.name, "cICP", 5); + png_set_keep_unknown_chunks(png_ptr, 3, + reinterpret_cast("cICP"), 1); + png_set_unknown_chunks(png_ptr, info_ptr, &cicp_chunk, 1); +} + +Status APNGEncoder::EncodePackedPixelFileToAPNG( + const PackedPixelFile& ppf, ThreadPool* pool, + std::vector* bytes) const { + size_t xsize = ppf.info.xsize; + size_t ysize = ppf.info.ysize; + bool has_alpha = ppf.info.alpha_bits != 0; + bool is_gray = ppf.info.num_color_channels == 1; + size_t color_channels = ppf.info.num_color_channels; + size_t num_channels = color_channels + (has_alpha ? 1 : 0); + size_t num_samples = num_channels * xsize * ysize; + + if (!ppf.info.have_animation && ppf.frames.size() != 1) { + return JXL_FAILURE("Invalid number of frames"); + } + + size_t count = 0; + size_t anim_chunks = 0; + + for (const auto& frame : ppf.frames) { + JXL_RETURN_IF_ERROR(VerifyPackedImage(frame.color, ppf.info)); + + const PackedImage& color = frame.color; + const JxlPixelFormat format = color.format; + const uint8_t* in = reinterpret_cast(color.pixels()); + size_t data_bits_per_sample = PackedImage::BitsPerChannel(format.data_type); + size_t bytes_per_sample = data_bits_per_sample / 8; + size_t out_bytes_per_sample = bytes_per_sample > 1 ? 2 : 1; + size_t out_stride = xsize * num_channels * out_bytes_per_sample; + size_t out_size = ysize * out_stride; + std::vector out(out_size); + + if (format.data_type == JXL_TYPE_UINT8) { + if (ppf.info.bits_per_sample < 8) { + float mul = 255.0 / ((1u << ppf.info.bits_per_sample) - 1); + for (size_t i = 0; i < num_samples; ++i) { + out[i] = static_cast(in[i] * mul + 0.5); + } + } else { + memcpy(&out[0], in, out_size); + } + } else if (format.data_type == JXL_TYPE_UINT16) { + if (ppf.info.bits_per_sample < 16 || + format.endianness != JXL_BIG_ENDIAN) { + float mul = 65535.0 / ((1u << ppf.info.bits_per_sample) - 1); + const uint8_t* p_in = in; + uint8_t* p_out = out.data(); + for (size_t i = 0; i < num_samples; ++i, p_in += 2, p_out += 2) { + uint32_t val = (format.endianness == JXL_BIG_ENDIAN ? LoadBE16(p_in) + : LoadLE16(p_in)); + StoreBE16(static_cast(val * mul + 0.5), p_out); + } + } else { + memcpy(&out[0], in, out_size); + } + } else if (format.data_type == JXL_TYPE_FLOAT) { + float mul = 65535.0; + const uint8_t* p_in = in; + uint8_t* p_out = out.data(); + for (size_t i = 0; i < num_samples; ++i, p_in += 4, p_out += 2) { + uint32_t val = (format.endianness == JXL_BIG_ENDIAN ? LoadBE32(p_in) + : LoadLE32(p_in)); + float fval; + memcpy(&fval, &val, 4); + StoreBE16(static_cast(fval * mul + 0.5), p_out); + } + } else { + return JXL_FAILURE("Unsupported pixel data type"); + } + + png_structp png_ptr; + png_infop info_ptr; + + png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL); + + if (!png_ptr) return JXL_FAILURE("Could not init png encoder"); + + info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) return JXL_FAILURE("Could not init png info struct"); + + png_set_write_fn(png_ptr, bytes, PngWrite, NULL); + png_set_flush(png_ptr, 0); + + int width = xsize; + int height = ysize; + + png_byte color_type = (is_gray ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB); + if (has_alpha) color_type |= PNG_COLOR_MASK_ALPHA; + png_byte bit_depth = out_bytes_per_sample * 8; + + png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, color_type, + PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_BASE, + PNG_FILTER_TYPE_BASE); + if (count == 0) { + MaybeAddCICP(ppf.color_encoding, png_ptr, info_ptr); + if (!ppf.icc.empty()) { + png_set_benign_errors(png_ptr, 1); + png_set_iCCP(png_ptr, info_ptr, "1", 0, ppf.icc.data(), ppf.icc.size()); + } + std::vector textstrings; + JXL_RETURN_IF_ERROR(BlobsWriterPNG::Encode(ppf.metadata, &textstrings)); + for (size_t kk = 0; kk + 1 < textstrings.size(); kk += 2) { + png_text text; + text.key = const_cast(textstrings[kk].c_str()); + text.text = const_cast(textstrings[kk + 1].c_str()); + text.compression = PNG_TEXT_COMPRESSION_zTXt; + png_set_text(png_ptr, info_ptr, &text, 1); + } + + png_write_info(png_ptr, info_ptr); + } else { + // fake writing a header, otherwise libpng gets confused + size_t pos = bytes->size(); + png_write_info(png_ptr, info_ptr); + bytes->resize(pos); + } + + if (ppf.info.have_animation) { + if (count == 0) { + png_byte adata[8]; + png_save_uint_32(adata, ppf.frames.size()); + png_save_uint_32(adata + 4, ppf.info.animation.num_loops); + png_byte actl[5] = "acTL"; + png_write_chunk(png_ptr, actl, adata, 8); + } + png_byte fdata[26]; + // TODO(jon): also make this work for the non-coalesced case + png_save_uint_32(fdata, anim_chunks++); + png_save_uint_32(fdata + 4, width); + png_save_uint_32(fdata + 8, height); + png_save_uint_32(fdata + 12, 0); + png_save_uint_32(fdata + 16, 0); + png_save_uint_16(fdata + 20, frame.frame_info.duration * + ppf.info.animation.tps_denominator); + png_save_uint_16(fdata + 22, ppf.info.animation.tps_numerator); + fdata[24] = 1; + fdata[25] = 0; + png_byte fctl[5] = "fcTL"; + png_write_chunk(png_ptr, fctl, fdata, 26); + } + + std::vector rows(height); + for (int y = 0; y < height; ++y) { + rows[y] = out.data() + y * out_stride; + } + + png_write_flush(png_ptr); + const size_t pos = bytes->size(); + png_write_image(png_ptr, &rows[0]); + png_write_flush(png_ptr); + if (count > 0) { + std::vector fdata(4); + png_save_uint_32(fdata.data(), anim_chunks++); + size_t p = pos; + while (p + 8 < bytes->size()) { + size_t len = png_get_uint_32(bytes->data() + p); + JXL_ASSERT(bytes->operator[](p + 4) == 'I'); + JXL_ASSERT(bytes->operator[](p + 5) == 'D'); + JXL_ASSERT(bytes->operator[](p + 6) == 'A'); + JXL_ASSERT(bytes->operator[](p + 7) == 'T'); + fdata.insert(fdata.end(), bytes->data() + p + 8, + bytes->data() + p + 8 + len); + p += len + 12; + } + bytes->resize(pos); + + png_byte fdat[5] = "fdAT"; + png_write_chunk(png_ptr, fdat, fdata.data(), fdata.size()); + } + + count++; + if (count == ppf.frames.size() || !ppf.info.have_animation) { + png_write_end(png_ptr, NULL); + } + + png_destroy_write_struct(&png_ptr, &info_ptr); + } + + return true; +} + +} // namespace + +std::unique_ptr GetAPNGEncoder() { + return jxl::make_unique(); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/enc/apng.h b/media/libjxl/src/lib/extras/enc/apng.h new file mode 100644 index 000000000..2a2139c8f --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/apng.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_ENC_APNG_H_ +#define LIB_EXTRAS_ENC_APNG_H_ + +// Encodes APNG images in memory. + +#include + +#include "lib/extras/enc/encode.h" + +namespace jxl { +namespace extras { + +std::unique_ptr GetAPNGEncoder(); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_ENC_APNG_H_ diff --git a/media/libjxl/src/lib/extras/enc/encode.cc b/media/libjxl/src/lib/extras/enc/encode.cc new file mode 100644 index 000000000..dc593d290 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/encode.cc @@ -0,0 +1,136 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/enc/encode.h" + +#include + +#if JPEGXL_ENABLE_APNG +#include "lib/extras/enc/apng.h" +#endif +#if JPEGXL_ENABLE_EXR +#include "lib/extras/enc/exr.h" +#endif +#if JPEGXL_ENABLE_JPEG +#include "lib/extras/enc/jpg.h" +#endif +#include "lib/extras/enc/npy.h" +#include "lib/extras/enc/pgx.h" +#include "lib/extras/enc/pnm.h" +#include "lib/jxl/base/printf_macros.h" + +namespace jxl { +namespace extras { + +Status Encoder::VerifyBasicInfo(const JxlBasicInfo& info) const { + if (info.xsize == 0 || info.ysize == 0) { + return JXL_FAILURE("Empty image"); + } + if (info.num_color_channels != 1 && info.num_color_channels != 3) { + return JXL_FAILURE("Invalid number of color channels"); + } + if (info.alpha_bits > 0 && info.alpha_bits != info.bits_per_sample) { + return JXL_FAILURE("Alpha bit depth does not match image bit depth"); + } + if (info.orientation != JXL_ORIENT_IDENTITY) { + return JXL_FAILURE("Orientation must be identity"); + } + return true; +} + +Status Encoder::VerifyPackedImage(const PackedImage& image, + const JxlBasicInfo& info) const { + if (image.pixels() == nullptr) { + return JXL_FAILURE("Invalid image."); + } + if (image.stride != image.xsize * image.pixel_stride()) { + return JXL_FAILURE("Invalid image stride."); + } + if (image.pixels_size != image.ysize * image.stride) { + return JXL_FAILURE("Invalid image size."); + } + size_t info_num_channels = + (info.num_color_channels + (info.alpha_bits > 0 ? 1 : 0)); + if (image.xsize != info.xsize || image.ysize != info.ysize || + image.format.num_channels != info_num_channels) { + return JXL_FAILURE("Frame size does not match image size"); + } + if (info.bits_per_sample > + PackedImage::BitsPerChannel(image.format.data_type)) { + return JXL_FAILURE("Bit depth does not fit pixel data type"); + } + return true; +} + +Status SelectFormat(const std::vector& accepted_formats, + const JxlBasicInfo& basic_info, JxlPixelFormat* format) { + const size_t original_bit_depth = basic_info.bits_per_sample; + size_t current_bit_depth = 0; + size_t num_alpha_channels = (basic_info.alpha_bits != 0 ? 1 : 0); + size_t num_channels = basic_info.num_color_channels + num_alpha_channels; + for (;;) { + for (const JxlPixelFormat& candidate : accepted_formats) { + if (candidate.num_channels != num_channels) continue; + const size_t candidate_bit_depth = + PackedImage::BitsPerChannel(candidate.data_type); + if ( + // Candidate bit depth is less than what we have and still enough + (original_bit_depth <= candidate_bit_depth && + candidate_bit_depth < current_bit_depth) || + // Or larger than the too-small bit depth we currently have + (current_bit_depth < candidate_bit_depth && + current_bit_depth < original_bit_depth)) { + *format = candidate; + current_bit_depth = candidate_bit_depth; + } + } + if (current_bit_depth == 0) { + if (num_channels > basic_info.num_color_channels) { + // Try dropping the alpha channel. + --num_channels; + continue; + } + return JXL_FAILURE("no appropriate format found"); + } + break; + } + if (current_bit_depth < original_bit_depth) { + JXL_WARNING("encoding %" PRIuS "-bit original to %" PRIuS " bits", + original_bit_depth, current_bit_depth); + } + return true; +} + +std::unique_ptr Encoder::FromExtension(std::string extension) { + std::transform( + extension.begin(), extension.end(), extension.begin(), + [](char c) { return std::tolower(c, std::locale::classic()); }); +#if JPEGXL_ENABLE_APNG + if (extension == ".png" || extension == ".apng") return GetAPNGEncoder(); +#endif + +#if JPEGXL_ENABLE_JPEG + if (extension == ".jpg") return GetJPEGEncoder(); + if (extension == ".jpeg") return GetJPEGEncoder(); +#endif + + if (extension == ".npy") return GetNumPyEncoder(); + + if (extension == ".pgx") return GetPGXEncoder(); + + if (extension == ".pam") return GetPAMEncoder(); + if (extension == ".pgm") return GetPGMEncoder(); + if (extension == ".ppm") return GetPPMEncoder(); + if (extension == ".pfm") return GetPFMEncoder(); + +#if JPEGXL_ENABLE_EXR + if (extension == ".exr") return GetEXREncoder(); +#endif + + return nullptr; +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/enc/encode.h b/media/libjxl/src/lib/extras/enc/encode.h new file mode 100644 index 000000000..92eec50b6 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/encode.h @@ -0,0 +1,77 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_ENC_ENCODE_H_ +#define LIB_EXTRAS_ENC_ENCODE_H_ + +// Facade for image encoders. + +#include +#include + +#include "lib/extras/dec/decode.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace extras { + +struct EncodedImage { + // One (if the format supports animations or the image has only one frame) or + // more sequential bitstreams. + std::vector> bitstreams; + + // For each extra channel one or more sequential bitstreams. + std::vector>> extra_channel_bitstreams; + + std::vector preview_bitstream; + + // If the format does not support embedding color profiles into the bitstreams + // above, it will be present here, to be written as a separate file. If it + // does support them, this field will be empty. + std::vector icc; + + // Additional output for conformance testing, only filled in by NumPyEncoder. + std::vector metadata; +}; + +class Encoder { + public: + static std::unique_ptr FromExtension(std::string extension); + + virtual ~Encoder() = default; + + virtual std::vector AcceptedFormats() const = 0; + + // Any existing data in encoded_image is discarded. + virtual Status Encode(const PackedPixelFile& ppf, EncodedImage* encoded_image, + ThreadPool* pool = nullptr) const = 0; + + void SetOption(std::string name, std::string value) { + options_[std::move(name)] = std::move(value); + } + + protected: + const std::unordered_map& options() const { + return options_; + } + + Status VerifyBasicInfo(const JxlBasicInfo& info) const; + + Status VerifyPackedImage(const PackedImage& image, + const JxlBasicInfo& info) const; + + private: + std::unordered_map options_; +}; + +// TODO(sboukortt): consider exposing this as part of the C API. +Status SelectFormat(const std::vector& accepted_formats, + const JxlBasicInfo& basic_info, JxlPixelFormat* format); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_ENC_ENCODE_H_ diff --git a/media/libjxl/src/lib/extras/enc/exr.cc b/media/libjxl/src/lib/extras/enc/exr.cc new file mode 100644 index 000000000..05e05f96c --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/exr.cc @@ -0,0 +1,200 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/enc/exr.h" + +#include +#include +#include +#include + +#include + +#include "jxl/codestream_header.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/byte_order.h" + +namespace jxl { +namespace extras { + +namespace { + +namespace OpenEXR = OPENEXR_IMF_NAMESPACE; +namespace Imath = IMATH_NAMESPACE; + +// OpenEXR::Int64 is deprecated in favor of using uint64_t directly, but using +// uint64_t as recommended causes build failures with previous OpenEXR versions +// on macOS, where the definition for OpenEXR::Int64 was actually not equivalent +// to uint64_t. This alternative should work in all cases. +using ExrInt64 = decltype(std::declval().tellg()); + +class InMemoryOStream : public OpenEXR::OStream { + public: + // `bytes` must outlive the InMemoryOStream. + explicit InMemoryOStream(std::vector* const bytes) + : OStream(/*fileName=*/""), bytes_(*bytes) {} + + void write(const char c[], const int n) override { + if (bytes_.size() < pos_ + n) { + bytes_.resize(pos_ + n); + } + std::copy_n(c, n, bytes_.begin() + pos_); + pos_ += n; + } + + ExrInt64 tellp() override { return pos_; } + void seekp(const ExrInt64 pos) override { + if (bytes_.size() + 1 < pos) { + bytes_.resize(pos - 1); + } + pos_ = pos; + } + + private: + std::vector& bytes_; + size_t pos_ = 0; +}; + +// Loads a Big-Endian float +float LoadBEFloat(const uint8_t* p) { + uint32_t u = LoadBE32(p); + float result; + memcpy(&result, &u, 4); + return result; +} + +// Loads a Little-Endian float +float LoadLEFloat(const uint8_t* p) { + uint32_t u = LoadLE32(p); + float result; + memcpy(&result, &u, 4); + return result; +} + +Status EncodeImageEXR(const PackedImage& image, const JxlBasicInfo& info, + const JxlColorEncoding& c_enc, ThreadPool* pool, + std::vector* bytes) { + OpenEXR::setGlobalThreadCount(0); + + const size_t xsize = info.xsize; + const size_t ysize = info.ysize; + const bool has_alpha = info.alpha_bits > 0; + const bool alpha_is_premultiplied = info.alpha_premultiplied; + + if (info.num_color_channels != 3 || + c_enc.color_space != JXL_COLOR_SPACE_RGB || + c_enc.transfer_function != JXL_TRANSFER_FUNCTION_LINEAR) { + return JXL_FAILURE("Unsupported color encoding for OpenEXR output."); + } + + const size_t num_channels = 3 + (has_alpha ? 1 : 0); + const JxlPixelFormat format = image.format; + + if (format.data_type != JXL_TYPE_FLOAT) { + return JXL_FAILURE("Unsupported pixel format for OpenEXR output"); + } + + const uint8_t* in = reinterpret_cast(image.pixels()); + size_t in_stride = num_channels * 4 * xsize; + + OpenEXR::Header header(xsize, ysize); + OpenEXR::Chromaticities chromaticities; + chromaticities.red = + Imath::V2f(c_enc.primaries_red_xy[0], c_enc.primaries_red_xy[1]); + chromaticities.green = + Imath::V2f(c_enc.primaries_green_xy[0], c_enc.primaries_green_xy[1]); + chromaticities.blue = + Imath::V2f(c_enc.primaries_blue_xy[0], c_enc.primaries_blue_xy[1]); + chromaticities.white = + Imath::V2f(c_enc.white_point_xy[0], c_enc.white_point_xy[1]); + OpenEXR::addChromaticities(header, chromaticities); + OpenEXR::addWhiteLuminance(header, 255.0f); + + auto loadFloat = + format.endianness == JXL_BIG_ENDIAN ? LoadBEFloat : LoadLEFloat; + auto loadAlpha = + has_alpha ? loadFloat : [](const uint8_t* p) -> float { return 1.0f; }; + + // Ensure that the destructor of RgbaOutputFile has run before we look at the + // size of `bytes`. + { + InMemoryOStream os(bytes); + OpenEXR::RgbaOutputFile output( + os, header, has_alpha ? OpenEXR::WRITE_RGBA : OpenEXR::WRITE_RGB); + // How many rows to write at once. Again, the OpenEXR documentation + // recommends writing the whole image in one call. + const int y_chunk_size = ysize; + std::vector output_rows(xsize * y_chunk_size); + + for (size_t start_y = 0; start_y < ysize; start_y += y_chunk_size) { + // Inclusive. + const size_t end_y = std::min(start_y + y_chunk_size - 1, ysize - 1); + output.setFrameBuffer(output_rows.data() - start_y * xsize, + /*xStride=*/1, /*yStride=*/xsize); + for (size_t y = start_y; y <= end_y; ++y) { + const uint8_t* in_row = &in[(y - start_y) * in_stride]; + OpenEXR::Rgba* const JXL_RESTRICT row_data = + &output_rows[(y - start_y) * xsize]; + for (size_t x = 0; x < xsize; ++x) { + const uint8_t* in_pixel = &in_row[4 * num_channels * x]; + float r = loadFloat(&in_pixel[0]); + float g = loadFloat(&in_pixel[4]); + float b = loadFloat(&in_pixel[8]); + const float alpha = loadAlpha(&in_pixel[12]); + if (!alpha_is_premultiplied) { + r *= alpha; + g *= alpha; + b *= alpha; + } + row_data[x] = OpenEXR::Rgba(r, g, b, alpha); + } + } + output.writePixels(/*numScanLines=*/end_y - start_y + 1); + } + } + + return true; +} + +class EXREncoder : public Encoder { + std::vector AcceptedFormats() const override { + std::vector formats; + for (const uint32_t num_channels : {1, 2, 3, 4}) { + for (const JxlDataType data_type : {JXL_TYPE_FLOAT, JXL_TYPE_FLOAT16}) { + for (JxlEndianness endianness : {JXL_BIG_ENDIAN, JXL_LITTLE_ENDIAN}) { + formats.push_back(JxlPixelFormat{/*num_channels=*/num_channels, + /*data_type=*/data_type, + /*endianness=*/endianness, + /*align=*/0}); + } + } + } + return formats; + } + Status Encode(const PackedPixelFile& ppf, EncodedImage* encoded_image, + ThreadPool* pool = nullptr) const override { + JXL_RETURN_IF_ERROR(VerifyBasicInfo(ppf.info)); + encoded_image->icc.clear(); + encoded_image->bitstreams.clear(); + encoded_image->bitstreams.reserve(ppf.frames.size()); + for (const auto& frame : ppf.frames) { + JXL_RETURN_IF_ERROR(VerifyPackedImage(frame.color, ppf.info)); + encoded_image->bitstreams.emplace_back(); + JXL_RETURN_IF_ERROR(EncodeImageEXR(frame.color, ppf.info, + ppf.color_encoding, pool, + &encoded_image->bitstreams.back())); + } + return true; + } +}; + +} // namespace + +std::unique_ptr GetEXREncoder() { + return jxl::make_unique(); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/enc/exr.h b/media/libjxl/src/lib/extras/enc/exr.h new file mode 100644 index 000000000..1baaa0272 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/exr.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_ENC_EXR_H_ +#define LIB_EXTRAS_ENC_EXR_H_ + +// Encodes OpenEXR images in memory. + +#include + +#include "lib/extras/enc/encode.h" + +namespace jxl { +namespace extras { + +std::unique_ptr GetEXREncoder(); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_ENC_EXR_H_ diff --git a/media/libjxl/src/lib/extras/enc/jpg.cc b/media/libjxl/src/lib/extras/enc/jpg.cc new file mode 100644 index 000000000..93a39dd2e --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/jpg.cc @@ -0,0 +1,298 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/enc/jpg.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "lib/extras/exif.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/sanitizers.h" +#if JPEGXL_ENABLE_SJPEG +#include "sjpeg.h" +#endif + +namespace jxl { +namespace extras { + +namespace { + +constexpr unsigned char kICCSignature[12] = { + 0x49, 0x43, 0x43, 0x5F, 0x50, 0x52, 0x4F, 0x46, 0x49, 0x4C, 0x45, 0x00}; +constexpr int kICCMarker = JPEG_APP0 + 2; +constexpr size_t kMaxBytesInMarker = 65533; + +constexpr unsigned char kExifSignature[6] = {0x45, 0x78, 0x69, + 0x66, 0x00, 0x00}; +constexpr int kExifMarker = JPEG_APP0 + 1; + +enum class JpegEncoder { + kLibJpeg, + kSJpeg, +}; + +bool IsSRGBEncoding(const JxlColorEncoding& c) { + return ((c.color_space == JXL_COLOR_SPACE_RGB || + c.color_space == JXL_COLOR_SPACE_GRAY) && + c.primaries == JXL_PRIMARIES_SRGB && + c.white_point == JXL_WHITE_POINT_D65 && + c.transfer_function == JXL_TRANSFER_FUNCTION_SRGB); +} + +void WriteICCProfile(jpeg_compress_struct* const cinfo, + const std::vector& icc) { + constexpr size_t kMaxIccBytesInMarker = + kMaxBytesInMarker - sizeof kICCSignature - 2; + const int num_markers = + static_cast(DivCeil(icc.size(), kMaxIccBytesInMarker)); + size_t begin = 0; + for (int current_marker = 0; current_marker < num_markers; ++current_marker) { + const size_t length = std::min(kMaxIccBytesInMarker, icc.size() - begin); + jpeg_write_m_header( + cinfo, kICCMarker, + static_cast(length + sizeof kICCSignature + 2)); + for (const unsigned char c : kICCSignature) { + jpeg_write_m_byte(cinfo, c); + } + jpeg_write_m_byte(cinfo, current_marker + 1); + jpeg_write_m_byte(cinfo, num_markers); + for (size_t i = 0; i < length; ++i) { + jpeg_write_m_byte(cinfo, icc[begin]); + ++begin; + } + } +} +void WriteExif(jpeg_compress_struct* const cinfo, + const std::vector& exif) { + jpeg_write_m_header( + cinfo, kExifMarker, + static_cast(exif.size() + sizeof kExifSignature)); + for (const unsigned char c : kExifSignature) { + jpeg_write_m_byte(cinfo, c); + } + for (size_t i = 0; i < exif.size(); ++i) { + jpeg_write_m_byte(cinfo, exif[i]); + } +} + +Status SetChromaSubsampling(const std::string& subsampling, + jpeg_compress_struct* const cinfo) { + const std::pair, std::array>> + options[] = {{"444", {{{1, 1, 1}}, {{1, 1, 1}}}}, + {"420", {{{2, 1, 1}}, {{2, 1, 1}}}}, + {"422", {{{2, 1, 1}}, {{1, 1, 1}}}}, + {"440", {{{1, 1, 1}}, {{2, 1, 1}}}}}; + for (const auto& option : options) { + if (subsampling == option.first) { + for (size_t i = 0; i < 3; i++) { + cinfo->comp_info[i].h_samp_factor = option.second.first[i]; + cinfo->comp_info[i].v_samp_factor = option.second.second[i]; + } + return true; + } + } + return false; +} + +Status EncodeWithLibJpeg(const PackedImage& image, const JxlBasicInfo& info, + const std::vector& icc, + std::vector exif, size_t quality, + const std::string& chroma_subsampling, + std::vector* bytes) { + if (BITS_IN_JSAMPLE != 8 || sizeof(JSAMPLE) != 1) { + return JXL_FAILURE("Only 8 bit JSAMPLE is supported."); + } + jpeg_compress_struct cinfo; + // cinfo is initialized by libjpeg, which we are not instrumenting with + // msan. + msan::UnpoisonMemory(&cinfo, sizeof(cinfo)); + jpeg_error_mgr jerr; + cinfo.err = jpeg_std_error(&jerr); + jpeg_create_compress(&cinfo); + unsigned char* buffer = nullptr; + unsigned long size = 0; + jpeg_mem_dest(&cinfo, &buffer, &size); + cinfo.image_width = image.xsize; + cinfo.image_height = image.ysize; + cinfo.input_components = info.num_color_channels; + cinfo.in_color_space = info.num_color_channels == 1 ? JCS_GRAYSCALE : JCS_RGB; + jpeg_set_defaults(&cinfo); + cinfo.optimize_coding = TRUE; + if (cinfo.input_components == 3) { + JXL_RETURN_IF_ERROR(SetChromaSubsampling(chroma_subsampling, &cinfo)); + } + jpeg_set_quality(&cinfo, quality, TRUE); + jpeg_start_compress(&cinfo, TRUE); + if (!icc.empty()) { + WriteICCProfile(&cinfo, icc); + } + if (!exif.empty()) { + ResetExifOrientation(exif); + WriteExif(&cinfo, exif); + } + if (cinfo.input_components > 3 || cinfo.input_components < 0) + return JXL_FAILURE("invalid numbers of components"); + + std::vector raw_bytes(image.pixels_size); + memcpy(&raw_bytes[0], reinterpret_cast(image.pixels()), + image.pixels_size); + for (size_t y = 0; y < info.ysize; ++y) { + JSAMPROW row[] = {raw_bytes.data() + y * image.stride}; + + jpeg_write_scanlines(&cinfo, row, 1); + } + jpeg_finish_compress(&cinfo); + jpeg_destroy_compress(&cinfo); + bytes->resize(size); + // Compressed image data is initialized by libjpeg, which we are not + // instrumenting with msan. + msan::UnpoisonMemory(buffer, size); + std::copy_n(buffer, size, bytes->data()); + std::free(buffer); + return true; +} + +Status EncodeWithSJpeg(const PackedImage& image, const JxlBasicInfo& info, + const std::vector& icc, + std::vector exif, size_t quality, + const std::string& chroma_subsampling, + std::vector* bytes) { +#if !JPEGXL_ENABLE_SJPEG + return JXL_FAILURE("JPEG XL was built without sjpeg support"); +#else + sjpeg::EncoderParam param(quality); + if (!icc.empty()) { + param.iccp.assign(icc.begin(), icc.end()); + } + if (!exif.empty()) { + ResetExifOrientation(exif); + param.exif.assign(exif.begin(), exif.end()); + } + if (chroma_subsampling == "444") { + param.yuv_mode = SJPEG_YUV_444; + } else if (chroma_subsampling == "420") { + param.yuv_mode = SJPEG_YUV_SHARP; + } else { + return JXL_FAILURE("sjpeg does not support this chroma subsampling mode"); + } + size_t stride = info.xsize * 3; + const uint8_t* pixels = reinterpret_cast(image.pixels()); + std::string output; + JXL_RETURN_IF_ERROR( + sjpeg::Encode(pixels, image.xsize, image.ysize, stride, param, &output)); + bytes->assign( + reinterpret_cast(output.data()), + reinterpret_cast(output.data() + output.size())); + return true; +#endif +} + +Status EncodeImageJPG(const PackedImage& image, const JxlBasicInfo& info, + const std::vector& icc, + std::vector exif, JpegEncoder encoder, + size_t quality, const std::string& chroma_subsampling, + ThreadPool* pool, std::vector* bytes) { + if (image.format.data_type != JXL_TYPE_UINT8) { + return JXL_FAILURE("Unsupported pixel data type"); + } + if (info.alpha_bits > 0) { + return JXL_FAILURE("alpha is not supported"); + } + if (quality > 100) { + return JXL_FAILURE("please specify a 0-100 JPEG quality"); + } + + switch (encoder) { + case JpegEncoder::kLibJpeg: + JXL_RETURN_IF_ERROR(EncodeWithLibJpeg(image, info, icc, std::move(exif), + quality, chroma_subsampling, + bytes)); + break; + case JpegEncoder::kSJpeg: + JXL_RETURN_IF_ERROR(EncodeWithSJpeg(image, info, icc, std::move(exif), + quality, chroma_subsampling, bytes)); + break; + default: + return JXL_FAILURE("tried to use an unknown JPEG encoder"); + } + + return true; +} + +class JPEGEncoder : public Encoder { + std::vector AcceptedFormats() const override { + std::vector formats; + for (const uint32_t num_channels : {1, 3}) { + for (JxlEndianness endianness : {JXL_BIG_ENDIAN, JXL_LITTLE_ENDIAN}) { + formats.push_back(JxlPixelFormat{/*num_channels=*/num_channels, + /*data_type=*/JXL_TYPE_UINT8, + /*endianness=*/endianness, + /*align=*/0}); + } + } + return formats; + } + Status Encode(const PackedPixelFile& ppf, EncodedImage* encoded_image, + ThreadPool* pool = nullptr) const override { + JXL_RETURN_IF_ERROR(VerifyBasicInfo(ppf.info)); + const auto& options = this->options(); + int quality = 100; + auto it_quality = options.find("q"); + if (it_quality != options.end()) { + std::istringstream is(it_quality->second); + JXL_RETURN_IF_ERROR(static_cast(is >> quality)); + } + std::string chroma_subsampling = "444"; + auto it_chroma_subsampling = options.find("chroma_subsampling"); + if (it_chroma_subsampling != options.end()) { + chroma_subsampling = it_chroma_subsampling->second; + } + JpegEncoder jpeg_encoder = JpegEncoder::kLibJpeg; + auto it_encoder = options.find("jpeg_encoder"); + if (it_encoder != options.end()) { + if (it_encoder->second == "libjpeg") { + jpeg_encoder = JpegEncoder::kLibJpeg; + } else if (it_encoder->second == "sjpeg") { + jpeg_encoder = JpegEncoder::kSJpeg; + } else { + return JXL_FAILURE("unknown jpeg encoder \"%s\"", + it_encoder->second.c_str()); + } + } + std::vector icc; + if (!IsSRGBEncoding(ppf.color_encoding)) { + icc = ppf.icc; + } + encoded_image->bitstreams.clear(); + encoded_image->bitstreams.reserve(ppf.frames.size()); + for (const auto& frame : ppf.frames) { + JXL_RETURN_IF_ERROR(VerifyPackedImage(frame.color, ppf.info)); + encoded_image->bitstreams.emplace_back(); + JXL_RETURN_IF_ERROR(EncodeImageJPG( + frame.color, ppf.info, icc, ppf.metadata.exif, jpeg_encoder, quality, + chroma_subsampling, pool, &encoded_image->bitstreams.back())); + } + return true; + } +}; + +} // namespace + +std::unique_ptr GetJPEGEncoder() { + return jxl::make_unique(); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/enc/jpg.h b/media/libjxl/src/lib/extras/enc/jpg.h new file mode 100644 index 000000000..20b37cd16 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/jpg.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_ENC_JPG_H_ +#define LIB_EXTRAS_ENC_JPG_H_ + +// Encodes JPG pixels and metadata in memory. + +#include + +#include "lib/extras/enc/encode.h" + +namespace jxl { +namespace extras { + +std::unique_ptr GetJPEGEncoder(); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_ENC_JPG_H_ diff --git a/media/libjxl/src/lib/extras/enc/npy.cc b/media/libjxl/src/lib/extras/enc/npy.cc new file mode 100644 index 000000000..1428e6427 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/npy.cc @@ -0,0 +1,322 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/enc/npy.h" + +#include + +#include +#include +#include + +#include "jxl/types.h" +#include "lib/extras/packed_image.h" + +namespace jxl { +namespace extras { +namespace { + +// JSON value writing + +class JSONField { + public: + virtual ~JSONField() = default; + virtual void Write(std::ostream& o, uint32_t indent) const = 0; + + protected: + JSONField() = default; +}; + +class JSONValue : public JSONField { + public: + template + explicit JSONValue(const T& value) : value_(std::to_string(value)) {} + + explicit JSONValue(const std::string& value) : value_("\"" + value + "\"") {} + + explicit JSONValue(bool value) : value_(value ? "true" : "false") {} + + void Write(std::ostream& o, uint32_t indent) const override { o << value_; } + + private: + std::string value_; +}; + +class JSONDict : public JSONField { + public: + JSONDict() = default; + + template + T* AddEmpty(const std::string& key) { + static_assert(std::is_convertible::value, + "T must be a JSONField"); + T* ret = new T(); + values_.emplace_back( + key, std::unique_ptr(static_cast(ret))); + return ret; + } + + template + void Add(const std::string& key, const T& value) { + values_.emplace_back(key, std::unique_ptr(new JSONValue(value))); + } + + void Write(std::ostream& o, uint32_t indent) const override { + std::string indent_str(indent, ' '); + o << "{"; + bool is_first = true; + for (const auto& key_value : values_) { + if (!is_first) { + o << ","; + } + is_first = false; + o << std::endl << indent_str << " \"" << key_value.first << "\": "; + key_value.second->Write(o, indent + 2); + } + if (!values_.empty()) { + o << std::endl << indent_str; + } + o << "}"; + } + + private: + // Dictionary with order. + std::vector>> values_; +}; + +class JSONArray : public JSONField { + public: + JSONArray() = default; + + template + T* AddEmpty() { + static_assert(std::is_convertible::value, + "T must be a JSONField"); + T* ret = new T(); + values_.emplace_back(ret); + return ret; + } + + template + void Add(const T& value) { + values_.emplace_back(new JSONValue(value)); + } + + void Write(std::ostream& o, uint32_t indent) const override { + std::string indent_str(indent, ' '); + o << "["; + bool is_first = true; + for (const auto& value : values_) { + if (!is_first) { + o << ","; + } + is_first = false; + o << std::endl << indent_str << " "; + value->Write(o, indent + 2); + } + if (!values_.empty()) { + o << std::endl << indent_str; + } + o << "]"; + } + + private: + std::vector> values_; +}; + +void GenerateMetadata(const PackedPixelFile& ppf, std::vector* out) { + JSONDict meta; + // Same order as in 18181-3 CD. + + // Frames. + auto* meta_frames = meta.AddEmpty("frames"); + for (size_t i = 0; i < ppf.frames.size(); i++) { + auto* frame_i = meta_frames->AddEmpty(); + if (ppf.info.have_animation) { + frame_i->Add("duration", + JSONValue(ppf.frames[i].frame_info.duration * 1.0f * + ppf.info.animation.tps_denominator / + ppf.info.animation.tps_numerator)); + } + + frame_i->Add("name", JSONValue(ppf.frames[i].name)); + + if (ppf.info.animation.have_timecodes) { + frame_i->Add("timecode", JSONValue(ppf.frames[i].frame_info.timecode)); + } + } + +#define METADATA(FIELD) meta.Add(#FIELD, ppf.info.FIELD) + + METADATA(intensity_target); + METADATA(min_nits); + METADATA(relative_to_max_display); + METADATA(linear_below); + + if (ppf.info.have_preview) { + meta.AddEmpty("preview"); + // TODO(veluca): can we have duration/name/timecode here? + } + + { + auto ectype = meta.AddEmpty("extra_channel_type"); + auto bps = meta.AddEmpty("bits_per_sample"); + auto ebps = meta.AddEmpty("exp_bits_per_sample"); + bps->Add(ppf.info.bits_per_sample); + ebps->Add(ppf.info.exponent_bits_per_sample); + for (size_t i = 0; i < ppf.extra_channels_info.size(); i++) { + switch (ppf.extra_channels_info[i].ec_info.type) { + case JXL_CHANNEL_ALPHA: { + ectype->Add(std::string("Alpha")); + break; + } + case JXL_CHANNEL_DEPTH: { + ectype->Add(std::string("Depth")); + break; + } + case JXL_CHANNEL_SPOT_COLOR: { + ectype->Add(std::string("SpotColor")); + break; + } + case JXL_CHANNEL_SELECTION_MASK: { + ectype->Add(std::string("SelectionMask")); + break; + } + case JXL_CHANNEL_BLACK: { + ectype->Add(std::string("Black")); + break; + } + case JXL_CHANNEL_CFA: { + ectype->Add(std::string("CFA")); + break; + } + case JXL_CHANNEL_THERMAL: { + ectype->Add(std::string("Thermal")); + break; + } + default: { + ectype->Add(std::string("UNKNOWN")); + break; + } + } + bps->Add(ppf.extra_channels_info[i].ec_info.bits_per_sample); + ebps->Add(ppf.extra_channels_info[i].ec_info.exponent_bits_per_sample); + } + } + + std::ostringstream os; + meta.Write(os, 0); + out->resize(os.str().size()); + memcpy(out->data(), os.str().data(), os.str().size()); +} + +void Append(std::vector* out, const void* data, size_t size) { + size_t pos = out->size(); + out->resize(pos + size); + memcpy(out->data() + pos, data, size); +} + +void WriteNPYHeader(size_t xsize, size_t ysize, uint32_t num_channels, + size_t num_frames, std::vector* out) { + const uint8_t header[] = "\x93NUMPY\x01\x00"; + Append(out, header, 8); + std::stringstream ss; + ss << "{'descr': '(ss.str().size() % 256), + static_cast(ss.str().size() / 256)}; + Append(out, header_len, 2); + Append(out, ss.str().data(), ss.str().size()); +} + +bool WriteFrameToNPYArray(size_t xsize, size_t ysize, const PackedFrame& frame, + std::vector* out) { + const auto& color = frame.color; + if (color.xsize != xsize || color.ysize != ysize) { + return false; + } + for (const auto& ec : frame.extra_channels) { + if (ec.xsize != xsize || ec.ysize != ysize) { + return false; + } + } + // interleave the samples from color and extra channels + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + { + size_t sample_size = color.pixel_stride(); + size_t offset = y * color.stride + x * sample_size; + uint8_t* pixels = reinterpret_cast(color.pixels()); + JXL_ASSERT(offset + sample_size <= color.pixels_size); + Append(out, pixels + offset, sample_size); + } + for (const auto& ec : frame.extra_channels) { + size_t sample_size = ec.pixel_stride(); + size_t offset = y * ec.stride + x * sample_size; + uint8_t* pixels = reinterpret_cast(ec.pixels()); + JXL_ASSERT(offset + sample_size <= ec.pixels_size); + Append(out, pixels + offset, sample_size); + } + } + } + return true; +} + +// Writes a PackedPixelFile as a numpy 4D ndarray in binary format. +bool WriteNPYArray(const PackedPixelFile& ppf, std::vector* out) { + size_t xsize = ppf.info.xsize; + size_t ysize = ppf.info.ysize; + WriteNPYHeader(xsize, ysize, + ppf.info.num_color_channels + ppf.extra_channels_info.size(), + ppf.frames.size(), out); + for (const auto& frame : ppf.frames) { + if (!WriteFrameToNPYArray(xsize, ysize, frame, out)) { + return false; + } + } + return true; +} + +class NumPyEncoder : public Encoder { + public: + Status Encode(const PackedPixelFile& ppf, EncodedImage* encoded_image, + ThreadPool* pool = nullptr) const override { + JXL_RETURN_IF_ERROR(VerifyBasicInfo(ppf.info)); + GenerateMetadata(ppf, &encoded_image->metadata); + encoded_image->bitstreams.emplace_back(); + if (!WriteNPYArray(ppf, &encoded_image->bitstreams.back())) { + return false; + } + if (ppf.preview_frame) { + size_t xsize = ppf.info.preview.xsize; + size_t ysize = ppf.info.preview.ysize; + WriteNPYHeader(xsize, ysize, ppf.info.num_color_channels, 1, + &encoded_image->preview_bitstream); + if (!WriteFrameToNPYArray(xsize, ysize, *ppf.preview_frame, + &encoded_image->preview_bitstream)) { + return false; + } + } + return true; + } + std::vector AcceptedFormats() const override { + std::vector formats; + for (const uint32_t num_channels : {1, 3}) { + formats.push_back(JxlPixelFormat{num_channels, JXL_TYPE_FLOAT, + JXL_LITTLE_ENDIAN, /*align=*/0}); + } + return formats; + } +}; + +} // namespace + +std::unique_ptr GetNumPyEncoder() { + return jxl::make_unique(); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/enc/npy.h b/media/libjxl/src/lib/extras/enc/npy.h new file mode 100644 index 000000000..3ee6208ec --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/npy.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_ENC_NPY_H_ +#define LIB_EXTRAS_ENC_NPY_H_ + +// Encodes pixels to numpy array, used for conformance testing. + +#include + +#include "lib/extras/enc/encode.h" + +namespace jxl { +namespace extras { + +std::unique_ptr GetNumPyEncoder(); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_ENC_NPY_H_ diff --git a/media/libjxl/src/lib/extras/enc/pgx.cc b/media/libjxl/src/lib/extras/enc/pgx.cc new file mode 100644 index 000000000..ef204ad1c --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/pgx.cc @@ -0,0 +1,124 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/enc/pgx.h" + +#include +#include + +#include "jxl/codestream_header.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/printf_macros.h" + +namespace jxl { +namespace extras { +namespace { + +constexpr size_t kMaxHeaderSize = 200; + +Status EncodeHeader(const JxlBasicInfo& info, char* header, + int* chars_written) { + if (info.alpha_bits > 0) { + return JXL_FAILURE("PGX: can't store alpha"); + } + if (info.num_color_channels != 1) { + return JXL_FAILURE("PGX: must be grayscale"); + } + // TODO(lode): verify other bit depths: for other bit depths such as 1 or 4 + // bits, have a test case to verify it works correctly. For bits > 16, we may + // need to change the way external_image works. + if (info.bits_per_sample != 8 && info.bits_per_sample != 16) { + return JXL_FAILURE("PGX: bits other than 8 or 16 not yet supported"); + } + + // Use ML (Big Endian), LM may not be well supported by all decoders. + *chars_written = snprintf(header, kMaxHeaderSize, "PG ML + %u %u %u\n", + info.bits_per_sample, info.xsize, info.ysize); + JXL_RETURN_IF_ERROR(static_cast(*chars_written) < + kMaxHeaderSize); + return true; +} + +Status EncodeImagePGX(const PackedFrame& frame, const JxlBasicInfo& info, + std::vector* bytes) { + char header[kMaxHeaderSize]; + int header_size = 0; + JXL_RETURN_IF_ERROR(EncodeHeader(info, header, &header_size)); + + const PackedImage& color = frame.color; + const JxlPixelFormat format = color.format; + const uint8_t* in = reinterpret_cast(color.pixels()); + size_t data_bits_per_sample = PackedImage::BitsPerChannel(format.data_type); + size_t bytes_per_sample = data_bits_per_sample / kBitsPerByte; + size_t num_samples = info.xsize * info.ysize; + + if (info.bits_per_sample != data_bits_per_sample) { + return JXL_FAILURE("Bit depth does not match pixel data type"); + } + + std::vector pixels(num_samples * bytes_per_sample); + + if (format.data_type == JXL_TYPE_UINT8) { + memcpy(&pixels[0], in, num_samples * bytes_per_sample); + } else if (format.data_type == JXL_TYPE_UINT16) { + if (format.endianness != JXL_BIG_ENDIAN) { + const uint8_t* p_in = in; + uint8_t* p_out = pixels.data(); + for (size_t i = 0; i < num_samples; ++i, p_in += 2, p_out += 2) { + StoreBE16(LoadLE16(p_in), p_out); + } + } else { + memcpy(&pixels[0], in, num_samples * bytes_per_sample); + } + } else { + return JXL_FAILURE("Unsupported pixel data type"); + } + + bytes->resize(static_cast(header_size) + pixels.size()); + memcpy(bytes->data(), header, static_cast(header_size)); + memcpy(bytes->data() + header_size, pixels.data(), pixels.size()); + + return true; +} + +class PGXEncoder : public Encoder { + public: + std::vector AcceptedFormats() const override { + std::vector formats; + for (const JxlDataType data_type : {JXL_TYPE_UINT8, JXL_TYPE_UINT16}) { + for (JxlEndianness endianness : {JXL_BIG_ENDIAN, JXL_LITTLE_ENDIAN}) { + formats.push_back(JxlPixelFormat{/*num_channels=*/1, + /*data_type=*/data_type, + /*endianness=*/endianness, + /*align=*/0}); + } + } + return formats; + } + Status Encode(const PackedPixelFile& ppf, EncodedImage* encoded_image, + ThreadPool* pool) const override { + JXL_RETURN_IF_ERROR(VerifyBasicInfo(ppf.info)); + encoded_image->icc.assign(ppf.icc.begin(), ppf.icc.end()); + encoded_image->bitstreams.clear(); + encoded_image->bitstreams.reserve(ppf.frames.size()); + for (const auto& frame : ppf.frames) { + JXL_RETURN_IF_ERROR(VerifyPackedImage(frame.color, ppf.info)); + encoded_image->bitstreams.emplace_back(); + JXL_RETURN_IF_ERROR( + EncodeImagePGX(frame, ppf.info, &encoded_image->bitstreams.back())); + } + return true; + } +}; + +} // namespace + +std::unique_ptr GetPGXEncoder() { + return jxl::make_unique(); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/enc/pgx.h b/media/libjxl/src/lib/extras/enc/pgx.h new file mode 100644 index 000000000..f24e391b0 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/pgx.h @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_ENC_PGX_H_ +#define LIB_EXTRAS_ENC_PGX_H_ + +// Encodes PGX pixels in memory. + +#include +#include + +#include "lib/extras/enc/encode.h" + +namespace jxl { +namespace extras { + +std::unique_ptr GetPGXEncoder(); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_ENC_PGX_H_ diff --git a/media/libjxl/src/lib/extras/enc/pnm.cc b/media/libjxl/src/lib/extras/enc/pnm.cc new file mode 100644 index 000000000..9b5f6cbc9 --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/pnm.cc @@ -0,0 +1,217 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/enc/pnm.h" + +#include +#include + +#include +#include + +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_image_bundle.h" +#include "lib/jxl/fields.h" // AllDefault +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { +namespace extras { +namespace { + +constexpr size_t kMaxHeaderSize = 200; + +Status EncodeHeader(const PackedImage& image, size_t bits_per_sample, + bool little_endian, char* header, int* chars_written) { + size_t num_channels = image.format.num_channels; + bool is_gray = num_channels <= 2; + bool has_alpha = num_channels == 2 || num_channels == 4; + if (has_alpha) { // PAM + if (bits_per_sample > 16) return JXL_FAILURE("PNM cannot have > 16 bits"); + const uint32_t max_val = (1U << bits_per_sample) - 1; + *chars_written = + snprintf(header, kMaxHeaderSize, + "P7\nWIDTH %" PRIuS "\nHEIGHT %" PRIuS + "\nDEPTH %u\nMAXVAL %u\nTUPLTYPE %s\nENDHDR\n", + image.xsize, image.ysize, is_gray ? 2 : 4, max_val, + is_gray ? "GRAYSCALE_ALPHA" : "RGB_ALPHA"); + JXL_RETURN_IF_ERROR(static_cast(*chars_written) < + kMaxHeaderSize); + } else if (bits_per_sample == 32) { // PFM + const char type = is_gray ? 'f' : 'F'; + const double scale = little_endian ? -1.0 : 1.0; + *chars_written = + snprintf(header, kMaxHeaderSize, "P%c\n%" PRIuS " %" PRIuS "\n%.1f\n", + type, image.xsize, image.ysize, scale); + JXL_RETURN_IF_ERROR(static_cast(*chars_written) < + kMaxHeaderSize); + } else { // PGM/PPM + if (bits_per_sample > 16) return JXL_FAILURE("PNM cannot have > 16 bits"); + const uint32_t max_val = (1U << bits_per_sample) - 1; + const char type = is_gray ? '5' : '6'; + *chars_written = + snprintf(header, kMaxHeaderSize, "P%c\n%" PRIuS " %" PRIuS "\n%u\n", + type, image.xsize, image.ysize, max_val); + JXL_RETURN_IF_ERROR(static_cast(*chars_written) < + kMaxHeaderSize); + } + return true; +} + +Status EncodeImagePNM(const PackedImage& image, size_t bits_per_sample, + std::vector* bytes) { + // Choose native for PFM; PGM/PPM require big-endian + bool is_little_endian = bits_per_sample > 16 && IsLittleEndian(); + char header[kMaxHeaderSize]; + int header_size = 0; + JXL_RETURN_IF_ERROR(EncodeHeader(image, bits_per_sample, is_little_endian, + header, &header_size)); + bytes->resize(static_cast(header_size) + image.pixels_size); + memcpy(bytes->data(), header, static_cast(header_size)); + const bool flipped_y = bits_per_sample == 32; // PFMs are flipped + const uint8_t* in = reinterpret_cast(image.pixels()); + uint8_t* out = bytes->data() + header_size; + for (size_t y = 0; y < image.ysize; ++y) { + size_t y_out = flipped_y ? image.ysize - 1 - y : y; + const uint8_t* row_in = &in[y * image.stride]; + uint8_t* row_out = &out[y_out * image.stride]; + memcpy(row_out, row_in, image.stride); + } + return true; +} + +class PNMEncoder : public Encoder { + public: + Status Encode(const PackedPixelFile& ppf, EncodedImage* encoded_image, + ThreadPool* pool = nullptr) const override { + JXL_RETURN_IF_ERROR(VerifyBasicInfo(ppf.info)); + if (!ppf.metadata.exif.empty() || !ppf.metadata.iptc.empty() || + !ppf.metadata.jumbf.empty() || !ppf.metadata.xmp.empty()) { + JXL_WARNING("PNM encoder ignoring metadata - use a different codec"); + } + encoded_image->icc = ppf.icc; + encoded_image->bitstreams.clear(); + encoded_image->bitstreams.reserve(ppf.frames.size()); + for (const auto& frame : ppf.frames) { + JXL_RETURN_IF_ERROR(VerifyPackedImage(frame.color, ppf.info)); + encoded_image->bitstreams.emplace_back(); + JXL_RETURN_IF_ERROR(EncodeImagePNM(frame.color, ppf.info.bits_per_sample, + &encoded_image->bitstreams.back())); + } + for (size_t i = 0; i < ppf.extra_channels_info.size(); ++i) { + const auto& ec_info = ppf.extra_channels_info[i].ec_info; + encoded_image->extra_channel_bitstreams.emplace_back(); + auto& ec_bitstreams = encoded_image->extra_channel_bitstreams.back(); + for (const auto& frame : ppf.frames) { + ec_bitstreams.emplace_back(); + JXL_RETURN_IF_ERROR(EncodeImagePNM(frame.extra_channels[i], + ec_info.bits_per_sample, + &ec_bitstreams.back())); + } + } + return true; + } +}; + +class PPMEncoder : public PNMEncoder { + public: + std::vector AcceptedFormats() const override { + std::vector formats; + for (const uint32_t num_channels : {1, 2, 3, 4}) { + for (const JxlDataType data_type : {JXL_TYPE_UINT8, JXL_TYPE_UINT16}) { + for (JxlEndianness endianness : {JXL_BIG_ENDIAN, JXL_LITTLE_ENDIAN}) { + formats.push_back(JxlPixelFormat{/*num_channels=*/num_channels, + /*data_type=*/data_type, + /*endianness=*/endianness, + /*align=*/0}); + } + } + } + return formats; + } +}; + +class PFMEncoder : public PNMEncoder { + public: + std::vector AcceptedFormats() const override { + std::vector formats; + for (const uint32_t num_channels : {1, 3}) { + for (const JxlDataType data_type : {JXL_TYPE_FLOAT16, JXL_TYPE_FLOAT}) { + for (JxlEndianness endianness : {JXL_BIG_ENDIAN, JXL_LITTLE_ENDIAN}) { + formats.push_back(JxlPixelFormat{/*num_channels=*/num_channels, + /*data_type=*/data_type, + /*endianness=*/endianness, + /*align=*/0}); + } + } + } + return formats; + } +}; + +class PGMEncoder : public PPMEncoder { + public: + std::vector AcceptedFormats() const override { + std::vector formats = PPMEncoder::AcceptedFormats(); + for (auto it = formats.begin(); it != formats.end();) { + if (it->num_channels > 2) { + it = formats.erase(it); + } else { + ++it; + } + } + return formats; + } +}; + +class PAMEncoder : public PPMEncoder { + public: + std::vector AcceptedFormats() const override { + std::vector formats = PPMEncoder::AcceptedFormats(); + for (auto it = formats.begin(); it != formats.end();) { + if (it->num_channels != 2 && it->num_channels != 4) { + it = formats.erase(it); + } else { + ++it; + } + } + return formats; + } +}; + +Span MakeSpan(const char* str) { + return Span(reinterpret_cast(str), + strlen(str)); +} + +} // namespace + +std::unique_ptr GetPPMEncoder() { + return jxl::make_unique(); +} + +std::unique_ptr GetPFMEncoder() { + return jxl::make_unique(); +} + +std::unique_ptr GetPGMEncoder() { + return jxl::make_unique(); +} + +std::unique_ptr GetPAMEncoder() { + return jxl::make_unique(); +} + +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/enc/pnm.h b/media/libjxl/src/lib/extras/enc/pnm.h new file mode 100644 index 000000000..403208cec --- /dev/null +++ b/media/libjxl/src/lib/extras/enc/pnm.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_ENC_PNM_H_ +#define LIB_EXTRAS_ENC_PNM_H_ + +// Encodes/decodes PBM/PGM/PPM/PFM pixels in memory. + +// TODO(janwas): workaround for incorrect Win64 codegen (cause unknown) +#include +#include + +#include "lib/extras/enc/encode.h" + +namespace jxl { +namespace extras { + +std::unique_ptr GetPAMEncoder(); +std::unique_ptr GetPGMEncoder(); +std::unique_ptr GetPPMEncoder(); +std::unique_ptr GetPFMEncoder(); + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_ENC_PNM_H_ diff --git a/media/libjxl/src/lib/extras/exif.cc b/media/libjxl/src/lib/extras/exif.cc new file mode 100644 index 000000000..7d926558c --- /dev/null +++ b/media/libjxl/src/lib/extras/exif.cc @@ -0,0 +1,55 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/exif.h" + +#include "lib/jxl/base/byte_order.h" + +namespace jxl { + +constexpr uint16_t kExifOrientationTag = 274; + +void ResetExifOrientation(std::vector& exif) { + if (exif.size() < 12) return; // not enough bytes for a valid exif blob + bool bigendian; + uint8_t* t = exif.data(); + if (LoadLE32(t) == 0x2A004D4D) { + bigendian = true; + } else if (LoadLE32(t) == 0x002A4949) { + bigendian = false; + } else { + return; // not a valid tiff header + } + t += 4; + uint32_t offset = (bigendian ? LoadBE32(t) : LoadLE32(t)); + if (exif.size() < 12 + offset + 2 || offset < 8) return; + t += offset - 4; + uint16_t nb_tags = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + while (nb_tags > 0) { + if (t + 12 >= exif.data() + exif.size()) return; + uint16_t tag = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + if (tag == kExifOrientationTag) { + uint16_t type = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + uint32_t count = (bigendian ? LoadBE32(t) : LoadLE32(t)); + t += 4; + if (type == 3 && count == 1) { + if (bigendian) { + StoreBE16(1, t); + } else { + StoreLE16(1, t); + } + } + return; + } else { + t += 10; + nb_tags--; + } + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/exif.h b/media/libjxl/src/lib/extras/exif.h new file mode 100644 index 000000000..f22b2ccef --- /dev/null +++ b/media/libjxl/src/lib/extras/exif.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_EXIF_H_ +#define LIB_EXTRAS_EXIF_H_ + +#include + +#include + +namespace jxl { + +// Sets the Exif orientation to the identity, to avoid repeated orientation +void ResetExifOrientation(std::vector& exif); + +} // namespace jxl + +#endif // LIB_EXTRAS_EXIF_H_ diff --git a/media/libjxl/src/lib/extras/hlg.cc b/media/libjxl/src/lib/extras/hlg.cc new file mode 100644 index 000000000..e39a0807f --- /dev/null +++ b/media/libjxl/src/lib/extras/hlg.cc @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/hlg.h" + +#include + +#include "lib/jxl/enc_color_management.h" + +namespace jxl { + +float GetHlgGamma(const float peak_luminance, const float surround_luminance) { + return 1.2f * std::pow(1.111f, std::log2(peak_luminance / 1000.f)) * + std::pow(0.98f, std::log2(surround_luminance / 5.f)); +} + +Status HlgOOTF(ImageBundle* ib, const float gamma, ThreadPool* pool) { + ColorEncoding linear_rec2020; + linear_rec2020.SetColorSpace(ColorSpace::kRGB); + linear_rec2020.primaries = Primaries::k2100; + linear_rec2020.white_point = WhitePoint::kD65; + linear_rec2020.tf.SetTransferFunction(TransferFunction::kLinear); + JXL_RETURN_IF_ERROR(linear_rec2020.CreateICC()); + JXL_RETURN_IF_ERROR(ib->TransformTo(linear_rec2020, GetJxlCms(), pool)); + + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ib->ysize(), ThreadPool::NoInit, + [&](const int y, const int thread) { + float* const JXL_RESTRICT rows[3] = {ib->color()->PlaneRow(0, y), + ib->color()->PlaneRow(1, y), + ib->color()->PlaneRow(2, y)}; + for (size_t x = 0; x < ib->xsize(); ++x) { + float& red = rows[0][x]; + float& green = rows[1][x]; + float& blue = rows[2][x]; + const float luminance = + 0.2627f * red + 0.6780f * green + 0.0593f * blue; + const float ratio = std::pow(luminance, gamma - 1); + if (std::isfinite(ratio)) { + red *= ratio; + green *= ratio; + blue *= ratio; + } + } + }, + "HlgOOTF")); + return true; +} + +Status HlgInverseOOTF(ImageBundle* ib, const float gamma, ThreadPool* pool) { + return HlgOOTF(ib, 1.f / gamma, pool); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/hlg.h b/media/libjxl/src/lib/extras/hlg.h new file mode 100644 index 000000000..4cfec444f --- /dev/null +++ b/media/libjxl/src/lib/extras/hlg.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_HLG_H_ +#define LIB_EXTRAS_HLG_H_ + +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +float GetHlgGamma(float peak_luminance, float surround_luminance = 5.f); + +Status HlgOOTF(ImageBundle* ib, float gamma, ThreadPool* pool = nullptr); + +Status HlgInverseOOTF(ImageBundle* ib, float gamma, ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_EXTRAS_HLG_H_ diff --git a/media/libjxl/src/lib/extras/packed_image.h b/media/libjxl/src/lib/extras/packed_image.h new file mode 100644 index 000000000..129647210 --- /dev/null +++ b/media/libjxl/src/lib/extras/packed_image.h @@ -0,0 +1,150 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_PACKED_IMAGE_H_ +#define LIB_EXTRAS_PACKED_IMAGE_H_ + +// Helper class for storing external (int or float, interleaved) images. This is +// the common format used by other libraries and in the libjxl API. + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "jxl/codestream_header.h" +#include "jxl/encode.h" +#include "jxl/types.h" +#include "lib/jxl/common.h" + +namespace jxl { +namespace extras { + +// Class representing an interleaved image with a bunch of channels. +class PackedImage { + public: + PackedImage(size_t xsize, size_t ysize, const JxlPixelFormat& format) + : PackedImage(xsize, ysize, format, CalcStride(format, xsize)) {} + + // The interleaved pixels as defined in the storage format. + void* pixels() const { return pixels_.get(); } + + // The image size in pixels. + size_t xsize; + size_t ysize; + + // The number of bytes per row. + size_t stride; + + // Pixel storage format and buffer size of the pixels_ pointer. + JxlPixelFormat format; + size_t pixels_size; + + size_t pixel_stride() const { + return (BitsPerChannel(format.data_type) * format.num_channels / + jxl::kBitsPerByte); + } + + static size_t BitsPerChannel(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + return 32; + case JXL_TYPE_FLOAT16: + return 16; + default: + JXL_ABORT("Unhandled JxlDataType"); + } + } + + private: + PackedImage(size_t xsize, size_t ysize, const JxlPixelFormat& format, + size_t stride) + : xsize(xsize), + ysize(ysize), + stride(stride), + format(format), + pixels_size(ysize * stride), + pixels_(malloc(std::max(1, pixels_size)), free) {} + + static size_t CalcStride(const JxlPixelFormat& format, size_t xsize) { + size_t stride = xsize * (BitsPerChannel(format.data_type) * + format.num_channels / jxl::kBitsPerByte); + if (format.align > 1) { + stride = jxl::DivCeil(stride, format.align) * format.align; + } + return stride; + } + + std::unique_ptr pixels_; +}; + +// Helper class representing a frame, as seen from the API. Animations will have +// multiple frames, but a single frame can have a color/grayscale channel and +// multiple extra channels. The order of the extra channels should be the same +// as all other frames in the same image. +class PackedFrame { + public: + template + explicit PackedFrame(Args&&... args) : color(std::forward(args)...) {} + + // The Frame metadata. + JxlFrameHeader frame_info = {}; + std::string name; + + // The pixel data for the color (or grayscale) channels. + PackedImage color; + // Extra channel image data. + std::vector extra_channels; +}; + +// Optional metadata associated with a file +class PackedMetadata { + public: + std::vector exif; + std::vector iptc; + std::vector jumbf; + std::vector xmp; +}; + +// Helper class representing a JXL image file as decoded to pixels from the API. +class PackedPixelFile { + public: + JxlBasicInfo info = {}; + + // The extra channel metadata information. + struct PackedExtraChannel { + JxlExtraChannelInfo ec_info; + size_t index; + std::string name; + }; + std::vector extra_channels_info; + + // Color information of the decoded pixels. + // If the icc is empty, the JxlColorEncoding should be used instead. + std::vector icc; + JxlColorEncoding color_encoding = {}; + // The icc profile of the original image. + std::vector orig_icc; + + std::unique_ptr preview_frame; + std::vector frames; + + PackedMetadata metadata; + PackedPixelFile() { JxlEncoderInitBasicInfo(&info); }; +}; + +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_PACKED_IMAGE_H_ diff --git a/media/libjxl/src/lib/extras/packed_image_convert.cc b/media/libjxl/src/lib/extras/packed_image_convert.cc new file mode 100644 index 000000000..dcdd12a67 --- /dev/null +++ b/media/libjxl/src/lib/extras/packed_image_convert.cc @@ -0,0 +1,302 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/packed_image_convert.h" + +#include + +#include "jxl/color_encoding.h" +#include "jxl/types.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_image_bundle.h" + +namespace jxl { +namespace extras { + +Status ConvertPackedFrameToImageBundle(const JxlBasicInfo& info, + const PackedFrame& frame, + const CodecInOut& io, ThreadPool* pool, + ImageBundle* bundle) { + JXL_ASSERT(frame.color.pixels() != nullptr); + const bool float_in = frame.color.format.data_type == JXL_TYPE_FLOAT16 || + frame.color.format.data_type == JXL_TYPE_FLOAT; + size_t frame_bits_per_sample = + float_in ? PackedImage::BitsPerChannel(frame.color.format.data_type) + : info.bits_per_sample; + JXL_ASSERT(frame_bits_per_sample != 0); + // It is ok for the frame.color.format.num_channels to not match the + // number of channels on the image. + JXL_ASSERT(1 <= frame.color.format.num_channels && + frame.color.format.num_channels <= 4); + + const Span span( + static_cast(frame.color.pixels()), + frame.color.pixels_size); + JXL_ASSERT(Rect(frame.frame_info.layer_info.crop_x0, + frame.frame_info.layer_info.crop_y0, + frame.frame_info.layer_info.xsize, + frame.frame_info.layer_info.ysize) + .IsInside(Rect(0, 0, info.xsize, info.ysize))); + if (info.have_animation) { + bundle->duration = frame.frame_info.duration; + bundle->blend = frame.frame_info.layer_info.blend_info.blendmode > 0; + bundle->use_for_next_frame = + frame.frame_info.layer_info.save_as_reference > 0; + bundle->origin.x0 = frame.frame_info.layer_info.crop_x0; + bundle->origin.y0 = frame.frame_info.layer_info.crop_y0; + } + bundle->name = frame.name; // frame.frame_info.name_length is ignored here. + JXL_ASSERT(io.metadata.m.color_encoding.IsGray() == + (frame.color.format.num_channels <= 2)); + + JXL_RETURN_IF_ERROR(ConvertFromExternal( + span, frame.color.xsize, frame.color.ysize, io.metadata.m.color_encoding, + frame.color.format.num_channels, + /*alpha_is_premultiplied=*/info.alpha_premultiplied, + frame_bits_per_sample, frame.color.format.endianness, pool, bundle, + /*float_in=*/float_in, /*align=*/0)); + + bundle->extra_channels().resize(io.metadata.m.extra_channel_info.size()); + for (size_t i = 0; i < frame.extra_channels.size(); i++) { + const auto& ppf_ec = frame.extra_channels[i]; + bundle->extra_channels()[i] = ImageF(ppf_ec.xsize, ppf_ec.ysize); + JXL_CHECK(BufferToImageF(ppf_ec.format, ppf_ec.xsize, ppf_ec.ysize, + ppf_ec.pixels(), ppf_ec.pixels_size, pool, + &bundle->extra_channels()[i])); + } + return true; +} + +Status ConvertPackedPixelFileToCodecInOut(const PackedPixelFile& ppf, + ThreadPool* pool, CodecInOut* io) { + const bool has_alpha = ppf.info.alpha_bits != 0; + JXL_ASSERT(!ppf.frames.empty()); + if (has_alpha) { + JXL_ASSERT(ppf.info.alpha_bits == ppf.info.bits_per_sample); + JXL_ASSERT(ppf.info.alpha_exponent_bits == + ppf.info.exponent_bits_per_sample); + } + + const bool is_gray = ppf.info.num_color_channels == 1; + JXL_ASSERT(ppf.info.num_color_channels == 1 || + ppf.info.num_color_channels == 3); + + // Convert the image metadata + io->SetSize(ppf.info.xsize, ppf.info.ysize); + io->metadata.m.bit_depth.bits_per_sample = ppf.info.bits_per_sample; + io->metadata.m.bit_depth.exponent_bits_per_sample = + ppf.info.exponent_bits_per_sample; + io->metadata.m.bit_depth.floating_point_sample = + ppf.info.exponent_bits_per_sample != 0; + io->metadata.m.modular_16_bit_buffer_sufficient = + ppf.info.exponent_bits_per_sample == 0 && ppf.info.bits_per_sample <= 12; + + io->metadata.m.SetAlphaBits(ppf.info.alpha_bits, + ppf.info.alpha_premultiplied); + + io->metadata.m.xyb_encoded = !ppf.info.uses_original_profile; + JXL_ASSERT(ppf.info.orientation > 0 && ppf.info.orientation <= 8); + io->metadata.m.orientation = ppf.info.orientation; + + // Convert animation metadata + JXL_ASSERT(ppf.frames.size() == 1 || ppf.info.have_animation); + io->metadata.m.have_animation = ppf.info.have_animation; + io->metadata.m.animation.tps_numerator = ppf.info.animation.tps_numerator; + io->metadata.m.animation.tps_denominator = ppf.info.animation.tps_denominator; + io->metadata.m.animation.num_loops = ppf.info.animation.num_loops; + + // Convert the color encoding. + if (!ppf.icc.empty()) { + PaddedBytes icc; + icc.append(ppf.icc); + if (!io->metadata.m.color_encoding.SetICC(std::move(icc))) { + fprintf(stderr, "Warning: error setting ICC profile, assuming SRGB\n"); + io->metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + } else { + if (io->metadata.m.color_encoding.IsGray() != is_gray) { + // E.g. JPG image has 3 channels, but gray ICC. + return JXL_FAILURE("Embedded ICC does not match image color type"); + } + } + } else { + JXL_RETURN_IF_ERROR(ConvertExternalToInternalColorEncoding( + ppf.color_encoding, &io->metadata.m.color_encoding)); + if (io->metadata.m.color_encoding.ICC().empty()) { + return JXL_FAILURE("Failed to serialize ICC"); + } + } + + // Convert the extra blobs + io->blobs.exif = ppf.metadata.exif; + io->blobs.iptc = ppf.metadata.iptc; + io->blobs.jumbf = ppf.metadata.jumbf; + io->blobs.xmp = ppf.metadata.xmp; + + // Append all other extra channels. + for (const PackedPixelFile::PackedExtraChannel& info : + ppf.extra_channels_info) { + ExtraChannelInfo out; + out.type = static_cast(info.ec_info.type); + out.bit_depth.bits_per_sample = info.ec_info.bits_per_sample; + out.bit_depth.exponent_bits_per_sample = + info.ec_info.exponent_bits_per_sample; + out.bit_depth.floating_point_sample = + info.ec_info.exponent_bits_per_sample != 0; + out.dim_shift = info.ec_info.dim_shift; + out.name = info.name; + out.alpha_associated = (info.ec_info.alpha_premultiplied != 0); + out.spot_color[0] = info.ec_info.spot_color[0]; + out.spot_color[1] = info.ec_info.spot_color[1]; + out.spot_color[2] = info.ec_info.spot_color[2]; + out.spot_color[3] = info.ec_info.spot_color[3]; + io->metadata.m.extra_channel_info.push_back(std::move(out)); + } + + // Convert the preview + if (ppf.preview_frame) { + size_t preview_xsize = ppf.preview_frame->color.xsize; + size_t preview_ysize = ppf.preview_frame->color.ysize; + io->metadata.m.have_preview = true; + JXL_RETURN_IF_ERROR( + io->metadata.m.preview_size.Set(preview_xsize, preview_ysize)); + JXL_RETURN_IF_ERROR(ConvertPackedFrameToImageBundle( + ppf.info, *ppf.preview_frame, *io, pool, &io->preview_frame)); + } + + // Convert the pixels + io->dec_pixels = 0; + io->frames.clear(); + for (const auto& frame : ppf.frames) { + ImageBundle bundle(&io->metadata.m); + JXL_RETURN_IF_ERROR( + ConvertPackedFrameToImageBundle(ppf.info, frame, *io, pool, &bundle)); + io->frames.push_back(std::move(bundle)); + io->dec_pixels += frame.color.xsize * frame.color.ysize; + } + + if (ppf.info.exponent_bits_per_sample == 0) { + // uint case. + io->metadata.m.bit_depth.bits_per_sample = io->Main().DetectRealBitdepth(); + } + if (ppf.info.intensity_target != 0) { + io->metadata.m.SetIntensityTarget(ppf.info.intensity_target); + } else { + SetIntensityTarget(io); + } + io->CheckMetadata(); + return true; +} + +// Allows converting from internal CodecInOut to external PackedPixelFile +Status ConvertCodecInOutToPackedPixelFile(const CodecInOut& io, + const JxlPixelFormat& pixel_format, + const ColorEncoding& c_desired, + ThreadPool* pool, + PackedPixelFile* ppf) { + const bool has_alpha = io.metadata.m.HasAlpha(); + bool alpha_premultiplied = false; + JXL_ASSERT(!io.frames.empty()); + + if (has_alpha) { + JXL_ASSERT(io.metadata.m.GetAlphaBits() == + io.metadata.m.bit_depth.bits_per_sample); + const auto* alpha_channel = io.metadata.m.Find(ExtraChannel::kAlpha); + JXL_ASSERT(alpha_channel->bit_depth.exponent_bits_per_sample == + io.metadata.m.bit_depth.exponent_bits_per_sample); + alpha_premultiplied = alpha_channel->alpha_associated; + } + + // Convert the image metadata + ppf->info.xsize = io.metadata.size.xsize(); + ppf->info.ysize = io.metadata.size.ysize(); + ppf->info.num_color_channels = io.metadata.m.color_encoding.Channels(); + ppf->info.bits_per_sample = io.metadata.m.bit_depth.bits_per_sample; + ppf->info.exponent_bits_per_sample = + io.metadata.m.bit_depth.exponent_bits_per_sample; + + ppf->info.alpha_bits = io.metadata.m.GetAlphaBits(); + ppf->info.alpha_premultiplied = alpha_premultiplied; + + ppf->info.uses_original_profile = !io.metadata.m.xyb_encoded; + JXL_ASSERT(0 < io.metadata.m.orientation && io.metadata.m.orientation <= 8); + ppf->info.orientation = + static_cast(io.metadata.m.orientation); + ppf->info.num_color_channels = io.metadata.m.color_encoding.Channels(); + + // Convert animation metadata + JXL_ASSERT(io.frames.size() == 1 || io.metadata.m.have_animation); + ppf->info.have_animation = io.metadata.m.have_animation; + ppf->info.animation.tps_numerator = io.metadata.m.animation.tps_numerator; + ppf->info.animation.tps_denominator = io.metadata.m.animation.tps_denominator; + ppf->info.animation.num_loops = io.metadata.m.animation.num_loops; + + // Convert the color encoding + ppf->icc.assign(c_desired.ICC().begin(), c_desired.ICC().end()); + ConvertInternalToExternalColorEncoding(c_desired, &ppf->color_encoding); + + // Convert the extra blobs + ppf->metadata.exif = io.blobs.exif; + ppf->metadata.iptc = io.blobs.iptc; + ppf->metadata.jumbf = io.blobs.jumbf; + ppf->metadata.xmp = io.blobs.xmp; + const bool float_out = pixel_format.data_type == JXL_TYPE_FLOAT || + pixel_format.data_type == JXL_TYPE_FLOAT16; + // Convert the pixels + ppf->frames.clear(); + for (const auto& frame : io.frames) { + JXL_ASSERT(frame.metadata()->bit_depth.bits_per_sample != 0); + // It is ok for the frame.color().kNumPlanes to not match the + // number of channels on the image. + const uint32_t num_channels = + frame.metadata()->color_encoding.Channels() + has_alpha; + JxlPixelFormat format{/*num_channels=*/num_channels, + /*data_type=*/pixel_format.data_type, + /*endianness=*/pixel_format.endianness, + /*align=*/pixel_format.align}; + + PackedFrame packed_frame(frame.oriented_xsize(), frame.oriented_ysize(), + format); + const size_t bits_per_sample = + float_out ? packed_frame.color.BitsPerChannel(pixel_format.data_type) + : ppf->info.bits_per_sample; + packed_frame.name = frame.name; + packed_frame.frame_info.name_length = frame.name.size(); + // Color transform + ImageBundle ib = frame.Copy(); + const ImageBundle* to_color_transform = &ib; + ImageMetadata metadata = io.metadata.m; + ImageBundle store(&metadata); + const ImageBundle* transformed; + // TODO(firsching): handle the transform here. + JXL_RETURN_IF_ERROR(TransformIfNeeded(*to_color_transform, c_desired, + GetJxlCms(), pool, &store, + &transformed)); + size_t stride = ib.oriented_xsize() * + (c_desired.Channels() * ppf->info.bits_per_sample) / + kBitsPerByte; + PaddedBytes pixels(stride * ib.oriented_ysize()); + + JXL_RETURN_IF_ERROR(ConvertToExternal( + *transformed, bits_per_sample, float_out, format.num_channels, + format.endianness, + /* stride_out=*/packed_frame.color.stride, pool, + packed_frame.color.pixels(), packed_frame.color.pixels_size, + /*out_callback=*/{}, frame.metadata()->GetOrientation())); + + // TODO(firsching): Convert the extra channels, beside one potential alpha + // channel. FIXME! + JXL_CHECK(frame.extra_channels().size() <= has_alpha); + ppf->frames.push_back(std::move(packed_frame)); + } + + return true; +} +} // namespace extras +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/packed_image_convert.h b/media/libjxl/src/lib/extras/packed_image_convert.h new file mode 100644 index 000000000..cada66044 --- /dev/null +++ b/media/libjxl/src/lib/extras/packed_image_convert.h @@ -0,0 +1,35 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_PACKED_IMAGE_CONVERT_H_ +#define LIB_EXTRAS_PACKED_IMAGE_CONVERT_H_ + +// Helper functions to convert from the external image types to the internal +// CodecInOut to help transitioning to the external types. + +#include "jxl/types.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace extras { + +// Converts an external PackedPixelFile to the internal CodecInOut for use with +// internal functions directly. +Status ConvertPackedPixelFileToCodecInOut(const PackedPixelFile& ppf, + ThreadPool* pool, CodecInOut* io); + +// Converts an internal CodecInOut for use with internal function to an external +// PackedPixelFile. +Status ConvertCodecInOutToPackedPixelFile(const CodecInOut& io, + const JxlPixelFormat& pixel_format, + const ColorEncoding& c_desired, + ThreadPool* pool, + PackedPixelFile* ppf); +} // namespace extras +} // namespace jxl + +#endif // LIB_EXTRAS_PACKED_IMAGE_CONVERT_H_ diff --git a/media/libjxl/src/lib/extras/render_hdr.cc b/media/libjxl/src/lib/extras/render_hdr.cc new file mode 100644 index 000000000..b247699cd --- /dev/null +++ b/media/libjxl/src/lib/extras/render_hdr.cc @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/render_hdr.h" + +#include "lib/extras/hlg.h" +#include "lib/extras/tone_mapping.h" +#include "lib/jxl/enc_color_management.h" + +namespace jxl { + +Status RenderHDR(CodecInOut* io, float display_nits, ThreadPool* pool) { + const ColorEncoding& original_color_encoding = io->metadata.m.color_encoding; + if (!(original_color_encoding.tf.IsPQ() || + original_color_encoding.tf.IsHLG())) { + // Nothing to do. + return true; + } + + if (original_color_encoding.tf.IsPQ()) { + JXL_RETURN_IF_ERROR(ToneMapTo({0, display_nits}, io, pool)); + JXL_RETURN_IF_ERROR(GamutMap(io, /*preserve_saturation=*/0.1, pool)); + } else { + const float intensity_target = io->metadata.m.IntensityTarget(); + const float gamma_hlg_to_display = GetHlgGamma(display_nits); + // If the image is already in display space, we need to account for the + // already-applied OOTF. + const float gamma_display_to_display = + gamma_hlg_to_display / GetHlgGamma(intensity_target); + // Ensures that conversions to linear in HlgOOTF below will not themselves + // include the OOTF. + io->metadata.m.SetIntensityTarget(300); + + bool need_gamut_mapping = false; + for (ImageBundle& ib : io->frames) { + const float gamma = ib.c_current().tf.IsHLG() ? gamma_hlg_to_display + : gamma_display_to_display; + if (gamma < 1) need_gamut_mapping = true; + JXL_RETURN_IF_ERROR(HlgOOTF(&ib, gamma, pool)); + } + io->metadata.m.SetIntensityTarget(display_nits); + + if (need_gamut_mapping) { + JXL_RETURN_IF_ERROR(GamutMap(io, /*preserve_saturation=*/0.1, pool)); + } + } + + ColorEncoding rec2020_pq; + rec2020_pq.SetColorSpace(ColorSpace::kRGB); + rec2020_pq.white_point = WhitePoint::kD65; + rec2020_pq.primaries = Primaries::k2100; + rec2020_pq.tf.SetTransferFunction(TransferFunction::kPQ); + JXL_RETURN_IF_ERROR(rec2020_pq.CreateICC()); + io->metadata.m.color_encoding = rec2020_pq; + return io->TransformTo(rec2020_pq, GetJxlCms(), pool); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/render_hdr.h b/media/libjxl/src/lib/extras/render_hdr.h new file mode 100644 index 000000000..95127e074 --- /dev/null +++ b/media/libjxl/src/lib/extras/render_hdr.h @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_RENDER_HDR_H_ +#define LIB_EXTRAS_RENDER_HDR_H_ + +#include "lib/jxl/codec_in_out.h" + +namespace jxl { + +// If `io` has an original color space using PQ or HLG, this renders it +// appropriately for a display with a peak luminance of `display_nits` and +// converts the result to a Rec. 2020 / PQ image. Otherwise, leaves the image as +// is. +// PQ images are tone-mapped using the method described in Rep. ITU-R BT.2408-5 +// annex 5, while HLG images are rendered using the HLG OOTF with a gamma +// appropriate for the given target luminance. +// With a sufficiently bright SDR display, converting the output of this +// function to an SDR colorspace may look decent. +Status RenderHDR(CodecInOut* io, float display_nits, + ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_EXTRAS_RENDER_HDR_H_ diff --git a/media/libjxl/src/lib/extras/time.cc b/media/libjxl/src/lib/extras/time.cc new file mode 100644 index 000000000..73d1b8f26 --- /dev/null +++ b/media/libjxl/src/lib/extras/time.cc @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/time.h" + +#include +#include +#include + +#include + +#include "lib/jxl/base/os_macros.h" // for JXL_OS_* + +#if JXL_OS_WIN +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#include +#endif // JXL_OS_WIN + +#if JXL_OS_MAC +#include +#include +#endif // JXL_OS_MAC + +#if JXL_OS_HAIKU +#include +#endif // JXL_OS_HAIKU + +namespace jxl { + +double Now() { +#if JXL_OS_WIN + LARGE_INTEGER counter; + (void)QueryPerformanceCounter(&counter); + LARGE_INTEGER freq; + (void)QueryPerformanceFrequency(&freq); + return double(counter.QuadPart) / freq.QuadPart; +#elif JXL_OS_MAC + const auto t = mach_absolute_time(); + // On OSX/iOS platform the elapsed time is cpu time unit + // We have to query the time base information to convert it back + // See https://developer.apple.com/library/mac/qa/qa1398/_index.html + static mach_timebase_info_data_t timebase; + if (timebase.denom == 0) { + (void)mach_timebase_info(&timebase); + } + return double(t) * timebase.numer / timebase.denom * 1E-9; +#elif JXL_OS_HAIKU + return double(system_time_nsecs()) * 1E-9; +#else + timespec t; + clock_gettime(CLOCK_MONOTONIC, &t); + return t.tv_sec + t.tv_nsec * 1E-9; +#endif +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/extras/time.h b/media/libjxl/src/lib/extras/time.h new file mode 100644 index 000000000..c71414b87 --- /dev/null +++ b/media/libjxl/src/lib/extras/time.h @@ -0,0 +1,19 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_TIME_H_ +#define LIB_EXTRAS_TIME_H_ + +// OS-specific function for timing. + +namespace jxl { + +// Returns current time [seconds] from a monotonic clock with unspecified +// starting point - only suitable for computing elapsed time. +double Now(); + +} // namespace jxl + +#endif // LIB_EXTRAS_TIME_H_ diff --git a/media/libjxl/src/lib/extras/tone_mapping.cc b/media/libjxl/src/lib/extras/tone_mapping.cc new file mode 100644 index 000000000..1ed1b2911 --- /dev/null +++ b/media/libjxl/src/lib/extras/tone_mapping.cc @@ -0,0 +1,131 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/tone_mapping.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/extras/tone_mapping.cc" +#include +#include + +#include "lib/jxl/dec_tone_mapping-inl.h" +#include "lib/jxl/enc_color_management.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +static constexpr float rec2020_luminances[3] = {0.2627f, 0.6780f, 0.0593f}; + +Status ToneMapFrame(const std::pair display_nits, + ImageBundle* const ib, ThreadPool* const pool) { + // Perform tone mapping as described in Report ITU-R BT.2390-8, section 5.4 + // (pp. 23-25). + // https://www.itu.int/pub/R-REP-BT.2390-8-2020 + + HWY_FULL(float) df; + using V = decltype(Zero(df)); + + ColorEncoding linear_rec2020; + linear_rec2020.SetColorSpace(ColorSpace::kRGB); + linear_rec2020.primaries = Primaries::k2100; + linear_rec2020.white_point = WhitePoint::kD65; + linear_rec2020.tf.SetTransferFunction(TransferFunction::kLinear); + JXL_RETURN_IF_ERROR(linear_rec2020.CreateICC()); + JXL_RETURN_IF_ERROR(ib->TransformTo(linear_rec2020, GetJxlCms(), pool)); + + Rec2408ToneMapper tone_mapper( + {ib->metadata()->tone_mapping.min_nits, + ib->metadata()->IntensityTarget()}, + display_nits, rec2020_luminances); + + return RunOnPool( + pool, 0, ib->ysize(), ThreadPool::NoInit, + [&](const uint32_t y, size_t /* thread */) { + float* const JXL_RESTRICT row_r = ib->color()->PlaneRow(0, y); + float* const JXL_RESTRICT row_g = ib->color()->PlaneRow(1, y); + float* const JXL_RESTRICT row_b = ib->color()->PlaneRow(2, y); + for (size_t x = 0; x < ib->xsize(); x += Lanes(df)) { + V red = Load(df, row_r + x); + V green = Load(df, row_g + x); + V blue = Load(df, row_b + x); + tone_mapper.ToneMap(&red, &green, &blue); + Store(red, df, row_r + x); + Store(green, df, row_g + x); + Store(blue, df, row_b + x); + } + }, + "ToneMap"); +} + +Status GamutMapFrame(ImageBundle* const ib, float preserve_saturation, + ThreadPool* const pool) { + HWY_FULL(float) df; + using V = decltype(Zero(df)); + + ColorEncoding linear_rec2020; + linear_rec2020.SetColorSpace(ColorSpace::kRGB); + linear_rec2020.primaries = Primaries::k2100; + linear_rec2020.white_point = WhitePoint::kD65; + linear_rec2020.tf.SetTransferFunction(TransferFunction::kLinear); + JXL_RETURN_IF_ERROR(linear_rec2020.CreateICC()); + JXL_RETURN_IF_ERROR(ib->TransformTo(linear_rec2020, GetJxlCms(), pool)); + + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ib->ysize(), ThreadPool::NoInit, + [&](const uint32_t y, size_t /* thread*/) { + float* const JXL_RESTRICT row_r = ib->color()->PlaneRow(0, y); + float* const JXL_RESTRICT row_g = ib->color()->PlaneRow(1, y); + float* const JXL_RESTRICT row_b = ib->color()->PlaneRow(2, y); + for (size_t x = 0; x < ib->xsize(); x += Lanes(df)) { + V red = Load(df, row_r + x); + V green = Load(df, row_g + x); + V blue = Load(df, row_b + x); + GamutMap(&red, &green, &blue, rec2020_luminances, + preserve_saturation); + Store(red, df, row_r + x); + Store(green, df, row_g + x); + Store(blue, df, row_b + x); + } + }, + "GamutMap")); + + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +namespace { +HWY_EXPORT(ToneMapFrame); +HWY_EXPORT(GamutMapFrame); +} // namespace + +Status ToneMapTo(const std::pair display_nits, + CodecInOut* const io, ThreadPool* const pool) { + const auto tone_map_frame = HWY_DYNAMIC_DISPATCH(ToneMapFrame); + for (ImageBundle& ib : io->frames) { + JXL_RETURN_IF_ERROR(tone_map_frame(display_nits, &ib, pool)); + } + io->metadata.m.SetIntensityTarget(display_nits.second); + return true; +} + +Status GamutMap(CodecInOut* const io, float preserve_saturation, + ThreadPool* const pool) { + const auto gamut_map_frame = HWY_DYNAMIC_DISPATCH(GamutMapFrame); + for (ImageBundle& ib : io->frames) { + JXL_RETURN_IF_ERROR(gamut_map_frame(&ib, preserve_saturation, pool)); + } + return true; +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/extras/tone_mapping.h b/media/libjxl/src/lib/extras/tone_mapping.h new file mode 100644 index 000000000..1f474101e --- /dev/null +++ b/media/libjxl/src/lib/extras/tone_mapping.h @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_EXTRAS_TONE_MAPPING_H_ +#define LIB_EXTRAS_TONE_MAPPING_H_ + +#include "lib/jxl/codec_in_out.h" + +namespace jxl { + +// Important: after calling this, the result will contain many out-of-gamut +// colors. It is very strongly recommended to call GamutMap afterwards to +// rectify this. +Status ToneMapTo(std::pair display_nits, CodecInOut* io, + ThreadPool* pool = nullptr); + +// `preserve_saturation` indicates to what extent to favor saturation over +// luminance when mapping out-of-gamut colors to Rec. 2020. 0 preserves +// luminance at the complete expense of saturation, while 1 gives the most +// saturated color with the same hue that Rec. 2020 can represent even if it +// means lowering the luminance. Values in between correspond to linear mixtures +// of those two extremes. +Status GamutMap(CodecInOut* io, float preserve_saturation, + ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_EXTRAS_TONE_MAPPING_H_ diff --git a/media/libjxl/src/lib/extras/tone_mapping_gbench.cc b/media/libjxl/src/lib/extras/tone_mapping_gbench.cc new file mode 100644 index 000000000..2f97b8866 --- /dev/null +++ b/media/libjxl/src/lib/extras/tone_mapping_gbench.cc @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/extras/codec.h" +#include "lib/extras/tone_mapping.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/testdata.h" + +namespace jxl { + +static void BM_ToneMapping(benchmark::State& state) { + CodecInOut image; + const PaddedBytes image_bytes = ReadTestData("jxl/flower/flower.png"); + JXL_CHECK(SetFromBytes(Span(image_bytes), &image)); + + // Convert to linear Rec. 2020 so that `ToneMapTo` doesn't have to and we + // mainly measure the tone mapping itself. + ColorEncoding linear_rec2020; + linear_rec2020.SetColorSpace(ColorSpace::kRGB); + linear_rec2020.primaries = Primaries::k2100; + linear_rec2020.white_point = WhitePoint::kD65; + linear_rec2020.tf.SetTransferFunction(TransferFunction::kLinear); + JXL_CHECK(linear_rec2020.CreateICC()); + JXL_CHECK(image.TransformTo(linear_rec2020, GetJxlCms())); + + for (auto _ : state) { + state.PauseTiming(); + CodecInOut tone_mapping_input; + tone_mapping_input.SetFromImage(CopyImage(*image.Main().color()), + image.Main().c_current()); + tone_mapping_input.metadata.m.SetIntensityTarget( + image.metadata.m.IntensityTarget()); + state.ResumeTiming(); + + JXL_CHECK(ToneMapTo({0.1, 100}, &tone_mapping_input)); + } + + state.SetItemsProcessed(state.iterations() * image.xsize() * image.ysize()); +} +BENCHMARK(BM_ToneMapping); + +} // namespace jxl diff --git a/media/libjxl/src/lib/gbench_main.cc b/media/libjxl/src/lib/gbench_main.cc new file mode 100644 index 000000000..1cc177201 --- /dev/null +++ b/media/libjxl/src/lib/gbench_main.cc @@ -0,0 +1,8 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" + +BENCHMARK_MAIN(); diff --git a/media/libjxl/src/lib/include/jxl/butteraugli.h b/media/libjxl/src/lib/include/jxl/butteraugli.h new file mode 100644 index 000000000..ba69a2962 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/butteraugli.h @@ -0,0 +1,160 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_butteraugli + * @{ + * @file butteraugli.h + * @brief Butteraugli API for JPEG XL. + */ + +#ifndef JXL_BUTTERAUGLI_H_ +#define JXL_BUTTERAUGLI_H_ + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +#include "jxl/jxl_export.h" +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" +#include "jxl/types.h" + +/** + * Opaque structure that holds a butteraugli API. + * + * Allocated and initialized with JxlButteraugliApiCreate(). + * Cleaned up and deallocated with JxlButteraugliApiDestroy(). + */ +typedef struct JxlButteraugliApiStruct JxlButteraugliApi; + +/** + * Opaque structure that holds intermediary butteraugli results. + * + * Allocated and initialized with JxlButteraugliCompute(). + * Cleaned up and deallocated with JxlButteraugliResultDestroy(). + */ +typedef struct JxlButteraugliResultStruct JxlButteraugliResult; + +/** + * Deinitializes and frees JxlButteraugliResult instance. + * + * @param result instance to be cleaned up and deallocated. + */ +JXL_EXPORT void JxlButteraugliResultDestroy(JxlButteraugliResult* result); + +/** + * Creates an instance of JxlButteraugliApi and initializes it. + * + * @p memory_manager will be used for all the library dynamic allocations made + * from this instance. The parameter may be NULL, in which case the default + * allocator will be used. See jxl/memory_manager.h for details. + * + * @param memory_manager custom allocator function. It may be NULL. The memory + * manager will be copied internally. + * @return @c NULL if the instance can not be allocated or initialized + * @return pointer to initialized JxlEncoder otherwise + */ +JXL_EXPORT JxlButteraugliApi* JxlButteraugliApiCreate( + const JxlMemoryManager* memory_manager); + +/** + * Set the parallel runner for multithreading. + * + * @param api api instance. + * @param parallel_runner function pointer to runner for multithreading. A + * multithreaded runner should be set to reach fast performance. + * @param parallel_runner_opaque opaque pointer for parallel_runner. + */ +JXL_EXPORT void JxlButteraugliApiSetParallelRunner( + JxlButteraugliApi* api, JxlParallelRunner parallel_runner, + void* parallel_runner_opaque); + +/** + * Set the hf_asymmetry option for butteraugli. + * + * @param api api instance. + * @param v new hf_asymmetry value. + */ +JXL_EXPORT void JxlButteraugliApiSetHFAsymmetry(JxlButteraugliApi* api, + float v); + +/** + * Set the intensity_target option for butteraugli. + * + * @param api api instance. + * @param v new intensity_target value. + */ +JXL_EXPORT void JxlButteraugliApiSetIntensityTarget(JxlButteraugliApi* api, + float v); + +/** + * Deinitializes and frees JxlButteraugliApi instance. + * + * @param api instance to be cleaned up and deallocated. + */ +JXL_EXPORT void JxlButteraugliApiDestroy(JxlButteraugliApi* api); + +/** + * Computes intermediary butteraugli result between an original image and a + * distortion. + * + * @param api api instance for this computation. + * @param xsize width of the compared images. + * @param ysize height of the compared images. + * @param pixel_format_orig pixel format for original image. + * @param buffer_orig pixel data for original image. + * @param size_orig size of buffer_orig in bytes. + * @param pixel_format_dist pixel format for distortion. + * @param buffer_dist pixel data for distortion. + * @param size_dist size of buffer_dist in bytes. + * @return @c NULL if the results can not be computed or initialized. + * @return pointer to initialized and computed intermediary result. + */ +JXL_EXPORT JxlButteraugliResult* JxlButteraugliCompute( + const JxlButteraugliApi* api, uint32_t xsize, uint32_t ysize, + const JxlPixelFormat* pixel_format_orig, const void* buffer_orig, + size_t size_orig, const JxlPixelFormat* pixel_format_dist, + const void* buffer_dist, size_t size_dist); + +/** + * Computes butteraugli max distance based on an intermediary butteraugli + * result. + * + * @param result intermediary result instance. + * @return max distance. + */ +JXL_EXPORT float JxlButteraugliResultGetMaxDistance( + const JxlButteraugliResult* result); + +/** + * Computes a butteraugli distance based on an intermediary butteraugli result. + * + * @param result intermediary result instance. + * @param pnorm pnorm to calculate. + * @return distance using the given pnorm. + */ +JXL_EXPORT float JxlButteraugliResultGetDistance( + const JxlButteraugliResult* result, float pnorm); + +/** + * Get a pointer to the distmap in the result. + * + * @param result intermediary result instance. + * @param buffer will be set to the distmap. The distance value for (x,y) will + * be available at buffer + y * row_stride + x. + * @param row_stride will be set to the row stride of the distmap. + */ +JXL_EXPORT void JxlButteraugliResultGetDistmap( + const JxlButteraugliResult* result, const float** buffer, + uint32_t* row_stride); + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_BUTTERAUGLI_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/butteraugli_cxx.h b/media/libjxl/src/lib/include/jxl/butteraugli_cxx.h new file mode 100644 index 000000000..55efd74d6 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/butteraugli_cxx.h @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/// @addtogroup libjxl_butteraugli +/// @{ +/// +/// @file butteraugli_cxx.h +/// @brief C++ header-only helper for @ref butteraugli.h. +/// +/// There's no binary library associated with the header since this is a header +/// only library. + +#ifndef JXL_BUTTERAUGLI_CXX_H_ +#define JXL_BUTTERAUGLI_CXX_H_ + +#include + +#include "jxl/butteraugli.h" + +#if !(defined(__cplusplus) || defined(c_plusplus)) +#error "This a C++ only header. Use jxl/butteraugli.h from C sources." +#endif + +/// Struct to call JxlButteraugliApiDestroy from the JxlButteraugliApiPtr +/// unique_ptr. +struct JxlButteraugliApiDestroyStruct { + /// Calls @ref JxlButteraugliApiDestroy() on the passed api. + void operator()(JxlButteraugliApi* api) { JxlButteraugliApiDestroy(api); } +}; + +/// std::unique_ptr<> type that calls JxlButteraugliApiDestroy() when releasing +/// the pointer. +/// +/// Use this helper type from C++ sources to ensure the api is destroyed and +/// their internal resources released. +typedef std::unique_ptr + JxlButteraugliApiPtr; + +/// Struct to call JxlButteraugliResultDestroy from the JxlButteraugliResultPtr +/// unique_ptr. +struct JxlButteraugliResultDestroyStruct { + /// Calls @ref JxlButteraugliResultDestroy() on the passed result object. + void operator()(JxlButteraugliResult* result) { + JxlButteraugliResultDestroy(result); + } +}; + +/// std::unique_ptr<> type that calls JxlButteraugliResultDestroy() when +/// releasing the pointer. +/// +/// Use this helper type from C++ sources to ensure the result object is +/// destroyed and their internal resources released. +typedef std::unique_ptr + JxlButteraugliResultPtr; + +#endif // JXL_BUTTERAUGLI_CXX_H_ + +/// @} diff --git a/media/libjxl/src/lib/include/jxl/cms_interface.h b/media/libjxl/src/lib/include/jxl/cms_interface.h new file mode 100644 index 000000000..fb852eeb1 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/cms_interface.h @@ -0,0 +1,232 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file cms_interface.h + * @brief Interface to allow the injection of different color management systems + * (CMSes, also called color management modules, or CMMs) in JPEG XL. + * + * A CMS is needed by the JPEG XL encoder and decoder to perform colorspace + * conversions. This defines an interface that can be implemented for different + * CMSes and then passed to the library. + */ + +#ifndef JXL_CMS_INTERFACE_H_ +#define JXL_CMS_INTERFACE_H_ + +#include "jxl/color_encoding.h" +#include "jxl/types.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** Represents an input or output colorspace to a color transform, as a + * serialized ICC profile. */ +typedef struct { + /** The serialized ICC profile. This is guaranteed to be present and valid. */ + struct { + const uint8_t* data; + size_t size; + } icc; + + /** Structured representation of the colorspace, if applicable. If all fields + * are different from their "unknown" value, then this is equivalent to the + * ICC representation of the colorspace. If some are "unknown", those that are + * not are still valid and can still be used on their own if they are useful. + */ + JxlColorEncoding color_encoding; + + /** Number of components per pixel. This can be deduced from the other + * representations of the colorspace but is provided for convenience and + * validation. */ + size_t num_channels; +} JxlColorProfile; + +/** Allocates and returns the data needed for @p num_threads parallel transforms + * from the @p input colorspace to @p output, with up to @p pixels_per_thread + * pixels to transform per call to JxlCmsInterface::run. @p init_data comes + * directly from the JxlCmsInterface instance. Since @c run only receives the + * data returned by @c init, a reference to @p init_data should be kept there + * if access to it is desired in @c run. Likewise for JxlCmsInterface::destroy. + * + * The ICC data in @p input and @p output is guaranteed to outlive the @c init / + * @c run / @c destroy cycle. + * + * @param init_data JxlCmsInterface::init_data passed as-is. + * @param num_threads the maximum number of threads from which + * JxlCmsInterface::run will be called. + * @param pixels_per_thread the maximum number of pixels that each call to + * JxlCmsInterface::run will have to transform. + * @param input_profile the input colorspace for the transform. + * @param output_profile the colorspace to which JxlCmsInterface::run should + * convert the input data. + * @param intensity_target for colorspaces where luminance is relative + * (essentially: not PQ), indicates the luminance at which (1, 1, 1) will + * be displayed. This is useful for conversions between PQ and a relative + * luminance colorspace, in either direction: @p intensity_target cd/m² + * in PQ should map to and from (1, 1, 1) in the relative one.\n + * It is also used for conversions to and from HLG, as it is + * scene-referred while other colorspaces are assumed to be + * display-referred. That is, conversions from HLG should apply the OOTF + * for a peak display luminance of @p intensity_target, and conversions + * to HLG should undo it. The OOTF is a gamma function applied to the + * luminance channel (https://www.itu.int/rec/R-REC-BT.2100-2-201807-I + * page 7), with the gamma value computed as + * 1.2 * 1.111^log2(intensity_target / 1000) (footnote 2 page 8 + * of the same document). + * @return The data needed for the transform, or @c NULL in case of failure. + * This will be passed to the other functions as @c user_data. + */ +typedef void* (*jpegxl_cms_init_func)(void* init_data, size_t num_threads, + size_t pixels_per_thread, + const JxlColorProfile* input_profile, + const JxlColorProfile* output_profile, + float intensity_target); + +/** Returns a buffer that can be used by callers of the interface to store the + * input of the conversion or read its result, if they pass it as the input or + * output of the @c run function. + * @param user_data the data returned by @c init. + * @param thread the index of the thread for which to return a buffer. + * @return A buffer that can be used by the caller for passing to @c run. + */ +typedef float* (*jpegxl_cms_get_buffer_func)(void* user_data, size_t thread); + +/** Executes one transform and returns true on success or false on error. It + * must be possible to call this from different threads with different values + * for @p thread, all between 0 (inclusive) and the value of @p num_threads + * passed to @c init (exclusive). It is allowed to implement this by locking + * such that the transforms are essentially performed sequentially, if such a + * performance profile is acceptable. @p user_data is the data returned by + * @c init. + * The buffers each contain @p num_pixels × @c num_channels interleaved floating + * point (0..1) samples where @c num_channels is the number of color channels of + * their respective color profiles. It is guaranteed that the only case in which + * they might overlap is if the output has fewer channels than the input, in + * which case the pointers may be identical. + * For CMYK data, 0 represents the maximum amount of ink while 1 represents no + * ink. + * @param user_data the data returned by @c init. + * @param thread the index of the thread from which the function is being + * called. + * @param input_buffer the buffer containing the pixel data to be transformed. + * @param output_buffer the buffer receiving the transformed pixel data. + * @param num_pixels the number of pixels to transform from @p input to + * @p output. + * @return JXL_TRUE on success, JXL_FALSE on failure. + */ +typedef JXL_BOOL (*jpegxl_cms_run_func)(void* user_data, size_t thread, + const float* input_buffer, + float* output_buffer, + size_t num_pixels); + +/** Performs the necessary clean-up and frees the memory allocated for user + * data. + */ +typedef void (*jpegxl_cms_destroy_func)(void*); + +/** + * Interface for performing colorspace transforms. The @c init function can be + * called several times to instantiate several transforms, including before + * other transforms have been destroyed. + * + * The call sequence for a given colorspace transform could look like the + * following: + * @dot + * digraph calls { + * newrank = true + * node [shape = box, fontname = monospace] + * init [label = "user_data <- init(\l\ + * init_data = data,\l\ + * num_threads = 3,\l\ + * pixels_per_thread = 20,\l\ + * input = (sRGB, 3 channels),\l\ + * output = (Display-P3, 3 channels),\l\ + * intensity_target = 255\l\ + * )\l"] + * subgraph cluster_0 { + * color = lightgrey + * label = "thread 1" + * labeljust = "c" + * run_1_1 [label = "run(\l\ + * user_data,\l\ + * thread = 1,\l\ + * input = in[0],\l\ + * output = out[0],\l\ + * num_pixels = 20\l\ + * )\l"] + * run_1_2 [label = "run(\l\ + * user_data,\l\ + * thread = 1,\l\ + * input = in[3],\l\ + * output = out[3],\l\ + * num_pixels = 20\l\ + * )\l"] + * } + * subgraph cluster_1 { + * color = lightgrey + * label = "thread 2" + * labeljust = "l" + * run_2_1 [label = "run(\l\ + * user_data,\l\ + * thread = 2,\l\ + * input = in[1],\l\ + * output = out[1],\l\ + * num_pixels = 20\l\ + * )\l"] + * run_2_2 [label = "run(\l\ + * user_data,\l\ + * thread = 2,\l\ + * input = in[4],\l\ + * output = out[4],\l\ + * num_pixels = 13\l\ + * )\l"] + * } + * subgraph cluster_3 { + * color = lightgrey + * label = "thread 3" + * labeljust = "c" + * run_3_1 [label = "run(\l\ + * user_data,\l\ + * thread = 3,\l\ + * input = in[2],\l\ + * output = out[2],\l\ + * num_pixels = 20\l\ + * )\l"] + * } + * init -> {run_1_1; run_2_1; run_3_1; rank = same} + * run_1_1 -> run_1_2 + * run_2_1 -> run_2_2 + * {run_1_2; run_2_2, run_3_1} -> "destroy(user_data)" + * } + * @enddot + */ +typedef struct { + /** CMS-specific data that will be passed to @ref init. */ + void* init_data; + /** Prepares a colorspace transform as described in the documentation of @ref + * jpegxl_cms_init_func. */ + jpegxl_cms_init_func init; + /** Returns a buffer that can be used as input to @c run. */ + jpegxl_cms_get_buffer_func get_src_buf; + /** Returns a buffer that can be used as output from @c run. */ + jpegxl_cms_get_buffer_func get_dst_buf; + /** Executes the transform on a batch of pixels, per @ref jpegxl_cms_run_func. + */ + jpegxl_cms_run_func run; + /** Cleans up the transform. */ + jpegxl_cms_destroy_func destroy; +} JxlCmsInterface; + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_CMS_INTERFACE_H_ */ + +/** @} */ diff --git a/media/libjxl/src/lib/include/jxl/codestream_header.h b/media/libjxl/src/lib/include/jxl/codestream_header.h new file mode 100644 index 000000000..d12657787 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/codestream_header.h @@ -0,0 +1,438 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file codestream_header.h + * @brief Definitions of structs and enums for the metadata from the JPEG XL + * codestream headers (signature, metadata, preview dimensions, ...), excluding + * color encoding which is in color_encoding.h. + */ + +#ifndef JXL_CODESTREAM_HEADER_H_ +#define JXL_CODESTREAM_HEADER_H_ + +#include +#include + +#include "jxl/types.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** Image orientation metadata. + * Values 1..8 match the EXIF definitions. + * The name indicates the operation to perform to transform from the encoded + * image to the display image. + */ +typedef enum { + JXL_ORIENT_IDENTITY = 1, + JXL_ORIENT_FLIP_HORIZONTAL = 2, + JXL_ORIENT_ROTATE_180 = 3, + JXL_ORIENT_FLIP_VERTICAL = 4, + JXL_ORIENT_TRANSPOSE = 5, + JXL_ORIENT_ROTATE_90_CW = 6, + JXL_ORIENT_ANTI_TRANSPOSE = 7, + JXL_ORIENT_ROTATE_90_CCW = 8, +} JxlOrientation; + +/** Given type of an extra channel. + */ +typedef enum { + JXL_CHANNEL_ALPHA, + JXL_CHANNEL_DEPTH, + JXL_CHANNEL_SPOT_COLOR, + JXL_CHANNEL_SELECTION_MASK, + JXL_CHANNEL_BLACK, + JXL_CHANNEL_CFA, + JXL_CHANNEL_THERMAL, + JXL_CHANNEL_RESERVED0, + JXL_CHANNEL_RESERVED1, + JXL_CHANNEL_RESERVED2, + JXL_CHANNEL_RESERVED3, + JXL_CHANNEL_RESERVED4, + JXL_CHANNEL_RESERVED5, + JXL_CHANNEL_RESERVED6, + JXL_CHANNEL_RESERVED7, + JXL_CHANNEL_UNKNOWN, + JXL_CHANNEL_OPTIONAL +} JxlExtraChannelType; + +/** The codestream preview header */ +typedef struct { + /** Preview width in pixels */ + uint32_t xsize; + + /** Preview height in pixels */ + uint32_t ysize; +} JxlPreviewHeader; + +/** The intrinsic size header */ +typedef struct { + /** Intrinsic width in pixels */ + uint32_t xsize; + + /** Intrinsic height in pixels */ + uint32_t ysize; +} JxlIntrinsicSizeHeader; + +/** The codestream animation header, optionally present in the beginning of + * the codestream, and if it is it applies to all animation frames, unlike + * JxlFrameHeader which applies to an individual frame. + */ +typedef struct { + /** Numerator of ticks per second of a single animation frame time unit */ + uint32_t tps_numerator; + + /** Denominator of ticks per second of a single animation frame time unit */ + uint32_t tps_denominator; + + /** Amount of animation loops, or 0 to repeat infinitely */ + uint32_t num_loops; + + /** Whether animation time codes are present at animation frames in the + * codestream */ + JXL_BOOL have_timecodes; +} JxlAnimationHeader; + +/** Basic image information. This information is available from the file + * signature and first part of the codestream header. + */ +typedef struct { + /* TODO(lode): need additional fields for (transcoded) JPEG? For reusable + * fields orientation must be read from Exif APP1. For has_icc_profile: must + * look up where ICC profile is guaranteed to be in a JPEG file to be able to + * indicate this. */ + + /* TODO(lode): make struct packed, and/or make this opaque struct with getter + * functions (still separate struct from opaque decoder) */ + + /** Whether the codestream is embedded in the container format. If true, + * metadata information and extensions may be available in addition to the + * codestream. + */ + JXL_BOOL have_container; + + /** Width of the image in pixels, before applying orientation. + */ + uint32_t xsize; + + /** Height of the image in pixels, before applying orientation. + */ + uint32_t ysize; + + /** Original image color channel bit depth. + */ + uint32_t bits_per_sample; + + /** Original image color channel floating point exponent bits, or 0 if they + * are unsigned integer. For example, if the original data is half-precision + * (binary16) floating point, bits_per_sample is 16 and + * exponent_bits_per_sample is 5, and so on for other floating point + * precisions. + */ + uint32_t exponent_bits_per_sample; + + /** Upper bound on the intensity level present in the image in nits. For + * unsigned integer pixel encodings, this is the brightness of the largest + * representable value. The image does not necessarily contain a pixel + * actually this bright. An encoder is allowed to set 255 for SDR images + * without computing a histogram. + * Leaving this set to its default of 0 lets libjxl choose a sensible default + * value based on the color encoding. + */ + float intensity_target; + + /** Lower bound on the intensity level present in the image. This may be + * loose, i.e. lower than the actual darkest pixel. When tone mapping, a + * decoder will map [min_nits, intensity_target] to the display range. + */ + float min_nits; + + /** See the description of @see linear_below. + */ + JXL_BOOL relative_to_max_display; + + /** The tone mapping will leave unchanged (linear mapping) any pixels whose + * brightness is strictly below this. The interpretation depends on + * relative_to_max_display. If true, this is a ratio [0, 1] of the maximum + * display brightness [nits], otherwise an absolute brightness [nits]. + */ + float linear_below; + + /** Whether the data in the codestream is encoded in the original color + * profile that is attached to the codestream metadata header, or is + * encoded in an internally supported absolute color space (which the decoder + * can always convert to linear or non-linear sRGB or to XYB). If the original + * profile is used, the decoder outputs pixel data in the color space matching + * that profile, but doesn't convert it to any other color space. If the + * original profile is not used, the decoder only outputs the data as sRGB + * (linear if outputting to floating point, nonlinear with standard sRGB + * transfer function if outputting to unsigned integers) but will not convert + * it to to the original color profile. The decoder also does not convert to + * the target display color profile. To convert the pixel data produced by + * the decoder to the original color profile, one of the JxlDecoderGetColor* + * functions needs to be called with @ref JXL_COLOR_PROFILE_TARGET_DATA to get + * the color profile of the decoder output, and then an external CMS can be + * used for conversion. + * Note that for lossy compression, this should be set to false for most use + * cases, and if needed, the image should be converted to the original color + * profile after decoding, as described above. + */ + JXL_BOOL uses_original_profile; + + /** Indicates a preview image exists near the beginning of the codestream. + * The preview itself or its dimensions are not included in the basic info. + */ + JXL_BOOL have_preview; + + /** Indicates animation frames exist in the codestream. The animation + * information is not included in the basic info. + */ + JXL_BOOL have_animation; + + /** Image orientation, value 1-8 matching the values used by JEITA CP-3451C + * (Exif version 2.3). + */ + JxlOrientation orientation; + + /** Number of color channels encoded in the image, this is either 1 for + * grayscale data, or 3 for colored data. This count does not include + * the alpha channel or other extra channels. To check presence of an alpha + * channel, such as in the case of RGBA color, check alpha_bits != 0. + * If and only if this is 1, the JxlColorSpace in the JxlColorEncoding is + * JXL_COLOR_SPACE_GRAY. + */ + uint32_t num_color_channels; + + /** Number of additional image channels. This includes the main alpha channel, + * but can also include additional channels such as depth, additional alpha + * channels, spot colors, and so on. Information about the extra channels + * can be queried with JxlDecoderGetExtraChannelInfo. The main alpha channel, + * if it exists, also has its information available in the alpha_bits, + * alpha_exponent_bits and alpha_premultiplied fields in this JxlBasicInfo. + */ + uint32_t num_extra_channels; + + /** Bit depth of the encoded alpha channel, or 0 if there is no alpha channel. + * If present, matches the alpha_bits value of the JxlExtraChannelInfo + * associated with this alpha channel. + */ + uint32_t alpha_bits; + + /** Alpha channel floating point exponent bits, or 0 if they are unsigned. If + * present, matches the alpha_bits value of the JxlExtraChannelInfo associated + * with this alpha channel. integer. + */ + uint32_t alpha_exponent_bits; + + /** Whether the alpha channel is premultiplied. Only used if there is a main + * alpha channel. Matches the alpha_premultiplied value of the + * JxlExtraChannelInfo associated with this alpha channel. + */ + JXL_BOOL alpha_premultiplied; + + /** Dimensions of encoded preview image, only used if have_preview is + * JXL_TRUE. + */ + JxlPreviewHeader preview; + + /** Animation header with global animation properties for all frames, only + * used if have_animation is JXL_TRUE. + */ + JxlAnimationHeader animation; + + /** Intrinsic width of the image. + * The intrinsic size can be different from the actual size in pixels + * (as given by xsize and ysize) and it denotes the recommended dimensions + * for displaying the image, i.e. applications are advised to resample the + * decoded image to the intrinsic dimensions. + */ + uint32_t intrinsic_xsize; + + /** Intrinsic heigth of the image. + * The intrinsic size can be different from the actual size in pixels + * (as given by xsize and ysize) and it denotes the recommended dimensions + * for displaying the image, i.e. applications are advised to resample the + * decoded image to the intrinsic dimensions. + */ + uint32_t intrinsic_ysize; + + /** Padding for forwards-compatibility, in case more fields are exposed + * in a future version of the library. + */ + uint8_t padding[100]; +} JxlBasicInfo; + +/** Information for a single extra channel. + */ +typedef struct { + /** Given type of an extra channel. + */ + JxlExtraChannelType type; + + /** Total bits per sample for this channel. + */ + uint32_t bits_per_sample; + + /** Floating point exponent bits per channel, or 0 if they are unsigned + * integer. + */ + uint32_t exponent_bits_per_sample; + + /** The exponent the channel is downsampled by on each axis. + * TODO(lode): expand this comment to match the JPEG XL specification, + * specify how to upscale, how to round the size computation, and to which + * extra channels this field applies. + */ + uint32_t dim_shift; + + /** Length of the extra channel name in bytes, or 0 if no name. + * Excludes null termination character. + */ + uint32_t name_length; + + /** Whether alpha channel uses premultiplied alpha. Only applicable if + * type is JXL_CHANNEL_ALPHA. + */ + JXL_BOOL alpha_premultiplied; + + /** Spot color of the current spot channel in linear RGBA. Only applicable if + * type is JXL_CHANNEL_SPOT_COLOR. + */ + float spot_color[4]; + + /** Only applicable if type is JXL_CHANNEL_CFA. + * TODO(lode): add comment about the meaning of this field. + */ + uint32_t cfa_channel; +} JxlExtraChannelInfo; + +/* TODO(lode): add API to get the codestream header extensions. */ +/** Extensions in the codestream header. */ +typedef struct { + /** Extension bits. */ + uint64_t extensions; +} JxlHeaderExtensions; + +/** Frame blend modes. + * When decoding, if coalescing is enabled (default), this can be ignored. + */ +typedef enum { + JXL_BLEND_REPLACE = 0, + JXL_BLEND_ADD = 1, + JXL_BLEND_BLEND = 2, + JXL_BLEND_MULADD = 3, + JXL_BLEND_MUL = 4, +} JxlBlendMode; + +/** The information about blending the color channels or a single extra channel. + * When decoding, if coalescing is enabled (default), this can be ignored and + * the blend mode is considered to be JXL_BLEND_REPLACE. + * When encoding, these settings apply to the pixel data given to the encoder. + */ +typedef struct { + /** Blend mode. + */ + JxlBlendMode blendmode; + /** Reference frame ID to use as the 'bottom' layer (0-3). + */ + uint32_t source; + /** Which extra channel to use as the 'alpha' channel for blend modes + * JXL_BLEND_BLEND and JXL_BLEND_MULADD. + */ + uint32_t alpha; + /** Clamp values to [0,1] for the purpose of blending. + */ + JXL_BOOL clamp; +} JxlBlendInfo; + +/** The information about layers. + * When decoding, if coalescing is enabled (default), this can be ignored. + * When encoding, these settings apply to the pixel data given to the encoder, + * the encoder could choose an internal representation that differs. + */ +typedef struct { + /** Whether cropping is applied for this frame. When decoding, if false, + * crop_x0 and crop_y0 are set to zero, and xsize and ysize to the main + * image dimensions. When encoding and this is false, those fields are + * ignored. When decoding, if coalescing is enabled (default), this is always + * false, regardless of the internal encoding in the JPEG XL codestream. + */ + JXL_BOOL have_crop; + + /** Horizontal offset of the frame (can be negative). + */ + int32_t crop_x0; + + /** Vertical offset of the frame (can be negative). + */ + int32_t crop_y0; + + /** Width of the frame (number of columns). + */ + uint32_t xsize; + + /** Height of the frame (number of rows). + */ + uint32_t ysize; + + /** The blending info for the color channels. Blending info for extra channels + * has to be retrieved separately using JxlDecoderGetExtraChannelBlendInfo. + */ + JxlBlendInfo blend_info; + + /** After blending, save the frame as reference frame with this ID (0-3). + * Special case: if the frame duration is nonzero, ID 0 means "will not be + * referenced in the future". This value is not used for the last frame. + */ + uint32_t save_as_reference; +} JxlLayerInfo; + +/** The header of one displayed frame or non-coalesced layer. */ +typedef struct { + /** How long to wait after rendering in ticks. The duration in seconds of a + * tick is given by tps_numerator and tps_denominator in JxlAnimationHeader. + */ + uint32_t duration; + + /** SMPTE timecode of the current frame in form 0xHHMMSSFF, or 0. The bits are + * interpreted from most-significant to least-significant as hour, minute, + * second, and frame. If timecode is nonzero, it is strictly larger than that + * of a previous frame with nonzero duration. These values are only available + * if have_timecodes in JxlAnimationHeader is JXL_TRUE. + * This value is only used if have_timecodes in JxlAnimationHeader is + * JXL_TRUE. + */ + uint32_t timecode; + + /** Length of the frame name in bytes, or 0 if no name. + * Excludes null termination character. This value is set by the decoder. + * For the encoder, this value is ignored and @ref JxlEncoderSetFrameName is + * used instead to set the name and the length. + */ + uint32_t name_length; + + /** Indicates this is the last animation frame. This value is set by the + * decoder to indicate no further frames follow. For the encoder, it is not + * required to set this value and it is ignored, @ref JxlEncoderCloseFrames is + * used to indicate the last frame to the encoder instead. + */ + JXL_BOOL is_last; + + /** Information about the layer in case of no coalescing. + */ + JxlLayerInfo layer_info; +} JxlFrameHeader; + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_CODESTREAM_HEADER_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/color_encoding.h b/media/libjxl/src/lib/include/jxl/color_encoding.h new file mode 100644 index 000000000..b16f6a01e --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/color_encoding.h @@ -0,0 +1,162 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file color_encoding.h + * @brief Color Encoding definitions used by JPEG XL. + * All CIE units are for the standard 1931 2 degree observer. + */ + +#ifndef JXL_COLOR_ENCODING_H_ +#define JXL_COLOR_ENCODING_H_ + +#include + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** Color space of the image data. */ +typedef enum { + /** Tristimulus RGB */ + JXL_COLOR_SPACE_RGB, + /** Luminance based, the primaries in JxlColorEncoding must be ignored. This + * value implies that num_color_channels in JxlBasicInfo is 1, any other value + * implies num_color_channels is 3. */ + JXL_COLOR_SPACE_GRAY, + /** XYB (opsin) color space */ + JXL_COLOR_SPACE_XYB, + /** None of the other table entries describe the color space appropriately */ + JXL_COLOR_SPACE_UNKNOWN, +} JxlColorSpace; + +/** Built-in whitepoints for color encoding. When decoding, the numerical xy + * whitepoint value can be read from the JxlColorEncoding white_point field + * regardless of the enum value. When encoding, enum values except + * JXL_WHITE_POINT_CUSTOM override the numerical fields. Some enum values match + * a subset of CICP (Rec. ITU-T H.273 | ISO/IEC 23091-2:2019(E)), however the + * white point and RGB primaries are separate enums here. + */ +typedef enum { + /** CIE Standard Illuminant D65: 0.3127, 0.3290 */ + JXL_WHITE_POINT_D65 = 1, + /** White point must be read from the JxlColorEncoding white_point field, or + * as ICC profile. This enum value is not an exact match of the corresponding + * CICP value. */ + JXL_WHITE_POINT_CUSTOM = 2, + /** CIE Standard Illuminant E (equal-energy): 1/3, 1/3 */ + JXL_WHITE_POINT_E = 10, + /** DCI-P3 from SMPTE RP 431-2: 0.314, 0.351 */ + JXL_WHITE_POINT_DCI = 11, +} JxlWhitePoint; + +/** Built-in primaries for color encoding. When decoding, the primaries can be + * read from the JxlColorEncoding primaries_red_xy, primaries_green_xy and + * primaries_blue_xy fields regardless of the enum value. When encoding, the + * enum values except JXL_PRIMARIES_CUSTOM override the numerical fields. Some + * enum values match a subset of CICP (Rec. ITU-T H.273 | ISO/IEC + * 23091-2:2019(E)), however the white point and RGB primaries are separate + * enums here. + */ +typedef enum { + /** The CIE xy values of the red, green and blue primaries are: 0.639998686, + 0.330010138; 0.300003784, 0.600003357; 0.150002046, 0.059997204 */ + JXL_PRIMARIES_SRGB = 1, + /** Primaries must be read from the JxlColorEncoding primaries_red_xy, + * primaries_green_xy and primaries_blue_xy fields, or as ICC profile. This + * enum value is not an exact match of the corresponding CICP value. */ + JXL_PRIMARIES_CUSTOM = 2, + /** As specified in Rec. ITU-R BT.2100-1 */ + JXL_PRIMARIES_2100 = 9, + /** As specified in SMPTE RP 431-2 */ + JXL_PRIMARIES_P3 = 11, +} JxlPrimaries; + +/** Built-in transfer functions for color encoding. Enum values match a subset + * of CICP (Rec. ITU-T H.273 | ISO/IEC 23091-2:2019(E)) unless specified + * otherwise. */ +typedef enum { + /** As specified in SMPTE RP 431-2 */ + JXL_TRANSFER_FUNCTION_709 = 1, + /** None of the other table entries describe the transfer function. */ + JXL_TRANSFER_FUNCTION_UNKNOWN = 2, + /** The gamma exponent is 1 */ + JXL_TRANSFER_FUNCTION_LINEAR = 8, + /** As specified in IEC 61966-2-1 sRGB */ + JXL_TRANSFER_FUNCTION_SRGB = 13, + /** As specified in SMPTE ST 2084 */ + JXL_TRANSFER_FUNCTION_PQ = 16, + /** As specified in SMPTE ST 428-1 */ + JXL_TRANSFER_FUNCTION_DCI = 17, + /** As specified in Rec. ITU-R BT.2100-1 (HLG) */ + JXL_TRANSFER_FUNCTION_HLG = 18, + /** Transfer function follows power law given by the gamma value in + JxlColorEncoding. Not a CICP value. */ + JXL_TRANSFER_FUNCTION_GAMMA = 65535, +} JxlTransferFunction; + +/** Renderig intent for color encoding, as specified in ISO 15076-1:2010 */ +typedef enum { + /** vendor-specific */ + JXL_RENDERING_INTENT_PERCEPTUAL = 0, + /** media-relative */ + JXL_RENDERING_INTENT_RELATIVE, + /** vendor-specific */ + JXL_RENDERING_INTENT_SATURATION, + /** ICC-absolute */ + JXL_RENDERING_INTENT_ABSOLUTE, +} JxlRenderingIntent; + +/** Color encoding of the image as structured information. + */ +typedef struct { + /** Color space of the image data. + */ + JxlColorSpace color_space; + + /** Built-in white point. If this value is JXL_WHITE_POINT_CUSTOM, must + * use the numerical whitepoint values from white_point_xy. + */ + JxlWhitePoint white_point; + + /** Numerical whitepoint values in CIE xy space. */ + double white_point_xy[2]; + + /** Built-in RGB primaries. If this value is JXL_PRIMARIES_CUSTOM, must + * use the numerical primaries values below. This field and the custom values + * below are unused and must be ignored if the color space is + * JXL_COLOR_SPACE_GRAY or JXL_COLOR_SPACE_XYB. + */ + JxlPrimaries primaries; + + /** Numerical red primary values in CIE xy space. */ + double primaries_red_xy[2]; + + /** Numerical green primary values in CIE xy space. */ + double primaries_green_xy[2]; + + /** Numerical blue primary values in CIE xy space. */ + double primaries_blue_xy[2]; + + /** Transfer function if have_gamma is 0 */ + JxlTransferFunction transfer_function; + + /** Gamma value used when transfer_function is JXL_TRANSFER_FUNCTION_GAMMA + */ + double gamma; + + /** Rendering intent defined for the color profile. */ + JxlRenderingIntent rendering_intent; +} JxlColorEncoding; + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_COLOR_ENCODING_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/decode.h b/media/libjxl/src/lib/include/jxl/decode.h new file mode 100644 index 000000000..66820bfcd --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/decode.h @@ -0,0 +1,1447 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_decoder + * @{ + * @file decode.h + * @brief Decoding API for JPEG XL. + */ + +#ifndef JXL_DECODE_H_ +#define JXL_DECODE_H_ + +#include +#include + +#include "jxl/codestream_header.h" +#include "jxl/color_encoding.h" +#include "jxl/jxl_export.h" +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" +#include "jxl/types.h" +#include "jxl/version.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** + * Decoder library version. + * + * @return the decoder library version as an integer: + * MAJOR_VERSION * 1000000 + MINOR_VERSION * 1000 + PATCH_VERSION. For example, + * version 1.2.3 would return 1002003. + */ +JXL_EXPORT uint32_t JxlDecoderVersion(void); + +/** The result of @ref JxlSignatureCheck. + */ +typedef enum { + /** Not enough bytes were passed to determine if a valid signature was found. + */ + JXL_SIG_NOT_ENOUGH_BYTES = 0, + + /** No valid JPEG XL header was found. */ + JXL_SIG_INVALID = 1, + + /** A valid JPEG XL codestream signature was found, that is a JPEG XL image + * without container. + */ + JXL_SIG_CODESTREAM = 2, + + /** A valid container signature was found, that is a JPEG XL image embedded + * in a box format container. + */ + JXL_SIG_CONTAINER = 3, +} JxlSignature; + +/** + * JPEG XL signature identification. + * + * Checks if the passed buffer contains a valid JPEG XL signature. The passed @p + * buf of size + * @p size doesn't need to be a full image, only the beginning of the file. + * + * @return a flag indicating if a JPEG XL signature was found and what type. + * - @ref JXL_SIG_NOT_ENOUGH_BYTES if not enough bytes were passed to + * determine if a valid signature is there. + * - @ref JXL_SIG_INVALID if no valid signature found for JPEG XL decoding. + * - @ref JXL_SIG_CODESTREAM if a valid JPEG XL codestream signature was + * found. + * - @ref JXL_SIG_CONTAINER if a valid JPEG XL container signature was found. + */ +JXL_EXPORT JxlSignature JxlSignatureCheck(const uint8_t* buf, size_t len); + +/** + * Opaque structure that holds the JPEG XL decoder. + * + * Allocated and initialized with @ref JxlDecoderCreate(). + * Cleaned up and deallocated with @ref JxlDecoderDestroy(). + */ +typedef struct JxlDecoderStruct JxlDecoder; + +/** + * Creates an instance of @ref JxlDecoder and initializes it. + * + * @p memory_manager will be used for all the library dynamic allocations made + * from this instance. The parameter may be NULL, in which case the default + * allocator will be used. See jxl/memory_manager.h for details. + * + * @param memory_manager custom allocator function. It may be NULL. The memory + * manager will be copied internally. + * @return @c NULL if the instance can not be allocated or initialized + * @return pointer to initialized @ref JxlDecoder otherwise + */ +JXL_EXPORT JxlDecoder* JxlDecoderCreate(const JxlMemoryManager* memory_manager); + +/** + * Re-initializes a @ref JxlDecoder instance, so it can be re-used for decoding + * another image. All state and settings are reset as if the object was + * newly created with @ref JxlDecoderCreate, but the memory manager is kept. + * + * @param dec instance to be re-initialized. + */ +JXL_EXPORT void JxlDecoderReset(JxlDecoder* dec); + +/** + * Deinitializes and frees @ref JxlDecoder instance. + * + * @param dec instance to be cleaned up and deallocated. + */ +JXL_EXPORT void JxlDecoderDestroy(JxlDecoder* dec); + +/** + * Return value for @ref JxlDecoderProcessInput. + * The values from @ref JXL_DEC_BASIC_INFO onwards are optional informative + * events that can be subscribed to, they are never returned if they + * have not been registered with @ref JxlDecoderSubscribeEvents. + */ +typedef enum { + /** Function call finished successfully, or decoding is finished and there is + * nothing more to be done. + * + * Note that @ref JxlDecoderProcessInput will return JXL_DEC_SUCCESS if all + * events that were registered with @ref JxlDecoderSubscribeEvents were + * processed, even before the end of the JPEG XL codestream. + * + * In this case, the return value @ref JxlDecoderReleaseInput will be the same + * as it was at the last signaled event. E.g. if JXL_DEC_FULL_IMAGE was + * subscribed to, then all bytes from the end of the JPEG XL codestream + * (including possible boxes needed for jpeg reconstruction) will be returned + * as unprocessed. + */ + JXL_DEC_SUCCESS = 0, + + /** An error occurred, for example invalid input file or out of memory. + * TODO(lode): add function to get error information from decoder. + */ + JXL_DEC_ERROR = 1, + + /** The decoder needs more input bytes to continue. Before the next @ref + * JxlDecoderProcessInput call, more input data must be set, by calling @ref + * JxlDecoderReleaseInput (if input was set previously) and then calling @ref + * JxlDecoderSetInput. @ref JxlDecoderReleaseInput returns how many bytes + * are not yet processed, before a next call to @ref JxlDecoderProcessInput + * all unprocessed bytes must be provided again (the address need not match, + * but the contents must), and more bytes must be concatenated after the + * unprocessed bytes. + * In most cases, @ref JxlDecoderReleaseInput will return no unprocessed bytes + * at this event, the only exceptions are if the previously set input ended + * within (a) the raw codestream signature, (b) the signature box, (c) a box + * header, or (d) the first 4 bytes of a brob, ftyp, or jxlp box. In any of + * these cases the number of unprocessed bytes is less than 20. + */ + JXL_DEC_NEED_MORE_INPUT = 2, + + /** The decoder is able to decode a preview image and requests setting a + * preview output buffer using @ref JxlDecoderSetPreviewOutBuffer. This occurs + * if @ref JXL_DEC_PREVIEW_IMAGE is requested and it is possible to decode a + * preview image from the codestream and the preview out buffer was not yet + * set. There is maximum one preview image in a codestream. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the frame header (including ToC) of the preview frame as + * unprocessed. + */ + JXL_DEC_NEED_PREVIEW_OUT_BUFFER = 3, + + /** The decoder is able to decode a DC image and requests setting a DC output + * buffer using @ref JxlDecoderSetDCOutBuffer. This occurs if @ref + * JXL_DEC_DC_IMAGE is requested and it is possible to decode a DC image from + * the codestream and the DC out buffer was not yet set. This event re-occurs + * for new frames if there are multiple animation frames. + * @deprecated The DC feature in this form will be removed. For progressive + * rendering, @ref JxlDecoderFlushImage should be used. + */ + JXL_DEC_NEED_DC_OUT_BUFFER = 4, + + /** The decoder requests an output buffer to store the full resolution image, + * which can be set with @ref JxlDecoderSetImageOutBuffer or with @ref + * JxlDecoderSetImageOutCallback. This event re-occurs for new frames if + * there are multiple animation frames and requires setting an output again. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the frame header (including ToC) as unprocessed. + */ + JXL_DEC_NEED_IMAGE_OUT_BUFFER = 5, + + /** The JPEG reconstruction buffer is too small for reconstructed JPEG + * codestream to fit. @ref JxlDecoderSetJPEGBuffer must be called again to + * make room for remaining bytes. This event may occur multiple times + * after @ref JXL_DEC_JPEG_RECONSTRUCTION. + */ + JXL_DEC_JPEG_NEED_MORE_OUTPUT = 6, + + /** The box contents output buffer is too small. @ref JxlDecoderSetBoxBuffer + * must be called again to make room for remaining bytes. This event may occur + * multiple times after @ref JXL_DEC_BOX. + */ + JXL_DEC_BOX_NEED_MORE_OUTPUT = 7, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": Basic information such as image dimensions and + * extra channels. This event occurs max once per image. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the basic info as unprocessed (including the last byte of basic info + * if it did not end on a byte boundary). + */ + JXL_DEC_BASIC_INFO = 0x40, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": User extensions of the codestream header. This + * event occurs max once per image and always later than @ref + * JXL_DEC_BASIC_INFO and earlier than any pixel data. + * + * @deprecated The decoder no longer returns this, the header extensions, + * if any, are available at the JXL_DEC_BASIC_INFO event. + */ + JXL_DEC_EXTENSIONS = 0x80, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": Color encoding or ICC profile from the + * codestream header. This event occurs max once per image and always later + * than @ref JXL_DEC_BASIC_INFO and earlier than any pixel data. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the image header (which is the start of the first frame) as + * unprocessed. + */ + JXL_DEC_COLOR_ENCODING = 0x100, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": Preview image, a small frame, decoded. This + * event can only happen if the image has a preview frame encoded. This event + * occurs max once for the codestream and always later than @ref + * JXL_DEC_COLOR_ENCODING and before @ref JXL_DEC_FRAME. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the preview frame as unprocessed. + */ + JXL_DEC_PREVIEW_IMAGE = 0x200, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": Beginning of a frame. @ref + * JxlDecoderGetFrameHeader can be used at this point. A note on frames: + * a JPEG XL image can have internal frames that are not intended to be + * displayed (e.g. used for compositing a final frame), but this only returns + * displayed frames, unless @ref JxlDecoderSetCoalescing was set to JXL_FALSE: + * in that case, the individual layers are returned, without blending. Note + * that even when coalescing is disabled, only frames of type kRegularFrame + * are returned; frames of type kReferenceOnly and kLfFrame are always for + * internal purposes only and cannot be accessed. A displayed frame either has + * an animation duration or is the only or last frame in the image. This event + * occurs max once per displayed frame, always later than @ref + * JXL_DEC_COLOR_ENCODING, and always earlier than any pixel data. While + * JPEG XL supports encoding a single frame as the composition of multiple + * internal sub-frames also called frames, this event is not indicated for the + * internal frames. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the frame header (including ToC) as unprocessed. + */ + JXL_DEC_FRAME = 0x400, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": DC image, 8x8 sub-sampled frame, decoded. It is + * not guaranteed that the decoder will always return DC separately, but when + * it does it will do so before outputting the full frame. @ref + * JxlDecoderSetDCOutBuffer must be used after getting the basic image + * information to be able to get the DC pixels, if not this return status only + * indicates we're past this point in the codestream. This event occurs max + * once per frame and always later than @ref JXL_DEC_FRAME and other header + * events and earlier than full resolution pixel data. + * + * @deprecated The DC feature in this form will be removed. For progressive + * rendering, @ref JxlDecoderFlushImage should be used. + */ + JXL_DEC_DC_IMAGE = 0x800, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": full frame (or layer, in case coalescing is + * disabled) is decoded. @ref JxlDecoderSetImageOutBuffer must be used after + * getting the basic image information to be able to get the image pixels, if + * not this return status only indicates we're past this point in the + * codestream. This event occurs max once per frame and always later than @ref + * JXL_DEC_DC_IMAGE. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the frame (or if @ref JXL_DEC_JPEG_RECONSTRUCTION is subscribed to, + * from the end of the last box that is needed for jpeg reconstruction) as + * unprocessed. + */ + JXL_DEC_FULL_IMAGE = 0x1000, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": JPEG reconstruction data decoded. @ref + * JxlDecoderSetJPEGBuffer may be used to set a JPEG reconstruction buffer + * after getting the JPEG reconstruction data. If a JPEG reconstruction buffer + * is set a byte stream identical to the JPEG codestream used to encode the + * image will be written to the JPEG reconstruction buffer instead of pixels + * to the image out buffer. This event occurs max once per image and always + * before @ref JXL_DEC_FULL_IMAGE. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the 'jbrd' box as unprocessed. + */ + JXL_DEC_JPEG_RECONSTRUCTION = 0x2000, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": The header of a box of the container format + * (BMFF) is decoded. The following API functions related to boxes can be used + * after this event: + * - @ref JxlDecoderSetBoxBuffer and @ref JxlDecoderReleaseBoxBuffer + * "JxlDecoderReleaseBoxBuffer": set and release a buffer to get the box + * data. + * - @ref JxlDecoderGetBoxType get the 4-character box typename. + * - @ref JxlDecoderGetBoxSizeRaw get the size of the box as it appears in + * the container file, not decompressed. + * - @ref JxlDecoderSetDecompressBoxes to configure whether to get the box + * data decompressed, or possibly compressed. + * + * Boxes can be compressed. This is so when their box type is + * "brob". In that case, they have an underlying decompressed box + * type and decompressed data. @ref JxlDecoderSetDecompressBoxes allows + * configuring which data to get. Decompressing requires + * Brotli. @ref JxlDecoderGetBoxType has a flag to get the compressed box + * type, which can be "brob", or the decompressed box type. If a box + * is not compressed (its compressed type is not "brob"), then + * the output decompressed box type and data is independent of what + * setting is configured. + * + * The buffer set with @ref JxlDecoderSetBoxBuffer must be set again for each + * next box to be obtained, or can be left unset to skip outputting this box. + * The output buffer contains the full box data when the next @ref JXL_DEC_BOX + * event or @ref JXL_DEC_SUCCESS occurs. @ref JXL_DEC_BOX occurs for all + * boxes, including non-metadata boxes such as the signature box or codestream + * boxes. To check whether the box is a metadata type for respectively EXIF, + * XMP or JUMBF, use @ref JxlDecoderGetBoxType and check for types "Exif", + * "xml " and "jumb" respectively. + * + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * start of the box header as unprocessed. + */ + JXL_DEC_BOX = 0x4000, + + /** Informative event by @ref JxlDecoderProcessInput + * "JxlDecoderProcessInput": a progressive step in decoding the frame is + * reached. When calling @ref JxlDecoderFlushImage at this point, the flushed + * image will correspond exactly to this point in decoding, and not yet + * contain partial results (such as partially more fine detail) of a next + * step. By default, this event will trigger maximum once per frame, when a + * 8x8th resolution (DC) image is ready (the image data is still returned at + * full resolution, giving upscaled DC). Use @ref + * JxlDecoderSetProgressiveDetail to configure more fine-grainedness. The + * event is not guaranteed to trigger, not all images have progressive steps + * or DC encoded. + * In this case, @ref JxlDecoderReleaseInput will return all bytes from the + * end of the section that was needed to produce this progressive event as + * unprocessed. + */ + JXL_DEC_FRAME_PROGRESSION = 0x8000, +} JxlDecoderStatus; + +/** Rewinds decoder to the beginning. The same input must be given again from + * the beginning of the file and the decoder will emit events from the beginning + * again. When rewinding (as opposed to @ref JxlDecoderReset), the decoder can + * keep state about the image, which it can use to skip to a requested frame + * more efficiently with @ref JxlDecoderSkipFrames. Settings such as parallel + * runner or subscribed events are kept. After rewind, @ref + * JxlDecoderSubscribeEvents can be used again, and it is feasible to leave out + * events that were already handled before, such as @ref JXL_DEC_BASIC_INFO + * and @ref JXL_DEC_COLOR_ENCODING, since they will provide the same information + * as before. + * The difference to @ref JxlDecoderReset is that some state is kept, namely + * settings set by a call to + * - @ref JxlDecoderSetCoalescing, + * - @ref JxlDecoderSetDesiredIntensityTarget, + * - @ref JxlDecoderSetDecompressBoxes, + * - @ref JxlDecoderSetKeepOrientation, + * - @ref JxlDecoderSetUnpremultiplyAlpha, + * - @ref JxlDecoderSetParallelRunner, + * - @ref JxlDecoderSetRenderSpotcolors, and + * - @ref JxlDecoderSubscribeEvents. + * + * @param dec decoder object + */ +JXL_EXPORT void JxlDecoderRewind(JxlDecoder* dec); + +/** Makes the decoder skip the next `amount` frames. It still needs to process + * the input, but will not output the frame events. It can be more efficient + * when skipping frames, and even more so when using this after @ref + * JxlDecoderRewind. If the decoder is already processing a frame (could + * have emitted @ref JXL_DEC_FRAME but not yet @ref JXL_DEC_FULL_IMAGE), it + * starts skipping from the next frame. If the amount is larger than the amount + * of frames remaining in the image, all remaining frames are skipped. Calling + * this function multiple times adds the amount to skip to the already existing + * amount. + * + * A frame here is defined as a frame that without skipping emits events such + * as @ref JXL_DEC_FRAME and @ref JXL_DEC_FULL_IMAGE, frames that are internal + * to the file format but are not rendered as part of an animation, or are not + * the final still frame of a still image, are not counted. + * + * @param dec decoder object + * @param amount the amount of frames to skip + */ +JXL_EXPORT void JxlDecoderSkipFrames(JxlDecoder* dec, size_t amount); + +/** + * Skips processing the current frame. Can be called after frame processing + * already started, signaled by a @ref JXL_DEC_NEED_IMAGE_OUT_BUFFER event, + * but before the corrsponding @ref JXL_DEC_FULL_IMAGE event. The next signaled + * event will be another @ref JXL_DEC_FRAME, or @ref JXL_DEC_SUCCESS if there + * are no more frames. If pixel data is required from the already processed part + * of the frame, @ref JxlDecoderFlushImage must be called before this. + * + * @param dec decoder object + * @return @ref JXL_DEC_SUCCESS if there is a frame to skip, and @ref + * JXL_DEC_ERROR if the function was not called during frame processing. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSkipCurrentFrame(JxlDecoder* dec); + +/** + * Get the default pixel format for this decoder. + * + * Requires that the decoder can produce JxlBasicInfo. + * + * @param dec @ref JxlDecoder to query when creating the recommended pixel + * format. + * @param format JxlPixelFormat to populate with the recommended settings for + * the data loaded into this decoder. + * @return @ref JXL_DEC_SUCCESS if no error, @ref JXL_DEC_NEED_MORE_INPUT if the + * basic info isn't yet available, and @ref JXL_DEC_ERROR otherwise. + * + * DEPRECATED: this function will be removed in the future. + */ +JXL_EXPORT JXL_DEPRECATED JxlDecoderStatus +JxlDecoderDefaultPixelFormat(const JxlDecoder* dec, JxlPixelFormat* format); + +/** + * Set the parallel runner for multithreading. May only be set before starting + * decoding. + * + * @param dec decoder object + * @param parallel_runner function pointer to runner for multithreading. It may + * be NULL to use the default, single-threaded, runner. A multithreaded + * runner should be set to reach fast performance. + * @param parallel_runner_opaque opaque pointer for parallel_runner. + * @return @ref JXL_DEC_SUCCESS if the runner was set, @ref JXL_DEC_ERROR + * otherwise (the previous runner remains set). + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetParallelRunner(JxlDecoder* dec, JxlParallelRunner parallel_runner, + void* parallel_runner_opaque); + +/** + * Returns a hint indicating how many more bytes the decoder is expected to + * need to make @ref JxlDecoderGetBasicInfo available after the next @ref + * JxlDecoderProcessInput call. This is a suggested large enough value for + * the amount of bytes to provide in the next @ref JxlDecoderSetInput call, but + * it is not guaranteed to be an upper bound nor a lower bound. This number does + * not include bytes that have already been released from the input. Can be used + * before the first @ref JxlDecoderProcessInput call, and is correct the first + * time in most cases. If not, @ref JxlDecoderSizeHintBasicInfo can be called + * again to get an updated hint. + * + * @param dec decoder object + * @return the size hint in bytes if the basic info is not yet fully decoded. + * @return 0 when the basic info is already available. + */ +JXL_EXPORT size_t JxlDecoderSizeHintBasicInfo(const JxlDecoder* dec); + +/** Select for which informative events, i.e. @ref JXL_DEC_BASIC_INFO, etc., the + * decoder should return with a status. It is not required to subscribe to any + * events, data can still be requested from the decoder as soon as it available. + * By default, the decoder is subscribed to no events (events_wanted == 0), and + * the decoder will then only return when it cannot continue because it needs + * more input data or more output buffer. This function may only be be called + * before using @ref JxlDecoderProcessInput. + * + * @param dec decoder object + * @param events_wanted bitfield of desired events. + * @return @ref JXL_DEC_SUCCESS if no error, @ref JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSubscribeEvents(JxlDecoder* dec, + int events_wanted); + +/** Enables or disables preserving of as-in-bitstream pixeldata + * orientation. Some images are encoded with an Orientation tag + * indicating that the decoder must perform a rotation and/or + * mirroring to the encoded image data. + * + * - If skip_reorientation is JXL_FALSE (the default): the decoder + * will apply the transformation from the orientation setting, hence + * rendering the image according to its specified intent. When + * producing a JxlBasicInfo, the decoder will always set the + * orientation field to JXL_ORIENT_IDENTITY (matching the returned + * pixel data) and also align xsize and ysize so that they correspond + * to the width and the height of the returned pixel data. + * - If skip_reorientation is JXL_TRUE: the decoder will skip + * applying the transformation from the orientation setting, returning + * the image in the as-in-bitstream pixeldata orientation. + * This may be faster to decode since the decoder doesn't have to apply the + * transformation, but can cause wrong display of the image if the + * orientation tag is not correctly taken into account by the user. + * + * By default, this option is disabled, and the returned pixel data is + * re-oriented according to the image's Orientation setting. + * + * This function must be called at the beginning, before decoding is performed. + * + * @see JxlBasicInfo for the orientation field, and @ref JxlOrientation for the + * possible values. + * + * @param dec decoder object + * @param skip_reorientation JXL_TRUE to enable, JXL_FALSE to disable. + * @return @ref JXL_DEC_SUCCESS if no error, @ref JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetKeepOrientation(JxlDecoder* dec, JXL_BOOL skip_reorientation); + +/** + * Enables or disables preserving of associated alpha channels. If + * unpremul_alpha is set to JXL_FALSE then for associated alpha channel, the + * pixel data is returned with premultiplied colors. If it is set to JXL_TRUE, + * The colors will be unpremultiplied based on the alpha channel. This function + * has no effect if the image does not have an associated alpha channel. + * + * By default, this option is disabled, and the returned pixel data "as is". + * + * This function must be called at the beginning, before decoding is performed. + * + * @param dec decoder object + * @param unpremul_alpha JXL_TRUE to enable, JXL_FALSE to disable. + * @return @ref JXL_DEC_SUCCESS if no error, @ref JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetUnpremultiplyAlpha(JxlDecoder* dec, JXL_BOOL unpremul_alpha); + +/** Enables or disables rendering spot colors. By default, spot colors + * are rendered, which is OK for viewing the decoded image. If render_spotcolors + * is JXL_FALSE, then spot colors are not rendered, and have to be retrieved + * separately using @ref JxlDecoderSetExtraChannelBuffer. This is useful for + * e.g. printing applications. + * + * @param dec decoder object + * @param render_spotcolors JXL_TRUE to enable (default), JXL_FALSE to disable. + * @return @ref JXL_DEC_SUCCESS if no error, @ref JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetRenderSpotcolors(JxlDecoder* dec, JXL_BOOL render_spotcolors); + +/** Enables or disables coalescing of zero-duration frames. By default, frames + * are returned with coalescing enabled, i.e. all frames have the image + * dimensions, and are blended if needed. When coalescing is disabled, frames + * can have arbitrary dimensions, a non-zero crop offset, and blending is not + * performed. For display, coalescing is recommended. For loading a multi-layer + * still image as separate layers (as opposed to the merged image), coalescing + * has to be disabled. + * + * @param dec decoder object + * @param coalescing JXL_TRUE to enable coalescing (default), JXL_FALSE to + * disable it. + * @return @ref JXL_DEC_SUCCESS if no error, @ref JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetCoalescing(JxlDecoder* dec, + JXL_BOOL coalescing); + +/** + * Decodes JPEG XL file using the available bytes. Requires input has been + * set with @ref JxlDecoderSetInput. After @ref JxlDecoderProcessInput, input + * can optionally be released with @ref JxlDecoderReleaseInput and then set + * again to next bytes in the stream. @ref JxlDecoderReleaseInput returns how + * many bytes are not yet processed, before a next call to @ref + * JxlDecoderProcessInput all unprocessed bytes must be provided again (the + * address need not match, but the contents must), and more bytes may be + * concatenated after the unprocessed bytes. + * + * The returned status indicates whether the decoder needs more input bytes, or + * more output buffer for a certain type of output data. No matter what the + * returned status is (other than @ref JXL_DEC_ERROR), new information, such + * as @ref JxlDecoderGetBasicInfo, may have become available after this call. + * When the return value is not @ref JXL_DEC_ERROR or @ref JXL_DEC_SUCCESS, the + * decoding requires more @ref JxlDecoderProcessInput calls to continue. + * + * @param dec decoder object + * @return @ref JXL_DEC_SUCCESS when decoding finished and all events handled. + * If you still have more unprocessed input data anyway, then you can still + * continue by using @ref JxlDecoderSetInput and calling @ref + * JxlDecoderProcessInput again, similar to handling @ref + * JXL_DEC_NEED_MORE_INPUT. @ref JXL_DEC_SUCCESS can occur instead of @ref + * JXL_DEC_NEED_MORE_INPUT when, for example, the input data ended right at + * the boundary of a box of the container format, all essential codestream + * boxes were already decoded, but extra metadata boxes are still present in + * the next data. @ref JxlDecoderProcessInput cannot return success if all + * codestream boxes have not been seen yet. + * @return @ref JXL_DEC_ERROR when decoding failed, e.g. invalid codestream. + * TODO(lode): document the input data mechanism + * @return @ref JXL_DEC_NEED_MORE_INPUT when more input data is necessary. + * @return @ref JXL_DEC_BASIC_INFO when basic info such as image dimensions is + * available and this informative event is subscribed to. + * @return @ref JXL_DEC_COLOR_ENCODING when color profile information is + * available and this informative event is subscribed to. + * @return @ref JXL_DEC_PREVIEW_IMAGE when preview pixel information is + * available and output in the preview buffer. + * @return @ref JXL_DEC_DC_IMAGE when DC pixel information (8x8 downscaled + * version of the image) is available and output is in the DC buffer. + * @return @ref JXL_DEC_FULL_IMAGE when all pixel information at highest detail + * is available and has been output in the pixel buffer. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderProcessInput(JxlDecoder* dec); + +/** + * Sets input data for @ref JxlDecoderProcessInput. The data is owned by the + * caller and may be used by the decoder until @ref JxlDecoderReleaseInput is + * called or the decoder is destroyed or reset so must be kept alive until then. + * Cannot be called if @ref JxlDecoderSetInput was already called and @ref + * JxlDecoderReleaseInput was not yet called, and cannot be called after @ref + * JxlDecoderCloseInput indicating the end of input was called. + * + * @param dec decoder object + * @param data pointer to next bytes to read from + * @param size amount of bytes available starting from data + * @return @ref JXL_DEC_ERROR if input was already set without releasing or @ref + * JxlDecoderCloseInput was already called, @ref JXL_DEC_SUCCESS otherwise. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetInput(JxlDecoder* dec, + const uint8_t* data, + size_t size); + +/** + * Releases input which was provided with @ref JxlDecoderSetInput. Between @ref + * JxlDecoderProcessInput and @ref JxlDecoderReleaseInput, the user may not + * alter the data in the buffer. Calling @ref JxlDecoderReleaseInput is required + * whenever any input is already set and new input needs to be added with @ref + * JxlDecoderSetInput, but is not required before @ref JxlDecoderDestroy or @ref + * JxlDecoderReset. Calling @ref JxlDecoderReleaseInput when no input is set is + * not an error and returns 0. + * + * @param dec decoder object + * @return The amount of bytes the decoder has not yet processed that are still + * remaining in the data set by @ref JxlDecoderSetInput, or 0 if no input is + * set or @ref JxlDecoderReleaseInput was already called. For a next call + * to @ref JxlDecoderProcessInput, the buffer must start with these + * unprocessed bytes. From this value it is possible to infer the position + * of certain JPEG XL codestream elements (e.g. end of headers, frame + * start/end). See the documentation of individual values of @ref + * JxlDecoderStatus for more information. + */ +JXL_EXPORT size_t JxlDecoderReleaseInput(JxlDecoder* dec); + +/** + * Marks the input as finished, indicates that no more @ref JxlDecoderSetInput + * will be called. This function allows the decoder to determine correctly if it + * should return success, need more input or error in certain cases. For + * backwards compatibility with a previous version of the API, using this + * function is optional when not using the @ref JXL_DEC_BOX event (the decoder + * is able to determine the end of the image frames without marking the end), + * but using this function is required when using @ref JXL_DEC_BOX for getting + * metadata box contents. This function does not replace @ref + * JxlDecoderReleaseInput, that function should still be called if its return + * value is needed. + * + * @ref JxlDecoderCloseInput should be called as soon as all known input bytes + * are set (e.g. at the beginning when not streaming but setting all input + * at once), before the final @ref JxlDecoderProcessInput calls. + * + * @param dec decoder object + */ +JXL_EXPORT void JxlDecoderCloseInput(JxlDecoder* dec); + +/** + * Outputs the basic image information, such as image dimensions, bit depth and + * all other JxlBasicInfo fields, if available. + * + * @param dec decoder object + * @param info struct to copy the information into, or NULL to only check + * whether the information is available through the return value. + * @return @ref JXL_DEC_SUCCESS if the value is available, @ref + * JXL_DEC_NEED_MORE_INPUT if not yet available, @ref JXL_DEC_ERROR + * in case of other error conditions. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetBasicInfo(const JxlDecoder* dec, + JxlBasicInfo* info); + +/** + * Outputs information for extra channel at the given index. The index must be + * smaller than num_extra_channels in the associated JxlBasicInfo. + * + * @param dec decoder object + * @param index index of the extra channel to query. + * @param info struct to copy the information into, or NULL to only check + * whether the information is available through the return value. + * @return @ref JXL_DEC_SUCCESS if the value is available, @ref + * JXL_DEC_NEED_MORE_INPUT if not yet available, @ref JXL_DEC_ERROR + * in case of other error conditions. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetExtraChannelInfo( + const JxlDecoder* dec, size_t index, JxlExtraChannelInfo* info); + +/** + * Outputs name for extra channel at the given index in UTF-8. The index must be + * smaller than num_extra_channels in the associated JxlBasicInfo. The buffer + * for name must have at least name_length + 1 bytes allocated, gotten from + * the associated JxlExtraChannelInfo. + * + * @param dec decoder object + * @param index index of the extra channel to query. + * @param name buffer to copy the name into + * @param size size of the name buffer in bytes + * @return @ref JXL_DEC_SUCCESS if the value is available, @ref + * JXL_DEC_NEED_MORE_INPUT if not yet available, @ref JXL_DEC_ERROR + * in case of other error conditions. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetExtraChannelName(const JxlDecoder* dec, + size_t index, + char* name, + size_t size); + +/** Defines which color profile to get: the profile from the codestream + * metadata header, which represents the color profile of the original image, + * or the color profile from the pixel data produced by the decoder. Both are + * the same if the JxlBasicInfo has uses_original_profile set. + */ +typedef enum { + /** Get the color profile of the original image from the metadata. + */ + JXL_COLOR_PROFILE_TARGET_ORIGINAL = 0, + + /** Get the color profile of the pixel data the decoder outputs. */ + JXL_COLOR_PROFILE_TARGET_DATA = 1, +} JxlColorProfileTarget; + +/** + * Outputs the color profile as JPEG XL encoded structured data, if available. + * This is an alternative to an ICC Profile, which can represent a more limited + * amount of color spaces, but represents them exactly through enum values. + * + * It is often possible to use @ref JxlDecoderGetColorAsICCProfile as an + * alternative anyway. The following scenarios are possible: + * - The JPEG XL image has an attached ICC Profile, in that case, the encoded + * structured data is not available, this function will return an error + * status. @ref JxlDecoderGetColorAsICCProfile should be called instead. + * - The JPEG XL image has an encoded structured color profile, and it + * represents an RGB or grayscale color space. This function will return it. + * You can still use @ref JxlDecoderGetColorAsICCProfile as well as an + * alternative if desired, though depending on which RGB color space is + * represented, the ICC profile may be a close approximation. It is also not + * always feasible to deduce from an ICC profile which named color space it + * exactly represents, if any, as it can represent any arbitrary space. + * - The JPEG XL image has an encoded structured color profile, and it + * indicates an unknown or xyb color space. In that case, @ref + * JxlDecoderGetColorAsICCProfile is not available. + * + * When rendering an image on a system that supports ICC profiles, @ref + * JxlDecoderGetColorAsICCProfile should be used first. When rendering + * for a specific color space, possibly indicated in the JPEG XL + * image, @ref JxlDecoderGetColorAsEncodedProfile should be used first. + * + * @param dec decoder object + * @param unused_format deprecated, can be NULL + * @param target whether to get the original color profile from the metadata + * or the color profile of the decoded pixels. + * @param color_encoding struct to copy the information into, or NULL to only + * check whether the information is available through the return value. + * @return @ref JXL_DEC_SUCCESS if the data is available and returned, @ref + * JXL_DEC_NEED_MORE_INPUT if not yet available, @ref JXL_DEC_ERROR in + * case the encoded structured color profile does not exist in the + * codestream. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetColorAsEncodedProfile( + const JxlDecoder* dec, const JxlPixelFormat* unused_format, + JxlColorProfileTarget target, JxlColorEncoding* color_encoding); + +/** + * Outputs the size in bytes of the ICC profile returned by @ref + * JxlDecoderGetColorAsICCProfile, if available, or indicates there is none + * available. In most cases, the image will have an ICC profile available, but + * if it does not, @ref JxlDecoderGetColorAsEncodedProfile must be used instead. + * + * @see JxlDecoderGetColorAsEncodedProfile for more information. The ICC + * profile is either the exact ICC profile attached to the codestream metadata, + * or a close approximation generated from JPEG XL encoded structured data, + * depending of what is encoded in the codestream. + * + * @param dec decoder object + * @param unused_format deprecated, can be NULL + * @param target whether to get the original color profile from the metadata + * or the color profile of the decoded pixels. + * @param size variable to output the size into, or NULL to only check the + * return status. + * @return @ref JXL_DEC_SUCCESS if the ICC profile is available, @ref + * JXL_DEC_NEED_MORE_INPUT if the decoder has not yet received enough + * input data to determine whether an ICC profile is available or what its + * size is, @ref JXL_DEC_ERROR in case the ICC profile is not available and + * cannot be generated. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetICCProfileSize( + const JxlDecoder* dec, const JxlPixelFormat* unused_format, + JxlColorProfileTarget target, size_t* size); + +/** + * Outputs ICC profile if available. The profile is only available if @ref + * JxlDecoderGetICCProfileSize returns success. The output buffer must have + * at least as many bytes as given by @ref JxlDecoderGetICCProfileSize. + * + * @param dec decoder object + * @param unused_format deprecated, can be NULL + * @param target whether to get the original color profile from the metadata + * or the color profile of the decoded pixels. + * @param icc_profile buffer to copy the ICC profile into + * @param size size of the icc_profile buffer in bytes + * @return @ref JXL_DEC_SUCCESS if the profile was successfully returned is + * available, @ref JXL_DEC_NEED_MORE_INPUT if not yet available, @ref + * JXL_DEC_ERROR if the profile doesn't exist or the output size is not + * large enough. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetColorAsICCProfile( + const JxlDecoder* dec, const JxlPixelFormat* unused_format, + JxlColorProfileTarget target, uint8_t* icc_profile, size_t size); + +/** Sets the color profile to use for @ref JXL_COLOR_PROFILE_TARGET_DATA for the + * special case when the decoder has a choice. This only has effect for a JXL + * image where uses_original_profile is false. If uses_original_profile is true, + * this setting is ignored and the decoder uses a profile related to the image. + * No matter what, the @ref JXL_COLOR_PROFILE_TARGET_DATA must still be queried + * to know the actual data format of the decoded pixels after decoding. + * + * The JXL decoder has no color management system built in, but can convert XYB + * color to any of the ones supported by JxlColorEncoding. Note that if the + * requested color encoding has a narrower gamut, or the white points differ, + * then the resulting image can have significant color distortion. + * + * Can only be set after the @ref JXL_DEC_COLOR_ENCODING event occurred and + * before any other event occurred, and can affect the result of @ref + * JXL_COLOR_PROFILE_TARGET_DATA (but not of @ref + * JXL_COLOR_PROFILE_TARGET_ORIGINAL), so should be used after getting @ref + * JXL_COLOR_PROFILE_TARGET_ORIGINAL but before getting @ref + * JXL_COLOR_PROFILE_TARGET_DATA. The color_encoding must be grayscale if + * num_color_channels from the basic info is 1, RGB if num_color_channels from + * the basic info is 3. + * + * If @ref JxlDecoderSetPreferredColorProfile is not used, then for images for + * which uses_original_profile is false and with ICC color profile, the decoder + * will choose linear sRGB for color images, linear grayscale for grayscale + * images. This function only sets a preference, since for other images the + * decoder has no choice what color profile to use, it is determined by the + * image. + * + * @param dec decoder object + * @param color_encoding the default color encoding to set + * @return @ref JXL_DEC_SUCCESS if the preference was set successfully, @ref + * JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetPreferredColorProfile( + JxlDecoder* dec, const JxlColorEncoding* color_encoding); + +/** Requests that the decoder perform tone mapping to the peak display luminance + * passed as @c desired_intensity_target, if appropriate. + * @note This is provided for convenience and the exact tone mapping that is + * performed is not meant to be considered authoritative in any way. It may + * change from version to version. + * @param dec decoder object + * @param desired_intensity_target the intended target peak luminance + * @return @ref JXL_DEC_SUCCESS if the preference was set successfully, @ref + * JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetDesiredIntensityTarget( + JxlDecoder* dec, float desired_intensity_target); + +/** + * Returns the minimum size in bytes of the preview image output pixel buffer + * for the given format. This is the buffer for @ref + * JxlDecoderSetPreviewOutBuffer. Requires the preview header information is + * available in the decoder. + * + * @param dec decoder object + * @param format format of pixels + * @param size output value, buffer size in bytes + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * information not available yet. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderPreviewOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size); + +/** + * Sets the buffer to write the small resolution preview image + * to. The size of the buffer must be at least as large as given by @ref + * JxlDecoderPreviewOutBufferSize. The buffer follows the format described + * by JxlPixelFormat. The preview image dimensions are given by the + * JxlPreviewHeader. The buffer is owned by the caller. + * + * @param dec decoder object + * @param format format of pixels. Object owned by user and its contents are + * copied internally. + * @param buffer buffer type to output the pixel data to + * @param size size of buffer in bytes + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * size too small. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetPreviewOutBuffer( + JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size); + +/** + * Outputs the information from the frame, such as duration when have_animation. + * This function can be called when @ref JXL_DEC_FRAME occurred for the current + * frame, even when have_animation in the JxlBasicInfo is JXL_FALSE. + * + * @param dec decoder object + * @param header struct to copy the information into, or NULL to only check + * whether the information is available through the return value. + * @return @ref JXL_DEC_SUCCESS if the value is available, @ref + * JXL_DEC_NEED_MORE_INPUT if not yet available, @ref JXL_DEC_ERROR in + * case of other error conditions. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetFrameHeader(const JxlDecoder* dec, + JxlFrameHeader* header); + +/** + * Outputs name for the current frame. The buffer for name must have at least + * name_length + 1 bytes allocated, gotten from the associated JxlFrameHeader. + * + * @param dec decoder object + * @param name buffer to copy the name into + * @param size size of the name buffer in bytes, including zero termination + * character, so this must be at least JxlFrameHeader.name_length + 1. + * @return @ref JXL_DEC_SUCCESS if the value is available, @ref + * JXL_DEC_NEED_MORE_INPUT if not yet available, @ref JXL_DEC_ERROR in + * case of other error conditions. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetFrameName(const JxlDecoder* dec, + char* name, size_t size); + +/** + * Outputs the blend information for the current frame for a specific extra + * channel. This function can be called when @ref JXL_DEC_FRAME occurred for the + * current frame, even when have_animation in the JxlBasicInfo is JXL_FALSE. + * This information is only useful if coalescing is disabled; otherwise the + * decoder will have performed blending already. + * + * @param dec decoder object + * @param index the index of the extra channel + * @param blend_info struct to copy the information into + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetExtraChannelBlendInfo( + const JxlDecoder* dec, size_t index, JxlBlendInfo* blend_info); + +/** + * Returns the minimum size in bytes of the DC image output buffer + * for the given format. This is the buffer for @ref JxlDecoderSetDCOutBuffer. + * Requires the basic image information is available in the decoder. + * + * @param dec decoder object + * @param format format of pixels + * @param size output value, buffer size in bytes + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * information not available yet. + * + * @deprecated The DC feature in this form will be removed. Use @ref + * JxlDecoderFlushImage for progressive rendering. + */ +JXL_EXPORT JXL_DEPRECATED JxlDecoderStatus JxlDecoderDCOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size); + +/** + * Sets the buffer to write the lower resolution (8x8 sub-sampled) DC image + * to. The size of the buffer must be at least as large as given by @ref + * JxlDecoderDCOutBufferSize. The buffer follows the format described by + * JxlPixelFormat. The DC image has dimensions ceil(xsize / 8) * ceil(ysize / + * 8). The buffer is owned by the caller. + * + * @param dec decoder object + * @param format format of pixels. Object owned by user and its contents are + * copied internally. + * @param buffer buffer type to output the pixel data to + * @param size size of buffer in bytes + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * size too small. + * + * @deprecated The DC feature in this form will be removed. Use @ref + * JxlDecoderFlushImage for progressive rendering. + */ +JXL_EXPORT JXL_DEPRECATED JxlDecoderStatus JxlDecoderSetDCOutBuffer( + JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size); + +/** + * Returns the minimum size in bytes of the image output pixel buffer for the + * given format. This is the buffer for @ref JxlDecoderSetImageOutBuffer. + * Requires that the basic image information is available in the decoder in the + * case of coalescing enabled (default). In case coalescing is disabled, this + * can only be called after the @ref JXL_DEC_FRAME event occurs. In that case, + * it will return the size required to store the possibly cropped frame (which + * can be larger or smaller than the image dimensions). + * + * @param dec decoder object + * @param format format of the pixels. + * @param size output value, buffer size in bytes + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * information not available yet. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderImageOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size); + +/** + * Sets the buffer to write the full resolution image to. This can be set when + * the @ref JXL_DEC_FRAME event occurs, must be set when the @ref + * JXL_DEC_NEED_IMAGE_OUT_BUFFER event occurs, and applies only for the + * current frame. The size of the buffer must be at least as large as given + * by @ref JxlDecoderImageOutBufferSize. The buffer follows the format described + * by JxlPixelFormat. The buffer is owned by the caller. + * + * @param dec decoder object + * @param format format of the pixels. Object owned by user and its contents + * are copied internally. + * @param buffer buffer type to output the pixel data to + * @param size size of buffer in bytes + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * size too small. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetImageOutBuffer( + JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size); + +/** + * Function type for @ref JxlDecoderSetImageOutCallback. + * + * The callback may be called simultaneously by different threads when using a + * threaded parallel runner, on different pixels. + * + * @param opaque optional user data, as given to @ref + * JxlDecoderSetImageOutCallback. + * @param x horizontal position of leftmost pixel of the pixel data. + * @param y vertical position of the pixel data. + * @param num_pixels amount of pixels included in the pixel data, horizontally. + * This is not the same as xsize of the full image, it may be smaller. + * @param pixels pixel data as a horizontal stripe, in the format passed to @ref + * JxlDecoderSetImageOutCallback. The memory is not owned by the user, and + * is only valid during the time the callback is running. + */ +typedef void (*JxlImageOutCallback)(void* opaque, size_t x, size_t y, + size_t num_pixels, const void* pixels); + +/** + * Initialization callback for @ref JxlDecoderSetMultithreadedImageOutCallback. + * + * @param init_opaque optional user data, as given to @ref + * JxlDecoderSetMultithreadedImageOutCallback. + * @param num_threads maximum number of threads that will call the @c run + * callback concurrently. + * @param num_pixels_per_thread maximum number of pixels that will be passed in + * one call to @c run. + * @return a pointer to data that will be passed to the @c run callback, or + * @c NULL if initialization failed. + */ +typedef void* (*JxlImageOutInitCallback)(void* init_opaque, size_t num_threads, + size_t num_pixels_per_thread); + +/** + * Worker callback for @ref JxlDecoderSetMultithreadedImageOutCallback. + * + * @param run_opaque user data returned by the @c init callback. + * @param thread_id number in `[0, num_threads)` identifying the thread of the + * current invocation of the callback. + * @param x horizontal position of the first (leftmost) pixel of the pixel data. + * @param y vertical position of the pixel data. + * @param num_pixels number of pixels in the pixel data. May be less than the + * full @c xsize of the image, and will be at most equal to the @c + * num_pixels_per_thread that was passed to @c init. + * @param pixels pixel data as a horizontal stripe, in the format passed to @ref + * JxlDecoderSetMultithreadedImageOutCallback. The data pointed to + * remains owned by the caller and is only guaranteed to outlive the current + * callback invocation. + */ +typedef void (*JxlImageOutRunCallback)(void* run_opaque, size_t thread_id, + size_t x, size_t y, size_t num_pixels, + const void* pixels); + +/** + * Destruction callback for @ref JxlDecoderSetMultithreadedImageOutCallback, + * called after all invocations of the @c run callback to perform any + * appropriate clean-up of the @c run_opaque data returned by @c init. + * + * @param run_opaque user data returned by the @c init callback. + */ +typedef void (*JxlImageOutDestroyCallback)(void* run_opaque); + +/** + * Sets pixel output callback. This is an alternative to @ref + * JxlDecoderSetImageOutBuffer. This can be set when the @ref JXL_DEC_FRAME + * event occurs, must be set when the @ref JXL_DEC_NEED_IMAGE_OUT_BUFFER event + * occurs, and applies only for the current frame. Only one of @ref + * JxlDecoderSetImageOutBuffer or @ref JxlDecoderSetImageOutCallback may be used + * for the same frame, not both at the same time. + * + * The callback will be called multiple times, to receive the image + * data in small chunks. The callback receives a horizontal stripe of pixel + * data, 1 pixel high, xsize pixels wide, called a scanline. The xsize here is + * not the same as the full image width, the scanline may be a partial section, + * and xsize may differ between calls. The user can then process and/or copy the + * partial scanline to an image buffer. The callback may be called + * simultaneously by different threads when using a threaded parallel runner, on + * different pixels. + * + * If @ref JxlDecoderFlushImage is not used, then each pixel will be visited + * exactly once by the different callback calls, during processing with one or + * more @ref JxlDecoderProcessInput calls. These pixels are decoded to full + * detail, they are not part of a lower resolution or lower quality progressive + * pass, but the final pass. + * + * If @ref JxlDecoderFlushImage is used, then in addition each pixel will be + * visited zero or one times during the blocking @ref JxlDecoderFlushImage call. + * Pixels visited as a result of @ref JxlDecoderFlushImage may represent a lower + * resolution or lower quality intermediate progressive pass of the image. Any + * visited pixel will be of a quality at least as good or better than previous + * visits of this pixel. A pixel may be visited zero times if it cannot be + * decoded yet or if it was already decoded to full precision (this behavior is + * not guaranteed). + * + * @param dec decoder object + * @param format format of the pixels. Object owned by user; its contents are + * copied internally. + * @param callback the callback function receiving partial scanlines of pixel + * data. + * @param opaque optional user data, which will be passed on to the callback, + * may be NULL. + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such + * as @ref JxlDecoderSetImageOutBuffer already set. + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetImageOutCallback(JxlDecoder* dec, const JxlPixelFormat* format, + JxlImageOutCallback callback, void* opaque); + +/** Similar to @ref JxlDecoderSetImageOutCallback except that the callback is + * allowed an initialization phase during which it is informed of how many + * threads will call it concurrently, and those calls are further informed of + * which thread they are occurring in. + * + * @param dec decoder object + * @param format format of the pixels. Object owned by user; its contents are + * copied internally. + * @param init_callback initialization callback. + * @param run_callback the callback function receiving partial scanlines of + * pixel data. + * @param destroy_callback clean-up callback invoked after all calls to @c + * run_callback. May be NULL if no clean-up is necessary. + * @param init_opaque optional user data passed to @c init_callback, may be NULL + * (unlike the return value from @c init_callback which may only be NULL if + * initialization failed). + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such + * as @ref JxlDecoderSetImageOutBuffer having already been called. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetMultithreadedImageOutCallback( + JxlDecoder* dec, const JxlPixelFormat* format, + JxlImageOutInitCallback init_callback, JxlImageOutRunCallback run_callback, + JxlImageOutDestroyCallback destroy_callback, void* init_opaque); + +/** + * Returns the minimum size in bytes of an extra channel pixel buffer for the + * given format. This is the buffer for @ref JxlDecoderSetExtraChannelBuffer. + * Requires the basic image information is available in the decoder. + * + * @param dec decoder object + * @param format format of the pixels. The num_channels value is ignored and is + * always treated to be 1. + * @param size output value, buffer size in bytes + * @param index which extra channel to get, matching the index used in @ref + * JxlDecoderGetExtraChannelInfo. Must be smaller than num_extra_channels in + * the associated JxlBasicInfo. + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * information not available yet or invalid index. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderExtraChannelBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size, + uint32_t index); + +/** + * Sets the buffer to write an extra channel to. This can be set when + * the @ref JXL_DEC_FRAME or @ref JXL_DEC_NEED_IMAGE_OUT_BUFFER event occurs, + * and applies only for the current frame. The size of the buffer must be at + * least as large as given by @ref JxlDecoderExtraChannelBufferSize. The buffer + * follows the format described by JxlPixelFormat, but where num_channels is 1. + * The buffer is owned by the caller. The amount of extra channels is given by + * the num_extra_channels field in the associated JxlBasicInfo, and the + * information of individual extra channels can be queried with @ref + * JxlDecoderGetExtraChannelInfo. To get multiple extra channels, this function + * must be called multiple times, once for each wanted index. Not all images + * have extra channels. The alpha channel is an extra channel and can be gotten + * as part of the color channels when using an RGBA pixel buffer with @ref + * JxlDecoderSetImageOutBuffer, but additionally also can be gotten + * separately as extra channel. The color channels themselves cannot be gotten + * this way. + * + * + * @param dec decoder object + * @param format format of the pixels. Object owned by user and its contents + * are copied internally. The num_channels value is ignored and is always + * treated to be 1. + * @param buffer buffer type to output the pixel data to + * @param size size of buffer in bytes + * @param index which extra channel to get, matching the index used in @ref + * JxlDecoderGetExtraChannelInfo. Must be smaller than num_extra_channels in + * the associated JxlBasicInfo. + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * size too small or invalid index. + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetExtraChannelBuffer(JxlDecoder* dec, const JxlPixelFormat* format, + void* buffer, size_t size, uint32_t index); + +/** + * Sets output buffer for reconstructed JPEG codestream. + * + * The data is owned by the caller and may be used by the decoder until @ref + * JxlDecoderReleaseJPEGBuffer is called or the decoder is destroyed or + * reset so must be kept alive until then. + * + * If a JPEG buffer was set before and released with @ref + * JxlDecoderReleaseJPEGBuffer, bytes that the decoder has already output + * should not be included, only the remaining bytes output must be set. + * + * @param dec decoder object + * @param data pointer to next bytes to write to + * @param size amount of bytes available starting from data + * @return @ref JXL_DEC_ERROR if output buffer was already set and @ref + * JxlDecoderReleaseJPEGBuffer was not called on it, @ref JXL_DEC_SUCCESS + * otherwise + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetJPEGBuffer(JxlDecoder* dec, + uint8_t* data, size_t size); + +/** + * Releases buffer which was provided with @ref JxlDecoderSetJPEGBuffer. + * + * Calling @ref JxlDecoderReleaseJPEGBuffer is required whenever + * a buffer is already set and a new buffer needs to be added with @ref + * JxlDecoderSetJPEGBuffer, but is not required before @ref + * JxlDecoderDestroy or @ref JxlDecoderReset. + * + * Calling @ref JxlDecoderReleaseJPEGBuffer when no buffer is set is + * not an error and returns 0. + * + * @param dec decoder object + * @return the amount of bytes the decoder has not yet written to of the data + * set by @ref JxlDecoderSetJPEGBuffer, or 0 if no buffer is set or @ref + * JxlDecoderReleaseJPEGBuffer was already called. + */ +JXL_EXPORT size_t JxlDecoderReleaseJPEGBuffer(JxlDecoder* dec); + +/** + * Sets output buffer for box output codestream. + * + * The data is owned by the caller and may be used by the decoder until @ref + * JxlDecoderReleaseBoxBuffer is called or the decoder is destroyed or + * reset so must be kept alive until then. + * + * If for the current box a box buffer was set before and released with @ref + * JxlDecoderReleaseBoxBuffer, bytes that the decoder has already output + * should not be included, only the remaining bytes output must be set. + * + * The @ref JxlDecoderReleaseBoxBuffer must be used at the next @ref JXL_DEC_BOX + * event or final @ref JXL_DEC_SUCCESS event to compute the size of the output + * box bytes. + * + * @param dec decoder object + * @param data pointer to next bytes to write to + * @param size amount of bytes available starting from data + * @return @ref JXL_DEC_ERROR if output buffer was already set and @ref + * JxlDecoderReleaseBoxBuffer was not called on it, @ref JXL_DEC_SUCCESS + * otherwise + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetBoxBuffer(JxlDecoder* dec, + uint8_t* data, size_t size); + +/** + * Releases buffer which was provided with @ref JxlDecoderSetBoxBuffer. + * + * Calling @ref JxlDecoderReleaseBoxBuffer is required whenever + * a buffer is already set and a new buffer needs to be added with @ref + * JxlDecoderSetBoxBuffer, but is not required before @ref + * JxlDecoderDestroy or @ref JxlDecoderReset. + * + * Calling @ref JxlDecoderReleaseBoxBuffer when no buffer is set is + * not an error and returns 0. + * + * @param dec decoder object + * @return the amount of bytes the decoder has not yet written to of the data + * set by @ref JxlDecoderSetBoxBuffer, or 0 if no buffer is set or @ref + * JxlDecoderReleaseBoxBuffer was already called. + */ +JXL_EXPORT size_t JxlDecoderReleaseBoxBuffer(JxlDecoder* dec); + +/** + * Configures whether to get boxes in raw mode or in decompressed mode. In raw + * mode, boxes are output as their bytes appear in the container file, which may + * be decompressed, or compressed if their type is "brob". In decompressed mode, + * "brob" boxes are decompressed with Brotli before outputting them. The size of + * the decompressed stream is not known before the decompression has already + * finished. + * + * The default mode is raw. This setting can only be changed before decoding, or + * directly after a @ref JXL_DEC_BOX event, and is remembered until the decoder + * is reset or destroyed. + * + * Enabling decompressed mode requires Brotli support from the library. + * + * @param dec decoder object + * @param decompress JXL_TRUE to transparently decompress, JXL_FALSE to get + * boxes in raw mode. + * @return @ref JXL_DEC_ERROR if decompressed mode is set and Brotli is not + * available, @ref JXL_DEC_SUCCESS otherwise. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetDecompressBoxes(JxlDecoder* dec, + JXL_BOOL decompress); + +/** + * Outputs the type of the current box, after a @ref JXL_DEC_BOX event occured, + * as 4 characters without null termination character. In case of a compressed + * "brob" box, this will return "brob" if the decompressed argument is + * JXL_FALSE, or the underlying box type if the decompressed argument is + * JXL_TRUE. + * + * The following box types are currently described in ISO/IEC 18181-2: + * - "Exif": a box with EXIF metadata. Starts with a 4-byte tiff header offset + * (big-endian uint32) that indicates the start of the actual EXIF data + * (which starts with a tiff header). Usually the offset will be zero and the + * EXIF data starts immediately after the offset field. The Exif orientation + * should be ignored by applications; the JPEG XL codestream orientation + * takes precedence and libjxl will by default apply the correct orientation + * automatically (see @ref JxlDecoderSetKeepOrientation). + * - "xml ": a box with XML data, in particular XMP metadata. + * - "jumb": a JUMBF superbox (JPEG Universal Metadata Box Format, ISO/IEC + * 19566-5). + * - "JXL ": mandatory signature box, must come first, 12 bytes long including + * the box header + * - "ftyp": a second mandatory signature box, must come second, 20 bytes long + * including the box header + * - "jxll": a JXL level box. This indicates if the codestream is level 5 or + * level 10 compatible. If not present, it is level 5. Level 10 allows more + * features such as very high image resolution and bit-depths above 16 bits + * per channel. Added automatically by the encoder when + * JxlEncoderSetCodestreamLevel is used + * - "jxlc": a box with the image codestream, in case the codestream is not + * split across multiple boxes. The codestream contains the JPEG XL image + * itself, including the basic info such as image dimensions, ICC color + * profile, and all the pixel data of all the image frames. + * - "jxlp": a codestream box in case it is split across multiple boxes. + * The contents are the same as in case of a jxlc box, when concatenated. + * - "brob": a Brotli-compressed box, which otherwise represents an existing + * type of box such as Exif or "xml ". When @ref JxlDecoderSetDecompressBoxes + * is set to JXL_TRUE, these boxes will be transparently decompressed by the + * decoder. + * - "jxli": frame index box, can list the keyframes in case of a JPEG XL + * animation allowing the decoder to jump to individual frames more + * efficiently. + * - "jbrd": JPEG reconstruction box, contains the information required to + * byte-for-byte losslessly recontruct a JPEG-1 image. The JPEG DCT + * coefficients (pixel content) themselves as well as the ICC profile are + * encoded in the JXL codestream (jxlc or jxlp) itself. EXIF, XMP and JUMBF + * metadata is encoded in the corresponding boxes. The jbrd box itself + * contains information such as the remaining app markers of the JPEG-1 file + * and everything else required to fit the information together into the + * exact original JPEG file. + * + * Other application-specific boxes can exist. Their typename should not begin + * with "jxl" or "JXL" or conflict with other existing typenames. + * + * The signature, jxl* and jbrd boxes are processed by the decoder and would + * typically be ignored by applications. The typical way to use this function is + * to check if an encountered box contains metadata that the application is + * interested in (e.g. EXIF or XMP metadata), in order to conditionally set a + * box buffer. + * + * @param dec decoder object + * @param type buffer to copy the type into + * @param decompressed which box type to get: JXL_FALSE to get the raw box type, + * which can be "brob", JXL_TRUE, get the underlying box type. + * @return @ref JXL_DEC_SUCCESS if the value is available, @ref JXL_DEC_ERROR if + * not, for example the JXL file does not use the container format. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetBoxType(JxlDecoder* dec, + JxlBoxType type, + JXL_BOOL decompressed); + +/** + * Returns the size of a box as it appears in the container file, after the @ref + * JXL_DEC_BOX event. For a non-compressed box, this is the size of the + * contents, excluding the 4 bytes indicating the box type. For a compressed + * "brob" box, this is the size of the compressed box contents plus the + * additional 4 byte indicating the underlying box type, but excluding the 4 + * bytes indicating "brob". This function gives the size of the data that will + * be written in the output buffer when getting boxes in the default raw + * compressed mode. When @ref JxlDecoderSetDecompressBoxes is enabled, the + * return value of function does not change, and the decompressed size is not + * known before it has already been decompressed and output. + * + * @param dec decoder object + * @param size raw size of the box in bytes + * @return @ref JXL_DEC_ERROR if no box size is available, @ref JXL_DEC_SUCCESS + * otherwise. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderGetBoxSizeRaw(const JxlDecoder* dec, + uint64_t* size); + +/** + * Configures at which progressive steps in frame decoding these @ref + * JXL_DEC_FRAME_PROGRESSION event occurs. The default value for the level + * of detail if this function is never called is `kDC`. + * + * @param dec decoder object + * @param detail at which level of detail to trigger @ref + * JXL_DEC_FRAME_PROGRESSION + * @return @ref JXL_DEC_SUCCESS on success, @ref JXL_DEC_ERROR on error, such as + * an invalid value for the progressive detail. + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetProgressiveDetail(JxlDecoder* dec, JxlProgressiveDetail detail); + +/** + * Returns the intended downsampling ratio for the progressive frame produced + * by @ref JxlDecoderFlushImage after the latest @ref JXL_DEC_FRAME_PROGRESSION + * event. + * + * @param dec decoder object + * @return The intended downsampling ratio, can be 1, 2, 4 or 8. + */ +JXL_EXPORT size_t JxlDecoderGetIntendedDownsamplingRatio(JxlDecoder* dec); + +/** + * Outputs progressive step towards the decoded image so far when only partial + * input was received. If the flush was successful, the buffer set with @ref + * JxlDecoderSetImageOutBuffer will contain partial image data. + * + * Can be called when @ref JxlDecoderProcessInput returns @ref + * JXL_DEC_NEED_MORE_INPUT, after the @ref JXL_DEC_FRAME event already occurred + * and before the @ref JXL_DEC_FULL_IMAGE event occurred for a frame. + * + * @param dec decoder object + * @return @ref JXL_DEC_SUCCESS if image data was flushed to the output buffer, + * or @ref JXL_DEC_ERROR when no flush was done, e.g. if not enough image + * data was available yet even for flush, or no output buffer was set yet. + * This error is not fatal, it only indicates no flushed image is available + * right now. Regular decoding can still be performed. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderFlushImage(JxlDecoder* dec); + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_DECODE_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/decode_cxx.h b/media/libjxl/src/lib/include/jxl/decode_cxx.h new file mode 100644 index 000000000..ed5c39347 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/decode_cxx.h @@ -0,0 +1,57 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/// @addtogroup libjxl_decoder +/// @{ +/// +/// @file decode_cxx.h +/// @brief C++ header-only helper for @ref decode.h. +/// +/// There's no binary library associated with the header since this is a header +/// only library. + +#ifndef JXL_DECODE_CXX_H_ +#define JXL_DECODE_CXX_H_ + +#include + +#include "jxl/decode.h" + +#if !(defined(__cplusplus) || defined(c_plusplus)) +#error "This a C++ only header. Use jxl/decode.h from C sources." +#endif + +/// Struct to call JxlDecoderDestroy from the JxlDecoderPtr unique_ptr. +struct JxlDecoderDestroyStruct { + /// Calls @ref JxlDecoderDestroy() on the passed decoder. + void operator()(JxlDecoder* decoder) { JxlDecoderDestroy(decoder); } +}; + +/// std::unique_ptr<> type that calls JxlDecoderDestroy() when releasing the +/// decoder. +/// +/// Use this helper type from C++ sources to ensure the decoder is destroyed and +/// their internal resources released. +typedef std::unique_ptr JxlDecoderPtr; + +/// Creates an instance of JxlDecoder into a JxlDecoderPtr and initializes it. +/// +/// This function returns a unique_ptr that will call JxlDecoderDestroy() when +/// releasing the pointer. See @ref JxlDecoderCreate for details on the +/// instance creation. +/// +/// @param memory_manager custom allocator function. It may be NULL. The memory +/// manager will be copied internally. +/// @return a @c NULL JxlDecoderPtr if the instance can not be allocated or +/// initialized +/// @return initialized JxlDecoderPtr instance otherwise. +static inline JxlDecoderPtr JxlDecoderMake( + const JxlMemoryManager* memory_manager) { + return JxlDecoderPtr(JxlDecoderCreate(memory_manager)); +} + +#endif // JXL_DECODE_CXX_H_ + +/// @} diff --git a/media/libjxl/src/lib/include/jxl/encode.h b/media/libjxl/src/lib/include/jxl/encode.h new file mode 100644 index 000000000..4813e3b7c --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/encode.h @@ -0,0 +1,1151 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_encoder + * @{ + * @file encode.h + * @brief Encoding API for JPEG XL. + */ + +#ifndef JXL_ENCODE_H_ +#define JXL_ENCODE_H_ + +#include "jxl/cms_interface.h" +#include "jxl/codestream_header.h" +#include "jxl/jxl_export.h" +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** + * Encoder library version. + * + * @return the encoder library version as an integer: + * MAJOR_VERSION * 1000000 + MINOR_VERSION * 1000 + PATCH_VERSION. For example, + * version 1.2.3 would return 1002003. + */ +JXL_EXPORT uint32_t JxlEncoderVersion(void); + +/** + * Opaque structure that holds the JPEG XL encoder. + * + * Allocated and initialized with JxlEncoderCreate(). + * Cleaned up and deallocated with JxlEncoderDestroy(). + */ +typedef struct JxlEncoderStruct JxlEncoder; + +/** + * Settings and metadata for a single image frame. This includes encoder options + * for a frame such as compression quality and speed. + * + * Allocated and initialized with JxlEncoderFrameSettingsCreate(). + * Cleaned up and deallocated when the encoder is destroyed with + * JxlEncoderDestroy(). + */ +typedef struct JxlEncoderFrameSettingsStruct JxlEncoderFrameSettings; + +/** DEPRECATED: Use JxlEncoderFrameSettings instead. + */ +typedef JxlEncoderFrameSettings JxlEncoderOptions; + +/** + * Return value for multiple encoder functions. + */ +typedef enum { + /** Function call finished successfully, or encoding is finished and there is + * nothing more to be done. + */ + JXL_ENC_SUCCESS = 0, + + /** An error occurred, for example out of memory. + */ + JXL_ENC_ERROR = 1, + + /** The encoder needs more output buffer to continue encoding. + */ + JXL_ENC_NEED_MORE_OUTPUT = 2, + + /** DEPRECATED: the encoder does not return this status and there is no need + * to handle or expect it. + * Instead, JXL_ENC_ERROR is returned with error condition + * JXL_ENC_ERR_NOT_SUPPORTED. + */ + JXL_ENC_NOT_SUPPORTED = 3, + +} JxlEncoderStatus; + +/** + * Error conditions: + * API usage errors have the 0x80 bit set to 1 + * Other errors have the 0x80 bit set to 0 + */ +typedef enum { + /** No error + */ + JXL_ENC_ERR_OK = 0, + + /** Generic encoder error due to unspecified cause + */ + JXL_ENC_ERR_GENERIC = 1, + + /** Out of memory + * TODO(jon): actually catch this and return this error + */ + JXL_ENC_ERR_OOM = 2, + + /** JPEG bitstream reconstruction data could not be + * represented (e.g. too much tail data) + */ + JXL_ENC_ERR_JBRD = 3, + + /** Input is invalid (e.g. corrupt JPEG file or ICC profile) + */ + JXL_ENC_ERR_BAD_INPUT = 4, + + /** The encoder doesn't (yet) support this. Either no version of libjxl + * supports this, and the API is used incorrectly, or the libjxl version + * should have been checked before trying to do this. + */ + JXL_ENC_ERR_NOT_SUPPORTED = 0x80, + + /** The encoder API is used in an incorrect way. + * In this case, a debug build of libjxl should output a specific error + * message. (if not, please open an issue about it) + */ + JXL_ENC_ERR_API_USAGE = 0x81, + +} JxlEncoderError; + +/** + * Id of encoder options for a frame. This includes options such as setting + * encoding effort/speed or overriding the use of certain coding tools, for this + * frame. This does not include non-frame related encoder options such as for + * boxes. + */ +typedef enum { + /** Sets encoder effort/speed level without affecting decoding speed. Valid + * values are, from faster to slower speed: 1:lightning 2:thunder 3:falcon + * 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten 9:tortoise. + * Default: squirrel (7). + */ + JXL_ENC_FRAME_SETTING_EFFORT = 0, + + /** Sets the decoding speed tier for the provided options. Minimum is 0 + * (slowest to decode, best quality/density), and maximum is 4 (fastest to + * decode, at the cost of some quality/density). Default is 0. + */ + JXL_ENC_FRAME_SETTING_DECODING_SPEED = 1, + + /** Sets resampling option. If enabled, the image is downsampled before + * compression, and upsampled to original size in the decoder. Integer option, + * use -1 for the default behavior (resampling only applied for low quality), + * 1 for no downsampling (1x1), 2 for 2x2 downsampling, 4 for 4x4 + * downsampling, 8 for 8x8 downsampling. + */ + JXL_ENC_FRAME_SETTING_RESAMPLING = 2, + + /** Similar to JXL_ENC_FRAME_SETTING_RESAMPLING, but for extra channels. + * Integer option, use -1 for the default behavior (depends on encoder + * implementation), 1 for no downsampling (1x1), 2 for 2x2 downsampling, 4 for + * 4x4 downsampling, 8 for 8x8 downsampling. + */ + JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING = 3, + + /** Indicates the frame added with @ref JxlEncoderAddImageFrame is already + * downsampled by the downsampling factor set with @ref + * JXL_ENC_FRAME_SETTING_RESAMPLING. The input frame must then be given in the + * downsampled resolution, not the full image resolution. The downsampled + * resolution is given by ceil(xsize / resampling), ceil(ysize / resampling) + * with xsize and ysize the dimensions given in the basic info, and resampling + * the factor set with @ref JXL_ENC_FRAME_SETTING_RESAMPLING. + * Use 0 to disable, 1 to enable. Default value is 0. + */ + JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED = 4, + + /** Adds noise to the image emulating photographic film noise, the higher the + * given number, the grainier the image will be. As an example, a value of 100 + * gives low noise whereas a value of 3200 gives a lot of noise. The default + * value is 0. + */ + JXL_ENC_FRAME_SETTING_PHOTON_NOISE = 5, + + /** Enables adaptive noise generation. This setting is not recommended for + * use, please use JXL_ENC_FRAME_SETTING_PHOTON_NOISE instead. Use -1 for the + * default (encoder chooses), 0 to disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_NOISE = 6, + + /** Enables or disables dots generation. Use -1 for the default (encoder + * chooses), 0 to disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_DOTS = 7, + + /** Enables or disables patches generation. Use -1 for the default (encoder + * chooses), 0 to disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_PATCHES = 8, + + /** Edge preserving filter level, -1 to 3. Use -1 for the default (encoder + * chooses), 0 to 3 to set a strength. + */ + JXL_ENC_FRAME_SETTING_EPF = 9, + + /** Enables or disables the gaborish filter. Use -1 for the default (encoder + * chooses), 0 to disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_GABORISH = 10, + + /** Enables modular encoding. Use -1 for default (encoder + * chooses), 0 to enforce VarDCT mode (e.g. for photographic images), 1 to + * enforce modular mode (e.g. for lossless images). + */ + JXL_ENC_FRAME_SETTING_MODULAR = 11, + + /** Enables or disables preserving color of invisible pixels. Use -1 for the + * default (1 if lossless, 0 if lossy), 0 to disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE = 12, + + /** Determines the order in which 256x256 regions are stored in the codestream + * for progressive rendering. Use -1 for the encoder + * default, 0 for scanline order, 1 for center-first order. + */ + JXL_ENC_FRAME_SETTING_GROUP_ORDER = 13, + + /** Determines the horizontal position of center for the center-first group + * order. Use -1 to automatically use the middle of the image, 0..xsize to + * specifically set it. + */ + JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X = 14, + + /** Determines the center for the center-first group order. Use -1 to + * automatically use the middle of the image, 0..ysize to specifically set it. + */ + JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_Y = 15, + + /** Enables or disables progressive encoding for modular mode. Use -1 for the + * encoder default, 0 to disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_RESPONSIVE = 16, + + /** Set the progressive mode for the AC coefficients of VarDCT, using spectral + * progression from the DCT coefficients. Use -1 for the encoder default, 0 to + * disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC = 17, + + /** Set the progressive mode for the AC coefficients of VarDCT, using + * quantization of the least significant bits. Use -1 for the encoder default, + * 0 to disable, 1 to enable. + */ + JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC = 18, + + /** Set the progressive mode using lower-resolution DC images for VarDCT. Use + * -1 for the encoder default, 0 to disable, 1 to have an extra 64x64 lower + * resolution pass, 2 to have a 512x512 and 64x64 lower resolution pass. + */ + JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC = 19, + + /** Use Global channel palette if the amount of colors is smaller than this + * percentage of range. Use 0-100 to set an explicit percentage, -1 to use the + * encoder default. Used for modular encoding. + */ + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT = 20, + + /** Use Local (per-group) channel palette if the amount of colors is smaller + * than this percentage of range. Use 0-100 to set an explicit percentage, -1 + * to use the encoder default. Used for modular encoding. + */ + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT = 21, + + /** Use color palette if amount of colors is smaller than or equal to this + * amount, or -1 to use the encoder default. Used for modular encoding. + */ + JXL_ENC_FRAME_SETTING_PALETTE_COLORS = 22, + + /** Enables or disables delta palette. Use -1 for the default (encoder + * chooses), 0 to disable, 1 to enable. Used in modular mode. + */ + JXL_ENC_FRAME_SETTING_LOSSY_PALETTE = 23, + + /** Color transform for internal encoding: -1 = default, 0=XYB, 1=none (RGB), + * 2=YCbCr. The XYB setting performs the forward XYB transform. None and + * YCbCr both perform no transform, but YCbCr is used to indicate that the + * encoded data losslessly represents YCbCr values. + */ + JXL_ENC_FRAME_SETTING_COLOR_TRANSFORM = 24, + + /** Reversible color transform for modular encoding: -1=default, 0-41=RCT + * index, e.g. index 0 = none, index 6 = YCoCg. + * If this option is set to a non-default value, the RCT will be globally + * applied to the whole frame. + * The default behavior is to try several RCTs locally per modular group, + * depending on the speed and distance setting. + */ + JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE = 25, + + /** Group size for modular encoding: -1=default, 0=128, 1=256, 2=512, 3=1024. + */ + JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE = 26, + + /** Predictor for modular encoding. -1 = default, 0=zero, 1=left, 2=top, + * 3=avg0, 4=select, 5=gradient, 6=weighted, 7=topright, 8=topleft, + * 9=leftleft, 10=avg1, 11=avg2, 12=avg3, 13=toptop predictive average 14=mix + * 5 and 6, 15=mix everything. + */ + JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR = 27, + + /** Fraction of pixels used to learn MA trees as a percentage. -1 = default, + * 0 = no MA and fast decode, 50 = default value, 100 = all, values above + * 100 are also permitted. Higher values use more encoder memory. + */ + JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT = 28, + + /** Number of extra (previous-channel) MA tree properties to use. -1 = + * default, 0-11 = valid values. Recommended values are in the range 0 to 3, + * or 0 to amount of channels minus 1 (including all extra channels, and + * excluding color channels when using VarDCT mode). Higher value gives slower + * encoding and slower decoding. + */ + JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS = 29, + + /** Enable or disable CFL (chroma-from-luma) for lossless JPEG recompression. + * -1 = default, 0 = disable CFL, 1 = enable CFL. + */ + JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL = 30, + + /** Prepare the frame for indexing in the frame index box. + * 0 = ignore this frame (same as not setting a value), + * 1 = index this frame within the Frame Index Box. + * If any frames are indexed, the first frame needs to + * be indexed, too. If the first frame is not indexed, and + * a later frame is attempted to be indexed, JXL_ENC_ERROR will occur. + * If non-keyframes, i.e., frames with cropping, blending or patches are + * attempted to be indexed, JXL_ENC_ERROR will occur. + */ + JXL_ENC_FRAME_INDEX_BOX = 31, + + /** Sets brotli encode effort for use in JPEG recompression and compressed + * metadata boxes (brob). Can be -1 (default) or 0 (fastest) to 11 (slowest). + * Default is based on the general encode effort in case of JPEG + * recompression, and 4 for brob boxes. + */ + JXL_ENC_FRAME_SETTING_BROTLI_EFFORT = 32, + + /** Enum value not to be used as an option. This value is added to force the + * C compiler to have the enum to take a known size. + */ + JXL_ENC_FRAME_SETTING_FILL_ENUM = 65535, + +} JxlEncoderFrameSettingId; + +/** + * Creates an instance of JxlEncoder and initializes it. + * + * @p memory_manager will be used for all the library dynamic allocations made + * from this instance. The parameter may be NULL, in which case the default + * allocator will be used. See jpegxl/memory_manager.h for details. + * + * @param memory_manager custom allocator function. It may be NULL. The memory + * manager will be copied internally. + * @return @c NULL if the instance can not be allocated or initialized + * @return pointer to initialized JxlEncoder otherwise + */ +JXL_EXPORT JxlEncoder* JxlEncoderCreate(const JxlMemoryManager* memory_manager); + +/** + * Re-initializes a JxlEncoder instance, so it can be re-used for encoding + * another image. All state and settings are reset as if the object was + * newly created with JxlEncoderCreate, but the memory manager is kept. + * + * @param enc instance to be re-initialized. + */ +JXL_EXPORT void JxlEncoderReset(JxlEncoder* enc); + +/** + * Deinitializes and frees JxlEncoder instance. + * + * @param enc instance to be cleaned up and deallocated. + */ +JXL_EXPORT void JxlEncoderDestroy(JxlEncoder* enc); + +/** + * Sets the color management system (CMS) that will be used for color conversion + * (if applicable) during encoding. May only be set before starting encoding. If + * left unset, the default CMS implementation will be used. + * + * @param enc encoder object. + * @param cms structure representing a CMS implementation. See JxlCmsInterface + * for more details. + */ +JXL_EXPORT void JxlEncoderSetCms(JxlEncoder* enc, JxlCmsInterface cms); + +/** + * Set the parallel runner for multithreading. May only be set before starting + * encoding. + * + * @param enc encoder object. + * @param parallel_runner function pointer to runner for multithreading. It may + * be NULL to use the default, single-threaded, runner. A multithreaded + * runner should be set to reach fast performance. + * @param parallel_runner_opaque opaque pointer for parallel_runner. + * @return JXL_ENC_SUCCESS if the runner was set, JXL_ENC_ERROR + * otherwise (the previous runner remains set). + */ +JXL_EXPORT JxlEncoderStatus +JxlEncoderSetParallelRunner(JxlEncoder* enc, JxlParallelRunner parallel_runner, + void* parallel_runner_opaque); + +/** + * Get the (last) error code in case JXL_ENC_ERROR was returned. + * + * @param enc encoder object. + * @return the JxlEncoderError that caused the (last) JXL_ENC_ERROR to be + * returned. + */ +JXL_EXPORT JxlEncoderError JxlEncoderGetError(JxlEncoder* enc); + +/** + * Encodes JPEG XL file using the available bytes. @p *avail_out indicates how + * many output bytes are available, and @p *next_out points to the input bytes. + * *avail_out will be decremented by the amount of bytes that have been + * processed by the encoder and *next_out will be incremented by the same + * amount, so *next_out will now point at the amount of *avail_out unprocessed + * bytes. + * + * The returned status indicates whether the encoder needs more output bytes. + * When the return value is not JXL_ENC_ERROR or JXL_ENC_SUCCESS, the encoding + * requires more JxlEncoderProcessOutput calls to continue. + * + * This encodes the frames and/or boxes added so far. If the last frame or last + * box has been added, @ref JxlEncoderCloseInput, @ref JxlEncoderCloseFrames + * and/or @ref JxlEncoderCloseBoxes must be called before the next + * @ref JxlEncoderProcessOutput call, or the codestream won't be encoded + * correctly. + * + * @param enc encoder object. + * @param next_out pointer to next bytes to write to. + * @param avail_out amount of bytes available starting from *next_out. + * @return JXL_ENC_SUCCESS when encoding finished and all events handled. + * @return JXL_ENC_ERROR when encoding failed, e.g. invalid input. + * @return JXL_ENC_NEED_MORE_OUTPUT more output buffer is necessary. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderProcessOutput(JxlEncoder* enc, + uint8_t** next_out, + size_t* avail_out); + +/** + * Sets the frame information for this frame to the encoder. This includes + * animation information such as frame duration to store in the frame header. + * The frame header fields represent the frame as passed to the encoder, but not + * necessarily the exact values as they will be encoded file format: the encoder + * could change crop and blending options of a frame for more efficient encoding + * or introduce additional internal frames. Animation duration and time code + * information is not altered since those are immutable metadata of the frame. + * + * It is not required to use this function, however if have_animation is set + * to true in the basic info, then this function should be used to set the + * time duration of this individual frame. By default individual frames have a + * time duration of 0, making them form a composite still. See @ref + * JxlFrameHeader for more information. + * + * This information is stored in the JxlEncoderFrameSettings and so is used for + * any frame encoded with these JxlEncoderFrameSettings. It is ok to change + * between @ref JxlEncoderAddImageFrame calls, each added image frame will have + * the frame header that was set in the options at the time of calling + * JxlEncoderAddImageFrame. + * + * The is_last and name_length fields of the JxlFrameHeader are ignored, use + * @ref JxlEncoderCloseFrames to indicate last frame, and @ref + * JxlEncoderSetFrameName to indicate the name and its length instead. + * Calling this function will clear any name that was previously set with @ref + * JxlEncoderSetFrameName. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param frame_header frame header data to set. Object owned by the caller and + * does not need to be kept in memory, its information is copied internally. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus +JxlEncoderSetFrameHeader(JxlEncoderFrameSettings* frame_settings, + const JxlFrameHeader* frame_header); + +/** + * Sets blend info of an extra channel. The blend info of extra channels is set + * separately from that of the color channels, the color channels are set with + * @ref JxlEncoderSetFrameHeader. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param index index of the extra channel to use. + * @param blend_info blend info to set for the extra channel + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelBlendInfo( + JxlEncoderFrameSettings* frame_settings, size_t index, + const JxlBlendInfo* blend_info); + +/** + * Sets the name of the animation frame. This function is optional, frames are + * not required to have a name. This setting is a part of the frame header, and + * the same principles as for @ref JxlEncoderSetFrameHeader apply. The + * name_length field of JxlFrameHeader is ignored by the encoder, this function + * determines the name length instead as the length in bytes of the C string. + * + * The maximum possible name length is 1071 bytes (excluding terminating null + * character). + * + * Calling @ref JxlEncoderSetFrameHeader clears any name that was + * previously set. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param frame_name name of the next frame to be encoded, as a UTF-8 encoded C + * string (zero terminated). Owned by the caller, and copied internally. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetFrameName( + JxlEncoderFrameSettings* frame_settings, const char* frame_name); + +/** + * Sets the buffer to read JPEG encoded bytes from for the next frame to encode. + * + * If JxlEncoderSetBasicInfo has not yet been called, calling + * JxlEncoderAddJPEGFrame will implicitly call it with the parameters of the + * added JPEG frame. + * + * If JxlEncoderSetColorEncoding or JxlEncoderSetICCProfile has not yet been + * called, calling JxlEncoderAddJPEGFrame will implicitly call it with the + * parameters of the added JPEG frame. + * + * If the encoder is set to store JPEG reconstruction metadata using @ref + * JxlEncoderStoreJPEGMetadata and a single JPEG frame is added, it will be + * possible to losslessly reconstruct the JPEG codestream. + * + * If this is the last frame, @ref JxlEncoderCloseInput or @ref + * JxlEncoderCloseFrames must be called before the next + * @ref JxlEncoderProcessOutput call. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param buffer bytes to read JPEG from. Owned by the caller and its contents + * are copied internally. + * @param size size of buffer in bytes. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus +JxlEncoderAddJPEGFrame(const JxlEncoderFrameSettings* frame_settings, + const uint8_t* buffer, size_t size); + +/** + * Sets the buffer to read pixels from for the next image to encode. Must call + * JxlEncoderSetBasicInfo before JxlEncoderAddImageFrame. + * + * Currently only some data types for pixel formats are supported: + * - JXL_TYPE_UINT8, with range 0..255 + * - JXL_TYPE_UINT16, with range 0..65535 + * - JXL_TYPE_FLOAT16, with nominal range 0..1 + * - JXL_TYPE_FLOAT, with nominal range 0..1 + * + * Note: the sample data type in pixel_format is allowed to be different from + * what is described in the JxlBasicInfo. The type in pixel_format describes the + * format of the uncompressed pixel buffer. The bits_per_sample and + * exponent_bits_per_sample in the JxlBasicInfo describes what will actually be + * encoded in the JPEG XL codestream. For example, to encode a 12-bit image, you + * would set bits_per_sample to 12, and you could use e.g. JXL_TYPE_UINT16 + * (where the values are rescaled to 16-bit, i.e. multiplied by 65535/4095) or + * JXL_TYPE_FLOAT (where the values are rescaled to 0..1, i.e. multiplied + * by 1.f/4095.f). While it is allowed, it is obviously not recommended to use a + * pixel_format with lower precision than what is specified in the JxlBasicInfo. + * + * We support interleaved channels as described by the JxlPixelFormat: + * - single-channel data, e.g. grayscale + * - single-channel + alpha + * - trichromatic, e.g. RGB + * - trichromatic + alpha + * + * Extra channels not handled here need to be set by @ref + * JxlEncoderSetExtraChannelBuffer. + * If the image has alpha, and alpha is not passed here, it will implicitly be + * set to all-opaque (an alpha value of 1.0 everywhere). + * + * The pixels are assumed to be encoded in the original profile that is set with + * JxlEncoderSetColorEncoding or JxlEncoderSetICCProfile. If none of these + * functions were used, the pixels are assumed to be nonlinear sRGB for integer + * data types (JXL_TYPE_UINT8, JXL_TYPE_UINT16), and linear sRGB for floating + * point data types (JXL_TYPE_FLOAT16, JXL_TYPE_FLOAT). + * + * Sample values in floating-point pixel formats are allowed to be outside the + * nominal range, e.g. to represent out-of-sRGB-gamut colors in the + * uses_original_profile=false case. They are however not allowed to be NaN or + * +-infinity. + * + * If this is the last frame, @ref JxlEncoderCloseInput or @ref + * JxlEncoderCloseFrames must be called before the next + * @ref JxlEncoderProcessOutput call. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param pixel_format format for pixels. Object owned by the caller and its + * contents are copied internally. + * @param buffer buffer type to input the pixel data from. Owned by the caller + * and its contents are copied internally. + * @param size size of buffer in bytes. This size should match what is implied + * by the frame dimensions and the pixel format. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderAddImageFrame( + const JxlEncoderFrameSettings* frame_settings, + const JxlPixelFormat* pixel_format, const void* buffer, size_t size); + +/** + * Sets the buffer to read pixels from for an extra channel at a given index. + * The index must be smaller than the num_extra_channels in the associated + * JxlBasicInfo. Must call @ref JxlEncoderSetExtraChannelInfo before + * JxlEncoderSetExtraChannelBuffer. + * + * TODO(firsching): mention what data types in pixel formats are supported. + * + * It is required to call this function for every extra channel, except for the + * alpha channel if that was already set through @ref JxlEncoderAddImageFrame. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param pixel_format format for pixels. Object owned by the caller and its + * contents are copied internally. The num_channels value is ignored, since the + * number of channels for an extra channel is always assumed to be one. + * @param buffer buffer type to input the pixel data from. Owned by the caller + * and its contents are copied internally. + * @param size size of buffer in bytes. This size should match what is implied + * by the frame dimensions and the pixel format. + * @param index index of the extra channel to use. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelBuffer( + const JxlEncoderFrameSettings* frame_settings, + const JxlPixelFormat* pixel_format, const void* buffer, size_t size, + uint32_t index); + +/** Adds a metadata box to the file format. JxlEncoderProcessOutput must be used + * to effectively write the box to the output. @ref JxlEncoderUseBoxes must + * be enabled before using this function. + * + * Boxes allow inserting application-specific data and metadata (Exif, XML/XMP, + * JUMBF and user defined boxes). + * + * The box format follows ISO BMFF and shares features and box types with other + * image and video formats, including the Exif, XML and JUMBF boxes. The box + * format for JPEG XL is specified in ISO/IEC 18181-2. + * + * Boxes in general don't contain other boxes inside, except a JUMBF superbox. + * Boxes follow each other sequentially and are byte-aligned. If the container + * format is used, the JXL stream consists of concatenated boxes. + * It is also possible to use a direct codestream without boxes, but in that + * case metadata cannot be added. + * + * Each box generally has the following byte structure in the file: + * - 4 bytes: box size including box header (Big endian. If set to 0, an + * 8-byte 64-bit size follows instead). + * - 4 bytes: type, e.g. "JXL " for the signature box, "jxlc" for a codestream + * box. + * - N bytes: box contents. + * + * Only the box contents are provided to the contents argument of this function, + * the encoder encodes the size header itself. Most boxes are written + * automatically by the encoder as needed ("JXL ", "ftyp", "jxll", "jxlc", + * "jxlp", "jxli", "jbrd"), and this function only needs to be called to add + * optional metadata when encoding from pixels (using JxlEncoderAddImageFrame). + * When recompressing JPEG files (using JxlEncoderAddJPEGFrame), if the input + * JPEG contains EXIF, XMP or JUMBF metadata, the corresponding boxes are + * already added automatically. + * + * Box types are given by 4 characters. The following boxes can be added with + * this function: + * - "Exif": a box with EXIF metadata, can be added by libjxl users, or is + * automatically added when needed for JPEG reconstruction. The contents of + * this box must be prepended by a 4-byte tiff header offset, which may + * be 4 zero bytes in case the tiff header follows immediately. + * The EXIF metadata must be in sync with what is encoded in the JPEG XL + * codestream, specifically the image orientation. While this is not + * recommended in practice, in case of conflicting metadata, the JPEG XL + * codestream takes precedence. + * - "xml ": a box with XML data, in particular XMP metadata, can be added by + * libjxl users, or is automatically added when needed for JPEG reconstruction + * - "jumb": a JUMBF superbox, which can contain boxes with different types of + * metadata inside. This box type can be added by the encoder transparently, + * and other libraries to create and handle JUMBF content exist. + * - Application-specific boxes. Their typename should not begin with "jxl" or + * "JXL" or conflict with other existing typenames, and they should be + * registered with MP4RA (mp4ra.org). + * + * These boxes can be stored uncompressed or Brotli-compressed (using a "brob" + * box), depending on the compress_box parameter. + * + * @param enc encoder object. + * @param type the box type, e.g. "Exif" for EXIF metadata, "xml " for XMP or + * IPTC metadata, "jumb" for JUMBF metadata. + * @param contents the full contents of the box, for example EXIF + * data. ISO BMFF box header must not be included, only the contents. Owned by + * the caller and its contents are copied internally. + * @param size size of the box contents. + * @param compress_box Whether to compress this box as a "brob" box. Requires + * Brotli support. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error, such as when + * using this function without JxlEncoderUseContainer, or adding a box type + * that would result in an invalid file format. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderAddBox(JxlEncoder* enc, + const JxlBoxType type, + const uint8_t* contents, + size_t size, + JXL_BOOL compress_box); + +/** + * Indicates the intention to add metadata boxes. This allows @ref + * JxlEncoderAddBox to be used. When using this function, then it is required + * to use @ref JxlEncoderCloseBoxes at the end. + * + * By default the encoder assumes no metadata boxes will be added. + * + * This setting can only be set at the beginning, before encoding starts. + * + * @param enc encoder object. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderUseBoxes(JxlEncoder* enc); + +/** + * Declares that no further boxes will be added with @ref JxlEncoderAddBox. + * This function must be called after the last box is added so the encoder knows + * the stream will be finished. It is not necessary to use this function if + * @ref JxlEncoderUseBoxes is not used. Further frames may still be added. + * + * Must be called between JxlEncoderAddBox of the last box + * and the next call to JxlEncoderProcessOutput, or @ref JxlEncoderProcessOutput + * won't output the last box correctly. + * + * NOTE: if you don't need to close frames and boxes at separate times, you can + * use @ref JxlEncoderCloseInput instead to close both at once. + * + * @param enc encoder object. + */ +JXL_EXPORT void JxlEncoderCloseBoxes(JxlEncoder* enc); + +/** + * Declares that no frames will be added and @ref JxlEncoderAddImageFrame and + * @ref JxlEncoderAddJPEGFrame won't be called anymore. Further metadata boxes + * may still be added. This function or @ref JxlEncoderCloseInput must be called + * after adding the last frame and the next call to + * @ref JxlEncoderProcessOutput, or the frame won't be properly marked as last. + * + * NOTE: if you don't need to close frames and boxes at separate times, you can + * use @ref JxlEncoderCloseInput instead to close both at once. + * + * @param enc encoder object. + */ +JXL_EXPORT void JxlEncoderCloseFrames(JxlEncoder* enc); + +/** + * Closes any input to the encoder, equivalent to calling JxlEncoderCloseFrames + * as well as calling JxlEncoderCloseBoxes if needed. No further input of any + * kind may be given to the encoder, but further @ref JxlEncoderProcessOutput + * calls should be done to create the final output. + * + * The requirements of both @ref JxlEncoderCloseFrames and @ref + * JxlEncoderCloseBoxes apply to this function. Either this function or the + * other two must be called after the final frame and/or box, and the next + * @ref JxlEncoderProcessOutput call, or the codestream won't be encoded + * correctly. + * + * @param enc encoder object. + */ +JXL_EXPORT void JxlEncoderCloseInput(JxlEncoder* enc); + +/** + * Sets the original color encoding of the image encoded by this encoder. This + * is an alternative to JxlEncoderSetICCProfile and only one of these two must + * be used. This one sets the color encoding as a @ref JxlColorEncoding, while + * the other sets it as ICC binary data. + * Must be called after JxlEncoderSetBasicInfo. + * + * @param enc encoder object. + * @param color color encoding. Object owned by the caller and its contents are + * copied internally. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR or + * JXL_ENC_NOT_SUPPORTED otherwise + */ +JXL_EXPORT JxlEncoderStatus +JxlEncoderSetColorEncoding(JxlEncoder* enc, const JxlColorEncoding* color); + +/** + * Sets the original color encoding of the image encoded by this encoder as an + * ICC color profile. This is an alternative to JxlEncoderSetColorEncoding and + * only one of these two must be used. This one sets the color encoding as ICC + * binary data, while the other defines it as a @ref JxlColorEncoding. + * Must be called after JxlEncoderSetBasicInfo. + * + * @param enc encoder object. + * @param icc_profile bytes of the original ICC profile + * @param size size of the icc_profile buffer in bytes + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR or + * JXL_ENC_NOT_SUPPORTED otherwise + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetICCProfile(JxlEncoder* enc, + const uint8_t* icc_profile, + size_t size); + +/** + * Initializes a JxlBasicInfo struct to default values. + * For forwards-compatibility, this function has to be called before values + * are assigned to the struct fields. + * The default values correspond to an 8-bit RGB image, no alpha or any + * other extra channels. + * + * @param info global image metadata. Object owned by the caller. + */ +JXL_EXPORT void JxlEncoderInitBasicInfo(JxlBasicInfo* info); + +/** + * Initializes a JxlFrameHeader struct to default values. + * For forwards-compatibility, this function has to be called before values + * are assigned to the struct fields. + * The default values correspond to a frame with no animation duration and the + * 'replace' blend mode. After using this function, For animation duration must + * be set, for composite still blend settings must be set. + * + * @param frame_header frame metadata. Object owned by the caller. + */ +JXL_EXPORT void JxlEncoderInitFrameHeader(JxlFrameHeader* frame_header); + +/** + * Initializes a JxlBlendInfo struct to default values. + * For forwards-compatibility, this function has to be called before values + * are assigned to the struct fields. + * + * @param blend_info blending info. Object owned by the caller. + */ +JXL_EXPORT void JxlEncoderInitBlendInfo(JxlBlendInfo* blend_info); + +/** + * Sets the global metadata of the image encoded by this encoder. + * + * If the JxlBasicInfo contains information of extra channels beyond an alpha + * channel, then @ref JxlEncoderSetExtraChannelInfo must be called between + * JxlEncoderSetBasicInfo and @ref JxlEncoderAddImageFrame. In order to indicate + * extra channels, the value of `info.num_extra_channels` should be set to the + * number of extra channels, also counting the alpha channel if present. + * + * @param enc encoder object. + * @param info global image metadata. Object owned by the caller and its + * contents are copied internally. + * @return JXL_ENC_SUCCESS if the operation was successful, + * JXL_ENC_ERROR or JXL_ENC_NOT_SUPPORTED otherwise + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetBasicInfo(JxlEncoder* enc, + const JxlBasicInfo* info); + +/** + * Initializes a JxlExtraChannelInfo struct to default values. + * For forwards-compatibility, this function has to be called before values + * are assigned to the struct fields. + * The default values correspond to an 8-bit channel of the provided type. + * + * @param type type of the extra channel. + * @param info global extra channel metadata. Object owned by the caller and its + * contents are copied internally. + */ +JXL_EXPORT void JxlEncoderInitExtraChannelInfo(JxlExtraChannelType type, + JxlExtraChannelInfo* info); + +/** + * Sets information for the extra channel at the given index. The index + * must be smaller than num_extra_channels in the associated JxlBasicInfo. + * + * @param enc encoder object + * @param index index of the extra channel to set. + * @param info global extra channel metadata. Object owned by the caller and its + * contents are copied internally. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelInfo( + JxlEncoder* enc, size_t index, const JxlExtraChannelInfo* info); + +/** + * Sets the name for the extra channel at the given index in UTF-8. The index + * must be smaller than the num_extra_channels in the associated JxlBasicInfo. + * + * TODO(lode): remove size parameter for consistency with + * JxlEncoderSetFrameName + * + * @param enc encoder object + * @param index index of the extra channel to set. + * @param name buffer with the name of the extra channel. + * @param size size of the name buffer in bytes, not counting the terminating + * character. + * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelName(JxlEncoder* enc, + size_t index, + const char* name, + size_t size); + +/** + * Sets a frame-specific option of integer type to the encoder options. + * The JxlEncoderFrameSettingId argument determines which option is set. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param option ID of the option to set. + * @param value Integer value to set for this option. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR in + * case of an error, such as invalid or unknown option id, or invalid integer + * value for the given option. If an error is returned, the state of the + * JxlEncoderFrameSettings object is still valid and is the same as before this + * function was called. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderFrameSettingsSetOption( + JxlEncoderFrameSettings* frame_settings, JxlEncoderFrameSettingId option, + int64_t value); + +/** + * Sets a frame-specific option of float type to the encoder options. + * The JxlEncoderFrameSettingId argument determines which option is set. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param option ID of the option to set. + * @param value Float value to set for this option. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR in + * case of an error, such as invalid or unknown option id, or invalid integer + * value for the given option. If an error is returned, the state of the + * JxlEncoderFrameSettings object is still valid and is the same as before this + * function was called. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderFrameSettingsSetFloatOption( + JxlEncoderFrameSettings* frame_settings, JxlEncoderFrameSettingId option, + float value); + +/** Forces the encoder to use the box-based container format (BMFF) even + * when not necessary. + * + * When using @ref JxlEncoderUseBoxes, @ref JxlEncoderStoreJPEGMetadata or @ref + * JxlEncoderSetCodestreamLevel with level 10, the encoder will automatically + * also use the container format, it is not necessary to use + * JxlEncoderUseContainer for those use cases. + * + * By default this setting is disabled. + * + * This setting can only be set at the beginning, before encoding starts. + * + * @param enc encoder object. + * @param use_container true if the encoder should always output the JPEG XL + * container format, false to only output it when necessary. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR + * otherwise. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderUseContainer(JxlEncoder* enc, + JXL_BOOL use_container); + +/** + * Configure the encoder to store JPEG reconstruction metadata in the JPEG XL + * container. + * + * If this is set to true and a single JPEG frame is added, it will be + * possible to losslessly reconstruct the JPEG codestream. + * + * This setting can only be set at the beginning, before encoding starts. + * + * @param enc encoder object. + * @param store_jpeg_metadata true if the encoder should store JPEG metadata. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR + * otherwise. + */ +JXL_EXPORT JxlEncoderStatus +JxlEncoderStoreJPEGMetadata(JxlEncoder* enc, JXL_BOOL store_jpeg_metadata); + +/** Sets the feature level of the JPEG XL codestream. Valid values are 5 and + * 10, or -1 (to choose automatically). Using the minimum required level, or + * level 5 in most cases, is recommended for compatibility with all decoders. + * + * Level 5: for end-user image delivery, this level is the most widely + * supported level by image decoders and the recommended level to use unless a + * level 10 feature is absolutely necessary. Supports a maximum resolution + * 268435456 pixels total with a maximum width or height of 262144 pixels, + * maximum 16-bit color channel depth, maximum 120 frames per second for + * animation, maximum ICC color profile size of 4 MiB, it allows all color + * models and extra channel types except CMYK and the JXL_CHANNEL_BLACK extra + * channel, and a maximum of 4 extra channels in addition to the 3 color + * channels. It also sets boundaries to certain internally used coding tools. + * + * Level 10: this level removes or increases the bounds of most of the level + * 5 limitations, allows CMYK color and up to 32 bits per color channel, but + * may be less widely supported. + * + * The default value is -1. This means the encoder will automatically choose + * between level 5 and level 10 based on what information is inside the @ref + * JxlBasicInfo structure. Do note that some level 10 features, particularly + * those used by animated JPEG XL codestreams, might require level 10, even + * though the @ref JxlBasicInfo only suggests level 5. In this case, the level + * must be explicitly set to 10, otherwise the encoder will return an error. + * The encoder will restrict internal encoding choices to those compatible with + * the level setting. + * + * This setting can only be set at the beginning, before encoding starts. + * + * @param enc encoder object. + * @param level the level value to set, must be -1, 5, or 10. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR + * otherwise. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetCodestreamLevel(JxlEncoder* enc, + int level); + +/** Returns the codestream level required to support the currently configured + * settings and basic info. This function can only be used at the beginning, + * before encoding starts, but after setting basic info. + * + * This does not support per-frame settings, only global configuration, such as + * the image dimensions, that are known at the time of writing the header of + * the JPEG XL file. + * + * If this returns 5, nothing needs to be done and the codestream can be + * compatible with any decoder. If this returns 10, JxlEncoderSetCodestreamLevel + * has to be used to set the codestream level to 10, or the encoder can be + * configured differently to allow using the more compatible level 5. + * + * @param enc encoder object. + * @return -1 if no level can support the configuration (e.g. image dimensions + * larger than even level 10 supports), 5 if level 5 is supported, 10 if setting + * the codestream level to 10 is required. + * + */ +JXL_EXPORT int JxlEncoderGetRequiredCodestreamLevel(const JxlEncoder* enc); + +/** + * Enables lossless encoding. + * + * This is not an option like the others on itself, but rather while enabled it + * overrides a set of existing options (such as distance, modular mode and + * color transform) that enables bit-for-bit lossless encoding. + * + * When disabled, those options are not overridden, but since those options + * could still have been manually set to a combination that operates losslessly, + * using this function with lossless set to JXL_DEC_FALSE does not guarantee + * lossy encoding, though the default set of options is lossy. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param lossless whether to override options for lossless mode + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR + * otherwise. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetFrameLossless( + JxlEncoderFrameSettings* frame_settings, JXL_BOOL lossless); + +/** DEPRECATED: use JxlEncoderSetFrameLossless instead. + */ +JXL_EXPORT JxlEncoderStatus +JxlEncoderOptionsSetLossless(JxlEncoderFrameSettings*, JXL_BOOL); + +/** + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param effort the effort value to set. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR + * otherwise. + * + * DEPRECATED: use JxlEncoderFrameSettingsSetOption(frame_settings, + * JXL_ENC_FRAME_SETTING_EFFORT, effort) instead. + */ +JXL_EXPORT JXL_DEPRECATED JxlEncoderStatus +JxlEncoderOptionsSetEffort(JxlEncoderFrameSettings* frame_settings, int effort); + +/** + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param tier the decoding speed tier to set. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR + * otherwise. + * + * DEPRECATED: use JxlEncoderFrameSettingsSetOption(frame_settings, + * JXL_ENC_FRAME_SETTING_DECODING_SPEED, tier) instead. + */ +JXL_EXPORT JXL_DEPRECATED JxlEncoderStatus JxlEncoderOptionsSetDecodingSpeed( + JxlEncoderFrameSettings* frame_settings, int tier); + +/** + * Sets the distance level for lossy compression: target max butteraugli + * distance, lower = higher quality. Range: 0 .. 15. + * 0.0 = mathematically lossless (however, use JxlEncoderSetFrameLossless + * instead to use true lossless, as setting distance to 0 alone is not the only + * requirement). 1.0 = visually lossless. Recommended range: 0.5 .. 3.0. Default + * value: 1.0. + * + * @param frame_settings set of options and metadata for this frame. Also + * includes reference to the encoder object. + * @param distance the distance value to set. + * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR + * otherwise. + */ +JXL_EXPORT JxlEncoderStatus JxlEncoderSetFrameDistance( + JxlEncoderFrameSettings* frame_settings, float distance); + +/** DEPRECATED: use JxlEncoderSetFrameDistance instead. + */ +JXL_EXPORT JXL_DEPRECATED JxlEncoderStatus +JxlEncoderOptionsSetDistance(JxlEncoderFrameSettings*, float); + +/** + * Create a new set of encoder options, with all values initially copied from + * the @p source options, or set to default if @p source is NULL. + * + * The returned pointer is an opaque struct tied to the encoder and it will be + * deallocated by the encoder when JxlEncoderDestroy() is called. For functions + * taking both a @ref JxlEncoder and a @ref JxlEncoderFrameSettings, only + * JxlEncoderFrameSettings created with this function for the same encoder + * instance can be used. + * + * @param enc encoder object. + * @param source source options to copy initial values from, or NULL to get + * defaults initialized to defaults. + * @return the opaque struct pointer identifying a new set of encoder options. + */ +JXL_EXPORT JxlEncoderFrameSettings* JxlEncoderFrameSettingsCreate( + JxlEncoder* enc, const JxlEncoderFrameSettings* source); + +/** DEPRECATED: use JxlEncoderFrameSettingsCreate instead. + */ +JXL_EXPORT JXL_DEPRECATED JxlEncoderFrameSettings* JxlEncoderOptionsCreate( + JxlEncoder*, const JxlEncoderFrameSettings*); + +/** + * Sets a color encoding to be sRGB. + * + * @param color_encoding color encoding instance. + * @param is_gray whether the color encoding should be gray scale or color. + */ +JXL_EXPORT void JxlColorEncodingSetToSRGB(JxlColorEncoding* color_encoding, + JXL_BOOL is_gray); + +/** + * Sets a color encoding to be linear sRGB. + * + * @param color_encoding color encoding instance. + * @param is_gray whether the color encoding should be gray scale or color. + */ +JXL_EXPORT void JxlColorEncodingSetToLinearSRGB( + JxlColorEncoding* color_encoding, JXL_BOOL is_gray); + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_ENCODE_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/encode_cxx.h b/media/libjxl/src/lib/include/jxl/encode_cxx.h new file mode 100644 index 000000000..494c03c7e --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/encode_cxx.h @@ -0,0 +1,57 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/// @addtogroup libjxl_encoder +///@{ +/// +/// @file encode_cxx.h +/// @brief C++ header-only helper for @ref encode.h. +/// +/// There's no binary library associated with the header since this is a header +/// only library. + +#ifndef JXL_ENCODE_CXX_H_ +#define JXL_ENCODE_CXX_H_ + +#include + +#include "jxl/encode.h" + +#if !(defined(__cplusplus) || defined(c_plusplus)) +#error "This a C++ only header. Use jxl/encode.h from C sources." +#endif + +/// Struct to call JxlEncoderDestroy from the JxlEncoderPtr unique_ptr. +struct JxlEncoderDestroyStruct { + /// Calls @ref JxlEncoderDestroy() on the passed encoder. + void operator()(JxlEncoder* encoder) { JxlEncoderDestroy(encoder); } +}; + +/// std::unique_ptr<> type that calls JxlEncoderDestroy() when releasing the +/// encoder. +/// +/// Use this helper type from C++ sources to ensure the encoder is destroyed and +/// their internal resources released. +typedef std::unique_ptr JxlEncoderPtr; + +/// Creates an instance of JxlEncoder into a JxlEncoderPtr and initializes it. +/// +/// This function returns a unique_ptr that will call JxlEncoderDestroy() when +/// releasing the pointer. See @ref JxlEncoderCreate for details on the +/// instance creation. +/// +/// @param memory_manager custom allocator function. It may be NULL. The memory +/// manager will be copied internally. +/// @return a @c NULL JxlEncoderPtr if the instance can not be allocated or +/// initialized +/// @return initialized JxlEncoderPtr instance otherwise. +static inline JxlEncoderPtr JxlEncoderMake( + const JxlMemoryManager* memory_manager) { + return JxlEncoderPtr(JxlEncoderCreate(memory_manager)); +} + +#endif // JXL_ENCODE_CXX_H_ + +/// @} diff --git a/media/libjxl/src/lib/include/jxl/memory_manager.h b/media/libjxl/src/lib/include/jxl/memory_manager.h new file mode 100644 index 000000000..52640a8be --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/memory_manager.h @@ -0,0 +1,72 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file memory_manager.h + * @brief Abstraction functions used by JPEG XL to allocate memory. + */ + +#ifndef JXL_MEMORY_MANAGER_H_ +#define JXL_MEMORY_MANAGER_H_ + +#include + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** + * Allocating function for a memory region of a given size. + * + * Allocates a contiguous memory region of size @p size bytes. The returned + * memory may not be aligned to a specific size or initialized at all. + * + * @param opaque custom memory manager handle provided by the caller. + * @param size in bytes of the requested memory region. + * @return @c NULL if the memory can not be allocated, + * @return pointer to the memory otherwise. + */ +typedef void* (*jpegxl_alloc_func)(void* opaque, size_t size); + +/** + * Deallocating function pointer type. + * + * This function @b MUST do nothing if @p address is @c NULL. + * + * @param opaque custom memory manager handle provided by the caller. + * @param address memory region pointer returned by ::jpegxl_alloc_func, or @c + * NULL. + */ +typedef void (*jpegxl_free_func)(void* opaque, void* address); + +/** + * Memory Manager struct. + * These functions, when provided by the caller, will be used to handle memory + * allocations. + */ +typedef struct JxlMemoryManagerStruct { + /** The opaque pointer that will be passed as the first parameter to all the + * functions in this struct. */ + void* opaque; + + /** Memory allocation function. This can be NULL if and only if also the + * free() member in this class is NULL. All dynamic memory will be allocated + * and freed with these functions if they are not NULL. */ + jpegxl_alloc_func alloc; + /** Free function matching the alloc() member. */ + jpegxl_free_func free; + + /* TODO(deymo): Add cache-aligned alloc/free functions here. */ +} JxlMemoryManager; + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_MEMORY_MANAGER_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/parallel_runner.h b/media/libjxl/src/lib/include/jxl/parallel_runner.h new file mode 100644 index 000000000..45394e972 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/parallel_runner.h @@ -0,0 +1,156 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + */ +/** + * @file parallel_runner.h + */ + +/** API for running data operations in parallel in a multi-threaded environment. + * This module allows the JPEG XL caller to define their own way of creating and + * assigning threads. + * + * The JxlParallelRunner function type defines a parallel data processing + * runner that may be implemented by the caller to allow the library to process + * in multiple threads. The multi-threaded processing in this library only + * requires to run the same function over each number of a range, possibly + * running each call in a different thread. The JPEG XL caller is responsible + * for implementing this logic using the thread APIs available in their system. + * For convenience, a C++ implementation based on std::thread is provided in + * jpegxl/parallel_runner_thread.h (part of the jpegxl_threads library). + * + * Thread pools usually store small numbers of heterogeneous tasks in a queue. + * When tasks are identical or differ only by an integer input parameter, it is + * much faster to store just one function of an integer parameter and call it + * for each value. Conventional vector-of-tasks can be run in parallel using a + * lambda function adapter that simply calls task_funcs[task]. + * + * If no multi-threading is desired, a @c NULL value of JxlParallelRunner + * will use an internal implementation without multi-threading. + */ + +#ifndef JXL_PARALLEL_RUNNER_H_ +#define JXL_PARALLEL_RUNNER_H_ + +#include +#include + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** Return code used in the JxlParallel* functions as return value. A value + * of 0 means success and any other value means error. The special value + * JXL_PARALLEL_RET_RUNNER_ERROR can be used by the runner to indicate any + * other error. + */ +typedef int JxlParallelRetCode; + +/** + * General error returned by the JxlParallelRunInit function to indicate + * an error. + */ +#define JXL_PARALLEL_RET_RUNNER_ERROR (-1) + +/** + * Parallel run initialization callback. See JxlParallelRunner for details. + * + * This function MUST be called by the JxlParallelRunner only once, on the + * same thread that called JxlParallelRunner, before any parallel execution. + * The purpose of this call is to provide the maximum number of threads that the + * JxlParallelRunner will use, which can be used by JPEG XL to allocate + * per-thread storage if needed. + * + * @param jpegxl_opaque the @p jpegxl_opaque handle provided to + * JxlParallelRunner() must be passed here. + * @param num_threads the maximum number of threads. This value must be + * positive. + * @return 0 if the initialization process was successful. + * @return an error code if there was an error, which should be returned by + * JxlParallelRunner(). + */ +typedef JxlParallelRetCode (*JxlParallelRunInit)(void* jpegxl_opaque, + size_t num_threads); + +/** + * Parallel run data processing callback. See JxlParallelRunner for details. + * + * This function MUST be called once for every number in the range [start_range, + * end_range) (including start_range but not including end_range) passing this + * number as the @p value. Calls for different value may be executed from + * different threads in parallel. + * + * @param jpegxl_opaque the @p jpegxl_opaque handle provided to + * JxlParallelRunner() must be passed here. + * @param value the number in the range [start_range, end_range) of the call. + * @param thread_id the thread number where this function is being called from. + * This must be lower than the @p num_threads value passed to + * JxlParallelRunInit. + */ +typedef void (*JxlParallelRunFunction)(void* jpegxl_opaque, uint32_t value, + size_t thread_id); + +/** + * JxlParallelRunner function type. A parallel runner implementation can be + * provided by a JPEG XL caller to allow running computations in multiple + * threads. This function must call the initialization function @p init in the + * same thread that called it and then call the passed @p func once for every + * number in the range [start_range, end_range) (including start_range but not + * including end_range) possibly from different multiple threads in parallel. + * + * The JxlParallelRunner function does not need to be re-entrant. This means + * that the same JxlParallelRunner function with the same runner_opaque + * provided parameter will not be called from the library from either @p init or + * @p func in the same decoder or encoder instance. However, a single decoding + * or encoding instance may call the provided JxlParallelRunner multiple + * times for different parts of the decoding or encoding process. + * + * @return 0 if the @p init call succeeded (returned 0) and no other error + * occurred in the runner code. + * @return JXL_PARALLEL_RET_RUNNER_ERROR if an error occurred in the runner + * code, for example, setting up the threads. + * @return the return value of @p init() if non-zero. + */ +typedef JxlParallelRetCode (*JxlParallelRunner)( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range); + +/* The following is an example of a JxlParallelRunner that doesn't use any + * multi-threading. Note that this implementation doesn't store any state + * between multiple calls of the ExampleSequentialRunner function, so the + * runner_opaque value is not used. + + JxlParallelRetCode ExampleSequentialRunner(void* runner_opaque, + void* jpegxl_opaque, + JxlParallelRunInit init, + JxlParallelRunFunction func, + uint32_t start_range, + uint32_t end_range) { + // We only use one thread (the currently running thread). + JxlParallelRetCode init_ret = (*init)(jpegxl_opaque, 1); + if (init_ret != 0) return init_ret; + + // In case of other initialization error (for example when initializing the + // threads) one can return JXL_PARALLEL_RET_RUNNER_ERROR. + + for (uint32_t i = start_range; i < end_range; i++) { + // Every call is in the thread number 0. These don't need to be in any + // order. + (*func)(jpegxl_opaque, i, 0); + } + return 0; + } + */ + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_PARALLEL_RUNNER_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/resizable_parallel_runner.h b/media/libjxl/src/lib/include/jxl/resizable_parallel_runner.h new file mode 100644 index 000000000..f6344bdfd --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/resizable_parallel_runner.h @@ -0,0 +1,79 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_threads + * @{ + * @file resizable_parallel_runner.h + * @brief implementation using std::thread of a resizeable ::JxlParallelRunner. + */ + +/** Implementation of JxlParallelRunner than can be used to enable + * multithreading when using the JPEG XL library. This uses std::thread + * internally and related synchronization functions. The number of threads + * created can be changed after creation of the thread pool; the threads + * (including the main thread) are re-used for every + * ResizableParallelRunner::Runner call. Only one concurrent + * JxlResizableParallelRunner call per instance is allowed at a time. + * + * This is a scalable, lower-overhead thread pool runner, especially suitable + * for data-parallel computations in the fork-join model, where clients need to + * know when all tasks have completed. + * + * Compared to the implementation in @ref thread_parallel_runner.h, this + * implementation is tuned for execution on lower-powered systems, including + * for example ARM CPUs with big.LITTLE computation models. + */ + +#ifndef JXL_RESIZABLE_PARALLEL_RUNNER_H_ +#define JXL_RESIZABLE_PARALLEL_RUNNER_H_ + +#include +#include +#include +#include + +#include "jxl/jxl_threads_export.h" +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** Parallel runner internally using std::thread. Use as JxlParallelRunner. + */ +JXL_THREADS_EXPORT JxlParallelRetCode JxlResizableParallelRunner( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range); + +/** Creates the runner for JxlResizableParallelRunner. Use as the opaque + * runner. The runner will execute tasks on the calling thread until + * @ref JxlResizableParallelRunnerSetThreads is called. + */ +JXL_THREADS_EXPORT void* JxlResizableParallelRunnerCreate( + const JxlMemoryManager* memory_manager); + +/** Changes the number of threads for JxlResizableParallelRunner. + */ +JXL_THREADS_EXPORT void JxlResizableParallelRunnerSetThreads( + void* runner_opaque, size_t num_threads); + +/** Suggests a number of threads to use for an image of given size. + */ +JXL_THREADS_EXPORT uint32_t +JxlResizableParallelRunnerSuggestThreads(uint64_t xsize, uint64_t ysize); + +/** Destroys the runner created by JxlResizableParallelRunnerCreate. + */ +JXL_THREADS_EXPORT void JxlResizableParallelRunnerDestroy(void* runner_opaque); + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_RESIZABLE_PARALLEL_RUNNER_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/resizable_parallel_runner_cxx.h b/media/libjxl/src/lib/include/jxl/resizable_parallel_runner_cxx.h new file mode 100644 index 000000000..9a310c81a --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/resizable_parallel_runner_cxx.h @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/// @addtogroup libjxl_threads +/// @{ +/// +/// @file resizable_parallel_runner_cxx.h +/// @ingroup libjxl_threads +/// @brief C++ header-only helper for @ref resizable_parallel_runner.h. +/// +/// There's no binary library associated with the header since this is a header +/// only library. + +#ifndef JXL_RESIZABLE_PARALLEL_RUNNER_CXX_H_ +#define JXL_RESIZABLE_PARALLEL_RUNNER_CXX_H_ + +#include + +#include "jxl/resizable_parallel_runner.h" + +#if !(defined(__cplusplus) || defined(c_plusplus)) +#error \ + "This a C++ only header. Use jxl/jxl_resizable_parallel_runner.h from C" \ + "sources." +#endif + +/// Struct to call JxlResizableParallelRunnerDestroy from the +/// JxlResizableParallelRunnerPtr unique_ptr. +struct JxlResizableParallelRunnerDestroyStruct { + /// Calls @ref JxlResizableParallelRunnerDestroy() on the passed runner. + void operator()(void* runner) { JxlResizableParallelRunnerDestroy(runner); } +}; + +/// std::unique_ptr<> type that calls JxlResizableParallelRunnerDestroy() when +/// releasing the runner. +/// +/// Use this helper type from C++ sources to ensure the runner is destroyed and +/// their internal resources released. +typedef std::unique_ptr + JxlResizableParallelRunnerPtr; + +/// Creates an instance of JxlResizableParallelRunner into a +/// JxlResizableParallelRunnerPtr and initializes it. +/// +/// This function returns a unique_ptr that will call +/// JxlResizableParallelRunnerDestroy() when releasing the pointer. See @ref +/// JxlResizableParallelRunnerCreate for details on the instance creation. +/// +/// @param memory_manager custom allocator function. It may be NULL. The memory +/// manager will be copied internally. +/// @return a @c NULL JxlResizableParallelRunnerPtr if the instance can not be +/// allocated or initialized +/// @return initialized JxlResizableParallelRunnerPtr instance otherwise. +static inline JxlResizableParallelRunnerPtr JxlResizableParallelRunnerMake( + const JxlMemoryManager* memory_manager) { + return JxlResizableParallelRunnerPtr( + JxlResizableParallelRunnerCreate(memory_manager)); +} + +#endif // JXL_RESIZABLE_PARALLEL_RUNNER_CXX_H_ + +/// @} diff --git a/media/libjxl/src/lib/include/jxl/thread_parallel_runner.h b/media/libjxl/src/lib/include/jxl/thread_parallel_runner.h new file mode 100644 index 000000000..581ff7327 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/thread_parallel_runner.h @@ -0,0 +1,73 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_threads + * @{ + * @file thread_parallel_runner.h + * @brief implementation using std::thread of a ::JxlParallelRunner. + */ + +/** Implementation of JxlParallelRunner than can be used to enable + * multithreading when using the JPEG XL library. This uses std::thread + * internally and related synchronization functions. The number of threads + * created is fixed at construction time and the threads are re-used for every + * ThreadParallelRunner::Runner call. Only one concurrent + * JxlThreadParallelRunner call per instance is allowed at a time. + * + * This is a scalable, lower-overhead thread pool runner, especially suitable + * for data-parallel computations in the fork-join model, where clients need to + * know when all tasks have completed. + * + * This thread pool can efficiently load-balance millions of tasks using an + * atomic counter, thus avoiding per-task virtual or system calls. With 48 + * hyperthreads and 1M tasks that add to an atomic counter, overall runtime is + * 10-20x higher when using std::async, and ~200x for a queue-based thread + */ + +#ifndef JXL_THREAD_PARALLEL_RUNNER_H_ +#define JXL_THREAD_PARALLEL_RUNNER_H_ + +#include +#include +#include +#include + +#include "jxl/jxl_threads_export.h" +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** Parallel runner internally using std::thread. Use as JxlParallelRunner. + */ +JXL_THREADS_EXPORT JxlParallelRetCode JxlThreadParallelRunner( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range); + +/** Creates the runner for JxlThreadParallelRunner. Use as the opaque + * runner. + */ +JXL_THREADS_EXPORT void* JxlThreadParallelRunnerCreate( + const JxlMemoryManager* memory_manager, size_t num_worker_threads); + +/** Destroys the runner created by JxlThreadParallelRunnerCreate. + */ +JXL_THREADS_EXPORT void JxlThreadParallelRunnerDestroy(void* runner_opaque); + +/** Returns a default num_worker_threads value for + * JxlThreadParallelRunnerCreate. + */ +JXL_THREADS_EXPORT size_t JxlThreadParallelRunnerDefaultNumWorkerThreads(); + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_THREAD_PARALLEL_RUNNER_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/include/jxl/thread_parallel_runner_cxx.h b/media/libjxl/src/lib/include/jxl/thread_parallel_runner_cxx.h new file mode 100644 index 000000000..a71d18c20 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/thread_parallel_runner_cxx.h @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/// @addtogroup libjxl_threads +/// @{ +/// +/// @file thread_parallel_runner_cxx.h +/// @brief C++ header-only helper for @ref thread_parallel_runner.h. +/// +/// There's no binary library associated with the header since this is a header +/// only library. + +#ifndef JXL_THREAD_PARALLEL_RUNNER_CXX_H_ +#define JXL_THREAD_PARALLEL_RUNNER_CXX_H_ + +#include + +#include "jxl/thread_parallel_runner.h" + +#if !(defined(__cplusplus) || defined(c_plusplus)) +#error \ + "This a C++ only header. Use jxl/jxl_thread_parallel_runner.h from C" \ + "sources." +#endif + +/// Struct to call JxlThreadParallelRunnerDestroy from the +/// JxlThreadParallelRunnerPtr unique_ptr. +struct JxlThreadParallelRunnerDestroyStruct { + /// Calls @ref JxlThreadParallelRunnerDestroy() on the passed runner. + void operator()(void* runner) { JxlThreadParallelRunnerDestroy(runner); } +}; + +/// std::unique_ptr<> type that calls JxlThreadParallelRunnerDestroy() when +/// releasing the runner. +/// +/// Use this helper type from C++ sources to ensure the runner is destroyed and +/// their internal resources released. +typedef std::unique_ptr + JxlThreadParallelRunnerPtr; + +/// Creates an instance of JxlThreadParallelRunner into a +/// JxlThreadParallelRunnerPtr and initializes it. +/// +/// This function returns a unique_ptr that will call +/// JxlThreadParallelRunnerDestroy() when releasing the pointer. See @ref +/// JxlThreadParallelRunnerCreate for details on the instance creation. +/// +/// @param memory_manager custom allocator function. It may be NULL. The memory +/// manager will be copied internally. +/// @param num_worker_threads the number of worker threads to create. +/// @return a @c NULL JxlThreadParallelRunnerPtr if the instance can not be +/// allocated or initialized +/// @return initialized JxlThreadParallelRunnerPtr instance otherwise. +static inline JxlThreadParallelRunnerPtr JxlThreadParallelRunnerMake( + const JxlMemoryManager* memory_manager, size_t num_worker_threads) { + return JxlThreadParallelRunnerPtr( + JxlThreadParallelRunnerCreate(memory_manager, num_worker_threads)); +} + +#endif // JXL_THREAD_PARALLEL_RUNNER_CXX_H_ + +/// @} diff --git a/media/libjxl/src/lib/include/jxl/types.h b/media/libjxl/src/lib/include/jxl/types.h new file mode 100644 index 000000000..1f8197864 --- /dev/null +++ b/media/libjxl/src/lib/include/jxl/types.h @@ -0,0 +1,151 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file types.h + * @brief Data types for the JPEG XL API, for both encoding and decoding. + */ + +#ifndef JXL_TYPES_H_ +#define JXL_TYPES_H_ + +#include +#include + +#include "jxl/jxl_export.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** + * A portable @c bool replacement. + * + * ::JXL_BOOL is a "documentation" type: actually it is @c int, but in API it + * denotes a type, whose only values are ::JXL_TRUE and ::JXL_FALSE. + */ +#define JXL_BOOL int +/** Portable @c true replacement. */ +#define JXL_TRUE 1 +/** Portable @c false replacement. */ +#define JXL_FALSE 0 + +/** Data type for the sample values per channel per pixel. + */ +typedef enum { + /** Use 32-bit single-precision floating point values, with range 0.0-1.0 + * (within gamut, may go outside this range for wide color gamut). Floating + * point output, either JXL_TYPE_FLOAT or JXL_TYPE_FLOAT16, is recommended + * for HDR and wide gamut images when color profile conversion is required. */ + JXL_TYPE_FLOAT = 0, + + /** Use type uint8_t. May clip wide color gamut data. + */ + JXL_TYPE_UINT8 = 2, + + /** Use type uint16_t. May clip wide color gamut data. + */ + JXL_TYPE_UINT16 = 3, + + /** Use 16-bit IEEE 754 half-precision floating point values */ + JXL_TYPE_FLOAT16 = 5, +} JxlDataType; + +/* DEPRECATED: bit-packed 1-bit data type. Use JXL_TYPE_UINT8 instead. + */ +static const int JXL_DEPRECATED JXL_TYPE_BOOLEAN = 1; + +/* DEPRECATED: uint32_t data type. Use JXL_TYPE_FLOAT instead. + */ +static const int JXL_DEPRECATED JXL_TYPE_UINT32 = 4; + +/** Ordering of multi-byte data. + */ +typedef enum { + /** Use the endianness of the system, either little endian or big endian, + * without forcing either specific endianness. Do not use if pixel data + * should be exported to a well defined format. + */ + JXL_NATIVE_ENDIAN = 0, + /** Force little endian */ + JXL_LITTLE_ENDIAN = 1, + /** Force big endian */ + JXL_BIG_ENDIAN = 2, +} JxlEndianness; + +/** Data type for the sample values per channel per pixel for the output buffer + * for pixels. This is not necessarily the same as the data type encoded in the + * codestream. The channels are interleaved per pixel. The pixels are + * organized row by row, left to right, top to bottom. + * TODO(lode): implement padding / alignment (row stride) + * TODO(lode): support different channel orders if needed (RGB, BGR, ...) + */ +typedef struct { + /** Amount of channels available in a pixel buffer. + * 1: single-channel data, e.g. grayscale or a single extra channel + * 2: single-channel + alpha + * 3: trichromatic, e.g. RGB + * 4: trichromatic + alpha + * TODO(lode): this needs finetuning. It is not yet defined how the user + * chooses output color space. CMYK+alpha needs 5 channels. + */ + uint32_t num_channels; + + /** Data type of each channel. + */ + JxlDataType data_type; + + /** Whether multi-byte data types are represented in big endian or little + * endian format. This applies to JXL_TYPE_UINT16, JXL_TYPE_UINT32 + * and JXL_TYPE_FLOAT. + */ + JxlEndianness endianness; + + /** Align scanlines to a multiple of align bytes, or 0 to require no + * alignment at all (which has the same effect as value 1) + */ + size_t align; +} JxlPixelFormat; + +/** Data type holding the 4-character type name of an ISOBMFF box. + */ +typedef char JxlBoxType[4]; + +/** Types of progressive detail. + * Setting a progressive detail with value N implies all progressive details + * with smaller or equal value. Currently only the following level of + * progressive detail is implemented: + * - kDC (which implies kFrames) + * - kLastPasses (which implies kDC and kFrames) + * - kPasses (which implies kLastPasses, kDC and kFrames) + */ +typedef enum { + // after completed kRegularFrames + kFrames = 0, + // after completed DC (1:8) + kDC = 1, + // after completed AC passes that are the last pass for their resolution + // target. + kLastPasses = 2, + // after completed AC passes that are not the last pass for their resolution + // target. + kPasses = 3, + // during DC frame when lower resolution are completed (1:32, 1:16) + kDCProgressive = 4, + // after completed groups + kDCGroups = 5, + // after completed groups + kGroups = 6, +} JxlProgressiveDetail; + +#if defined(__cplusplus) || defined(c_plusplus) +} +#endif + +#endif /* JXL_TYPES_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/jxl.cmake b/media/libjxl/src/lib/jxl.cmake new file mode 100644 index 000000000..72c07f48e --- /dev/null +++ b/media/libjxl/src/lib/jxl.cmake @@ -0,0 +1,643 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Lists all source files for the JPEG XL decoder library. These are also used +# by the encoder: the encoder uses both dec and enc ourse files, while the +# decoder uses only dec source files. +# TODO(lode): further prune these files and move to JPEGXL_INTERNAL_SOURCES_ENC: +# only those files that the decoder absolutely needs, and or not +# only for encoding, should be listed here. +set(JPEGXL_INTERNAL_SOURCES_DEC + jxl/ac_context.h + jxl/ac_strategy.cc + jxl/ac_strategy.h + jxl/alpha.cc + jxl/alpha.h + jxl/ans_common.cc + jxl/ans_common.h + jxl/ans_params.h + jxl/aux_out.cc + jxl/aux_out.h + jxl/aux_out_fwd.h + jxl/base/arch_macros.h + jxl/base/bits.h + jxl/base/byte_order.h + jxl/base/cache_aligned.cc + jxl/base/cache_aligned.h + jxl/base/compiler_specific.h + jxl/base/data_parallel.cc + jxl/base/data_parallel.h + jxl/base/file_io.h + jxl/base/iaca.h + jxl/base/os_macros.h + jxl/base/override.h + jxl/base/padded_bytes.cc + jxl/base/padded_bytes.h + jxl/base/printf_macros.h + jxl/base/profiler.h + jxl/base/random.cc + jxl/base/random.h + jxl/base/sanitizer_definitions.h + jxl/base/scope_guard.h + jxl/base/span.h + jxl/base/status.h + jxl/base/thread_pool_internal.h + jxl/blending.cc + jxl/blending.h + jxl/box_content_decoder.cc + jxl/box_content_decoder.h + jxl/chroma_from_luma.cc + jxl/chroma_from_luma.h + jxl/codec_in_out.h + jxl/coeff_order.cc + jxl/coeff_order.h + jxl/coeff_order_fwd.h + jxl/color_encoding_internal.cc + jxl/color_encoding_internal.h + jxl/color_management.cc + jxl/color_management.h + jxl/common.h + jxl/compressed_dc.cc + jxl/compressed_dc.h + jxl/convolve-inl.h + jxl/convolve.h + jxl/convolve_separable5.cc + jxl/convolve_separable7.cc + jxl/convolve_slow.cc + jxl/convolve_symmetric3.cc + jxl/convolve_symmetric5.cc + jxl/dct-inl.h + jxl/dct_block-inl.h + jxl/dct_scales.cc + jxl/dct_scales.h + jxl/dct_util.h + jxl/dec_ans.cc + jxl/dec_ans.h + jxl/dec_bit_reader.h + jxl/dec_cache.cc + jxl/dec_cache.h + jxl/dec_context_map.cc + jxl/dec_context_map.h + jxl/dec_external_image.cc + jxl/dec_external_image.h + jxl/dec_frame.cc + jxl/dec_frame.h + jxl/dec_group.cc + jxl/dec_group.h + jxl/dec_group_border.cc + jxl/dec_group_border.h + jxl/dec_huffman.cc + jxl/dec_huffman.h + jxl/dec_modular.cc + jxl/dec_modular.h + jxl/dec_noise.cc + jxl/dec_noise.h + jxl/dec_patch_dictionary.cc + jxl/dec_patch_dictionary.h + jxl/dec_tone_mapping-inl.h + jxl/dec_transforms-inl.h + jxl/dec_xyb-inl.h + jxl/dec_xyb.cc + jxl/dec_xyb.h + jxl/decode.cc + jxl/decode_to_jpeg.cc + jxl/decode_to_jpeg.h + jxl/enc_bit_writer.cc + jxl/enc_bit_writer.h + jxl/entropy_coder.cc + jxl/entropy_coder.h + jxl/epf.cc + jxl/epf.h + jxl/exif.h + jxl/fast_dct-inl.h + jxl/fast_dct.cc + jxl/fast_dct.h + jxl/fast_dct128-inl.h + jxl/fast_dct16-inl.h + jxl/fast_dct256-inl.h + jxl/fast_dct32-inl.h + jxl/fast_dct64-inl.h + jxl/fast_dct8-inl.h + jxl/fast_math-inl.h + jxl/field_encodings.h + jxl/fields.cc + jxl/fields.h + jxl/frame_header.cc + jxl/frame_header.h + jxl/gauss_blur.cc + jxl/gauss_blur.h + jxl/headers.cc + jxl/headers.h + jxl/huffman_table.cc + jxl/huffman_table.h + jxl/icc_codec.cc + jxl/icc_codec.h + jxl/icc_codec_common.cc + jxl/icc_codec_common.h + jxl/image.cc + jxl/image.h + jxl/image_bundle.cc + jxl/image_bundle.h + jxl/image_metadata.cc + jxl/image_metadata.h + jxl/image_ops.h + jxl/jpeg/dec_jpeg_data.cc + jxl/jpeg/dec_jpeg_data.h + jxl/jpeg/dec_jpeg_data_writer.cc + jxl/jpeg/dec_jpeg_data_writer.h + jxl/jpeg/dec_jpeg_output_chunk.h + jxl/jpeg/dec_jpeg_serialization_state.h + jxl/jpeg/jpeg_data.cc + jxl/jpeg/jpeg_data.h + jxl/jxl_inspection.h + jxl/lehmer_code.h + jxl/linalg.h + jxl/loop_filter.cc + jxl/loop_filter.h + jxl/luminance.cc + jxl/luminance.h + jxl/memory_manager_internal.cc + jxl/memory_manager_internal.h + jxl/modular/encoding/context_predict.h + jxl/modular/encoding/dec_ma.cc + jxl/modular/encoding/dec_ma.h + jxl/modular/encoding/encoding.cc + jxl/modular/encoding/encoding.h + jxl/modular/encoding/ma_common.h + jxl/modular/modular_image.cc + jxl/modular/modular_image.h + jxl/modular/options.h + jxl/modular/transform/palette.h + jxl/modular/transform/rct.cc + jxl/modular/transform/rct.h + jxl/modular/transform/squeeze.cc + jxl/modular/transform/squeeze.h + jxl/modular/transform/transform.cc + jxl/modular/transform/transform.h + jxl/noise.h + jxl/opsin_params.cc + jxl/opsin_params.h + jxl/passes_state.cc + jxl/passes_state.h + jxl/patch_dictionary_internal.h + jxl/quant_weights.cc + jxl/quant_weights.h + jxl/quantizer-inl.h + jxl/quantizer.cc + jxl/quantizer.h + jxl/rational_polynomial-inl.h + jxl/render_pipeline/low_memory_render_pipeline.cc + jxl/render_pipeline/low_memory_render_pipeline.h + jxl/render_pipeline/render_pipeline.cc + jxl/render_pipeline/render_pipeline.h + jxl/render_pipeline/render_pipeline_stage.h + jxl/render_pipeline/simple_render_pipeline.cc + jxl/render_pipeline/simple_render_pipeline.h + jxl/render_pipeline/stage_blending.cc + jxl/render_pipeline/stage_blending.h + jxl/render_pipeline/stage_chroma_upsampling.cc + jxl/render_pipeline/stage_chroma_upsampling.h + jxl/render_pipeline/stage_epf.cc + jxl/render_pipeline/stage_epf.h + jxl/render_pipeline/stage_from_linear.cc + jxl/render_pipeline/stage_from_linear.h + jxl/render_pipeline/stage_gaborish.cc + jxl/render_pipeline/stage_gaborish.h + jxl/render_pipeline/stage_noise.cc + jxl/render_pipeline/stage_noise.h + jxl/render_pipeline/stage_patches.cc + jxl/render_pipeline/stage_patches.h + jxl/render_pipeline/stage_splines.cc + jxl/render_pipeline/stage_splines.h + jxl/render_pipeline/stage_spot.cc + jxl/render_pipeline/stage_spot.h + jxl/render_pipeline/stage_to_linear.cc + jxl/render_pipeline/stage_to_linear.h + jxl/render_pipeline/stage_tone_mapping.cc + jxl/render_pipeline/stage_tone_mapping.h + jxl/render_pipeline/stage_upsampling.cc + jxl/render_pipeline/stage_upsampling.h + jxl/render_pipeline/stage_write.cc + jxl/render_pipeline/stage_write.h + jxl/render_pipeline/stage_xyb.cc + jxl/render_pipeline/stage_xyb.h + jxl/render_pipeline/stage_ycbcr.cc + jxl/render_pipeline/stage_ycbcr.h + jxl/render_pipeline/test_render_pipeline_stages.h + jxl/sanitizers.h + jxl/simd_util-inl.h + jxl/size_constraints.h + jxl/splines.cc + jxl/splines.h + jxl/toc.cc + jxl/toc.h + jxl/transfer_functions-inl.h + jxl/transpose-inl.h + jxl/xorshift128plus-inl.h +) + +# List of source files only needed by the encoder or by tools (including +# decoding tools), but not by the decoder library. +set(JPEGXL_INTERNAL_SOURCES_ENC + jxl/butteraugli/butteraugli.cc + jxl/butteraugli/butteraugli.h + jxl/butteraugli_wrapper.cc + jxl/enc_ac_strategy.cc + jxl/enc_ac_strategy.h + jxl/enc_adaptive_quantization.cc + jxl/enc_adaptive_quantization.h + jxl/enc_ans.cc + jxl/enc_ans.h + jxl/enc_ans_params.h + jxl/enc_ar_control_field.cc + jxl/enc_ar_control_field.h + jxl/enc_butteraugli_comparator.cc + jxl/enc_butteraugli_comparator.h + jxl/enc_butteraugli_pnorm.cc + jxl/enc_butteraugli_pnorm.h + jxl/enc_cache.cc + jxl/enc_cache.h + jxl/enc_chroma_from_luma.cc + jxl/enc_chroma_from_luma.h + jxl/enc_cluster.cc + jxl/enc_cluster.h + jxl/enc_coeff_order.cc + jxl/enc_coeff_order.h + jxl/enc_color_management.cc + jxl/enc_color_management.h + jxl/enc_comparator.cc + jxl/enc_comparator.h + jxl/enc_context_map.cc + jxl/enc_context_map.h + jxl/enc_detect_dots.cc + jxl/enc_detect_dots.h + jxl/enc_dot_dictionary.cc + jxl/enc_dot_dictionary.h + jxl/enc_entropy_coder.cc + jxl/enc_entropy_coder.h + jxl/enc_external_image.cc + jxl/enc_external_image.h + jxl/enc_file.cc + jxl/enc_file.h + jxl/enc_frame.cc + jxl/enc_frame.h + jxl/enc_gamma_correct.h + jxl/enc_group.cc + jxl/enc_group.h + jxl/enc_heuristics.cc + jxl/enc_heuristics.h + jxl/enc_huffman.cc + jxl/enc_huffman.h + jxl/enc_icc_codec.cc + jxl/enc_icc_codec.h + jxl/enc_image_bundle.cc + jxl/enc_image_bundle.h + jxl/enc_jxl_skcms.h + jxl/enc_modular.cc + jxl/enc_modular.h + jxl/enc_noise.cc + jxl/enc_noise.h + jxl/enc_params.h + jxl/enc_patch_dictionary.cc + jxl/enc_patch_dictionary.h + jxl/enc_photon_noise.cc + jxl/enc_photon_noise.h + jxl/enc_quant_weights.cc + jxl/enc_quant_weights.h + jxl/enc_splines.cc + jxl/enc_splines.h + jxl/enc_toc.cc + jxl/enc_toc.h + jxl/enc_transforms-inl.h + jxl/enc_transforms.cc + jxl/enc_transforms.h + jxl/enc_xyb.cc + jxl/enc_xyb.h + jxl/encode.cc + jxl/encode_internal.h + jxl/gaborish.cc + jxl/gaborish.h + jxl/huffman_tree.cc + jxl/huffman_tree.h + jxl/jpeg/enc_jpeg_data.cc + jxl/jpeg/enc_jpeg_data.h + jxl/jpeg/enc_jpeg_data_reader.cc + jxl/jpeg/enc_jpeg_data_reader.h + jxl/jpeg/enc_jpeg_huffman_decode.cc + jxl/jpeg/enc_jpeg_huffman_decode.h + jxl/linalg.cc + jxl/modular/encoding/enc_debug_tree.cc + jxl/modular/encoding/enc_debug_tree.h + jxl/modular/encoding/enc_encoding.cc + jxl/modular/encoding/enc_encoding.h + jxl/modular/encoding/enc_ma.cc + jxl/modular/encoding/enc_ma.h + jxl/modular/transform/enc_palette.cc + jxl/modular/transform/enc_palette.h + jxl/modular/transform/enc_rct.cc + jxl/modular/transform/enc_rct.h + jxl/modular/transform/enc_squeeze.cc + jxl/modular/transform/enc_squeeze.h + jxl/modular/transform/enc_transform.cc + jxl/modular/transform/enc_transform.h + jxl/optimize.cc + jxl/optimize.h + jxl/progressive_split.cc + jxl/progressive_split.h +) + +set(JPEGXL_DEC_INTERNAL_LIBS + brotlidec-static + brotlicommon-static + hwy + Threads::Threads + ${ATOMICS_LIBRARIES} +) + +if(JPEGXL_ENABLE_PROFILER) +list(APPEND JPEGXL_DEC_INTERNAL_LIBS jxl_profiler) +endif() + +set(JPEGXL_INTERNAL_LIBS + ${JPEGXL_DEC_INTERNAL_LIBS} + brotlienc-static +) + +# strips the -static suffix from all the elements in LIST +function(strip_static OUTPUT_VAR LIB_LIST) + foreach(lib IN LISTS ${LIB_LIST}) + string(REGEX REPLACE "-static$" "" lib "${lib}") + list(APPEND out_list "${lib}") + endforeach() + set(${OUTPUT_VAR} ${out_list} PARENT_SCOPE) +endfunction() + +if (JPEGXL_ENABLE_SKCMS) + list(APPEND JPEGXL_INTERNAL_FLAGS -DJPEGXL_ENABLE_SKCMS=1) + if (JPEGXL_BUNDLE_SKCMS) + list(APPEND JPEGXL_INTERNAL_FLAGS -DJPEGXL_BUNDLE_SKCMS=1) + # skcms objects are later added to JPEGXL_INTERNAL_OBJECTS + else () + list(APPEND JPEGXL_INTERNAL_LIBS skcms) + endif () +else () + list(APPEND JPEGXL_INTERNAL_LIBS lcms2) +endif () + +if (NOT JPEGXL_ENABLE_TRANSCODE_JPEG) + list(APPEND JPEGXL_INTERNAL_FLAGS -DJPEGXL_ENABLE_TRANSCODE_JPEG=0) +endif () + +set(OBJ_COMPILE_DEFINITIONS + JPEGXL_MAJOR_VERSION=${JPEGXL_MAJOR_VERSION} + JPEGXL_MINOR_VERSION=${JPEGXL_MINOR_VERSION} + JPEGXL_PATCH_VERSION=${JPEGXL_PATCH_VERSION} + # Used to determine if we are building the library when defined or just + # including the library when not defined. This is public so libjxl shared + # library gets this define too. + JXL_INTERNAL_LIBRARY_BUILD +) + +# Decoder-only object library +add_library(jxl_dec-obj OBJECT ${JPEGXL_INTERNAL_SOURCES_DEC}) +target_compile_options(jxl_dec-obj PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(jxl_dec-obj PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET jxl_dec-obj PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(jxl_dec-obj PUBLIC + "$" + "$" + "$>" + "$>" +) +target_compile_definitions(jxl_dec-obj PUBLIC + ${OBJ_COMPILE_DEFINITIONS} +) +if (JPEGXL_ENABLE_PROFILER) +target_link_libraries(jxl_dec-obj PUBLIC jxl_profiler) +endif() + +# Object library. This is used to hold the set of objects and properties. +add_library(jxl_enc-obj OBJECT ${JPEGXL_INTERNAL_SOURCES_ENC}) +target_compile_options(jxl_enc-obj PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(jxl_enc-obj PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET jxl_enc-obj PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(jxl_enc-obj PUBLIC + ${PROJECT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + $ + $ +) +target_compile_definitions(jxl_enc-obj PUBLIC + ${OBJ_COMPILE_DEFINITIONS} +) +if (JPEGXL_ENABLE_PROFILER) +target_link_libraries(jxl_enc-obj PUBLIC jxl_profiler) +endif() + +#TODO(lode): don't depend on CMS for the core library +if (JPEGXL_ENABLE_SKCMS) + target_include_directories(jxl_enc-obj PRIVATE + $ + ) +else () + target_include_directories(jxl_enc-obj PRIVATE + $ + ) +endif () + +# Generate version.h +configure_file("jxl/version.h.in" "include/jxl/version.h") + +# Headers for exporting/importing public headers +include(GenerateExportHeader) +set_target_properties(jxl_dec-obj PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 + DEFINE_SYMBOL JXL_INTERNAL_LIBRARY_BUILD +) +target_include_directories(jxl_dec-obj PUBLIC + ${CMAKE_CURRENT_BINARY_DIR}/include) + +set_target_properties(jxl_enc-obj PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 + DEFINE_SYMBOL JXL_INTERNAL_LIBRARY_BUILD +) +generate_export_header(jxl_enc-obj + BASE_NAME JXL + EXPORT_FILE_NAME include/jxl/jxl_export.h) +target_include_directories(jxl_enc-obj PUBLIC + ${CMAKE_CURRENT_BINARY_DIR}/include) + +# Private static library. This exposes all the internal functions and is used +# for tests. +add_library(jxl_dec-static STATIC + $ +) +target_link_libraries(jxl_dec-static + PUBLIC ${JPEGXL_COVERAGE_FLAGS} ${JPEGXL_DEC_INTERNAL_LIBS}) +target_include_directories(jxl_dec-static PUBLIC + "$" + "$" + "$") + +# The list of objects in the static and shared libraries. +set(JPEGXL_INTERNAL_OBJECTS + $ + $ +) +if (JPEGXL_ENABLE_SKCMS AND JPEGXL_BUNDLE_SKCMS) + list(APPEND JPEGXL_INTERNAL_OBJECTS $) +endif() + +# Private static library. This exposes all the internal functions and is used +# for tests. +# TODO(lode): once the source files are correctly split so that it is possible +# to do, remove $ here and depend on jxl_dec-static +add_library(jxl-static STATIC ${JPEGXL_INTERNAL_OBJECTS}) +target_link_libraries(jxl-static + PUBLIC ${JPEGXL_COVERAGE_FLAGS} ${JPEGXL_INTERNAL_LIBS}) +target_include_directories(jxl-static PUBLIC + "$" + "$" + "$") + +# JXL_EXPORT is defined to "__declspec(dllimport)" automatically by CMake +# in Windows builds when including headers from the C API and compiling from +# outside the jxl library. This is required when using the shared library, +# however in windows this causes the function to not be found when linking +# against the static library. This define JXL_EXPORT= here forces it to not +# use dllimport in tests and other tools that require the static library. +target_compile_definitions(jxl-static INTERFACE -DJXL_EXPORT=) +target_compile_definitions(jxl_dec-static INTERFACE -DJXL_EXPORT=) + +# TODO(deymo): Move TCMalloc linkage to the tools/ directory since the library +# shouldn't do any allocs anyway. +if(JPEGXL_ENABLE_TCMALLOC) + pkg_check_modules(TCMallocMinimal REQUIRED IMPORTED_TARGET + libtcmalloc_minimal) + # tcmalloc 2.8 has concurrency issues that makes it sometimes return nullptr + # for large allocs. See https://github.com/gperftools/gperftools/issues/1204 + # for details. + if(TCMallocMinimal_VERSION VERSION_EQUAL 2.8) + message(FATAL_ERROR + "tcmalloc version 2.8 has a concurrency bug. You have installed " + "version ${TCMallocMinimal_VERSION}, please either downgrade tcmalloc " + "to version 2.7, upgrade to 2.8.1 or newer or pass " + "-DJPEGXL_ENABLE_TCMALLOC=OFF to jpeg-xl cmake line. See the following " + "bug for details:\n" + " https://github.com/gperftools/gperftools/issues/1204\n") + endif() + target_link_libraries(jxl-static PUBLIC PkgConfig::TCMallocMinimal) +endif() # JPEGXL_ENABLE_TCMALLOC + +# Install the static library too, but as jxl.a file without the -static except +# in Windows. +if (NOT WIN32 OR MINGW) + set_target_properties(jxl-static PROPERTIES OUTPUT_NAME "jxl") + set_target_properties(jxl_dec-static PROPERTIES OUTPUT_NAME "jxl_dec") +endif() +install(TARGETS jxl-static DESTINATION ${CMAKE_INSTALL_LIBDIR}) +install(TARGETS jxl_dec-static DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +if (BUILD_SHARED_LIBS) + +# Public shared library. +add_library(jxl SHARED ${JPEGXL_INTERNAL_OBJECTS}) +strip_static(JPEGXL_INTERNAL_SHARED_LIBS JPEGXL_INTERNAL_LIBS) +target_link_libraries(jxl PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +target_link_libraries(jxl PRIVATE ${JPEGXL_INTERNAL_SHARED_LIBS}) +# Shared library include path contains only the "include/" paths. +target_include_directories(jxl PUBLIC + "$" + "$") +set_target_properties(jxl PROPERTIES + VERSION ${JPEGXL_LIBRARY_VERSION} + SOVERSION ${JPEGXL_LIBRARY_SOVERSION} + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") + +# Public shared decoder library. +add_library(jxl_dec SHARED $) +strip_static(JPEGXL_DEC_INTERNAL_SHARED_LIBS JPEGXL_DEC_INTERNAL_LIBS) +target_link_libraries(jxl_dec PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +target_link_libraries(jxl_dec PRIVATE ${JPEGXL_DEC_INTERNAL_SHARED_LIBS}) +# Shared library include path contains only the "include/" paths. +target_include_directories(jxl_dec PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include") +set_target_properties(jxl_dec PROPERTIES + VERSION ${JPEGXL_LIBRARY_VERSION} + SOVERSION ${JPEGXL_LIBRARY_SOVERSION} + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") + +# Check whether the linker support excluding libs +set(LINKER_EXCLUDE_LIBS_FLAG "-Wl,--exclude-libs=ALL") +include(CheckCSourceCompiles) +list(APPEND CMAKE_EXE_LINKER_FLAGS ${LINKER_EXCLUDE_LIBS_FLAG}) +check_c_source_compiles("int main(){return 0;}" LINKER_SUPPORT_EXCLUDE_LIBS) +list(REMOVE_ITEM CMAKE_EXE_LINKER_FLAGS ${LINKER_EXCLUDE_LIBS_FLAG}) + +# Add a jxl.version file as a version script to tag symbols with the +# appropriate version number. This script is also used to limit what's exposed +# in the shared library from the static dependencies bundled here. +foreach(target IN ITEMS jxl jxl_dec) + set_target_properties(${target} PROPERTIES + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl.version) + if(APPLE) + set_property(TARGET ${target} APPEND_STRING PROPERTY + LINK_FLAGS "-Wl,-exported_symbols_list,${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl_osx.syms") + elseif(WIN32) + # Nothing needed here, we use __declspec(dllexport) (jxl_export.h) + else() + set_property(TARGET ${target} APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl.version") + endif() # APPLE + # This hides the default visibility symbols from static libraries bundled into + # the shared library. In particular this prevents exposing symbols from hwy + # and skcms in the shared library. + if(LINKER_SUPPORT_EXCLUDE_LIBS) + set_property(TARGET ${target} APPEND_STRING PROPERTY + LINK_FLAGS " ${LINKER_EXCLUDE_LIBS_FLAG}") + endif() +endforeach() + +# Only install libjxl shared library. The libjxl_dec is not installed since it +# contains symbols also in libjxl which would conflict if programs try to use +# both. +install(TARGETS jxl + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +else() +add_library(jxl ALIAS jxl-static) +add_library(jxl_dec ALIAS jxl_dec-static) +endif() # BUILD_SHARED_LIBS + +# Add a pkg-config file for libjxl. +set(JPEGXL_LIBRARY_REQUIRES + "libhwy libbrotlicommon libbrotlienc libbrotlidec") +if(NOT JPEGXL_ENABLE_SKCMS) + set(JPEGXL_LIBRARY_REQUIRES "${JPEGXL_LIBRARY_REQUIRES} lcms2") +endif() + +# Allow adding prefix if CMAKE_INSTALL_INCLUDEDIR not absolute. +if(IS_ABSOLUTE "${CMAKE_INSTALL_INCLUDEDIR}") + set(PKGCONFIG_TARGET_INCLUDES "${CMAKE_INSTALL_INCLUDEDIR}") +else() + set(PKGCONFIG_TARGET_INCLUDES "\${prefix}/${CMAKE_INSTALL_INCLUDEDIR}") +endif() +# Allow adding prefix if CMAKE_INSTALL_LIBDIR not absolute. +if(IS_ABSOLUTE "${CMAKE_INSTALL_LIBDIR}") + set(PKGCONFIG_TARGET_LIBS "${CMAKE_INSTALL_LIBDIR}") +else() + set(PKGCONFIG_TARGET_LIBS "\${exec_prefix}/${CMAKE_INSTALL_LIBDIR}") +endif() + +configure_file("${CMAKE_CURRENT_SOURCE_DIR}/jxl/libjxl.pc.in" + "libjxl.pc" @ONLY) +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/libjxl.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") diff --git a/media/libjxl/src/lib/jxl/ac_context.h b/media/libjxl/src/lib/jxl/ac_context.h new file mode 100644 index 000000000..a2b9e046d --- /dev/null +++ b/media/libjxl/src/lib/jxl/ac_context.h @@ -0,0 +1,149 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_AC_CONTEXT_H_ +#define LIB_JXL_AC_CONTEXT_H_ + +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" + +namespace jxl { + +// Block context used for scanning order, number of non-zeros, AC coefficients. +// Equal to the channel. +constexpr uint32_t kDCTOrderContextStart = 0; + +// The number of predicted nonzeros goes from 0 to 1008. We use +// ceil(log2(predicted+1)) as a context for the number of nonzeros, so from 0 to +// 10, inclusive. +constexpr uint32_t kNonZeroBuckets = 37; + +static const uint16_t kCoeffFreqContext[64] = { + 0xBAD, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, + 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, + 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, +}; + +static const uint16_t kCoeffNumNonzeroContext[64] = { + 0xBAD, 0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, + 152, 152, 152, 152, 152, 152, 152, 152, 180, 180, 180, 180, 180, + 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, +}; + +// Supremum of ZeroDensityContext(x, y) + 1, when x + y < 64. +constexpr int kZeroDensityContextCount = 458; +// Supremum of ZeroDensityContext(x, y) + 1. +constexpr int kZeroDensityContextLimit = 474; + +/* This function is used for entropy-sources pre-clustering. + * + * Ideally, each combination of |nonzeros_left| and |k| should go to its own + * bucket; but it implies (64 * 63 / 2) == 2016 buckets. If there is other + * dimension (e.g. block context), then number of primary clusters becomes too + * big. + * + * To solve this problem, |nonzeros_left| and |k| values are clustered. It is + * known that their sum is at most 64, consequently, the total number buckets + * is at most A(64) * B(64). + */ +// TODO(user): investigate, why disabling pre-clustering makes entropy code +// less dense. Perhaps we would need to add HQ clustering algorithm that would +// be able to squeeze better by spending more CPU cycles. +static JXL_INLINE size_t ZeroDensityContext(size_t nonzeros_left, size_t k, + size_t covered_blocks, + size_t log2_covered_blocks, + size_t prev) { + JXL_DASSERT((1u << log2_covered_blocks) == covered_blocks); + nonzeros_left = (nonzeros_left + covered_blocks - 1) >> log2_covered_blocks; + k >>= log2_covered_blocks; + JXL_DASSERT(k > 0); + JXL_DASSERT(k < 64); + JXL_DASSERT(nonzeros_left > 0); + // Asserting nonzeros_left + k < 65 here causes crashes in debug mode with + // invalid input, since the (hot) decoding loop does not check this condition. + // As no out-of-bound memory reads are issued even if that condition is + // broken, we check this simpler condition which holds anyway. The decoder + // will still mark a file in which that condition happens as not valid at the + // end of the decoding loop, as `nzeros` will not be `0`. + JXL_DASSERT(nonzeros_left < 64); + return (kCoeffNumNonzeroContext[nonzeros_left] + kCoeffFreqContext[k]) * 2 + + prev; +} + +struct BlockCtxMap { + std::vector dc_thresholds[3]; + std::vector qf_thresholds; + std::vector ctx_map; + size_t num_ctxs, num_dc_ctxs; + + static constexpr uint8_t kDefaultCtxMap[] = { + // Default ctx map clusters all the large transforms together. + 0, 1, 2, 2, 3, 3, 4, 5, 6, 6, 6, 6, 6, // + 7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, // + 7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, // + }; + static_assert(3 * kNumOrders == + sizeof(kDefaultCtxMap) / sizeof *kDefaultCtxMap, + "Update default context map"); + + size_t Context(int dc_idx, uint32_t qf, size_t ord, size_t c) const { + size_t qf_idx = 0; + for (uint32_t t : qf_thresholds) { + if (qf > t) qf_idx++; + } + size_t idx = c < 2 ? c ^ 1 : 2; + idx = idx * kNumOrders + ord; + idx = idx * (qf_thresholds.size() + 1) + qf_idx; + idx = idx * num_dc_ctxs + dc_idx; + return ctx_map[idx]; + } + // Non-zero context is based on number of non-zeros and block context. + // For better clustering, contexts with same number of non-zeros are grouped. + constexpr uint32_t ZeroDensityContextsOffset(uint32_t block_ctx) const { + return num_ctxs * kNonZeroBuckets + kZeroDensityContextCount * block_ctx; + } + + // Context map for AC coefficients consists of 2 blocks: + // |num_ctxs x : context for number of non-zeros in the block + // kNonZeroBuckets| computed from block context and predicted + // value (based top and left values) + // |num_ctxs x : context for AC coefficient symbols, + // kZeroDensityContextCount| computed from block context, + // number of non-zeros left and + // index in scan order + constexpr uint32_t NumACContexts() const { + return num_ctxs * (kNonZeroBuckets + kZeroDensityContextCount); + } + + // Non-zero context is based on number of non-zeros and block context. + // For better clustering, contexts with same number of non-zeros are grouped. + inline uint32_t NonZeroContext(uint32_t non_zeros, uint32_t block_ctx) const { + uint32_t ctx; + if (non_zeros >= 64) non_zeros = 64; + if (non_zeros < 8) { + ctx = non_zeros; + } else { + ctx = 4 + non_zeros / 2; + } + return ctx * num_ctxs + block_ctx; + } + + BlockCtxMap() { + ctx_map.assign(std::begin(kDefaultCtxMap), std::end(kDefaultCtxMap)); + num_ctxs = *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; + num_dc_ctxs = 1; + } +}; + +} // namespace jxl + +#endif // LIB_JXL_AC_CONTEXT_H_ diff --git a/media/libjxl/src/lib/jxl/ac_strategy.cc b/media/libjxl/src/lib/jxl/ac_strategy.cc new file mode 100644 index 000000000..ada3bcb6f --- /dev/null +++ b/media/libjxl/src/lib/jxl/ac_strategy.cc @@ -0,0 +1,108 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ac_strategy.h" + +#include + +#include +#include // iota +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +// Tries to generalize zig-zag order to non-square blocks. Surprisingly, in +// square block frequency along the (i + j == const) diagonals is roughly the +// same. For historical reasons, consecutive diagonals are traversed +// in alternating directions - so called "zig-zag" (or "snake") order. +template +static void CoeffOrderAndLut(AcStrategy acs, coeff_order_t* out) { + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + + // CoefficientLayout ensures cx >= cy. + // We compute the zigzag order for a cx x cx block, then discard all the + // lines that are not multiple of the ratio between cx and cy. + size_t xs = cx / cy; + size_t xsm = xs - 1; + size_t xss = CeilLog2Nonzero(xs); + // First half of the block + size_t cur = cx * cy; + for (size_t i = 0; i < cx * kBlockDim; i++) { + for (size_t j = 0; j <= i; j++) { + size_t x = j; + size_t y = i - j; + if (i % 2) std::swap(x, y); + if ((y & xsm) != 0) continue; + y >>= xss; + size_t val = 0; + if (x < cx && y < cy) { + val = y * cx + x; + } else { + val = cur++; + } + if (is_lut) { + out[y * cx * kBlockDim + x] = val; + } else { + out[val] = y * cx * kBlockDim + x; + } + } + } + // Second half + for (size_t ip = cx * kBlockDim - 1; ip > 0; ip--) { + size_t i = ip - 1; + for (size_t j = 0; j <= i; j++) { + size_t x = cx * kBlockDim - 1 - (i - j); + size_t y = cx * kBlockDim - 1 - j; + if (i % 2) std::swap(x, y); + if ((y & xsm) != 0) continue; + y >>= xss; + size_t val = cur++; + if (is_lut) { + out[y * cx * kBlockDim + x] = val; + } else { + out[val] = y * cx * kBlockDim + x; + } + } + } +} + +void AcStrategy::ComputeNaturalCoeffOrder(coeff_order_t* order) const { + CoeffOrderAndLut(*this, order); +} +void AcStrategy::ComputeNaturalCoeffOrderLut(coeff_order_t* lut) const { + CoeffOrderAndLut(*this, lut); +} + +// These definitions are needed before C++17. +constexpr size_t AcStrategy::kMaxCoeffBlocks; +constexpr size_t AcStrategy::kMaxBlockDim; +constexpr size_t AcStrategy::kMaxCoeffArea; + +AcStrategyImage::AcStrategyImage(size_t xsize, size_t ysize) + : layers_(xsize, ysize) { + row_ = layers_.Row(0); + stride_ = layers_.PixelsPerRow(); +} + +size_t AcStrategyImage::CountBlocks(AcStrategy::Type type) const { + size_t ret = 0; + for (size_t y = 0; y < layers_.ysize(); y++) { + const uint8_t* JXL_RESTRICT row = layers_.ConstRow(y); + for (size_t x = 0; x < layers_.xsize(); x++) { + if (row[x] == ((static_cast(type) << 1) | 1)) ret++; + } + } + return ret; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/ac_strategy.h b/media/libjxl/src/lib/jxl/ac_strategy.h new file mode 100644 index 000000000..7d21167e6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/ac_strategy.h @@ -0,0 +1,261 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_AC_STRATEGY_H_ +#define LIB_JXL_AC_STRATEGY_H_ + +#include +#include + +#include // kMaxVectorSize + +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" + +// Defines the different kinds of transforms, and heuristics to choose between +// them. +// `AcStrategy` represents what transform should be used, and which sub-block of +// that transform we are currently in. Note that DCT4x4 is applied on all four +// 4x4 sub-blocks of an 8x8 block. +// `AcStrategyImage` defines which strategy should be used for each 8x8 block +// of the image. The highest 4 bits represent the strategy to be used, the +// lowest 4 represent the index of the block inside that strategy. + +namespace jxl { + +class AcStrategy { + public: + // Extremal values for the number of blocks/coefficients of a single strategy. + static constexpr size_t kMaxCoeffBlocks = 32; + static constexpr size_t kMaxBlockDim = kBlockDim * kMaxCoeffBlocks; + // Maximum number of coefficients in a block. Guaranteed to be a multiple of + // the vector size. + static constexpr size_t kMaxCoeffArea = kMaxBlockDim * kMaxBlockDim; + static_assert((kMaxCoeffArea * sizeof(float)) % hwy::kMaxVectorSize == 0, + "Coefficient area is not a multiple of vector size"); + + // Raw strategy types. + enum Type : uint32_t { + // Regular block size DCT + DCT = 0, + // Encode pixels without transforming + IDENTITY = 1, + // Use 2-by-2 DCT + DCT2X2 = 2, + // Use 4-by-4 DCT + DCT4X4 = 3, + // Use 16-by-16 DCT + DCT16X16 = 4, + // Use 32-by-32 DCT + DCT32X32 = 5, + // Use 16-by-8 DCT + DCT16X8 = 6, + // Use 8-by-16 DCT + DCT8X16 = 7, + // Use 32-by-8 DCT + DCT32X8 = 8, + // Use 8-by-32 DCT + DCT8X32 = 9, + // Use 32-by-16 DCT + DCT32X16 = 10, + // Use 16-by-32 DCT + DCT16X32 = 11, + // 4x8 and 8x4 DCT + DCT4X8 = 12, + DCT8X4 = 13, + // Corner-DCT. + AFV0 = 14, + AFV1 = 15, + AFV2 = 16, + AFV3 = 17, + // Larger DCTs + DCT64X64 = 18, + DCT64X32 = 19, + DCT32X64 = 20, + DCT128X128 = 21, + DCT128X64 = 22, + DCT64X128 = 23, + DCT256X256 = 24, + DCT256X128 = 25, + DCT128X256 = 26, + // Marker for num of valid strategies. + kNumValidStrategies + }; + + static constexpr uint32_t TypeBit(const Type type) { + return 1u << static_cast(type); + } + + // Returns true if this block is the first 8x8 block (i.e. top-left) of a + // possibly multi-block strategy. + JXL_INLINE bool IsFirstBlock() const { return is_first_; } + + JXL_INLINE bool IsMultiblock() const { + constexpr uint32_t bits = + TypeBit(Type::DCT16X16) | TypeBit(Type::DCT32X32) | + TypeBit(Type::DCT16X8) | TypeBit(Type::DCT8X16) | + TypeBit(Type::DCT32X8) | TypeBit(Type::DCT8X32) | + TypeBit(Type::DCT16X32) | TypeBit(Type::DCT32X16) | + TypeBit(Type::DCT32X64) | TypeBit(Type::DCT64X32) | + TypeBit(Type::DCT64X64) | TypeBit(DCT64X128) | TypeBit(DCT128X64) | + TypeBit(DCT128X128) | TypeBit(DCT128X256) | TypeBit(DCT256X128) | + TypeBit(DCT256X256); + JXL_DASSERT(Strategy() < kNumValidStrategies); + return ((1u << static_cast(Strategy())) & bits) != 0; + } + + // Returns the raw strategy value. Should only be used for tokenization. + JXL_INLINE uint8_t RawStrategy() const { + return static_cast(strategy_); + } + + JXL_INLINE Type Strategy() const { return strategy_; } + + // Inverse check + static JXL_INLINE constexpr bool IsRawStrategyValid(int raw_strategy) { + return raw_strategy < static_cast(kNumValidStrategies) && + raw_strategy >= 0; + } + static JXL_INLINE AcStrategy FromRawStrategy(uint8_t raw_strategy) { + return FromRawStrategy(static_cast(raw_strategy)); + } + static JXL_INLINE AcStrategy FromRawStrategy(Type raw_strategy) { + JXL_DASSERT(IsRawStrategyValid(static_cast(raw_strategy))); + return AcStrategy(raw_strategy, /*is_first=*/true); + } + + // "Natural order" means the order of increasing of "anisotropic" frequency of + // continuous version of DCT basis. + // Round-trip, for any given strategy s: + // X = NaturalCoeffOrder(s)[NaturalCoeffOrderLutN(s)[X]] + // X = NaturalCoeffOrderLut(s)[NaturalCoeffOrderN(s)[X]] + void ComputeNaturalCoeffOrder(coeff_order_t* order) const; + void ComputeNaturalCoeffOrderLut(coeff_order_t* lut) const; + + // Number of 8x8 blocks that this strategy will cover. 0 for non-top-left + // blocks inside a multi-block transform. + JXL_INLINE size_t covered_blocks_x() const { + static constexpr uint8_t kLut[] = {1, 1, 1, 1, 2, 4, 1, 2, 1, + 4, 2, 4, 1, 1, 1, 1, 1, 1, + 8, 4, 8, 16, 8, 16, 32, 16, 32}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + JXL_INLINE size_t covered_blocks_y() const { + static constexpr uint8_t kLut[] = {1, 1, 1, 1, 2, 4, 2, 1, 4, + 1, 4, 2, 1, 1, 1, 1, 1, 1, + 8, 8, 4, 16, 16, 8, 32, 32, 16}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + JXL_INLINE size_t log2_covered_blocks() const { + static constexpr uint8_t kLut[] = {0, 0, 0, 0, 2, 4, 1, 1, 2, + 2, 3, 3, 0, 0, 0, 0, 0, 0, + 6, 5, 5, 8, 7, 7, 10, 9, 9}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + private: + friend class AcStrategyRow; + JXL_INLINE AcStrategy(Type strategy, bool is_first) + : strategy_(strategy), is_first_(is_first) { + JXL_DASSERT(IsMultiblock() || is_first == true); + } + + Type strategy_; + bool is_first_; +}; + +// Class to use a certain row of the AC strategy. +class AcStrategyRow { + public: + explicit AcStrategyRow(const uint8_t* row) : row_(row) {} + AcStrategy operator[](size_t x) const { + return AcStrategy(static_cast(row_[x] >> 1), row_[x] & 1); + } + + private: + const uint8_t* JXL_RESTRICT row_; +}; + +class AcStrategyImage { + public: + AcStrategyImage() = default; + AcStrategyImage(size_t xsize, size_t ysize); + AcStrategyImage(AcStrategyImage&&) = default; + AcStrategyImage& operator=(AcStrategyImage&&) = default; + + void FillDCT8(const Rect& rect) { + FillPlane((static_cast(AcStrategy::Type::DCT) << 1) | 1, + &layers_, rect); + } + void FillDCT8() { FillDCT8(Rect(layers_)); } + + void FillInvalid() { FillImage(INVALID, &layers_); } + + void Set(size_t x, size_t y, AcStrategy::Type type) { +#if JXL_ENABLE_ASSERT + AcStrategy acs = AcStrategy::FromRawStrategy(type); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(y + acs.covered_blocks_y() <= layers_.ysize()); + JXL_ASSERT(x + acs.covered_blocks_x() <= layers_.xsize()); + JXL_CHECK(SetNoBoundsCheck(x, y, type, /*check=*/false)); + } + + Status SetNoBoundsCheck(size_t x, size_t y, AcStrategy::Type type, + bool check = true) { + AcStrategy acs = AcStrategy::FromRawStrategy(type); + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + size_t pos = (y + iy) * stride_ + x + ix; + if (check && row_[pos] != INVALID) { + return JXL_FAILURE("Invalid AC strategy: block overlap"); + } + row_[pos] = + (static_cast(type) << 1) | ((iy | ix) == 0 ? 1 : 0); + } + } + return true; + } + + bool IsValid(size_t x, size_t y) { return row_[y * stride_ + x] != INVALID; } + + AcStrategyRow ConstRow(size_t y, size_t x_prefix = 0) const { + return AcStrategyRow(layers_.ConstRow(y) + x_prefix); + } + + AcStrategyRow ConstRow(const Rect& rect, size_t y) const { + return ConstRow(rect.y0() + y, rect.x0()); + } + + size_t PixelsPerRow() const { return layers_.PixelsPerRow(); } + + size_t xsize() const { return layers_.xsize(); } + size_t ysize() const { return layers_.ysize(); } + + // Count the number of blocks of a given type. + size_t CountBlocks(AcStrategy::Type type) const; + + private: + ImageB layers_; + uint8_t* JXL_RESTRICT row_; + size_t stride_; + + // A value that does not represent a valid combined AC strategy + // value. Used as a sentinel. + static constexpr uint8_t INVALID = 0xFF; +}; + +} // namespace jxl + +#endif // LIB_JXL_AC_STRATEGY_H_ diff --git a/media/libjxl/src/lib/jxl/ac_strategy_test.cc b/media/libjxl/src/lib/jxl/ac_strategy_test.cc new file mode 100644 index 000000000..d366aa3f8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/ac_strategy_test.cc @@ -0,0 +1,237 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ac_strategy.h" + +#include + +#include +#include +#include // HWY_ALIGN_MAX +#include +#include + +#include "lib/jxl/base/random.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_transforms_testonly.h" +#include "lib/jxl/enc_transforms.h" + +namespace jxl { +namespace { + +// Test that DCT -> IDCT is a noop. +class AcStrategyRoundtrip : public ::hwy::TestWithParamTargetAndT { + protected: + void Run() { + const AcStrategy::Type type = static_cast(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + + auto mem = hwy::AllocateAligned(4 * AcStrategy::kMaxCoeffArea); + float* scratch_space = mem.get(); + float* coeffs = scratch_space + AcStrategy::kMaxCoeffArea; + float* idct = coeffs + AcStrategy::kMaxCoeffArea; + Rng rng(type * 65537 + 13); + + for (size_t j = 0; j < 64; j++) { + size_t i = (acs.log2_covered_blocks() + ? rng.UniformU(0, 64u << acs.log2_covered_blocks()) + : j); + float* input = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(input, AcStrategy::kMaxCoeffArea, 0); + input[i] = 0.2f; + TransformFromPixels(type, input, acs.covered_blocks_x() * 8, coeffs, + scratch_space); + ASSERT_NEAR(coeffs[0], 0.2 / (64 << acs.log2_covered_blocks()), 1e-6) + << " i = " << i; + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + for (size_t j = 0; j < 64u << acs.log2_covered_blocks(); j++) { + ASSERT_NEAR(idct[j], j == i ? 0.2f : 0, 2e-6) + << "j = " << j << " i = " << i << " acs " << type; + } + } + // Test DC. + std::fill_n(idct, AcStrategy::kMaxCoeffArea, 0); + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + float* dc = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2; + LowestFrequenciesFromDC(type, dc, acs.covered_blocks_x() * 8, coeffs); + DCFromLowestFrequencies(type, coeffs, idct, acs.covered_blocks_x() * 8); + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2; + for (size_t j = 0; j < 64u << acs.log2_covered_blocks(); j++) { + ASSERT_NEAR(idct[j], dc[j], 1e-6) + << "j = " << j << " x = " << x << " y = " << y << " acs " << type; + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyRoundtrip, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyRoundtrip, Test) { Run(); } + +// Test that DC(2x2) -> DCT coefficients -> IDCT -> downsampled IDCT is a noop. +class AcStrategyRoundtripDownsample + : public ::hwy::TestWithParamTargetAndT { + protected: + void Run() { + const AcStrategy::Type type = static_cast(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + + auto mem = hwy::AllocateAligned(4 * AcStrategy::kMaxCoeffArea); + float* scratch_space = mem.get(); + float* coeffs = scratch_space + AcStrategy::kMaxCoeffArea; + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0.0f); + float* idct = coeffs + AcStrategy::kMaxCoeffArea; + Rng rng(type * 65537 + 13); + + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + if (x > 4 || y > 4) { + if (rng.Bernoulli(0.9f)) continue; + } + float* dc = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2f; + LowestFrequenciesFromDC(type, dc, acs.covered_blocks_x() * 8, coeffs); + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0.0f); + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2f; + // Downsample + for (size_t dy = 0; dy < acs.covered_blocks_y(); dy++) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); dx++) { + float sum = 0; + for (size_t iy = 0; iy < 8; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + sum += idct[(dy * 8 + iy) * 8 * acs.covered_blocks_x() + + dx * 8 + ix]; + } + } + sum /= 64.0f; + ASSERT_NEAR(sum, dc[dy * 8 * acs.covered_blocks_x() + dx], 1e-6) + << "acs " << type; + } + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyRoundtripDownsample, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyRoundtripDownsample, Test) { Run(); } + +// Test that IDCT(block with zeros in the non-topleft corner) -> downsampled +// IDCT is the same as IDCT -> DC(2x2) of the same block. +class AcStrategyDownsample : public ::hwy::TestWithParamTargetAndT { + protected: + void Run() { + const AcStrategy::Type type = static_cast(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + size_t cx = acs.covered_blocks_y(); + size_t cy = acs.covered_blocks_x(); + CoefficientLayout(&cy, &cx); + + auto mem = hwy::AllocateAligned(4 * AcStrategy::kMaxCoeffArea); + float* scratch_space = mem.get(); + float* idct = scratch_space + AcStrategy::kMaxCoeffArea; + float* idct_acs_downsampled = idct + AcStrategy::kMaxCoeffArea; + Rng rng(type * 65537 + 13); + + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx; x++) { + if (x > 4 || y > 4) { + if (rng.Bernoulli(0.9f)) continue; + } + float* coeffs = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0); + coeffs[y * cx * 8 + x] = 0.2f; + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0); + coeffs[y * cx * 8 + x] = 0.2f; + DCFromLowestFrequencies(type, coeffs, idct_acs_downsampled, + acs.covered_blocks_x() * 8); + // Downsample + for (size_t dy = 0; dy < acs.covered_blocks_y(); dy++) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); dx++) { + float sum = 0; + for (size_t iy = 0; iy < 8; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + sum += idct[(dy * 8 + iy) * 8 * acs.covered_blocks_x() + + dx * 8 + ix]; + } + } + sum /= 64; + ASSERT_NEAR( + sum, idct_acs_downsampled[dy * 8 * acs.covered_blocks_x() + dx], + 1e-6) + << " acs " << type; + } + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyDownsample, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyDownsample, Test) { Run(); } + +class AcStrategyTargetTest : public ::hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(AcStrategyTargetTest); + +TEST_P(AcStrategyTargetTest, RoundtripAFVDCT) { + HWY_ALIGN_MAX float idct[16]; + for (size_t i = 0; i < 16; i++) { + HWY_ALIGN_MAX float pixels[16] = {}; + pixels[i] = 1; + HWY_ALIGN_MAX float coeffs[16] = {}; + + AFVDCT4x4(pixels, coeffs); + AFVIDCT4x4(coeffs, idct); + for (size_t j = 0; j < 16; j++) { + EXPECT_NEAR(idct[j], pixels[j], 1e-6); + } + } +} + +TEST_P(AcStrategyTargetTest, BenchmarkAFV) { + const AcStrategy::Type type = AcStrategy::Type::AFV0; + HWY_ALIGN_MAX float pixels[64] = {1}; + HWY_ALIGN_MAX float coeffs[64] = {}; + HWY_ALIGN_MAX float scratch_space[64] = {}; + for (size_t i = 0; i < 1 << 14; i++) { + TransformToPixels(type, coeffs, pixels, 8, scratch_space); + TransformFromPixels(type, pixels, 8, coeffs, scratch_space); + } + EXPECT_NEAR(pixels[0], 0.0, 1E-6); +} + +TEST_P(AcStrategyTargetTest, BenchmarkAFVDCT) { + HWY_ALIGN_MAX float pixels[64] = {1}; + HWY_ALIGN_MAX float coeffs[64] = {}; + for (size_t i = 0; i < 1 << 14; i++) { + AFVDCT4x4(pixels, coeffs); + AFVIDCT4x4(coeffs, pixels); + } + EXPECT_NEAR(pixels[0], 1.0, 1E-6); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/alpha.cc b/media/libjxl/src/lib/jxl/alpha.cc new file mode 100644 index 000000000..f0ab39ac0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/alpha.cc @@ -0,0 +1,120 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/alpha.h" + +#include + +#include + +namespace jxl { + +static float Clamp(float x) { return std::max(std::min(1.0f, x), 0.0f); } + +void PerformAlphaBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp) { + if (alpha_is_premultiplied) { + for (size_t x = 0; x < num_pixels; ++x) { + float fga = clamp ? Clamp(fg.a[x]) : fg.a[x]; + out.r[x] = (fg.r[x] + bg.r[x] * (1.f - fga)); + out.g[x] = (fg.g[x] + bg.g[x] * (1.f - fga)); + out.b[x] = (fg.b[x] + bg.b[x] * (1.f - fga)); + out.a[x] = (1.f - (1.f - fga) * (1.f - bg.a[x])); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + float fga = clamp ? Clamp(fg.a[x]) : fg.a[x]; + const float new_a = 1.f - (1.f - fga) * (1.f - bg.a[x]); + const float rnew_a = (new_a > 0 ? 1.f / new_a : 0.f); + out.r[x] = (fg.r[x] * fga + bg.r[x] * bg.a[x] * (1.f - fga)) * rnew_a; + out.g[x] = (fg.g[x] * fga + bg.g[x] * bg.a[x] * (1.f - fga)) * rnew_a; + out.b[x] = (fg.b[x] * fga + bg.b[x] * bg.a[x] * (1.f - fga)) * rnew_a; + out.a[x] = new_a; + } + } +} +void PerformAlphaBlending(const float* bg, const float* bga, const float* fg, + const float* fga, float* out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp) { + if (bg == bga && fg == fga) { + for (size_t x = 0; x < num_pixels; ++x) { + float fa = clamp ? fga[x] : std::min(std::max(0.0f, fga[x]), 1.0f); + out[x] = (1.f - (1.f - fa) * (1.f - bga[x])); + } + } else { + if (alpha_is_premultiplied) { + for (size_t x = 0; x < num_pixels; ++x) { + float fa = clamp ? fga[x] : Clamp(fga[x]); + out[x] = (fg[x] + bg[x] * (1.f - fa)); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + float fa = clamp ? fga[x] : Clamp(fga[x]); + const float new_a = 1.f - (1.f - fa) * (1.f - bga[x]); + const float rnew_a = (new_a > 0 ? 1.f / new_a : 0.f); + out[x] = (fg[x] * fa + bg[x] * bga[x] * (1.f - fa)) * rnew_a; + } + } + } +} + +void PerformAlphaWeightedAdd(const float* bg, const float* fg, const float* fga, + float* out, size_t num_pixels, bool clamp) { + if (fg == fga) { + memcpy(out, bg, num_pixels * sizeof(*out)); + } else { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] + fg[x] * Clamp(fga[x]); + } + } +} + +void PerformMulBlending(const float* bg, const float* fg, float* out, + size_t num_pixels, bool clamp) { + if (clamp) { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] * Clamp(fg[x]); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] * fg[x]; + } + } +} + +void PremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + const float multiplier = std::max(kSmallAlpha, a[x]); + r[x] *= multiplier; + g[x] *= multiplier; + b[x] *= multiplier; + } +} + +void UnpremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + const float multiplier = 1.f / std::max(kSmallAlpha, a[x]); + r[x] *= multiplier; + g[x] *= multiplier; + b[x] *= multiplier; + } +} + +void UnpremultiplyAlpha(float* JXL_RESTRICT rgba, size_t num_pixels) { + for (size_t x = 0, ix = 0; x < num_pixels; ++x, ix += 4) { + const float multiplier = 1.f / std::max(kSmallAlpha, rgba[ix + 3]); + rgba[ix] *= multiplier; + rgba[ix + 1] *= multiplier; + rgba[ix + 2] *= multiplier; + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/alpha.h b/media/libjxl/src/lib/jxl/alpha.h new file mode 100644 index 000000000..f49790b58 --- /dev/null +++ b/media/libjxl/src/lib/jxl/alpha.h @@ -0,0 +1,67 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ALPHA_H_ +#define LIB_JXL_ALPHA_H_ + +#include +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// A very small value to avoid divisions by zero when converting to +// unpremultiplied alpha. Page 21 of the technical introduction to OpenEXR +// (https://www.openexr.com/documentation/TechnicalIntroduction.pdf) recommends +// "a power of two" that is "less than half of the smallest positive 16-bit +// floating-point value". That smallest value happens to be the denormal number +// 2^-24, so 2^-26 should be a good choice. +static constexpr float kSmallAlpha = 1.f / (1u << 26u); + +struct AlphaBlendingInputLayer { + const float* r; + const float* g; + const float* b; + const float* a; +}; + +struct AlphaBlendingOutput { + float* r; + float* g; + float* b; + float* a; +}; + +// Note: The pointers in `out` are allowed to alias those in `bg` or `fg`. +// No pointer shall be null. +void PerformAlphaBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp); +// Single plane alpha blending +void PerformAlphaBlending(const float* bg, const float* bga, const float* fg, + const float* fga, float* out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp); + +void PerformAlphaWeightedAdd(const float* bg, const float* fg, const float* fga, + float* out, size_t num_pixels, bool clamp); + +void PerformMulBlending(const float* bg, const float* fg, float* out, + size_t num_pixels, bool clamp); + +void PremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels); +void UnpremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels); +void UnpremultiplyAlpha(float* JXL_RESTRICT rgba, size_t num_pixels); + +} // namespace jxl + +#endif // LIB_JXL_ALPHA_H_ diff --git a/media/libjxl/src/lib/jxl/alpha_test.cc b/media/libjxl/src/lib/jxl/alpha_test.cc new file mode 100644 index 000000000..c643fbdc5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/alpha_test.cc @@ -0,0 +1,133 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/alpha.h" + +#include "lib/jxl/test_utils.h" + +namespace jxl { +namespace { + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::FloatNear; + +TEST(AlphaTest, BlendingWithNonPremultiplied) { + const float bg_rgb[3] = {100, 110, 120}; + const float bg_a = 180.f / 255; + const float fg_rgb[3] = {25, 21, 23}; + const float fg_a = 15420.f / 65535; + const float fg_a2 = 2.0f; + float out_rgb[3]; + float out_a; + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/false, /*clamp=*/false); + EXPECT_THAT(out_rgb, + ElementsAre(FloatNear(77.2f, .05f), FloatNear(83.0f, .05f), + FloatNear(90.6f, .05f))); + EXPECT_NEAR(out_a, 3174.f / 4095, 1e-5); + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a2}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/false, /*clamp=*/true); + EXPECT_THAT(out_rgb, ElementsAre(FloatNear(fg_rgb[0], .05f), + FloatNear(fg_rgb[1], .05f), + FloatNear(fg_rgb[2], .05f))); + EXPECT_NEAR(out_a, 1.0f, 1e-5); +} + +TEST(AlphaTest, BlendingWithPremultiplied) { + const float bg_rgb[3] = {100, 110, 120}; + const float bg_a = 180.f / 255; + const float fg_rgb[3] = {25, 21, 23}; + const float fg_a = 15420.f / 65535; + const float fg_a2 = 2.0f; + float out_rgb[3]; + float out_a; + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/true, /*clamp=*/false); + EXPECT_THAT(out_rgb, + ElementsAre(FloatNear(101.5f, .05f), FloatNear(105.1f, .05f), + FloatNear(114.8f, .05f))); + EXPECT_NEAR(out_a, 3174.f / 4095, 1e-5); + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a2}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/true, /*clamp=*/true); + EXPECT_THAT(out_rgb, ElementsAre(FloatNear(fg_rgb[0], .05f), + FloatNear(fg_rgb[1], .05f), + FloatNear(fg_rgb[2], .05f))); + EXPECT_NEAR(out_a, 1.0f, 1e-5); +} + +TEST(AlphaTest, Mul) { + const float bg = 100; + const float fg = 25; + float out; + PerformMulBlending(&bg, &fg, &out, 1, /*clamp=*/false); + EXPECT_THAT(out, FloatNear(fg * bg, .05f)); + PerformMulBlending(&bg, &fg, &out, 1, /*clamp=*/true); + EXPECT_THAT(out, FloatNear(bg, .05f)); +} + +TEST(AlphaTest, PremultiplyAndUnpremultiply) { + const float alpha[] = {0.f, 63.f / 255, 127.f / 255, 1.f}; + float r[] = {120, 130, 140, 150}; + float g[] = {124, 134, 144, 154}; + float b[] = {127, 137, 147, 157}; + + PremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT( + r, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(130 * 63.f / 255, 1e-5f), + FloatNear(140 * 127.f / 255, 1e-5f), 150)); + EXPECT_THAT( + g, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(134 * 63.f / 255, 1e-5f), + FloatNear(144 * 127.f / 255, 1e-5f), 154)); + EXPECT_THAT( + b, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(137 * 63.f / 255, 1e-5f), + FloatNear(147 * 127.f / 255, 1e-5f), 157)); + + UnpremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT(r, ElementsAre(FloatNear(120, 1e-4f), FloatNear(130, 1e-4f), + FloatNear(140, 1e-4f), FloatNear(150, 1e-4f))); + EXPECT_THAT(g, ElementsAre(FloatNear(124, 1e-4f), FloatNear(134, 1e-4f), + FloatNear(144, 1e-4f), FloatNear(154, 1e-4f))); + EXPECT_THAT(b, ElementsAre(FloatNear(127, 1e-4f), FloatNear(137, 1e-4f), + FloatNear(147, 1e-4f), FloatNear(157, 1e-4f))); +} + +TEST(AlphaTest, UnpremultiplyAndPremultiply) { + const float alpha[] = {0.f, 63.f / 255, 127.f / 255, 1.f}; + float r[] = {50, 60, 70, 80}; + float g[] = {54, 64, 74, 84}; + float b[] = {57, 67, 77, 87}; + + UnpremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT(r, ElementsAre(_, FloatNear(60 * 255.f / 63, 1e-4f), + FloatNear(70 * 255.f / 127, 1e-4f), 80)); + EXPECT_THAT(g, ElementsAre(_, FloatNear(64 * 255.f / 63, 1e-4f), + FloatNear(74 * 255.f / 127, 1e-4f), 84)); + EXPECT_THAT(b, ElementsAre(_, FloatNear(67 * 255.f / 63, 1e-4f), + FloatNear(77 * 255.f / 127, 1e-4f), 87)); + + PremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT(r, ElementsAre(FloatNear(50, 1e-4f), FloatNear(60, 1e-4f), + FloatNear(70, 1e-4f), FloatNear(80, 1e-4f))); + EXPECT_THAT(g, ElementsAre(FloatNear(54, 1e-4f), FloatNear(64, 1e-4f), + FloatNear(74, 1e-4f), FloatNear(84, 1e-4f))); + EXPECT_THAT(b, ElementsAre(FloatNear(57, 1e-4f), FloatNear(67, 1e-4f), + FloatNear(77, 1e-4f), FloatNear(87, 1e-4f))); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/ans_common.cc b/media/libjxl/src/lib/jxl/ans_common.cc new file mode 100644 index 000000000..32a658fb4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/ans_common.cc @@ -0,0 +1,148 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ans_common.h" + +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +std::vector CreateFlatHistogram(int length, int total_count) { + JXL_ASSERT(length > 0); + JXL_ASSERT(length <= total_count); + const int count = total_count / length; + std::vector result(length, count); + const int rem_counts = total_count % length; + for (int i = 0; i < rem_counts; ++i) { + ++result[i]; + } + return result; +} + +// First, all trailing non-occuring symbols are removed from the distribution; +// if this leaves the distribution empty, a dummy symbol with max weight is +// added. This ensures that the resulting distribution sums to total table size. +// Then, `entry_size` is chosen to be the largest power of two so that +// `table_size` = ANS_TAB_SIZE/`entry_size` is at least as big as the +// distribution size. +// Note that each entry will only ever contain two different symbols, and +// consecutive ranges of offsets, which allows us to use a compact +// representation. +// Each entry is initialized with only the (symbol=i, offset) pairs; then +// positions for which the entry overflows (i.e. distribution[i] > entry_size) +// or is not full are computed, and put into a stack in increasing order. +// Missing symbols in the distribution are padded with 0 (because `table_size` +// >= number of symbols). The `cutoff` value for each entry is initialized to +// the number of occupied slots in that entry (i.e. `distributions[i]`). While +// the overflowing-symbol stack is not empty (which implies that the +// underflowing-symbol stack also is not), the top overfull and underfull +// positions are popped from the stack; the empty slots in the underfull entry +// are then filled with as many slots as needed from the overfull entry; such +// slots are placed after the slots in the overfull entry, and `offsets[1]` is +// computed accordingly. The formerly underfull entry is thus now neither +// underfull nor overfull, and represents exactly two symbols. The overfull +// entry might be either overfull or underfull, and is pushed into the +// corresponding stack. +void InitAliasTable(std::vector distribution, uint32_t range, + size_t log_alpha_size, AliasTable::Entry* JXL_RESTRICT a) { + while (!distribution.empty() && distribution.back() == 0) { + distribution.pop_back(); + } + // Ensure that a valid table is always returned, even for an empty + // alphabet. Otherwise, a specially-crafted stream might crash the + // decoder. + if (distribution.empty()) { + distribution.emplace_back(range); + } + const size_t table_size = 1 << log_alpha_size; +#if JXL_ENABLE_ASSERT + int sum = std::accumulate(distribution.begin(), distribution.end(), 0); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(static_cast(sum) == range); + // range must be a power of two + JXL_ASSERT((range & (range - 1)) == 0); + JXL_ASSERT(distribution.size() <= table_size); + JXL_ASSERT(table_size <= range); + const uint32_t entry_size = range >> log_alpha_size; // this is exact + // Special case for single-symbol distributions, that ensures that the state + // does not change when decoding from such a distribution. Note that, since we + // hardcode offset0 == 0, it is not straightforward (if at all possible) to + // fix the general case to produce this result. + for (size_t sym = 0; sym < distribution.size(); sym++) { + if (distribution[sym] == ANS_TAB_SIZE) { + for (size_t i = 0; i < table_size; i++) { + a[i].right_value = sym; + a[i].cutoff = 0; + a[i].offsets1 = entry_size * i; + a[i].freq0 = 0; + a[i].freq1_xor_freq0 = ANS_TAB_SIZE; + } + return; + } + } + std::vector underfull_posn; + std::vector overfull_posn; + std::vector cutoffs(1 << log_alpha_size); + // Initialize entries. + for (size_t i = 0; i < distribution.size(); i++) { + cutoffs[i] = distribution[i]; + if (cutoffs[i] > entry_size) { + overfull_posn.push_back(i); + } else if (cutoffs[i] < entry_size) { + underfull_posn.push_back(i); + } + } + for (uint32_t i = distribution.size(); i < table_size; i++) { + cutoffs[i] = 0; + underfull_posn.push_back(i); + } + // Reassign overflow/underflow values. + while (!overfull_posn.empty()) { + uint32_t overfull_i = overfull_posn.back(); + overfull_posn.pop_back(); + JXL_ASSERT(!underfull_posn.empty()); + uint32_t underfull_i = underfull_posn.back(); + underfull_posn.pop_back(); + uint32_t underfull_by = entry_size - cutoffs[underfull_i]; + cutoffs[overfull_i] -= underfull_by; + // overfull positions have their original symbols + a[underfull_i].right_value = overfull_i; + a[underfull_i].offsets1 = cutoffs[overfull_i]; + // Slots in the right part of entry underfull_i were taken from the end + // of the symbols in entry overfull_i. + if (cutoffs[overfull_i] < entry_size) { + underfull_posn.push_back(overfull_i); + } else if (cutoffs[overfull_i] > entry_size) { + overfull_posn.push_back(overfull_i); + } + } + for (uint32_t i = 0; i < table_size; i++) { + // cutoffs[i] is properly initialized but the clang-analyzer doesn't infer + // it since it is partially initialized across two for-loops. + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + if (cutoffs[i] == entry_size) { + a[i].right_value = i; + a[i].offsets1 = 0; + a[i].cutoff = 0; + } else { + // Note that, if cutoff is not equal to entry_size, + // a[i].offsets1 was initialized with (overfull cutoff) - + // (entry_size - a[i].cutoff). Thus, subtracting + // a[i].cutoff cannot make it negative. + a[i].offsets1 -= cutoffs[i]; + a[i].cutoff = cutoffs[i]; + } + const size_t freq0 = i < distribution.size() ? distribution[i] : 0; + const size_t i1 = a[i].right_value; + const size_t freq1 = i1 < distribution.size() ? distribution[i1] : 0; + a[i].freq0 = static_cast(freq0); + a[i].freq1_xor_freq0 = static_cast(freq1 ^ freq0); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/ans_common.h b/media/libjxl/src/lib/jxl/ans_common.h new file mode 100644 index 000000000..fb5058e31 --- /dev/null +++ b/media/libjxl/src/lib/jxl/ans_common.h @@ -0,0 +1,143 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ANS_COMMON_H_ +#define LIB_JXL_ANS_COMMON_H_ + +#include + +#include +#include // Prefetch +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Returns the precision (number of bits) that should be used to store +// a histogram count such that Log2Floor(count) == logcount. +static JXL_INLINE uint32_t GetPopulationCountPrecision(uint32_t logcount, + uint32_t shift) { + int32_t r = std::min( + logcount, int(shift) - int((ANS_LOG_TAB_SIZE - logcount) >> 1)); + if (r < 0) return 0; + return r; +} + +// Returns a histogram where the counts are positive, differ by at most 1, +// and add up to total_count. The bigger counts (if any) are at the beginning +// of the histogram. +std::vector CreateFlatHistogram(int length, int total_count); + +// An alias table implements a mapping from the [0, ANS_TAB_SIZE) range into +// the [0, ANS_MAX_ALPHABET_SIZE) range, satisfying the following conditions: +// - each symbol occurs as many times as specified by any valid distribution +// of frequencies of the symbols. A valid distribution here is an array of +// ANS_MAX_ALPHABET_SIZE that contains numbers in the range [0, ANS_TAB_SIZE], +// and whose sum is ANS_TAB_SIZE. +// - lookups can be done in constant time, and also return how many smaller +// input values map into the same symbol, according to some well-defined order +// of input values. +// - the space used by the alias table is given by a small constant times the +// index of the largest symbol with nonzero probability in the distribution. +// Each of the entries in the table covers a range of `entry_size` values in the +// [0, ANS_TAB_SIZE) range; consecutive entries represent consecutive +// sub-ranges. In the range covered by entry `i`, the first `cutoff` values map +// to symbol `i`, while the others map to symbol `right_value`. +// +// TODO(veluca): consider making the order used for computing offsets easier to +// define - it is currently defined by the algorithm to compute the alias table. +// Beware of breaking the implicit assumption that symbols that come after the +// cutoff value should have an offset at least as big as the cutoff. + +struct AliasTable { + struct Symbol { + size_t value; + size_t offset; + size_t freq; + }; + +// Working set size matters here (~64 tables x 256 entries). +// offsets0 is always zero (beginning of [0] side among the same symbol). +// offsets1 is an offset of (pos >= cutoff) side decremented by cutoff. +#pragma pack(push, 1) + struct Entry { + uint8_t cutoff; // < kEntrySizeMinus1 when used by ANS. + uint8_t right_value; // < alphabet size. + uint16_t freq0; + + // Only used if `greater` (see Lookup) + uint16_t offsets1; // <= ANS_TAB_SIZE + uint16_t freq1_xor_freq0; // for branchless ternary in Lookup + }; +#pragma pack(pop) + + // Dividing `value` by `entry_size` determines `i`, the entry which is + // responsible for the input. If the remainder is below `cutoff`, then the + // mapped symbol is `i`; since `offsets[0]` stores the number of occurrences + // of `i` "before" the start of this entry, the offset of the input will be + // `offsets[0] + remainder`. If the remainder is above cutoff, the mapped + // symbol is `right_value`; since `offsets[1]` stores the number of + // occurrences of `right_value` "before" this entry, minus the `cutoff` value, + // the input offset is then `remainder + offsets[1]`. + static JXL_INLINE Symbol Lookup(const Entry* JXL_RESTRICT table, size_t value, + size_t log_entry_size, + size_t entry_size_minus_1) { + const size_t i = value >> log_entry_size; + const size_t pos = value & entry_size_minus_1; + +#if JXL_BYTE_ORDER_LITTLE + uint64_t entry; + memcpy(&entry, &table[i].cutoff, sizeof(entry)); + const size_t cutoff = entry & 0xFF; // = MOVZX + const size_t right_value = (entry >> 8) & 0xFF; // = MOVZX + const size_t freq0 = (entry >> 16) & 0xFFFF; +#else + // Generates multiple loads with complex addressing. + const size_t cutoff = table[i].cutoff; + const size_t right_value = table[i].right_value; + const size_t freq0 = table[i].freq0; +#endif + + const bool greater = pos >= cutoff; + +#if JXL_BYTE_ORDER_LITTLE + const uint64_t conditional = greater ? entry : 0; // = CMOV + const size_t offsets1_or_0 = (conditional >> 32) & 0xFFFF; + const size_t freq1_xor_freq0_or_0 = conditional >> 48; +#else + const size_t offsets1_or_0 = greater ? table[i].offsets1 : 0; + const size_t freq1_xor_freq0_or_0 = greater ? table[i].freq1_xor_freq0 : 0; +#endif + + // WARNING: moving this code may interfere with CMOV heuristics. + Symbol s; + s.value = greater ? right_value : i; + s.offset = offsets1_or_0 + pos; + s.freq = freq0 ^ freq1_xor_freq0_or_0; // = greater ? freq1 : freq0 + // XOR avoids implementation-defined conversion from unsigned to signed. + // Alternatives considered: BEXTR is 2 cycles on HSW, SET+shift causes + // spills, simple ternary has a long dependency chain. + + return s; + } + + static HWY_INLINE void Prefetch(const Entry* JXL_RESTRICT table, size_t value, + size_t log_entry_size) { + const size_t i = value >> log_entry_size; + hwy::Prefetch(table + i); + } +}; + +// Computes an alias table for a given distribution. +void InitAliasTable(std::vector distribution, uint32_t range, + size_t log_alpha_size, AliasTable::Entry* JXL_RESTRICT a); + +} // namespace jxl + +#endif // LIB_JXL_ANS_COMMON_H_ diff --git a/media/libjxl/src/lib/jxl/ans_common_test.cc b/media/libjxl/src/lib/jxl/ans_common_test.cc new file mode 100644 index 000000000..2c4ea8ea4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/ans_common_test.cc @@ -0,0 +1,43 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ans_common.h" + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/ans_params.h" + +namespace jxl { +namespace { + +void VerifyAliasDistribution(const std::vector& distribution, + uint32_t range) { + constexpr size_t log_alpha_size = 8; + AliasTable::Entry table[1 << log_alpha_size]; + InitAliasTable(distribution, range, log_alpha_size, table); + std::vector> offsets(distribution.size()); + for (uint32_t i = 0; i < range; i++) { + AliasTable::Symbol s = AliasTable::Lookup( + table, i, ANS_LOG_TAB_SIZE - 8, (1 << (ANS_LOG_TAB_SIZE - 8)) - 1); + offsets[s.value].push_back(s.offset); + } + for (uint32_t i = 0; i < distribution.size(); i++) { + ASSERT_EQ(static_cast(distribution[i]), offsets[i].size()); + std::sort(offsets[i].begin(), offsets[i].end()); + for (uint32_t j = 0; j < offsets[i].size(); j++) { + ASSERT_EQ(offsets[i][j], j); + } + } +} + +TEST(ANSCommonTest, AliasDistributionSmoke) { + VerifyAliasDistribution({ANS_TAB_SIZE / 2, ANS_TAB_SIZE / 2}, ANS_TAB_SIZE); + VerifyAliasDistribution({ANS_TAB_SIZE}, ANS_TAB_SIZE); + VerifyAliasDistribution({0, 0, 0, ANS_TAB_SIZE, 0}, ANS_TAB_SIZE); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/ans_params.h b/media/libjxl/src/lib/jxl/ans_params.h new file mode 100644 index 000000000..4bbc284c0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/ans_params.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ANS_PARAMS_H_ +#define LIB_JXL_ANS_PARAMS_H_ + +// Common parameters that are needed for both the ANS entropy encoding and +// decoding methods. + +#include +#include + +namespace jxl { + +// TODO(veluca): decide if 12 is the best constant here (valid range is up to +// 16). This requires recomputing the Huffman tables in {enc,dec}_ans.cc +// 14 gives a 0.2% improvement at d1 and makes d8 slightly worse. This is +// likely not worth the increase in encoder complexity. +#define ANS_LOG_TAB_SIZE 12u +#define ANS_TAB_SIZE (1 << ANS_LOG_TAB_SIZE) +#define ANS_TAB_MASK (ANS_TAB_SIZE - 1) + +// Largest possible symbol to be encoded by either ANS or prefix coding. +#define PREFIX_MAX_ALPHABET_SIZE 4096 +#define ANS_MAX_ALPHABET_SIZE 256 + +// Max number of bits for prefix coding. +#define PREFIX_MAX_BITS 15 + +#define ANS_SIGNATURE 0x13 // Initial state, used as CRC. + +} // namespace jxl + +#endif // LIB_JXL_ANS_PARAMS_H_ diff --git a/media/libjxl/src/lib/jxl/ans_test.cc b/media/libjxl/src/lib/jxl/ans_test.cc new file mode 100644 index 000000000..ca9883d37 --- /dev/null +++ b/media/libjxl/src/lib/jxl/ans_test.cc @@ -0,0 +1,278 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { +namespace { + +void RoundtripTestcase(int n_histograms, int alphabet_size, + const std::vector& input_values) { + constexpr uint16_t kMagic1 = 0x9e33; + constexpr uint16_t kMagic2 = 0x8b04; + + BitWriter writer; + // Space for magic bytes. + BitWriter::Allotment allotment_magic1(&writer, 16); + writer.Write(16, kMagic1); + ReclaimAndCharge(&writer, &allotment_magic1, 0, nullptr); + + std::vector context_map; + EntropyEncodingData codes; + std::vector> input_values_vec; + input_values_vec.push_back(input_values); + + BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec, + &codes, &context_map, &writer, 0, nullptr); + WriteTokens(input_values_vec[0], codes, context_map, &writer, 0, nullptr); + + // Magic bytes + padding + BitWriter::Allotment allotment_magic2(&writer, 24); + writer.Write(16, kMagic2); + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment_magic2, 0, nullptr); + + // We do not truncate the output. Reading past the end reads out zeroes + // anyway. + BitReader br(writer.GetSpan()); + + ASSERT_EQ(br.ReadBits(16), kMagic1); + + std::vector dec_context_map; + ANSCode decoded_codes; + ASSERT_TRUE( + DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map)); + ASSERT_EQ(dec_context_map, context_map); + ANSSymbolReader reader(&decoded_codes, &br); + + for (const Token& symbol : input_values) { + uint32_t read_symbol = + reader.ReadHybridUint(symbol.context, &br, dec_context_map); + ASSERT_EQ(read_symbol, symbol.value); + } + ASSERT_TRUE(reader.CheckANSFinalState()); + + ASSERT_EQ(br.ReadBits(16), kMagic2); + EXPECT_TRUE(br.Close()); +} + +TEST(ANSTest, EmptyRoundtrip) { + RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, std::vector()); +} + +TEST(ANSTest, SingleSymbolRoundtrip) { + for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { + RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, {{0, i}}); + } + for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { + RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, + std::vector(1024, {0, i})); + } +} + +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) +constexpr size_t kReps = 3; +#else +constexpr size_t kReps = 10; +#endif + +void RoundtripRandomStream(int alphabet_size, size_t reps = kReps, + size_t num = 1 << 18) { + constexpr int kNumHistograms = 3; + Rng rng(0); + for (size_t i = 0; i < reps; i++) { + std::vector symbols; + for (size_t j = 0; j < num; j++) { + int context = rng.UniformI(0, kNumHistograms); + int value = rng.UniformU(0, alphabet_size); + symbols.emplace_back(context, value); + } + RoundtripTestcase(kNumHistograms, alphabet_size, symbols); + } +} + +void RoundtripRandomUnbalancedStream(int alphabet_size) { + constexpr int kNumHistograms = 3; + constexpr int kPrecision = 1 << 10; + Rng rng(0); + for (size_t i = 0; i < kReps; i++) { + std::vector distributions[kNumHistograms] = {}; + for (int j = 0; j < kNumHistograms; j++) { + distributions[j].resize(kPrecision); + int symbol = 0; + int remaining = 1; + for (int k = 0; k < kPrecision; k++) { + if (remaining == 0) { + if (symbol < alphabet_size - 1) symbol++; + // There is no meaning behind this distribution: it's anything that + // will create a nonuniform distribution and won't have too few + // symbols usually. Also we want different distributions we get to be + // sufficiently dissimilar. + remaining = rng.UniformU(0, kPrecision - k + 1); + } + distributions[j][k] = symbol; + remaining--; + } + } + std::vector symbols; + for (int j = 0; j < 1 << 18; j++) { + int context = rng.UniformI(0, kNumHistograms); + int value = rng.UniformU(0, kPrecision); + symbols.emplace_back(context, value); + } + RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols); + } +} + +TEST(ANSTest, RandomStreamRoundtrip3Small) { RoundtripRandomStream(3, 1, 16); } + +TEST(ANSTest, RandomStreamRoundtrip3) { RoundtripRandomStream(3); } + +TEST(ANSTest, RandomStreamRoundtripBig) { + RoundtripRandomStream(ANS_MAX_ALPHABET_SIZE); +} + +TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) { + RoundtripRandomUnbalancedStream(3); +} + +TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) { + RoundtripRandomUnbalancedStream(ANS_MAX_ALPHABET_SIZE); +} + +TEST(ANSTest, UintConfigRoundtrip) { + for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) { + std::vector uint_config, uint_config_dec; + for (size_t i = 0; i < log_alpha_size; i++) { + for (size_t j = 0; j <= i; j++) { + for (size_t k = 0; k <= i - j; k++) { + uint_config.emplace_back(i, j, k); + } + } + } + uint_config.emplace_back(log_alpha_size, 0, 0); + uint_config_dec.resize(uint_config.size()); + BitWriter writer; + BitWriter::Allotment allotment(&writer, 10 * uint_config.size()); + EncodeUintConfigs(uint_config, &writer, log_alpha_size); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + writer.ZeroPadToByte(); + BitReader br(writer.GetSpan()); + EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &uint_config_dec, &br)); + EXPECT_TRUE(br.Close()); + for (size_t i = 0; i < uint_config.size(); i++) { + EXPECT_EQ(uint_config[i].split_token, uint_config_dec[i].split_token); + EXPECT_EQ(uint_config[i].msb_in_token, uint_config_dec[i].msb_in_token); + EXPECT_EQ(uint_config[i].lsb_in_token, uint_config_dec[i].lsb_in_token); + } + } +} + +void TestCheckpointing(bool ans, bool lz77) { + std::vector> input_values(1); + for (size_t i = 0; i < 1024; i++) { + input_values[0].push_back(Token(0, i % 4)); + } + // up to lz77 window size. + for (size_t i = 0; i < (1 << 20) - 1022; i++) { + input_values[0].push_back(Token(0, (i % 5) + 4)); + } + // Ensure that when the window wraps around, new values are different. + input_values[0].push_back(Token(0, 0)); + for (size_t i = 0; i < 1024; i++) { + input_values[0].push_back(Token(0, i % 4)); + } + + std::vector context_map; + EntropyEncodingData codes; + HistogramParams params; + params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77 + : HistogramParams::LZ77Method::kNone; + params.force_huffman = !ans; + + BitWriter writer; + { + auto input_values_copy = input_values; + BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map, + &writer, 0, nullptr); + WriteTokens(input_values_copy[0], codes, context_map, &writer, 0, nullptr); + writer.ZeroPadToByte(); + } + + // We do not truncate the output. Reading past the end reads out zeroes + // anyway. + BitReader br(writer.GetSpan()); + Status status = true; + { + BitReaderScopedCloser bc(&br, &status); + + std::vector dec_context_map; + ANSCode decoded_codes; + ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map)); + ASSERT_EQ(dec_context_map, context_map); + ANSSymbolReader reader(&decoded_codes, &br); + + ANSSymbolReader::Checkpoint checkpoint; + size_t br_pos = 0; + constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2; + for (size_t i = 0; i < input_values[0].size(); i++) { + if (i % kInterval == 0 && i > 0) { + reader.Restore(checkpoint); + ASSERT_TRUE(br.Close()); + br = BitReader(writer.GetSpan()); + br.SkipBits(br_pos); + for (size_t j = i - kInterval; j < i; j++) { + Token symbol = input_values[0][j]; + uint32_t read_symbol = + reader.ReadHybridUint(symbol.context, &br, dec_context_map); + ASSERT_EQ(read_symbol, symbol.value) << "j = " << j; + } + } + if (i % kInterval == 0) { + reader.Save(&checkpoint); + br_pos = br.TotalBitsConsumed(); + } + Token symbol = input_values[0][i]; + uint32_t read_symbol = + reader.ReadHybridUint(symbol.context, &br, dec_context_map); + ASSERT_EQ(read_symbol, symbol.value) << "i = " << i; + } + ASSERT_TRUE(reader.CheckANSFinalState()); + } + EXPECT_TRUE(status); +} + +TEST(ANSTest, TestCheckpointingANS) { + TestCheckpointing(/*ans=*/true, /*lz77=*/false); +} + +TEST(ANSTest, TestCheckpointingPrefix) { + TestCheckpointing(/*ans=*/false, /*lz77=*/false); +} + +TEST(ANSTest, TestCheckpointingANSLZ77) { + TestCheckpointing(/*ans=*/true, /*lz77=*/true); +} + +TEST(ANSTest, TestCheckpointingPrefixLZ77) { + TestCheckpointing(/*ans=*/false, /*lz77=*/true); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/aux_out.cc b/media/libjxl/src/lib/jxl/aux_out.cc new file mode 100644 index 000000000..d8ee9460f --- /dev/null +++ b/media/libjxl/src/lib/jxl/aux_out.cc @@ -0,0 +1,87 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/aux_out.h" + +#include + +#include // accumulate + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +void AuxOut::Print(size_t num_inputs) const { + if (num_inputs == 0) return; + + LayerTotals all_layers; + for (size_t i = 0; i < layers.size(); ++i) { + all_layers.Assimilate(layers[i]); + } + + printf("Average butteraugli iters: %10.2f\n", + num_butteraugli_iters * 1.0 / num_inputs); + if (min_quant_rescale != 1.0 || max_quant_rescale != 1.0) { + printf("quant rescale range: %f .. %f\n", min_quant_rescale, + max_quant_rescale); + printf("bitrate error range: %.3f%% .. %.3f%%\n", + 100.0f * min_bitrate_error, 100.0f * max_bitrate_error); + } + + for (size_t i = 0; i < layers.size(); ++i) { + if (layers[i].total_bits != 0) { + printf("Total layer bits %-10s\t", LayerName(i)); + printf("%10f%%", 100.0 * layers[i].total_bits / all_layers.total_bits); + layers[i].Print(num_inputs); + } + } + printf("Total image size "); + all_layers.Print(num_inputs); + + const uint32_t dc_pred_total = + std::accumulate(dc_pred_usage.begin(), dc_pred_usage.end(), 0u); + const uint32_t dc_pred_total_xb = + std::accumulate(dc_pred_usage_xb.begin(), dc_pred_usage_xb.end(), 0u); + if (dc_pred_total + dc_pred_total_xb != 0) { + printf("\nDC pred Y XB:\n"); + for (size_t i = 0; i < dc_pred_usage.size(); ++i) { + printf(" %6u (%5.2f%%) %6u (%5.2f%%)\n", dc_pred_usage[i], + 100.0 * dc_pred_usage[i] / dc_pred_total, dc_pred_usage_xb[i], + 100.0 * dc_pred_usage_xb[i] / dc_pred_total_xb); + } + } + + size_t total_blocks = 0; + size_t total_positions = 0; + if (total_blocks != 0 && total_positions != 0) { + printf("\n\t\t Blocks\t\tPositions\t\t\tBlocks/Position\n"); + printf(" Total:\t\t %7" PRIuS "\t\t %7" PRIuS " \t\t\t%10f%%\n\n", + total_blocks, total_positions, + 100.0 * total_blocks / total_positions); + } +} + +void ReclaimAndCharge(BitWriter* JXL_RESTRICT writer, + BitWriter::Allotment* JXL_RESTRICT allotment, + size_t layer, AuxOut* JXL_RESTRICT aux_out) { + size_t used_bits, unused_bits; + allotment->PrivateReclaim(writer, &used_bits, &unused_bits); + +#if 0 + printf("Layer %s bits: max %" PRIuS " used %" PRIuS " unused %" PRIuS "\n", LayerName(layer), + allotment->MaxBits(), used_bits, unused_bits); +#endif + + // This may be a nested call with aux_out == null. Whenever we know that + // aux_out is null, we can call ReclaimUnused directly. + if (aux_out != nullptr) { + aux_out->layers[layer].total_bits += used_bits; + aux_out->layers[layer].histogram_bits += allotment->HistogramBits(); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/aux_out.h b/media/libjxl/src/lib/jxl/aux_out.h new file mode 100644 index 000000000..707660312 --- /dev/null +++ b/media/libjxl/src/lib/jxl/aux_out.h @@ -0,0 +1,309 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_AUX_OUT_H_ +#define LIB_JXL_AUX_OUT_H_ + +// Optional output information for debugging and analyzing size usage. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jxl_inspection.h" + +namespace jxl { + +// For LayerName and AuxOut::layers[] index. Order does not matter. +enum { + kLayerHeader = 0, + kLayerTOC, + kLayerDictionary, + kLayerSplines, + kLayerNoise, + kLayerQuant, + kLayerModularTree, + kLayerModularGlobal, + kLayerDC, + kLayerModularDcGroup, + kLayerControlFields, + kLayerOrder, + kLayerAC, + kLayerACTokens, + kLayerModularAcGroup, + kNumImageLayers +}; + +static inline const char* LayerName(size_t layer) { + switch (layer) { + case kLayerHeader: + return "Headers"; + case kLayerTOC: + return "TOC"; + case kLayerDictionary: + return "Patches"; + case kLayerSplines: + return "Splines"; + case kLayerNoise: + return "Noise"; + case kLayerQuant: + return "Quantizer"; + case kLayerModularTree: + return "ModularTree"; + case kLayerModularGlobal: + return "ModularGlobal"; + case kLayerDC: + return "DC"; + case kLayerModularDcGroup: + return "ModularDcGroup"; + case kLayerControlFields: + return "ControlFields"; + case kLayerOrder: + return "CoeffOrder"; + case kLayerAC: + return "ACHistograms"; + case kLayerACTokens: + return "ACTokens"; + case kLayerModularAcGroup: + return "ModularAcGroup"; + default: + JXL_ABORT("Invalid layer %d\n", static_cast(layer)); + } +} + +// Statistics gathered during compression or decompression. +struct AuxOut { + private: + struct LayerTotals { + void Assimilate(const LayerTotals& victim) { + num_clustered_histograms += victim.num_clustered_histograms; + histogram_bits += victim.histogram_bits; + extra_bits += victim.extra_bits; + total_bits += victim.total_bits; + clustered_entropy += victim.clustered_entropy; + } + void Print(size_t num_inputs) const { + printf("%10" PRId64, static_cast(total_bits)); + if (histogram_bits != 0) { + printf(" [c/i:%6.2f | hst:%8" PRId64 " | ex:%8" PRId64 + " | h+c+e:%12.3f", + num_clustered_histograms * 1.0 / num_inputs, + static_cast(histogram_bits >> 3), + static_cast(extra_bits >> 3), + (histogram_bits + clustered_entropy + extra_bits) / 8.0); + printf("]"); + } + printf("\n"); + } + size_t num_clustered_histograms = 0; + size_t extra_bits = 0; + + // Set via BitsWritten below + size_t histogram_bits = 0; + size_t total_bits = 0; + + double clustered_entropy = 0.0; + }; + + public: + AuxOut() = default; + AuxOut(const AuxOut&) = default; + + void Assimilate(const AuxOut& victim) { + for (size_t i = 0; i < layers.size(); ++i) { + layers[i].Assimilate(victim.layers[i]); + } + num_blocks += victim.num_blocks; + num_small_blocks += victim.num_small_blocks; + num_dct4x8_blocks += victim.num_dct4x8_blocks; + num_afv_blocks += victim.num_afv_blocks; + num_dct8_blocks += victim.num_dct8_blocks; + num_dct8x16_blocks += victim.num_dct8x16_blocks; + num_dct8x32_blocks += victim.num_dct8x32_blocks; + num_dct16_blocks += victim.num_dct16_blocks; + num_dct16x32_blocks += victim.num_dct16x32_blocks; + num_dct32_blocks += victim.num_dct32_blocks; + num_dct32x64_blocks += victim.num_dct32x64_blocks; + num_dct64_blocks += victim.num_dct64_blocks; + num_butteraugli_iters += victim.num_butteraugli_iters; + for (size_t i = 0; i < dc_pred_usage.size(); ++i) { + dc_pred_usage[i] += victim.dc_pred_usage[i]; + dc_pred_usage_xb[i] += victim.dc_pred_usage_xb[i]; + } + max_quant_rescale = std::max(max_quant_rescale, victim.max_quant_rescale); + min_quant_rescale = std::min(min_quant_rescale, victim.min_quant_rescale); + max_bitrate_error = std::max(max_bitrate_error, victim.max_bitrate_error); + min_bitrate_error = std::min(min_bitrate_error, victim.min_bitrate_error); + } + + void Print(size_t num_inputs) const; + + size_t TotalBits() const { + size_t total = 0; + for (const auto& layer : layers) { + total += layer.total_bits; + } + return total; + } + + template + void DumpImage(const char* label, const Image3& image) const { + if (!dump_image) return; + if (debug_prefix.empty()) return; + std::ostringstream pathname; + pathname << debug_prefix << label << ".png"; + CodecInOut io; + // Always save to 16-bit png. + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + io.SetFromImage(ConvertToFloat(image), io.metadata.m.color_encoding); + (void)dump_image(io, pathname.str()); + } + template + void DumpImage(const char* label, const Plane& image) { + DumpImage(label, + Image3(CopyImage(image), CopyImage(image), CopyImage(image))); + } + + template + void DumpXybImage(const char* label, const Image3& image) const { + if (!dump_image) return; + if (debug_prefix.empty()) return; + std::ostringstream pathname; + pathname << debug_prefix << label << ".png"; + + Image3F linear(image.xsize(), image.ysize()); + OpsinParams opsin_params; + opsin_params.Init(kDefaultIntensityTarget); + OpsinToLinear(image, Rect(linear), nullptr, &linear, opsin_params); + + CodecInOut io; + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + io.SetFromImage(std::move(linear), io.metadata.m.color_encoding); + + (void)dump_image(io, pathname.str()); + } + + // Normalizes all the channels to range 0-1, creating a false-color image + // which allows seeing the information from non-RGB channels in an RGB debug + // image. + template + void DumpImageNormalized(const char* label, const Image3& image) const { + std::array min; + std::array max; + Image3MinMax(image, &min, &max); + Image3B normalized(image.xsize(), image.ysize()); + for (size_t c = 0; c < 3; ++c) { + float mul = min[c] == max[c] ? 0 : (255.0f / (max[c] - min[c])); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* JXL_RESTRICT row_in = image.ConstPlaneRow(c, y); + uint8_t* JXL_RESTRICT row_out = normalized.PlaneRow(c, y); + for (size_t x = 0; x < image.xsize(); ++x) { + row_out[x] = static_cast((row_in[x] - min[c]) * mul); + } + } + } + DumpImage(label, normalized); + } + + template + void DumpPlaneNormalized(const char* label, const Plane& image) const { + T min; + T max; + ImageMinMax(image, &min, &max); + Image3B normalized(image.xsize(), image.ysize()); + for (size_t c = 0; c < 3; ++c) { + float mul = min == max ? 0 : (255.0f / (max - min)); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* JXL_RESTRICT row_in = image.ConstRow(y); + uint8_t* JXL_RESTRICT row_out = normalized.PlaneRow(c, y); + for (size_t x = 0; x < image.xsize(); ++x) { + row_out[x] = static_cast((row_in[x] - min) * mul); + } + } + } + DumpImage(label, normalized); + } + + void SetInspectorImage3F(const jxl::InspectorImage3F& inspector) { + inspector_image3f_ = inspector; + } + + // Allows hooking intermediate data inspection into various places of the + // processing pipeline. Returns true iff processing should proceed. + bool InspectImage3F(const char* label, const Image3F& image) { + if (inspector_image3f_ != nullptr) { + return inspector_image3f_(label, image); + } + return true; + } + + std::array layers; + size_t num_blocks = 0; + + // Number of blocks that use larger DCT (set by ac_strategy). + size_t num_small_blocks = 0; + size_t num_dct4x8_blocks = 0; + size_t num_afv_blocks = 0; + size_t num_dct8_blocks = 0; + size_t num_dct8x16_blocks = 0; + size_t num_dct8x32_blocks = 0; + size_t num_dct16_blocks = 0; + size_t num_dct16x32_blocks = 0; + size_t num_dct32_blocks = 0; + size_t num_dct32x64_blocks = 0; + size_t num_dct64_blocks = 0; + + std::array dc_pred_usage = {{0}}; + std::array dc_pred_usage_xb = {{0}}; + + int num_butteraugli_iters = 0; + + float max_quant_rescale = 1.0f; + float min_quant_rescale = 1.0f; + float min_bitrate_error = 0.0f; + float max_bitrate_error = 0.0f; + + // If not empty, additional debugging information (e.g. debug images) is + // saved in files with this prefix. + std::string debug_prefix; + + // By how much the decoded image was downsampled relative to the encoded + // image. + size_t downsampling = 1; + + jxl::InspectorImage3F inspector_image3f_; + + std::function dump_image = + nullptr; +}; + +// Used to skip image creation if they won't be written to debug directory. +static inline bool WantDebugOutput(const AuxOut* aux_out) { + // Need valid pointer and filename. + return aux_out != nullptr && !aux_out->debug_prefix.empty(); +} + +} // namespace jxl + +#endif // LIB_JXL_AUX_OUT_H_ diff --git a/media/libjxl/src/lib/jxl/aux_out_fwd.h b/media/libjxl/src/lib/jxl/aux_out_fwd.h new file mode 100644 index 000000000..29b31ad87 --- /dev/null +++ b/media/libjxl/src/lib/jxl/aux_out_fwd.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_AUX_OUT_FWD_H_ +#define LIB_JXL_AUX_OUT_FWD_H_ + +#include + +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +struct AuxOut; + +// Helper function that ensures the `bits_written` are charged to `layer` in +// `aux_out`. Example usage: +// BitWriter::Allotment allotment(&writer, max_bits); +// writer.Write(..); writer.Write(..); +// ReclaimAndCharge(&writer, &allotment, layer, aux_out); +void ReclaimAndCharge(BitWriter* JXL_RESTRICT writer, + BitWriter::Allotment* JXL_RESTRICT allotment, + size_t layer, AuxOut* JXL_RESTRICT aux_out); + +} // namespace jxl + +#endif // LIB_JXL_AUX_OUT_FWD_H_ diff --git a/media/libjxl/src/lib/jxl/base/arch_macros.h b/media/libjxl/src/lib/jxl/base/arch_macros.h new file mode 100644 index 000000000..a98301915 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/arch_macros.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_ARCH_MACROS_H_ +#define LIB_JXL_BASE_ARCH_MACROS_H_ + +// Defines the JXL_ARCH_* macros. + +namespace jxl { + +#if defined(__x86_64__) || defined(_M_X64) +#define JXL_ARCH_X64 1 +#else +#define JXL_ARCH_X64 0 +#endif + +#if defined(__powerpc64__) || defined(_M_PPC) +#define JXL_ARCH_PPC 1 +#else +#define JXL_ARCH_PPC 0 +#endif + +#if defined(__aarch64__) || defined(__arm__) +#define JXL_ARCH_ARM 1 +#else +#define JXL_ARCH_ARM 0 +#endif + +} // namespace jxl + +#endif // LIB_JXL_BASE_ARCH_MACROS_H_ diff --git a/media/libjxl/src/lib/jxl/base/bits.h b/media/libjxl/src/lib/jxl/base/bits.h new file mode 100644 index 000000000..9f86118e7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/bits.h @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_BITS_H_ +#define LIB_JXL_BASE_BITS_H_ + +// Specialized instructions for processing register-sized bit arrays. + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +#if JXL_COMPILER_MSVC +#include +#endif + +#include +#include + +namespace jxl { + +// Empty struct used as a size tag type. +template +struct SizeTag {}; + +template +constexpr bool IsSigned() { + return T(0) > T(-1); +} + +// Undefined results for x == 0. +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsAboveMS1Bit_Nonzero(SizeTag<4> /* tag */, const uint32_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC + unsigned long index; + _BitScanReverse(&index, x); + return 31 - index; +#else + return static_cast(__builtin_clz(x)); +#endif +} +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsAboveMS1Bit_Nonzero(SizeTag<8> /* tag */, const uint64_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC +#if JXL_ARCH_X64 + unsigned long index; + _BitScanReverse64(&index, x); + return 63 - index; +#else // JXL_ARCH_X64 + // _BitScanReverse64 not available + uint32_t msb = static_cast(x >> 32u); + unsigned long index; + if (msb == 0) { + uint32_t lsb = static_cast(x & 0xFFFFFFFF); + _BitScanReverse(&index, lsb); + return 63 - index; + } else { + _BitScanReverse(&index, msb); + return 31 - index; + } +#endif // JXL_ARCH_X64 +#else + return static_cast(__builtin_clzll(x)); +#endif +} +template +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsAboveMS1Bit_Nonzero(const T x) { + static_assert(!IsSigned(), "Num0BitsAboveMS1Bit_Nonzero: use unsigned"); + return Num0BitsAboveMS1Bit_Nonzero(SizeTag(), x); +} + +// Undefined results for x == 0. +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsBelowLS1Bit_Nonzero(SizeTag<4> /* tag */, const uint32_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC + unsigned long index; + _BitScanForward(&index, x); + return index; +#else + return static_cast(__builtin_ctz(x)); +#endif +} +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsBelowLS1Bit_Nonzero(SizeTag<8> /* tag */, const uint64_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC +#if JXL_ARCH_X64 + unsigned long index; + _BitScanForward64(&index, x); + return index; +#else // JXL_ARCH_64 + // _BitScanForward64 not available + uint32_t lsb = static_cast(x & 0xFFFFFFFF); + unsigned long index; + if (lsb == 0) { + uint32_t msb = static_cast(x >> 32u); + _BitScanForward(&index, msb); + return 32 + index; + } else { + _BitScanForward(&index, lsb); + return index; + } +#endif // JXL_ARCH_X64 +#else + return static_cast(__builtin_ctzll(x)); +#endif +} +template +static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsBelowLS1Bit_Nonzero(T x) { + static_assert(!IsSigned(), "Num0BitsBelowLS1Bit_Nonzero: use unsigned"); + return Num0BitsBelowLS1Bit_Nonzero(SizeTag(), x); +} + +// Returns bit width for x == 0. +template +static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsAboveMS1Bit(const T x) { + return (x == 0) ? sizeof(T) * 8 : Num0BitsAboveMS1Bit_Nonzero(x); +} + +// Returns bit width for x == 0. +template +static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsBelowLS1Bit(const T x) { + return (x == 0) ? sizeof(T) * 8 : Num0BitsBelowLS1Bit_Nonzero(x); +} + +// Returns base-2 logarithm, rounded down. +template +static JXL_INLINE JXL_MAYBE_UNUSED size_t FloorLog2Nonzero(const T x) { + return (sizeof(T) * 8 - 1) ^ Num0BitsAboveMS1Bit_Nonzero(x); +} + +// Returns base-2 logarithm, rounded up. +template +static JXL_INLINE JXL_MAYBE_UNUSED size_t CeilLog2Nonzero(const T x) { + const size_t floor_log2 = FloorLog2Nonzero(x); + if ((x & (x - 1)) == 0) return floor_log2; // power of two + return floor_log2 + 1; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_BITS_H_ diff --git a/media/libjxl/src/lib/jxl/base/byte_order.h b/media/libjxl/src/lib/jxl/base/byte_order.h new file mode 100644 index 000000000..ccf1a5e71 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/byte_order.h @@ -0,0 +1,241 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_BYTE_ORDER_H_ +#define LIB_JXL_BASE_BYTE_ORDER_H_ + +#include +#include // memcpy + +#include "lib/jxl/base/compiler_specific.h" + +#if JXL_COMPILER_MSVC +#include // _byteswap_* +#endif + +#if (defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) +#define JXL_BYTE_ORDER_LITTLE 1 +#else +// This means that we don't know that the byte order is little endian, in +// this case we use endian-neutral code that works for both little- and +// big-endian. +#define JXL_BYTE_ORDER_LITTLE 0 +#endif + +// Returns whether the system is little-endian (least-significant byte first). +#if JXL_BYTE_ORDER_LITTLE +static constexpr bool IsLittleEndian() { return true; } +#else +static inline bool IsLittleEndian() { + const uint32_t multibyte = 1; + uint8_t byte; + memcpy(&byte, &multibyte, 1); + return byte == 1; +} +#endif + +#if JXL_COMPILER_MSVC +#define JXL_BSWAP32(x) _byteswap_ulong(x) +#define JXL_BSWAP64(x) _byteswap_uint64(x) +#else +#define JXL_BSWAP32(x) __builtin_bswap32(x) +#define JXL_BSWAP64(x) __builtin_bswap64(x) +#endif + +static JXL_INLINE uint32_t LoadBE16(const uint8_t* p) { + const uint32_t byte1 = p[0]; + const uint32_t byte0 = p[1]; + return (byte1 << 8) | byte0; +} + +static JXL_INLINE uint32_t LoadLE16(const uint8_t* p) { + const uint32_t byte0 = p[0]; + const uint32_t byte1 = p[1]; + return (byte1 << 8) | byte0; +} + +static JXL_INLINE uint32_t LoadBE32(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint32_t big; + memcpy(&big, p, 4); + return JXL_BSWAP32(big); +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint32_t byte3 = p[0]; + const uint32_t byte2 = p[1]; + const uint32_t byte1 = p[2]; + const uint32_t byte0 = p[3]; + return (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0; +#endif +} + +static JXL_INLINE uint64_t LoadBE64(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint64_t big; + memcpy(&big, p, 8); + return JXL_BSWAP64(big); +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint64_t byte7 = p[0]; + const uint64_t byte6 = p[1]; + const uint64_t byte5 = p[2]; + const uint64_t byte4 = p[3]; + const uint64_t byte3 = p[4]; + const uint64_t byte2 = p[5]; + const uint64_t byte1 = p[6]; + const uint64_t byte0 = p[7]; + return (byte7 << 56ull) | (byte6 << 48ull) | (byte5 << 40ull) | + (byte4 << 32ull) | (byte3 << 24ull) | (byte2 << 16ull) | + (byte1 << 8ull) | byte0; +#endif +} + +static JXL_INLINE uint32_t LoadLE32(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint32_t little; + memcpy(&little, p, 4); + return little; +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint32_t byte0 = p[0]; + const uint32_t byte1 = p[1]; + const uint32_t byte2 = p[2]; + const uint32_t byte3 = p[3]; + return (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0; +#endif +} + +static JXL_INLINE uint64_t LoadLE64(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint64_t little; + memcpy(&little, p, 8); + return little; +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint64_t byte0 = p[0]; + const uint64_t byte1 = p[1]; + const uint64_t byte2 = p[2]; + const uint64_t byte3 = p[3]; + const uint64_t byte4 = p[4]; + const uint64_t byte5 = p[5]; + const uint64_t byte6 = p[6]; + const uint64_t byte7 = p[7]; + return (byte7 << 56) | (byte6 << 48) | (byte5 << 40) | (byte4 << 32) | + (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0; +#endif +} + +static JXL_INLINE void StoreBE16(const uint32_t native, uint8_t* p) { + p[0] = (native >> 8) & 0xFF; + p[1] = native & 0xFF; +} + +static JXL_INLINE void StoreLE16(const uint32_t native, uint8_t* p) { + p[1] = (native >> 8) & 0xFF; + p[0] = native & 0xFF; +} + +static JXL_INLINE void StoreBE32(const uint32_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint32_t big = JXL_BSWAP32(native); + memcpy(p, &big, 4); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[0] = native >> 24; + p[1] = (native >> 16) & 0xFF; + p[2] = (native >> 8) & 0xFF; + p[3] = native & 0xFF; +#endif +} + +static JXL_INLINE void StoreBE64(const uint64_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint64_t big = JXL_BSWAP64(native); + memcpy(p, &big, 8); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[0] = native >> 56ull; + p[1] = (native >> 48ull) & 0xFF; + p[2] = (native >> 40ull) & 0xFF; + p[3] = (native >> 32ull) & 0xFF; + p[4] = (native >> 24ull) & 0xFF; + p[5] = (native >> 16ull) & 0xFF; + p[6] = (native >> 8ull) & 0xFF; + p[7] = native & 0xFF; +#endif +} + +static JXL_INLINE void StoreLE32(const uint32_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint32_t little = native; + memcpy(p, &little, 4); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[3] = native >> 24; + p[2] = (native >> 16) & 0xFF; + p[1] = (native >> 8) & 0xFF; + p[0] = native & 0xFF; +#endif +} + +static JXL_INLINE void StoreLE64(const uint64_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint64_t little = native; + memcpy(p, &little, 8); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[7] = native >> 56; + p[6] = (native >> 48) & 0xFF; + p[5] = (native >> 40) & 0xFF; + p[4] = (native >> 32) & 0xFF; + p[3] = (native >> 24) & 0xFF; + p[2] = (native >> 16) & 0xFF; + p[1] = (native >> 8) & 0xFF; + p[0] = native & 0xFF; +#endif +} + +// Big/Little Endian order. +struct OrderBE {}; +struct OrderLE {}; + +// Wrappers for calling from generic code. +static JXL_INLINE void Store16(OrderBE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreBE16(native, p); +} + +static JXL_INLINE void Store16(OrderLE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreLE16(native, p); +} + +static JXL_INLINE void Store32(OrderBE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreBE32(native, p); +} + +static JXL_INLINE void Store32(OrderLE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreLE32(native, p); +} + +static JXL_INLINE uint32_t Load16(OrderBE /*tag*/, const uint8_t* p) { + return LoadBE16(p); +} + +static JXL_INLINE uint32_t Load16(OrderLE /*tag*/, const uint8_t* p) { + return LoadLE16(p); +} + +static JXL_INLINE uint32_t Load32(OrderBE /*tag*/, const uint8_t* p) { + return LoadBE32(p); +} + +static JXL_INLINE uint32_t Load32(OrderLE /*tag*/, const uint8_t* p) { + return LoadLE32(p); +} + +#endif // LIB_JXL_BASE_BYTE_ORDER_H_ diff --git a/media/libjxl/src/lib/jxl/base/cache_aligned.cc b/media/libjxl/src/lib/jxl/base/cache_aligned.cc new file mode 100644 index 000000000..9a9cc585a --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/cache_aligned.cc @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/cache_aligned.h" + +#include +#include + +// Disabled: slower than malloc + alignment. +#define JXL_USE_MMAP 0 + +#if JXL_USE_MMAP +#include +#endif + +#include // std::max +#include +#include // kMaxVectorSize +#include + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace { + +#pragma pack(push, 1) +struct AllocationHeader { + void* allocated; + size_t allocated_size; + uint8_t left_padding[hwy::kMaxVectorSize]; +}; +#pragma pack(pop) + +std::atomic num_allocations{0}; +std::atomic bytes_in_use{0}; +std::atomic max_bytes_in_use{0}; + +} // namespace + +// Avoids linker errors in pre-C++17 builds. +constexpr size_t CacheAligned::kPointerSize; +constexpr size_t CacheAligned::kCacheLineSize; +constexpr size_t CacheAligned::kAlignment; +constexpr size_t CacheAligned::kAlias; + +void CacheAligned::PrintStats() { + fprintf( + stderr, "Allocations: %" PRIuS " (max bytes in use: %E)\n", + static_cast(num_allocations.load(std::memory_order_relaxed)), + static_cast(max_bytes_in_use.load(std::memory_order_relaxed))); +} + +size_t CacheAligned::NextOffset() { + static std::atomic next{0}; + constexpr uint32_t kGroups = CacheAligned::kAlias / CacheAligned::kAlignment; + const uint32_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups; + return CacheAligned::kAlignment * group; +} + +void* CacheAligned::Allocate(const size_t payload_size, size_t offset) { + JXL_ASSERT(payload_size <= std::numeric_limits::max() / 2); + JXL_ASSERT((offset % kAlignment == 0) && offset <= kAlias); + + // What: | misalign | unused | AllocationHeader |payload + // Size: |<= kAlias | offset | |payload_size + // ^allocated.^aligned.^header............^payload + // The header must immediately precede payload, which must remain aligned. + // To avoid wasting space, the header resides at the end of `unused`, + // which therefore cannot be empty (offset == 0). + if (offset == 0) { + // SVE/RVV vectors can be large, so we cannot rely on them (including the + // padding at the end of AllocationHeader) to fit in kAlignment. + offset = hwy::RoundUpTo(sizeof(AllocationHeader), kAlignment); + } + +#if JXL_USE_MMAP + const size_t allocated_size = offset + payload_size; + const int flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE; + void* allocated = + mmap(nullptr, allocated_size, PROT_READ | PROT_WRITE, flags, -1, 0); + if (allocated == MAP_FAILED) return nullptr; + const uintptr_t aligned = reinterpret_cast(allocated); +#else + const size_t allocated_size = kAlias + offset + payload_size; + void* allocated = malloc(allocated_size); + if (allocated == nullptr) return nullptr; + // Always round up even if already aligned - we already asked for kAlias + // extra bytes and there's no way to give them back. + uintptr_t aligned = reinterpret_cast(allocated) + kAlias; + static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2"); + static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias"); + aligned &= ~(kAlias - 1); +#endif + +#if 0 + // No effect. + uintptr_t page_aligned = reinterpret_cast(allocated); + page_aligned &= ~(4096 - 1); + if (madvise(reinterpret_cast(page_aligned), allocated_size, + MADV_WILLNEED) != 0) { + JXL_NOTIFY_ERROR("madvise failed"); + } +#elif 0 + // INCREASES both first and subsequent decode times. + if (mlock(allocated, allocated_size) != 0) { + JXL_NOTIFY_ERROR("mlock failed"); + } +#endif + + // Update statistics (#allocations and max bytes in use) + num_allocations.fetch_add(1, std::memory_order_relaxed); + const uint64_t prev_bytes = + bytes_in_use.fetch_add(allocated_size, std::memory_order_acq_rel); + uint64_t expected_max = max_bytes_in_use.load(std::memory_order_acquire); + for (;;) { + const uint64_t desired = + std::max(expected_max, prev_bytes + allocated_size); + if (max_bytes_in_use.compare_exchange_strong(expected_max, desired, + std::memory_order_acq_rel)) { + break; + } + } + + const uintptr_t payload = aligned + offset; // still aligned + + // Stash `allocated` and payload_size inside header for use by Free(). + AllocationHeader* header = reinterpret_cast(payload) - 1; + header->allocated = allocated; + header->allocated_size = allocated_size; + + return JXL_ASSUME_ALIGNED(reinterpret_cast(payload), 64); +} + +void CacheAligned::Free(const void* aligned_pointer) { + if (aligned_pointer == nullptr) { + return; + } + const uintptr_t payload = reinterpret_cast(aligned_pointer); + JXL_ASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast(payload) - 1; + + // Subtract (2's complement negation). + bytes_in_use.fetch_add(~header->allocated_size + 1, + std::memory_order_acq_rel); + +#if JXL_USE_MMAP + munmap(header->allocated, header->allocated_size); +#else + free(header->allocated); +#endif +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/base/cache_aligned.h b/media/libjxl/src/lib/jxl/base/cache_aligned.h new file mode 100644 index 000000000..e57df1483 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/cache_aligned.h @@ -0,0 +1,74 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_CACHE_ALIGNED_H_ +#define LIB_JXL_BASE_CACHE_ALIGNED_H_ + +// Memory allocator with support for alignment + misalignment. + +#include +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// Functions that depend on the cache line size. +class CacheAligned { + public: + static void PrintStats(); + + static constexpr size_t kPointerSize = sizeof(void*); + static constexpr size_t kCacheLineSize = 64; + // To avoid RFOs, match L2 fill size (pairs of lines). + static constexpr size_t kAlignment = 2 * kCacheLineSize; + // Minimum multiple for which cache set conflicts and/or loads blocked by + // preceding stores can occur. + static constexpr size_t kAlias = 2048; + + // Returns a 'random' (cyclical) offset suitable for Allocate. + static size_t NextOffset(); + + // Returns null or memory whose address is congruent to `offset` (mod kAlias). + // This reduces cache conflicts and load/store stalls, especially with large + // allocations that would otherwise have similar alignments. At least + // `payload_size` (which can be zero) bytes will be accessible. + static void* Allocate(size_t payload_size, size_t offset); + + static void* Allocate(const size_t payload_size) { + return Allocate(payload_size, NextOffset()); + } + + static void Free(const void* aligned_pointer); +}; + +// Avoids the need for a function pointer (deleter) in CacheAlignedUniquePtr. +struct CacheAlignedDeleter { + void operator()(uint8_t* aligned_pointer) const { + return CacheAligned::Free(aligned_pointer); + } +}; + +using CacheAlignedUniquePtr = std::unique_ptr; + +// Does not invoke constructors. +static inline CacheAlignedUniquePtr AllocateArray(const size_t bytes) { + return CacheAlignedUniquePtr( + static_cast(CacheAligned::Allocate(bytes)), + CacheAlignedDeleter()); +} + +static inline CacheAlignedUniquePtr AllocateArray(const size_t bytes, + const size_t offset) { + return CacheAlignedUniquePtr( + static_cast(CacheAligned::Allocate(bytes, offset)), + CacheAlignedDeleter()); +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_CACHE_ALIGNED_H_ diff --git a/media/libjxl/src/lib/jxl/base/compiler_specific.h b/media/libjxl/src/lib/jxl/base/compiler_specific.h new file mode 100644 index 000000000..7aa8b9915 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/compiler_specific.h @@ -0,0 +1,156 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_COMPILER_SPECIFIC_H_ +#define LIB_JXL_BASE_COMPILER_SPECIFIC_H_ + +// Macros for compiler version + nonstandard keywords, e.g. __builtin_expect. + +#include + +#include "lib/jxl/base/sanitizer_definitions.h" + +// #if is shorter and safer than #ifdef. *_VERSION are zero if not detected, +// otherwise 100 * major + minor version. Note that other packages check for +// #ifdef COMPILER_MSVC, so we cannot use that same name. + +#ifdef _MSC_VER +#define JXL_COMPILER_MSVC _MSC_VER +#else +#define JXL_COMPILER_MSVC 0 +#endif + +#ifdef __GNUC__ +#define JXL_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__) +#else +#define JXL_COMPILER_GCC 0 +#endif + +#ifdef __clang__ +#define JXL_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__) +// Clang pretends to be GCC for compatibility. +#undef JXL_COMPILER_GCC +#define JXL_COMPILER_GCC 0 +#else +#define JXL_COMPILER_CLANG 0 +#endif + +#if JXL_COMPILER_MSVC +#define JXL_RESTRICT __restrict +#elif JXL_COMPILER_GCC || JXL_COMPILER_CLANG +#define JXL_RESTRICT __restrict__ +#else +#define JXL_RESTRICT +#endif + +#if JXL_COMPILER_MSVC +#define JXL_INLINE __forceinline +#define JXL_NOINLINE __declspec(noinline) +#else +#define JXL_INLINE inline __attribute__((always_inline)) +#define JXL_NOINLINE __attribute__((noinline)) +#endif + +#if JXL_COMPILER_MSVC +#define JXL_NORETURN __declspec(noreturn) +#elif JXL_COMPILER_GCC || JXL_COMPILER_CLANG +#define JXL_NORETURN __attribute__((noreturn)) +#else +#define JXL_NORETURN +#endif + +#if JXL_COMPILER_MSVC +#define JXL_UNREACHABLE __assume(false) +#elif JXL_COMPILER_CLANG || JXL_COMPILER_GCC >= 405 +#define JXL_UNREACHABLE __builtin_unreachable() +#else +#define JXL_UNREACHABLE +#endif + +#if JXL_COMPILER_MSVC +#define JXL_MAYBE_UNUSED +#else +// Encountered "attribute list cannot appear here" when using the C++17 +// [[maybe_unused]], so only use the old style attribute for now. +#define JXL_MAYBE_UNUSED __attribute__((unused)) +#endif + +// MSAN execution won't hurt if some code it not inlined, but this can greatly +// improve compilation time. Unfortunately this macro can not be used just +// everywhere - inside header files it leads to "multiple definition" error; +// though it would be better not to have JXL_INLINE in header overall. +#if JXL_MEMORY_SANITIZER || JXL_ADDRESS_SANITIZER || JXL_THREAD_SANITIZER +#define JXL_MAYBE_INLINE JXL_MAYBE_UNUSED +#else +#define JXL_MAYBE_INLINE JXL_INLINE +#endif + +#if JXL_COMPILER_MSVC +// Unsupported, __assume is not the same. +#define JXL_LIKELY(expr) expr +#define JXL_UNLIKELY(expr) expr +#else +#define JXL_LIKELY(expr) __builtin_expect(!!(expr), 1) +#define JXL_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* JXL_RESTRICT aligned = (float*)JXL_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if JXL_COMPILER_CLANG +// Early versions of Clang did not support __builtin_assume_aligned. +#define JXL_HAS_ASSUME_ALIGNED __has_builtin(__builtin_assume_aligned) +#elif JXL_COMPILER_GCC +#define JXL_HAS_ASSUME_ALIGNED 1 +#else +#define JXL_HAS_ASSUME_ALIGNED 0 +#endif + +#if JXL_HAS_ASSUME_ALIGNED +#define JXL_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align)) +#else +#define JXL_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */ +#endif + +#ifdef __has_attribute +#define JXL_HAVE_ATTRIBUTE(x) __has_attribute(x) +#else +#define JXL_HAVE_ATTRIBUTE(x) 0 +#endif + +// Raises warnings if the function return value is unused. Should appear as the +// first part of a function definition/declaration. +#if JXL_HAVE_ATTRIBUTE(nodiscard) +#define JXL_MUST_USE_RESULT [[nodiscard]] +#elif JXL_COMPILER_CLANG && JXL_HAVE_ATTRIBUTE(warn_unused_result) +#define JXL_MUST_USE_RESULT __attribute__((warn_unused_result)) +#else +#define JXL_MUST_USE_RESULT +#endif + +// Disable certain -fsanitize flags for functions that are expected to include +// things like unsigned integer overflow. For example use in the function +// declaration JXL_NO_SANITIZE("unsigned-integer-overflow") to silence unsigned +// integer overflow ubsan messages. +#if JXL_COMPILER_CLANG && JXL_HAVE_ATTRIBUTE(no_sanitize) +#define JXL_NO_SANITIZE(X) __attribute__((no_sanitize(X))) +#else +#define JXL_NO_SANITIZE(X) +#endif + +#if JXL_HAVE_ATTRIBUTE(__format__) +#define JXL_FORMAT(idx_fmt, idx_arg) \ + __attribute__((__format__(__printf__, idx_fmt, idx_arg))) +#else +#define JXL_FORMAT(idx_fmt, idx_arg) +#endif + +#if JXL_COMPILER_MSVC +using ssize_t = intptr_t; +#endif + +#endif // LIB_JXL_BASE_COMPILER_SPECIFIC_H_ diff --git a/media/libjxl/src/lib/jxl/base/data_parallel.cc b/media/libjxl/src/lib/jxl/base/data_parallel.cc new file mode 100644 index 000000000..20a911255 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/data_parallel.cc @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/data_parallel.h" + +namespace jxl { + +// static +JxlParallelRetCode ThreadPool::SequentialRunnerStatic( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) { + JxlParallelRetCode init_ret = (*init)(jpegxl_opaque, 1); + if (init_ret != 0) return init_ret; + + for (uint32_t i = start_range; i < end_range; i++) { + (*func)(jpegxl_opaque, i, 0); + } + return 0; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/base/data_parallel.h b/media/libjxl/src/lib/jxl/base/data_parallel.h new file mode 100644 index 000000000..666925aea --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/data_parallel.h @@ -0,0 +1,120 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_DATA_PARALLEL_H_ +#define LIB_JXL_BASE_DATA_PARALLEL_H_ + +// Portable, low-overhead C++11 ThreadPool alternative to OpenMP for +// data-parallel computations. + +#include +#include + +#include "jxl/parallel_runner.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#if JXL_COMPILER_MSVC +// suppress warnings about the const & applied to function types +#pragma warning(disable : 4180) +#endif + +namespace jxl { + +class ThreadPool { + public: + ThreadPool(JxlParallelRunner runner, void* runner_opaque) + : runner_(runner ? runner : &ThreadPool::SequentialRunnerStatic), + runner_opaque_(runner ? runner_opaque : static_cast(this)) {} + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator&(const ThreadPool&) = delete; + + JxlParallelRunner runner() const { return runner_; } + void* runner_opaque() const { return runner_opaque_; } + + // Runs init_func(num_threads) followed by data_func(task, thread) on worker + // thread(s) for every task in [begin, end). init_func() must return a Status + // indicating whether the initialization succeeded. + // "thread" is an integer smaller than num_threads. + // Not thread-safe - no two calls to Run may overlap. + // Subsequent calls will reuse the same threads. + // + // Precondition: begin <= end. + template + Status Run(uint32_t begin, uint32_t end, const InitFunc& init_func, + const DataFunc& data_func, const char* caller = "") { + JXL_ASSERT(begin <= end); + if (begin == end) return true; + RunCallState call_state(init_func, data_func); + // The runner_ uses the C convention and returns 0 in case of error, so we + // convert it to a Status. + return (*runner_)(runner_opaque_, static_cast(&call_state), + &call_state.CallInitFunc, &call_state.CallDataFunc, begin, + end) == 0; + } + + // Use this as init_func when no initialization is needed. + static Status NoInit(size_t num_threads) { return true; } + + private: + // class holding the state of a Run() call to pass to the runner_ as an + // opaque_jpegxl pointer. + template + class RunCallState final { + public: + RunCallState(const InitFunc& init_func, const DataFunc& data_func) + : init_func_(init_func), data_func_(data_func) {} + + // JxlParallelRunInit interface. + static int CallInitFunc(void* jpegxl_opaque, size_t num_threads) { + const auto* self = + static_cast*>(jpegxl_opaque); + // Returns -1 when the internal init function returns false Status to + // indicate an error. + return self->init_func_(num_threads) ? 0 : -1; + } + + // JxlParallelRunFunction interface. + static void CallDataFunc(void* jpegxl_opaque, uint32_t value, + size_t thread_id) { + const auto* self = + static_cast*>(jpegxl_opaque); + return self->data_func_(value, thread_id); + } + + private: + const InitFunc& init_func_; + const DataFunc& data_func_; + }; + + // Default JxlParallelRunner used when no runner is provided by the + // caller. This runner doesn't use any threading and thread_id is always 0. + static JxlParallelRetCode SequentialRunnerStatic( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range); + + // The caller supplied runner function and its opaque void*. + const JxlParallelRunner runner_; + void* const runner_opaque_; +}; + +template +Status RunOnPool(ThreadPool* pool, const uint32_t begin, const uint32_t end, + const InitFunc& init_func, const DataFunc& data_func, + const char* caller) { + if (pool == nullptr) { + ThreadPool default_pool(nullptr, nullptr); + return default_pool.Run(begin, end, init_func, data_func, caller); + } else { + return pool->Run(begin, end, init_func, data_func, caller); + } +} + +} // namespace jxl +#if JXL_COMPILER_MSVC +#pragma warning(default : 4180) +#endif + +#endif // LIB_JXL_BASE_DATA_PARALLEL_H_ diff --git a/media/libjxl/src/lib/jxl/base/file_io.h b/media/libjxl/src/lib/jxl/base/file_io.h new file mode 100644 index 000000000..8c7777c94 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/file_io.h @@ -0,0 +1,152 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_FILE_IO_H_ +#define LIB_JXL_BASE_FILE_IO_H_ + +// Helper functions for reading/writing files. + +#include +#include + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Returns extension including the dot, or empty string if none. Assumes +// filename is not a hidden file (e.g. ".bashrc"). May be called with a pathname +// if the filename contains a dot and/or no other path component does. +static inline std::string Extension(const std::string& filename) { + const size_t pos = filename.rfind('.'); + if (pos == std::string::npos) return std::string(); + return filename.substr(pos); +} + +// RAII, ensures files are closed even when returning early. +class FileWrapper { + public: + FileWrapper(const FileWrapper& other) = delete; + FileWrapper& operator=(const FileWrapper& other) = delete; + + explicit FileWrapper(const std::string& pathname, const char* mode) + : file_(pathname == "-" ? (mode[0] == 'r' ? stdin : stdout) + : fopen(pathname.c_str(), mode)), + close_on_delete_(pathname != "-") { +#ifdef _WIN32 + struct __stat64 s = {}; + const int err = _stat64(pathname.c_str(), &s); + const bool is_file = (s.st_mode & S_IFREG) != 0; +#else + struct stat s = {}; + const int err = stat(pathname.c_str(), &s); + const bool is_file = S_ISREG(s.st_mode); +#endif + if (err == 0 && is_file) { + size_ = s.st_size; + } + } + + ~FileWrapper() { + if (file_ != nullptr && close_on_delete_) { + const int err = fclose(file_); + JXL_CHECK(err == 0); + } + } + + // We intend to use FileWrapper as a replacement of FILE. + // NOLINTNEXTLINE(google-explicit-constructor) + operator FILE*() const { return file_; } + + int64_t size() { return size_; } + + private: + FILE* const file_; + bool close_on_delete_ = true; + int64_t size_ = -1; +}; + +template +static inline Status ReadFile(const std::string& pathname, + ContainerType* JXL_RESTRICT bytes) { + FileWrapper f(pathname, "rb"); + if (f == nullptr) + return JXL_FAILURE("Failed to open file for reading: %s", pathname.c_str()); + + // Get size of file in bytes + const int64_t size = f.size(); + if (size < 0) { + // Size is unknown, loop reading chunks until EOF. + bytes->clear(); + std::list> chunks; + + size_t total_size = 0; + while (true) { + std::vector chunk(16 * 1024); + const size_t bytes_read = fread(chunk.data(), 1, chunk.size(), f); + if (ferror(f) || bytes_read > chunk.size()) { + return JXL_FAILURE("Error reading %s", pathname.c_str()); + } + + chunk.resize(bytes_read); + total_size += bytes_read; + if (bytes_read != 0) { + chunks.emplace_back(std::move(chunk)); + } + if (feof(f)) { + break; + } + } + bytes->resize(total_size); + size_t pos = 0; + for (const auto& chunk : chunks) { + // Needed in case ContainerType is std::string, whose data() is const. + char* bytes_writable = reinterpret_cast(&(*bytes)[0]); + memcpy(bytes_writable + pos, chunk.data(), chunk.size()); + pos += chunk.size(); + } + } else { + // Size is known, read the file directly. + bytes->resize(static_cast(size)); + size_t pos = 0; + while (pos < bytes->size()) { + // Needed in case ContainerType is std::string, whose data() is const. + char* bytes_writable = reinterpret_cast(&(*bytes)[0]); + const size_t bytes_read = + fread(bytes_writable + pos, 1, bytes->size() - pos, f); + if (bytes_read == 0) return JXL_FAILURE("Failed to read"); + pos += bytes_read; + } + JXL_ASSERT(pos == bytes->size()); + } + return true; +} + +template +static inline Status WriteFile(const ContainerType& bytes, + const std::string& pathname) { + FileWrapper f(pathname, "wb"); + if (f == nullptr) return JXL_FAILURE("Failed to open file for writing"); + + size_t pos = 0; + while (pos < bytes.size()) { + const size_t bytes_written = + fwrite(bytes.data() + pos, 1, bytes.size() - pos, f); + if (bytes_written == 0) return JXL_FAILURE("Failed to write"); + pos += bytes_written; + } + JXL_ASSERT(pos == bytes.size()); + + return true; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_FILE_IO_H_ diff --git a/media/libjxl/src/lib/jxl/base/iaca.h b/media/libjxl/src/lib/jxl/base/iaca.h new file mode 100644 index 000000000..e5732dae5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/iaca.h @@ -0,0 +1,65 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_IACA_H_ +#define LIB_JXL_BASE_IACA_H_ + +#include "lib/jxl/base/compiler_specific.h" + +// IACA (Intel's Code Analyzer) analyzes instruction latencies, but only for +// code between special markers. These functions embed such markers in an +// executable, but only for reading via IACA - they deliberately trigger a +// crash if executed to ensure they are removed in normal builds. + +#ifndef JXL_IACA_ENABLED +#define JXL_IACA_ENABLED 0 +#endif + +namespace jxl { + +// Call before the region of interest. +static JXL_INLINE void BeginIACA() { +#if JXL_IACA_ENABLED && (JXL_COMPILER_GCC || JXL_COMPILER_CLANG) + asm volatile( + // UD2 "instruction" raises an invalid opcode exception. + ".byte 0x0F, 0x0B\n\t" + // Magic sequence recognized by IACA (MOV + addr32 fs:NOP). This actually + // clobbers EBX, but we don't care because the code won't be run, and we + // want IACA to observe the same code the compiler would have generated + // without this marker. + "movl $111, %%ebx\n\t" + ".byte 0x64, 0x67, 0x90\n\t" + : + : + // (Allegedly) clobbering memory may prevent reordering. + : "memory"); +#endif +} + +// Call after the region of interest. +static JXL_INLINE void EndIACA() { +#if JXL_IACA_ENABLED && (JXL_COMPILER_GCC || JXL_COMPILER_CLANG) + asm volatile( + // See above. + "movl $222, %%ebx\n\t" + ".byte 0x64, 0x67, 0x90\n\t" + // UD2 + ".byte 0x0F, 0x0B\n\t" + : + : + // (Allegedly) clobbering memory may prevent reordering. + : "memory"); +#endif +} + +// Add to a scope to mark a region. +struct ScopeIACA { + JXL_INLINE ScopeIACA() { BeginIACA(); } + JXL_INLINE ~ScopeIACA() { EndIACA(); } +}; + +} // namespace jxl + +#endif // LIB_JXL_BASE_IACA_H_ diff --git a/media/libjxl/src/lib/jxl/base/os_macros.h b/media/libjxl/src/lib/jxl/base/os_macros.h new file mode 100644 index 000000000..84d0b82bf --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/os_macros.h @@ -0,0 +1,50 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_OS_MACROS_H_ +#define LIB_JXL_BASE_OS_MACROS_H_ + +// Defines the JXL_OS_* macros. + +#if defined(_WIN32) || defined(_WIN64) +#define JXL_OS_WIN 1 +#else +#define JXL_OS_WIN 0 +#endif + +#ifdef __linux__ +#define JXL_OS_LINUX 1 +#else +#define JXL_OS_LINUX 0 +#endif + +#ifdef __APPLE__ +#define JXL_OS_MAC 1 +#else +#define JXL_OS_MAC 0 +#endif + +#define JXL_OS_IOS 0 +#ifdef __APPLE__ +#include +#if TARGET_OS_IPHONE +#undef JXL_OS_IOS +#define JXL_OS_IOS 1 +#endif +#endif + +#ifdef __FreeBSD__ +#define JXL_OS_FREEBSD 1 +#else +#define JXL_OS_FREEBSD 0 +#endif + +#ifdef __HAIKU__ +#define JXL_OS_HAIKU 1 +#else +#define JXL_OS_HAIKU 0 +#endif + +#endif // LIB_JXL_BASE_OS_MACROS_H_ diff --git a/media/libjxl/src/lib/jxl/base/override.h b/media/libjxl/src/lib/jxl/base/override.h new file mode 100644 index 000000000..1f8b65797 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/override.h @@ -0,0 +1,29 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_OVERRIDE_H_ +#define LIB_JXL_BASE_OVERRIDE_H_ + +// 'Trool' for command line arguments: force enable/disable, or use default. + +namespace jxl { + +// No effect if kDefault, otherwise forces a feature (typically a FrameHeader +// flag) on or off. +enum class Override : int { kOn = 1, kOff = 0, kDefault = -1 }; + +static inline Override OverrideFromBool(bool flag) { + return flag ? Override::kOn : Override::kOff; +} + +static inline bool ApplyOverride(Override o, bool default_condition) { + if (o == Override::kOn) return true; + if (o == Override::kOff) return false; + return default_condition; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_OVERRIDE_H_ diff --git a/media/libjxl/src/lib/jxl/base/padded_bytes.cc b/media/libjxl/src/lib/jxl/base/padded_bytes.cc new file mode 100644 index 000000000..11e4bff6f --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/padded_bytes.cc @@ -0,0 +1,63 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/padded_bytes.h" + +namespace jxl { + +void PaddedBytes::IncreaseCapacityTo(size_t capacity) { + JXL_ASSERT(capacity > capacity_); + + size_t new_capacity = std::max(capacity, 3 * capacity_ / 2); + new_capacity = std::max(64, new_capacity); + + // BitWriter writes up to 7 bytes past the end. + CacheAlignedUniquePtr new_data = AllocateArray(new_capacity + 8); + if (new_data == nullptr) { + // Allocation failed, discard all data to ensure this is noticed. + size_ = capacity_ = 0; + return; + } + + if (data_ == nullptr) { + // First allocation: ensure first byte is initialized (won't be copied). + new_data[0] = 0; + } else { + // Subsequent resize: copy existing data to new location. + memcpy(new_data.get(), data_.get(), size_); + // Ensure that the first new byte is initialized, to allow write_bits to + // safely append to the newly-resized PaddedBytes. + new_data[size_] = 0; + } + + capacity_ = new_capacity; + std::swap(new_data, data_); +} + +void PaddedBytes::assign(const uint8_t* new_begin, const uint8_t* new_end) { + JXL_DASSERT(new_begin <= new_end); + const size_t new_size = static_cast(new_end - new_begin); + + // memcpy requires non-overlapping ranges, and resizing might invalidate the + // new range. Neither happens if the new range is completely to the left or + // right of the _allocated_ range (irrespective of size_). + const uint8_t* allocated_end = begin() + capacity_; + const bool outside = new_end <= begin() || new_begin >= allocated_end; + if (outside) { + resize(new_size); // grow or shrink + memcpy(data(), new_begin, new_size); + return; + } + + // There is overlap. The new size cannot be larger because we own the memory + // and the new range cannot include anything outside the allocated range. + JXL_ASSERT(new_size <= capacity_); + + // memmove allows overlap and capacity_ is sufficient. + memmove(data(), new_begin, new_size); + size_ = new_size; // shrink +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/base/padded_bytes.h b/media/libjxl/src/lib/jxl/base/padded_bytes.h new file mode 100644 index 000000000..4534ddf86 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/padded_bytes.h @@ -0,0 +1,197 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_PADDED_BYTES_H_ +#define LIB_JXL_BASE_PADDED_BYTES_H_ + +// std::vector replacement with padding to reduce bounds checks in WriteBits + +#include +#include +#include // memcpy + +#include // max +#include +#include // swap + +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Provides a subset of the std::vector interface with some differences: +// - allows BitWriter to write 64 bits at a time without bounds checking; +// - ONLY zero-initializes the first byte (required by BitWriter); +// - ensures cache-line alignment. +class PaddedBytes { + public: + // Required for output params. + PaddedBytes() : size_(0), capacity_(0) {} + + explicit PaddedBytes(size_t size) : size_(size), capacity_(0) { + if (size != 0) IncreaseCapacityTo(size); + } + + PaddedBytes(size_t size, uint8_t value) : size_(size), capacity_(0) { + if (size != 0) { + IncreaseCapacityTo(size); + } + if (size_ != 0) { + memset(data(), value, size); + } + } + + PaddedBytes(const PaddedBytes& other) : size_(other.size_), capacity_(0) { + if (size_ != 0) IncreaseCapacityTo(size_); + if (data() != nullptr) memcpy(data(), other.data(), size_); + } + PaddedBytes& operator=(const PaddedBytes& other) { + // Self-assignment is safe. + resize(other.size()); + if (data() != nullptr) memmove(data(), other.data(), size_); + return *this; + } + + // default is not OK - need to set other.size_ to 0! + PaddedBytes(PaddedBytes&& other) noexcept + : size_(other.size_), + capacity_(other.capacity_), + data_(std::move(other.data_)) { + other.size_ = other.capacity_ = 0; + } + PaddedBytes& operator=(PaddedBytes&& other) noexcept { + size_ = other.size_; + capacity_ = other.capacity_; + data_ = std::move(other.data_); + + if (&other != this) { + other.size_ = other.capacity_ = 0; + } + return *this; + } + + void swap(PaddedBytes& other) { + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + std::swap(data_, other.data_); + } + + void reserve(size_t capacity) { + if (capacity > capacity_) IncreaseCapacityTo(capacity); + } + + // NOTE: unlike vector, this does not initialize the new data! + // However, we guarantee that write_bits can safely append after + // the resize, as we zero-initialize the first new byte of data. + // If size < capacity(), does not invalidate the memory. + void resize(size_t size) { + if (size > capacity_) IncreaseCapacityTo(size); + size_ = (data() == nullptr) ? 0 : size; + } + + // resize(size) plus explicit initialization of the new data with `value`. + void resize(size_t size, uint8_t value) { + size_t old_size = size_; + resize(size); + if (size_ > old_size) { + memset(data() + old_size, value, size_ - old_size); + } + } + + // Amortized constant complexity due to exponential growth. + void push_back(uint8_t x) { + if (size_ == capacity_) { + IncreaseCapacityTo(capacity_ + 1); + if (data() == nullptr) return; + } + + data_[size_++] = x; + } + + size_t size() const { return size_; } + size_t capacity() const { return capacity_; } + + uint8_t* data() { return data_.get(); } + const uint8_t* data() const { return data_.get(); } + + // std::vector operations implemented in terms of the public interface above. + + void clear() { resize(0); } + bool empty() const { return size() == 0; } + + void assign(std::initializer_list il) { + resize(il.size()); + memcpy(data(), il.begin(), il.size()); + } + + // Replaces data() with [new_begin, new_end); potentially reallocates. + void assign(const uint8_t* new_begin, const uint8_t* new_end); + + uint8_t* begin() { return data(); } + const uint8_t* begin() const { return data(); } + uint8_t* end() { return begin() + size(); } + const uint8_t* end() const { return begin() + size(); } + + uint8_t& operator[](const size_t i) { + BoundsCheck(i); + return data()[i]; + } + const uint8_t& operator[](const size_t i) const { + BoundsCheck(i); + return data()[i]; + } + + uint8_t& back() { + JXL_ASSERT(size() != 0); + return data()[size() - 1]; + } + const uint8_t& back() const { + JXL_ASSERT(size() != 0); + return data()[size() - 1]; + } + + template + void append(const T& other) { + append(reinterpret_cast(other.data()), + reinterpret_cast(other.data()) + other.size()); + } + + void append(const uint8_t* begin, const uint8_t* end) { + if (end - begin > 0) { + size_t old_size = size(); + resize(size() + (end - begin)); + memcpy(data() + old_size, begin, end - begin); + } + } + + private: + void BoundsCheck(size_t i) const { + // <= is safe due to padding and required by BitWriter. + JXL_ASSERT(i <= size()); + } + + // Copies existing data to newly allocated "data_". If allocation fails, + // data() == nullptr and size_ = capacity_ = 0. + // The new capacity will be at least 1.5 times the old capacity. This ensures + // that we avoid quadratic behaviour. + void IncreaseCapacityTo(size_t capacity); + + size_t size_; + size_t capacity_; + CacheAlignedUniquePtr data_; +}; + +template +static inline void Append(const T& s, PaddedBytes* out, + size_t* JXL_RESTRICT byte_pos) { + memcpy(out->data() + *byte_pos, s.data(), s.size()); + *byte_pos += s.size(); + JXL_CHECK(*byte_pos <= out->size()); +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_PADDED_BYTES_H_ diff --git a/media/libjxl/src/lib/jxl/base/printf_macros.h b/media/libjxl/src/lib/jxl/base/printf_macros.h new file mode 100644 index 000000000..3215052af --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/printf_macros.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_PRINTF_MACROS_H_ +#define LIB_JXL_BASE_PRINTF_MACROS_H_ + +// Format string macros. These should be included after any other system +// library since those may unconditionally define these, depending on the +// platform. + +// PRIuS and PRIdS macros to print size_t and ssize_t respectively. +#if !defined(PRIdS) +#if defined(_WIN64) +#define PRIdS "lld" +#elif defined(_WIN32) +#define PRIdS "d" +#else +#define PRIdS "zd" +#endif +#endif // PRIdS + +#if !defined(PRIuS) +#if defined(_WIN64) +#define PRIuS "llu" +#elif defined(_WIN32) +#define PRIuS "u" +#else +#define PRIuS "zu" +#endif +#endif // PRIuS + +#endif // LIB_JXL_BASE_PRINTF_MACROS_H_ diff --git a/media/libjxl/src/lib/jxl/base/profiler.h b/media/libjxl/src/lib/jxl/base/profiler.h new file mode 100644 index 000000000..13f95d2b7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/profiler.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_PROFILER_H_ +#define LIB_JXL_BASE_PROFILER_H_ + +// High precision, low overhead time measurements. Returns exact call counts and +// total elapsed time for user-defined 'zones' (code regions, i.e. C++ scopes). +// +// To use the profiler you must set the JPEGXL_ENABLE_PROFILER CMake flag, which +// defines PROFILER_ENABLED and links against the libjxl_profiler library. + +// If zero, this file has no effect and no measurements will be recorded. +#ifndef PROFILER_ENABLED +#define PROFILER_ENABLED 0 +#endif // PROFILER_ENABLED + +#if PROFILER_ENABLED + +#include "lib/profiler/profiler.h" + +#else // !PROFILER_ENABLED + +#define PROFILER_ZONE(name) +#define PROFILER_FUNC +#define PROFILER_PRINT_RESULTS() + +#endif // PROFILER_ENABLED + +#endif // LIB_JXL_BASE_PROFILER_H_ diff --git a/media/libjxl/src/lib/jxl/base/random.cc b/media/libjxl/src/lib/jxl/base/random.cc new file mode 100644 index 000000000..0fbe75806 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/random.cc @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/random.h" + +#include "lib/jxl/fast_math-inl.h" + +namespace jxl { + +Rng::GeometricDistribution::GeometricDistribution(float p) + : inv_log_1mp(1.0 / FastLog2f(1 - p)) {} + +uint32_t Rng::Geometric(const GeometricDistribution& dist) { + float f = UniformF(0, 1); + float log = FastLog2f(1 - f) * dist.inv_log_1mp; + return static_cast(log); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/base/random.h b/media/libjxl/src/lib/jxl/base/random.h new file mode 100644 index 000000000..663b88c95 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/random.h @@ -0,0 +1,95 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_RANDOM_ +#define LIB_JXL_BASE_RANDOM_ + +// Random number generator + distributions. +// We don't use because the implementation (and thus results) differs +// between libstdc++ and libc++. + +#include +#include + +#include + +#include "lib/jxl/base/status.h" + +namespace jxl { +struct Rng { + explicit Rng(size_t seed) + : s{static_cast(0x94D049BB133111EBull), + static_cast(0xBF58476D1CE4E5B9ull) + seed} {} + + // Xorshift128+ adapted from xorshift128+-inl.h + uint64_t operator()() { + uint64_t s1 = s[0]; + const uint64_t s0 = s[1]; + const uint64_t bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return bits; + } + + // Uniformly distributed int64_t in [begin, end), under the assumption that + // `end-begin` is significantly smaller than 1<<64, otherwise there is some + // bias. + int64_t UniformI(int64_t begin, int64_t end) { + JXL_DASSERT(end > begin); + return static_cast((*this)() % + static_cast(end - begin)) + + begin; + } + + // Same as UniformI, but for uint64_t. + uint64_t UniformU(uint64_t begin, uint64_t end) { + JXL_DASSERT(end > begin); + return (*this)() % (end - begin) + begin; + } + + // Uniformly distributed float in [begin, end) range. Note: only 23 bits of + // randomness. + float UniformF(float begin, float end) { + float f; + // Bits of a random [1, 2) float. + uint32_t u = ((*this)() >> (64 - 23)) | 0x3F800000; + static_assert(sizeof(f) == sizeof(u), + "Float and U32 must have the same size"); + memcpy(&f, &u, sizeof(f)); + // Note: (end-begin) * f + (2*begin-end) may fail to return a number >= + // begin. + return (end - begin) * (f - 1.0f) + begin; + } + + // Bernoulli trial + bool Bernoulli(float p) { return UniformF(0, 1) < p; } + + // State for geometric distributions. + struct GeometricDistribution { + explicit GeometricDistribution(float p); + + private: + float inv_log_1mp; + friend struct Rng; + }; + + uint32_t Geometric(const GeometricDistribution& dist); + + template + void Shuffle(T* t, size_t n) { + for (size_t i = 0; i + 1 < n; i++) { + size_t a = UniformU(i, n); + std::swap(t[a], t[i]); + } + } + + private: + uint64_t s[2]; +}; + +} // namespace jxl +#endif // LIB_JXL_BASE_RANDOM_ diff --git a/media/libjxl/src/lib/jxl/base/sanitizer_definitions.h b/media/libjxl/src/lib/jxl/base/sanitizer_definitions.h new file mode 100644 index 000000000..b52c538bc --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/sanitizer_definitions.h @@ -0,0 +1,44 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_SANITIZER_DEFINITIONS_H_ +#define LIB_JXL_BASE_SANITIZER_DEFINITIONS_H_ + +#ifdef MEMORY_SANITIZER +#define JXL_MEMORY_SANITIZER 1 +#elif defined(__has_feature) +#if __has_feature(memory_sanitizer) +#define JXL_MEMORY_SANITIZER 1 +#else +#define JXL_MEMORY_SANITIZER 0 +#endif +#else +#define JXL_MEMORY_SANITIZER 0 +#endif + +#ifdef ADDRESS_SANITIZER +#define JXL_ADDRESS_SANITIZER 1 +#elif defined(__has_feature) +#if __has_feature(address_sanitizer) +#define JXL_ADDRESS_SANITIZER 1 +#else +#define JXL_ADDRESS_SANITIZER 0 +#endif +#else +#define JXL_ADDRESS_SANITIZER 0 +#endif + +#ifdef THREAD_SANITIZER +#define JXL_THREAD_SANITIZER 1 +#elif defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define JXL_THREAD_SANITIZER 1 +#else +#define JXL_THREAD_SANITIZER 0 +#endif +#else +#define JXL_THREAD_SANITIZER 0 +#endif +#endif // LIB_JXL_BASE_SANITIZER_DEFINITIONS_H diff --git a/media/libjxl/src/lib/jxl/base/scope_guard.h b/media/libjxl/src/lib/jxl/base/scope_guard.h new file mode 100644 index 000000000..a18a44cb7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/scope_guard.h @@ -0,0 +1,48 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_SCOPE_GUARD_H_ +#define LIB_JXL_BASE_SCOPE_GUARD_H_ + +#include + +namespace jxl { + +template +class ScopeGuard { + public: + // Discourage unnecessary moves / copies. + ScopeGuard(const ScopeGuard &) = delete; + ScopeGuard &operator=(const ScopeGuard &) = delete; + ScopeGuard &operator=(ScopeGuard &&) = delete; + + // Pre-C++17 does not guarantee RVO -> require move constructor. + ScopeGuard(ScopeGuard &&other) : callback_(std::move(other.callback_)) { + other.armed_ = false; + } + + template + explicit ScopeGuard(CallbackParam &&callback) + : callback_(std::forward(callback)), armed_(true) {} + + ~ScopeGuard() { + if (armed_) callback_(); + } + + void Disarm() { armed_ = false; } + + private: + Callback callback_; + bool armed_; +}; + +template +ScopeGuard MakeScopeGuard(Callback &&callback) { + return ScopeGuard{std::forward(callback)}; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_SCOPE_GUARD_H_ diff --git a/media/libjxl/src/lib/jxl/base/span.h b/media/libjxl/src/lib/jxl/base/span.h new file mode 100644 index 000000000..41c3623a4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/span.h @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_SPAN_H_ +#define LIB_JXL_BASE_SPAN_H_ + +// Span (array view) is a non-owning container that provides cheap "cut" +// operations and could be used as "ArrayLike" data source for PaddedBytes. + +#include + +#include "lib/jxl/base/status.h" + +namespace jxl { + +template +class Span { + public: + constexpr Span() noexcept : Span(nullptr, 0) {} + + constexpr Span(T* array, size_t length) noexcept + : ptr_(array), len_(length) {} + + template + explicit constexpr Span(T (&a)[N]) noexcept : Span(a, N) {} + + template + explicit constexpr Span(const ArrayLike& other) noexcept + : Span(reinterpret_cast(other.data()), other.size()) { + static_assert(sizeof(*other.data()) == sizeof(T), + "Incompatible type of source."); + } + + constexpr T* data() const noexcept { return ptr_; } + + constexpr size_t size() const noexcept { return len_; } + + constexpr bool empty() const noexcept { return len_ == 0; } + + constexpr T& operator[](size_t i) const noexcept { + // MSVC 2015 accepts this as constexpr, but not ptr_[i] + return *(data() + i); + } + + void remove_prefix(size_t n) noexcept { + JXL_ASSERT(size() >= n); + ptr_ += n; + len_ -= n; + } + + private: + T* ptr_; + size_t len_; +}; + +} // namespace jxl + +#endif // LIB_JXL_BASE_SPAN_H_ diff --git a/media/libjxl/src/lib/jxl/base/status.h b/media/libjxl/src/lib/jxl/base/status.h new file mode 100644 index 000000000..682f44001 --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/status.h @@ -0,0 +1,324 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_STATUS_H_ +#define LIB_JXL_BASE_STATUS_H_ + +// Error handling: Status return type + helper macros. + +#include +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/sanitizer_definitions.h" + +#if JXL_ADDRESS_SANITIZER || JXL_MEMORY_SANITIZER || JXL_THREAD_SANITIZER +#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace +#endif // defined(*_SANITIZER) + +namespace jxl { + +// Uncomment to abort when JXL_FAILURE or JXL_STATUS with a fatal error is +// reached: +// #define JXL_CRASH_ON_ERROR + +#ifndef JXL_ENABLE_ASSERT +#define JXL_ENABLE_ASSERT 1 +#endif + +#ifndef JXL_ENABLE_CHECK +#define JXL_ENABLE_CHECK 1 +#endif + +// Pass -DJXL_DEBUG_ON_ERROR at compile time to print debug messages when a +// function returns JXL_FAILURE or calls JXL_NOTIFY_ERROR. Note that this is +// irrelevant if you also pass -DJXL_CRASH_ON_ERROR. +#if defined(JXL_DEBUG_ON_ERROR) || defined(JXL_CRASH_ON_ERROR) +#undef JXL_DEBUG_ON_ERROR +#define JXL_DEBUG_ON_ERROR 1 +#else // JXL_DEBUG_ON_ERROR || JXL_CRASH_ON_ERROR +#ifdef NDEBUG +#define JXL_DEBUG_ON_ERROR 0 +#else // NDEBUG +#define JXL_DEBUG_ON_ERROR 1 +#endif // NDEBUG +#endif // JXL_DEBUG_ON_ERROR || JXL_CRASH_ON_ERROR + +// Pass -DJXL_DEBUG_ON_ALL_ERROR at compile time to print debug messages on +// all error (fatal and non-fatal) status. This implies JXL_DEBUG_ON_ERROR. +#if defined(JXL_DEBUG_ON_ALL_ERROR) +#undef JXL_DEBUG_ON_ALL_ERROR +#define JXL_DEBUG_ON_ALL_ERROR 1 +// JXL_DEBUG_ON_ALL_ERROR implies JXL_DEBUG_ON_ERROR too. +#undef JXL_DEBUG_ON_ERROR +#define JXL_DEBUG_ON_ERROR 1 +#else // JXL_DEBUG_ON_ALL_ERROR +#define JXL_DEBUG_ON_ALL_ERROR 0 +#endif // JXL_DEBUG_ON_ALL_ERROR + +// The Verbose level for the library +#ifndef JXL_DEBUG_V_LEVEL +#define JXL_DEBUG_V_LEVEL 0 +#endif // JXL_DEBUG_V_LEVEL + +// Pass -DJXL_DEBUG_ON_ABORT=0 to disable the debug messages on JXL_ASSERT, +// JXL_CHECK and JXL_ABORT. +#ifndef JXL_DEBUG_ON_ABORT +#define JXL_DEBUG_ON_ABORT 1 +#endif // JXL_DEBUG_ON_ABORT + +// Print a debug message on standard error. You should use the JXL_DEBUG macro +// instead of calling Debug directly. This function returns false, so it can be +// used as a return value in JXL_FAILURE. +JXL_FORMAT(1, 2) +inline JXL_NOINLINE bool Debug(const char* format, ...) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + return false; +} + +// Print a debug message on standard error if "enabled" is true. "enabled" is +// normally a macro that evaluates to 0 or 1 at compile time, so the Debug +// function is never called and optimized out in release builds. Note that the +// arguments are compiled but not evaluated when enabled is false. The format +// string must be a explicit string in the call, for example: +// JXL_DEBUG(JXL_DEBUG_MYMODULE, "my module message: %d", some_var); +// Add a header at the top of your module's .cc or .h file (depending on whether +// you have JXL_DEBUG calls from the .h as well) like this: +// #ifndef JXL_DEBUG_MYMODULE +// #define JXL_DEBUG_MYMODULE 0 +// #endif JXL_DEBUG_MYMODULE +#define JXL_DEBUG(enabled, format, ...) \ + do { \ + if (enabled) { \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, \ + ##__VA_ARGS__); \ + } \ + } while (0) + +// JXL_DEBUG version that prints the debug message if the global verbose level +// defined at compile time by JXL_DEBUG_V_LEVEL is greater or equal than the +// passed level. +#define JXL_DEBUG_V(level, format, ...) \ + JXL_DEBUG(level <= JXL_DEBUG_V_LEVEL, format, ##__VA_ARGS__) + +// Warnings (via JXL_WARNING) are enabled by default in debug builds (opt and +// debug). +#ifdef JXL_DEBUG_WARNING +#undef JXL_DEBUG_WARNING +#define JXL_DEBUG_WARNING 1 +#else // JXL_DEBUG_WARNING +#ifdef NDEBUG +#define JXL_DEBUG_WARNING 0 +#else // JXL_DEBUG_WARNING +#define JXL_DEBUG_WARNING 1 +#endif // NDEBUG +#endif // JXL_DEBUG_WARNING +#define JXL_WARNING(format, ...) \ + JXL_DEBUG(JXL_DEBUG_WARNING, format, ##__VA_ARGS__) + +// Exits the program after printing a stack trace when possible. +JXL_NORETURN inline JXL_NOINLINE bool Abort() { +#if JXL_ADDRESS_SANITIZER || JXL_MEMORY_SANITIZER || JXL_THREAD_SANITIZER + // If compiled with any sanitizer print a stack trace. This call doesn't crash + // the program, instead the trap below will crash it also allowing gdb to + // break there. + __sanitizer_print_stack_trace(); +#endif // *_SANITIZER) + +#if JXL_COMPILER_MSVC + __debugbreak(); + abort(); +#else + __builtin_trap(); +#endif +} + +// Exits the program after printing file/line plus a formatted string. +#define JXL_ABORT(format, ...) \ + ((JXL_DEBUG_ON_ABORT) && ::jxl::Debug(("%s:%d: JXL_ABORT: " format "\n"), \ + __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort()) + +// Does not guarantee running the code, use only for debug mode checks. +#if JXL_ENABLE_ASSERT +#define JXL_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + JXL_DEBUG(JXL_DEBUG_ON_ABORT, "JXL_ASSERT: %s", #condition); \ + ::jxl::Abort(); \ + } \ + } while (0) +#else +#define JXL_ASSERT(condition) \ + do { \ + } while (0) +#endif + +// Define JXL_IS_DEBUG_BUILD that denotes asan, msan and other debug builds, +// but not opt or release. +#ifndef JXL_IS_DEBUG_BUILD +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) || \ + defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) || \ + defined(__clang_analyzer__) +#define JXL_IS_DEBUG_BUILD 1 +#else +#define JXL_IS_DEBUG_BUILD 0 +#endif +#endif // JXL_IS_DEBUG_BUILD + +// Same as above, but only runs in debug builds (builds where NDEBUG is not +// defined). This is useful for slower asserts that we want to run more rarely +// than usual. These will run on asan, msan and other debug builds, but not in +// opt or release. +#if JXL_IS_DEBUG_BUILD +#define JXL_DASSERT(condition) \ + do { \ + if (!(condition)) { \ + JXL_DEBUG(JXL_DEBUG_ON_ABORT, "JXL_DASSERT: %s", #condition); \ + ::jxl::Abort(); \ + } \ + } while (0) +#else +#define JXL_DASSERT(condition) \ + do { \ + } while (0) +#endif + +// Always runs the condition, so can be used for non-debug calls. +#if JXL_ENABLE_CHECK +#define JXL_CHECK(condition) \ + do { \ + if (!(condition)) { \ + JXL_DEBUG(JXL_DEBUG_ON_ABORT, "JXL_CHECK: %s", #condition); \ + ::jxl::Abort(); \ + } \ + } while (0) +#else +#define JXL_CHECK(condition) \ + do { \ + (void)(condition); \ + } while (0) +#endif + +// A jxl::Status value from a StatusCode or Status which prints a debug message +// when enabled. +#define JXL_STATUS(status, format, ...) \ + ::jxl::StatusMessage(::jxl::Status(status), "%s:%d: " format "\n", __FILE__, \ + __LINE__, ##__VA_ARGS__) + +// Notify of an error but discard the resulting Status value. This is only +// useful for debug builds or when building with JXL_CRASH_ON_ERROR. +#define JXL_NOTIFY_ERROR(format, ...) \ + (void)JXL_STATUS(::jxl::StatusCode::kGenericError, "JXL_ERROR: " format, \ + ##__VA_ARGS__) + +// An error Status with a message. The JXL_STATUS() macro will return a Status +// object with a kGenericError code, but the comma operator helps with +// clang-tidy inference and potentially with optimizations. +#define JXL_FAILURE(format, ...) \ + ((void)JXL_STATUS(::jxl::StatusCode::kGenericError, "JXL_FAILURE: " format, \ + ##__VA_ARGS__), \ + ::jxl::Status(::jxl::StatusCode::kGenericError)) + +// Always evaluates the status exactly once, so can be used for non-debug calls. +// Returns from the current context if the passed Status expression is an error +// (fatal or non-fatal). The return value is the passed Status. +#define JXL_RETURN_IF_ERROR(status) \ + do { \ + ::jxl::Status jxl_return_if_error_status = (status); \ + if (!jxl_return_if_error_status) { \ + (void)::jxl::StatusMessage( \ + jxl_return_if_error_status, \ + "%s:%d: JXL_RETURN_IF_ERROR code=%d: %s\n", __FILE__, __LINE__, \ + static_cast(jxl_return_if_error_status.code()), #status); \ + return jxl_return_if_error_status; \ + } \ + } while (0) + +// As above, but without calling StatusMessage. Intended for bundles (see +// fields.h), which have numerous call sites (-> relevant for code size) and do +// not want to generate excessive messages when decoding partial headers. +#define JXL_QUIET_RETURN_IF_ERROR(status) \ + do { \ + ::jxl::Status jxl_return_if_error_status = (status); \ + if (!jxl_return_if_error_status) { \ + return jxl_return_if_error_status; \ + } \ + } while (0) + +enum class StatusCode : int32_t { + // Non-fatal errors (negative values). + kNotEnoughBytes = -1, + + // The only non-error status code. + kOk = 0, + + // Fatal-errors (positive values) + kGenericError = 1, +}; + +// Drop-in replacement for bool that raises compiler warnings if not used +// after being returned from a function. Example: +// Status LoadFile(...) { return true; } is more compact than +// bool JXL_MUST_USE_RESULT LoadFile(...) { return true; } +// In case of error, the status can carry an extra error code in its value which +// is split between fatal and non-fatal error codes. +class JXL_MUST_USE_RESULT Status { + public: + // We want implicit constructor from bool to allow returning "true" or "false" + // on a function when using Status. "true" means kOk while "false" means a + // generic fatal error. + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Status(bool ok) + : code_(ok ? StatusCode::kOk : StatusCode::kGenericError) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Status(StatusCode code) : code_(code) {} + + // We also want implicit cast to bool to check for return values of functions. + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr operator bool() const { return code_ == StatusCode::kOk; } + + constexpr StatusCode code() const { return code_; } + + // Returns whether the status code is a fatal error. + constexpr bool IsFatalError() const { + return static_cast(code_) > 0; + } + + private: + StatusCode code_; +}; + +// Helper function to create a Status and print the debug message or abort when +// needed. +inline JXL_FORMAT(2, 3) Status + StatusMessage(const Status status, const char* format, ...) { + // This block will be optimized out when JXL_DEBUG_ON_ERROR and + // JXL_DEBUG_ON_ALL_ERROR are both disabled. + if ((JXL_DEBUG_ON_ERROR && status.IsFatalError()) || + (JXL_DEBUG_ON_ALL_ERROR && !status)) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + } +#ifdef JXL_CRASH_ON_ERROR + // JXL_CRASH_ON_ERROR means to Abort() only on non-fatal errors. + if (status.IsFatalError()) { + Abort(); + } +#endif // JXL_CRASH_ON_ERROR + return status; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_STATUS_H_ diff --git a/media/libjxl/src/lib/jxl/base/thread_pool_internal.h b/media/libjxl/src/lib/jxl/base/thread_pool_internal.h new file mode 100644 index 000000000..6e23a335a --- /dev/null +++ b/media/libjxl/src/lib/jxl/base/thread_pool_internal.h @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_THREAD_POOL_INTERNAL_H_ +#define LIB_JXL_BASE_THREAD_POOL_INTERNAL_H_ + +#include + +#include + +#include "jxl/parallel_runner.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/threads/thread_parallel_runner_internal.h" + +namespace jxl { + +// Helper class to pass an internal ThreadPool-like object using threads. This +// is only suitable for tests or tools that access the internal API of JPEG XL. +// In other cases the caller will provide a JxlParallelRunner() for handling +// this. This class uses jpegxl::ThreadParallelRunner (from jpegxl_threads +// library). For interface details check jpegxl::ThreadParallelRunner. +class ThreadPoolInternal : public ThreadPool { + public: + // Starts the given number of worker threads and blocks until they are ready. + // "num_worker_threads" defaults to one per hyperthread. If zero, all tasks + // run on the main thread. + explicit ThreadPoolInternal( + int num_worker_threads = std::thread::hardware_concurrency()) + : ThreadPool(&jpegxl::ThreadParallelRunner::Runner, + static_cast(&runner_)), + runner_(num_worker_threads) {} + + ThreadPoolInternal(const ThreadPoolInternal&) = delete; + ThreadPoolInternal& operator&(const ThreadPoolInternal&) = delete; + + size_t NumThreads() const { return runner_.NumThreads(); } + size_t NumWorkerThreads() const { return runner_.NumWorkerThreads(); } + + template + void RunOnEachThread(const Func& func) { + runner_.RunOnEachThread(func); + } + + private: + jpegxl::ThreadParallelRunner runner_; +}; + +} // namespace jxl + +#endif // LIB_JXL_BASE_THREAD_POOL_INTERNAL_H_ diff --git a/media/libjxl/src/lib/jxl/bit_reader_test.cc b/media/libjxl/src/lib/jxl/bit_reader_test.cc new file mode 100644 index 000000000..dbe93d40a --- /dev/null +++ b/media/libjxl/src/lib/jxl/bit_reader_test.cc @@ -0,0 +1,263 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { +namespace { + +TEST(BitReaderTest, ExtendsWithZeroes) { + for (size_t size = 4; size < 32; ++size) { + std::vector data(size, 0xff); + + for (size_t n_bytes = 0; n_bytes < size; n_bytes++) { + BitReader br(Span(data.data(), n_bytes)); + // Read all the bits + for (size_t i = 0; i < n_bytes * kBitsPerByte; i++) { + ASSERT_EQ(br.ReadBits(1), 1u) << "n_bytes=" << n_bytes << " i=" << i; + } + + // PEEK more than the declared size - all will be zero. Cannot consume. + for (size_t i = 0; i < BitReader::kMaxBitsPerCall; i++) { + ASSERT_EQ(br.PeekBits(i), 0u) + << "size=" << size << "n_bytes=" << n_bytes << " i=" << i; + } + + EXPECT_TRUE(br.Close()); + } + } +} + +struct Symbol { + uint32_t num_bits; + uint32_t value; +}; + +// Reading from output gives the same values. +TEST(BitReaderTest, TestRoundTrip) { + ThreadPoolInternal pool(8); + EXPECT_TRUE(RunOnPool( + &pool, 0, 1000, ThreadPool::NoInit, + [](const uint32_t task, size_t /* thread */) { + constexpr size_t kMaxBits = 8000; + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + + std::vector symbols; + symbols.reserve(1000); + + Rng rng(55537 + 129 * task); + + for (;;) { + const uint32_t num_bits = rng.UniformU(1, 33); + if (writer.BitsWritten() + num_bits > kMaxBits) break; + const uint32_t value = rng.UniformU(0, 1ULL << num_bits); + symbols.push_back({num_bits, value}); + writer.Write(num_bits, value); + } + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + BitReader reader(writer.GetSpan()); + for (const Symbol& s : symbols) { + EXPECT_EQ(s.value, reader.ReadBits(s.num_bits)); + } + EXPECT_TRUE(reader.Close()); + }, + "TestTBitReaderRoundTrip")); +} + +// SkipBits is the same as reading that many bits. +TEST(BitReaderTest, TestSkip) { + ThreadPoolInternal pool(8); + EXPECT_TRUE(RunOnPool( + &pool, 0, 96, ThreadPool::NoInit, + [](const uint32_t task, size_t /* thread */) { + constexpr size_t kSize = 100; + + for (size_t skip = 0; skip < 128; ++skip) { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kSize * kBitsPerByte); + // Start with "task" 1-bits. + for (size_t i = 0; i < task; ++i) { + writer.Write(1, 1); + } + + // Write 0-bits that we will skip over + for (size_t i = 0; i < skip; ++i) { + writer.Write(1, 0); + } + + // Write terminator bits '101' + writer.Write(3, 5); + EXPECT_EQ(task + skip + 3, writer.BitsWritten()); + writer.ZeroPadToByte(); + AuxOut aux_out; + ReclaimAndCharge(&writer, &allotment, 0, &aux_out); + EXPECT_LT(aux_out.layers[0].total_bits, kSize * 8); + + BitReader reader1(writer.GetSpan()); + BitReader reader2(writer.GetSpan()); + // Verify initial 1-bits + for (size_t i = 0; i < task; ++i) { + EXPECT_EQ(1u, reader1.ReadBits(1)); + EXPECT_EQ(1u, reader2.ReadBits(1)); + } + + // SkipBits or manually read "skip" bits + reader1.SkipBits(skip); + for (size_t i = 0; i < skip; ++i) { + EXPECT_EQ(0u, reader2.ReadBits(1)) + << " skip=" << skip << " i=" << i; + } + EXPECT_EQ(reader1.TotalBitsConsumed(), reader2.TotalBitsConsumed()); + + // Ensure both readers see the terminator bits. + EXPECT_EQ(5u, reader1.ReadBits(3)); + EXPECT_EQ(5u, reader2.ReadBits(3)); + + EXPECT_TRUE(reader1.Close()); + EXPECT_TRUE(reader2.Close()); + } + }, + "TestSkip")); +} + +// Verifies byte order and different groupings of bits. +TEST(BitReaderTest, TestOrder) { + constexpr size_t kMaxBits = 16; + + // u(1) - bits written into LSBs of first byte + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + for (size_t i = 0; i < 5; ++i) { + writer.Write(1, 1); + } + for (size_t i = 0; i < 5; ++i) { + writer.Write(1, 0); + } + for (size_t i = 0; i < 6; ++i) { + writer.Write(1, 1); + } + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0x1Fu, reader.ReadFixedBits<8>()); + EXPECT_EQ(0xFCu, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } + + // u(8) - get bytes in the same order + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + writer.Write(8, 0xF8); + writer.Write(8, 0x3F); + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0xF8u, reader.ReadFixedBits<8>()); + EXPECT_EQ(0x3Fu, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } + + // u(16) - little-endian bytes + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + writer.Write(16, 0xF83F); + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0x3Fu, reader.ReadFixedBits<8>()); + EXPECT_EQ(0xF8u, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } + + // Non-byte-aligned, mixed sizes + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + writer.Write(1, 1); + writer.Write(3, 6); + writer.Write(8, 0xDB); + writer.Write(4, 8); + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0xBDu, reader.ReadFixedBits<8>()); + EXPECT_EQ(0x8Du, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } +} + +TEST(BitReaderTest, TotalCountersTest) { + uint8_t buf[8] = {1, 2, 3, 4}; + BitReader reader(Span(buf, sizeof(buf))); + + EXPECT_EQ(sizeof(buf), reader.TotalBytes()); + EXPECT_EQ(0u, reader.TotalBitsConsumed()); + reader.ReadFixedBits<1>(); + EXPECT_EQ(1u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<10>(); + EXPECT_EQ(11u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<4>(); + EXPECT_EQ(15u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<1>(); + EXPECT_EQ(16u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<16>(); + EXPECT_EQ(32u, reader.TotalBitsConsumed()); + + EXPECT_TRUE(reader.Close()); +} + +TEST(BitReaderTest, MoveTest) { + uint8_t buf[8] = {1, 2, 3, 4}; + BitReader reader2; + { + BitReader reader1(Span(buf, sizeof(buf))); + + EXPECT_EQ(0u, reader1.TotalBitsConsumed()); + reader1.ReadFixedBits<16>(); + EXPECT_EQ(16u, reader1.TotalBitsConsumed()); + + reader2 = std::move(reader1); + // From this point reader1 is invalid, but can continue to access reader2 + // and we don't need to call Close() on reader1. + } + + EXPECT_EQ(16u, reader2.TotalBitsConsumed()); + EXPECT_EQ(3U, reader2.ReadFixedBits<8>()); + EXPECT_EQ(24u, reader2.TotalBitsConsumed()); + + EXPECT_TRUE(reader2.Close()); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/bits_test.cc b/media/libjxl/src/lib/jxl/bits_test.cc new file mode 100644 index 000000000..699090c8b --- /dev/null +++ b/media/libjxl/src/lib/jxl/bits_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/bits.h" + +#include "gtest/gtest.h" + +namespace jxl { +namespace { + +TEST(BitsTest, TestNumZeroBits) { + // Zero input is well-defined. + EXPECT_EQ(32u, Num0BitsAboveMS1Bit(0u)); + EXPECT_EQ(64u, Num0BitsAboveMS1Bit(0ull)); + EXPECT_EQ(32u, Num0BitsBelowLS1Bit(0u)); + EXPECT_EQ(64u, Num0BitsBelowLS1Bit(0ull)); + + EXPECT_EQ(31u, Num0BitsAboveMS1Bit(1u)); + EXPECT_EQ(30u, Num0BitsAboveMS1Bit(2u)); + EXPECT_EQ(63u, Num0BitsAboveMS1Bit(1ull)); + EXPECT_EQ(62u, Num0BitsAboveMS1Bit(2ull)); + + EXPECT_EQ(0u, Num0BitsBelowLS1Bit(1u)); + EXPECT_EQ(0u, Num0BitsBelowLS1Bit(1ull)); + EXPECT_EQ(1u, Num0BitsBelowLS1Bit(2u)); + EXPECT_EQ(1u, Num0BitsBelowLS1Bit(2ull)); + + EXPECT_EQ(0u, Num0BitsAboveMS1Bit(0x80000000u)); + EXPECT_EQ(0u, Num0BitsAboveMS1Bit(0x8000000000000000ull)); + EXPECT_EQ(31u, Num0BitsBelowLS1Bit(0x80000000u)); + EXPECT_EQ(63u, Num0BitsBelowLS1Bit(0x8000000000000000ull)); +} + +TEST(BitsTest, TestFloorLog2) { + // for input = [1, 7] + const size_t expected[7] = {0, 1, 1, 2, 2, 2, 2}; + for (uint32_t i = 1; i <= 7; ++i) { + EXPECT_EQ(expected[i - 1], FloorLog2Nonzero(i)) << " " << i; + EXPECT_EQ(expected[i - 1], FloorLog2Nonzero(uint64_t(i))) << " " << i; + } + + EXPECT_EQ(11u, FloorLog2Nonzero(0x00000fffu)); // 4095 + EXPECT_EQ(12u, FloorLog2Nonzero(0x00001000u)); // 4096 + EXPECT_EQ(12u, FloorLog2Nonzero(0x00001001u)); // 4097 + + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000000u)); + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000001u)); + EXPECT_EQ(31u, FloorLog2Nonzero(0xFFFFFFFFu)); + + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000000ull)); + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000001ull)); + EXPECT_EQ(31u, FloorLog2Nonzero(0xFFFFFFFFull)); + + EXPECT_EQ(63u, FloorLog2Nonzero(0x8000000000000000ull)); + EXPECT_EQ(63u, FloorLog2Nonzero(0x8000000000000001ull)); + EXPECT_EQ(63u, FloorLog2Nonzero(0xFFFFFFFFFFFFFFFFull)); +} + +TEST(BitsTest, TestCeilLog2) { + // for input = [1, 7] + const size_t expected[7] = {0, 1, 2, 2, 3, 3, 3}; + for (uint32_t i = 1; i <= 7; ++i) { + EXPECT_EQ(expected[i - 1], CeilLog2Nonzero(i)) << " " << i; + EXPECT_EQ(expected[i - 1], CeilLog2Nonzero(uint64_t(i))) << " " << i; + } + + EXPECT_EQ(12u, CeilLog2Nonzero(0x00000fffu)); // 4095 + EXPECT_EQ(12u, CeilLog2Nonzero(0x00001000u)); // 4096 + EXPECT_EQ(13u, CeilLog2Nonzero(0x00001001u)); // 4097 + + EXPECT_EQ(31u, CeilLog2Nonzero(0x80000000u)); + EXPECT_EQ(32u, CeilLog2Nonzero(0x80000001u)); + EXPECT_EQ(32u, CeilLog2Nonzero(0xFFFFFFFFu)); + + EXPECT_EQ(31u, CeilLog2Nonzero(0x80000000ull)); + EXPECT_EQ(32u, CeilLog2Nonzero(0x80000001ull)); + EXPECT_EQ(32u, CeilLog2Nonzero(0xFFFFFFFFull)); + + EXPECT_EQ(63u, CeilLog2Nonzero(0x8000000000000000ull)); + EXPECT_EQ(64u, CeilLog2Nonzero(0x8000000000000001ull)); + EXPECT_EQ(64u, CeilLog2Nonzero(0xFFFFFFFFFFFFFFFFull)); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/blending.cc b/media/libjxl/src/lib/jxl/blending.cc new file mode 100644 index 000000000..ab37fdabb --- /dev/null +++ b/media/libjxl/src/lib/jxl/blending.cc @@ -0,0 +1,152 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/blending.h" + +#include "lib/jxl/alpha.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +bool NeedsBlending(PassesDecoderState* dec_state) { + const PassesSharedState& state = *dec_state->shared; + if (!(state.frame_header.frame_type == FrameType::kRegularFrame || + state.frame_header.frame_type == FrameType::kSkipProgressive)) { + return false; + } + const auto& info = state.frame_header.blending_info; + bool replace_all = (info.mode == BlendMode::kReplace); + for (const auto& ec_i : state.frame_header.extra_channel_blending_info) { + if (ec_i.mode != BlendMode::kReplace) { + replace_all = false; + } + } + // Replace the full frame: nothing to do. + if (!state.frame_header.custom_size_or_origin && replace_all) { + return false; + } + return true; +} + +void PerformBlending(const float* const* bg, const float* const* fg, + float* const* out, size_t x0, size_t xsize, + const PatchBlending& color_blending, + const PatchBlending* ec_blending, + const std::vector& extra_channel_info) { + bool has_alpha = false; + size_t num_ec = extra_channel_info.size(); + for (size_t i = 0; i < num_ec; i++) { + if (extra_channel_info[i].type == jxl::ExtraChannel::kAlpha) { + has_alpha = true; + break; + } + } + ImageF tmp(xsize, 3 + num_ec); + // Blend extra channels first so that we use the pre-blending alpha. + for (size_t i = 0; i < num_ec; i++) { + if (ec_blending[i].mode == PatchBlendMode::kAdd) { + for (size_t x = 0; x < xsize; x++) { + tmp.Row(3 + i)[x] = bg[3 + i][x + x0] + fg[3 + i][x + x0]; + } + } else if (ec_blending[i].mode == PatchBlendMode::kBlendAbove) { + size_t alpha = ec_blending[i].alpha_channel; + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending(bg[3 + i] + x0, bg[3 + alpha] + x0, fg[3 + i] + x0, + fg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + is_premultiplied, ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kBlendBelow) { + size_t alpha = ec_blending[i].alpha_channel; + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending(fg[3 + i] + x0, fg[3 + alpha] + x0, bg[3 + i] + x0, + bg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + is_premultiplied, ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kAlphaWeightedAddAbove) { + size_t alpha = ec_blending[i].alpha_channel; + PerformAlphaWeightedAdd(bg[3 + i] + x0, fg[3 + i] + x0, + fg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kAlphaWeightedAddBelow) { + size_t alpha = ec_blending[i].alpha_channel; + PerformAlphaWeightedAdd(fg[3 + i] + x0, bg[3 + i] + x0, + bg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kMul) { + PerformMulBlending(bg[3 + i] + x0, fg[3 + i] + x0, tmp.Row(3 + i), xsize, + ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kReplace) { + memcpy(tmp.Row(3 + i), fg[3 + i] + x0, xsize * sizeof(**fg)); + } else if (ec_blending[i].mode == PatchBlendMode::kNone) { + if (xsize) memcpy(tmp.Row(3 + i), bg[3 + i] + x0, xsize * sizeof(**fg)); + } else { + JXL_ABORT("Unreachable"); + } + } + size_t alpha = color_blending.alpha_channel; + + if (color_blending.mode == PatchBlendMode::kAdd || + (color_blending.mode == PatchBlendMode::kAlphaWeightedAddAbove && + !has_alpha) || + (color_blending.mode == PatchBlendMode::kAlphaWeightedAddBelow && + !has_alpha)) { + for (int p = 0; p < 3; p++) { + float* out = tmp.Row(p); + for (size_t x = 0; x < xsize; x++) { + out[x] = bg[p][x + x0] + fg[p][x + x0]; + } + } + } else if (color_blending.mode == PatchBlendMode::kBlendAbove + // blend without alpha is just replace + && has_alpha) { + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending( + {bg[0] + x0, bg[1] + x0, bg[2] + x0, bg[3 + alpha] + x0}, + {fg[0] + x0, fg[1] + x0, fg[2] + x0, fg[3 + alpha] + x0}, + {tmp.Row(0), tmp.Row(1), tmp.Row(2), tmp.Row(3 + alpha)}, xsize, + is_premultiplied, color_blending.clamp); + } else if (color_blending.mode == PatchBlendMode::kBlendBelow + // blend without alpha is just replace + && has_alpha) { + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending( + {fg[0] + x0, fg[1] + x0, fg[2] + x0, fg[3 + alpha] + x0}, + {bg[0] + x0, bg[1] + x0, bg[2] + x0, bg[3 + alpha] + x0}, + {tmp.Row(0), tmp.Row(1), tmp.Row(2), tmp.Row(3 + alpha)}, xsize, + is_premultiplied, color_blending.clamp); + } else if (color_blending.mode == PatchBlendMode::kAlphaWeightedAddAbove) { + JXL_DASSERT(has_alpha); + for (size_t c = 0; c < 3; c++) { + PerformAlphaWeightedAdd(bg[c] + x0, fg[c] + x0, fg[3 + alpha] + x0, + tmp.Row(c), xsize, color_blending.clamp); + } + } else if (color_blending.mode == PatchBlendMode::kAlphaWeightedAddBelow) { + JXL_DASSERT(has_alpha); + for (size_t c = 0; c < 3; c++) { + PerformAlphaWeightedAdd(fg[c] + x0, bg[c] + x0, bg[3 + alpha] + x0, + tmp.Row(c), xsize, color_blending.clamp); + } + } else if (color_blending.mode == PatchBlendMode::kMul) { + for (int p = 0; p < 3; p++) { + PerformMulBlending(bg[p] + x0, fg[p] + x0, tmp.Row(p), xsize, + color_blending.clamp); + } + } else if (color_blending.mode == PatchBlendMode::kReplace || + color_blending.mode == PatchBlendMode::kBlendAbove || + color_blending.mode == PatchBlendMode::kBlendBelow) { // kReplace + for (size_t p = 0; p < 3; p++) { + memcpy(tmp.Row(p), fg[p] + x0, xsize * sizeof(**fg)); + } + } else if (color_blending.mode == PatchBlendMode::kNone) { + for (size_t p = 0; p < 3; p++) { + memcpy(tmp.Row(p), bg[p] + x0, xsize * sizeof(**fg)); + } + } else { + JXL_ABORT("Unreachable"); + } + for (size_t i = 0; i < 3 + num_ec; i++) { + if (xsize != 0) memcpy(out[i] + x0, tmp.Row(i), xsize * sizeof(**out)); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/blending.h b/media/libjxl/src/lib/jxl/blending.h new file mode 100644 index 000000000..7eab7d50c --- /dev/null +++ b/media/libjxl/src/lib/jxl/blending.h @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BLENDING_H_ +#define LIB_JXL_BLENDING_H_ +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +bool NeedsBlending(PassesDecoderState* dec_state); + +void PerformBlending(const float* const* bg, const float* const* fg, + float* const* out, size_t x0, size_t xsize, + const PatchBlending& color_blending, + const PatchBlending* ec_blending, + const std::vector& extra_channel_info); + +} // namespace jxl + +#endif // LIB_JXL_BLENDING_H_ diff --git a/media/libjxl/src/lib/jxl/blending_test.cc b/media/libjxl/src/lib/jxl/blending_test.cc new file mode 100644 index 000000000..e032b99a4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/blending_test.cc @@ -0,0 +1,38 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/codec.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { + +using ::testing::SizeIs; + +TEST(BlendingTest, Crops) { + ThreadPool* pool = nullptr; + + const PaddedBytes compressed = + ReadTestData("jxl/blending/cropped_traffic_light.jxl"); + CodecInOut decoded; + ASSERT_TRUE(test::DecodeFile({}, compressed, &decoded, pool)); + ASSERT_THAT(decoded.frames, SizeIs(4)); + + int i = 0; + for (const ImageBundle& ib : decoded.frames) { + std::ostringstream filename; + filename << "jxl/blending/cropped_traffic_light_frame-" << i << ".png"; + const PaddedBytes compressed_frame = ReadTestData(filename.str()); + CodecInOut frame; + ASSERT_TRUE(SetFromBytes(Span(compressed_frame), &frame)); + EXPECT_TRUE(SamePixels(ib.color(), *frame.Main().color())); + ++i; + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/box_content_decoder.cc b/media/libjxl/src/lib/jxl/box_content_decoder.cc new file mode 100644 index 000000000..c4cba3a31 --- /dev/null +++ b/media/libjxl/src/lib/jxl/box_content_decoder.cc @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/box_content_decoder.h" + +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +JxlBoxContentDecoder::JxlBoxContentDecoder() {} + +JxlBoxContentDecoder::~JxlBoxContentDecoder() { + if (brotli_dec) { + BrotliDecoderDestroyInstance(brotli_dec); + } +} + +void JxlBoxContentDecoder::StartBox(bool brob_decode, bool box_until_eof, + size_t contents_size) { + if (brotli_dec) { + BrotliDecoderDestroyInstance(brotli_dec); + brotli_dec = nullptr; + } + header_done_ = false; + brob_decode_ = brob_decode; + box_until_eof_ = box_until_eof; + remaining_ = box_until_eof ? 0 : contents_size; + pos_ = 0; +} + +JxlDecoderStatus JxlBoxContentDecoder::Process(const uint8_t* next_in, + size_t avail_in, size_t box_pos, + uint8_t** next_out, + size_t* avail_out) { + next_in += pos_ - box_pos; + avail_in -= pos_ - box_pos; + + if (brob_decode_) { + if (!header_done_) { + if (avail_in < 4) return JXL_DEC_NEED_MORE_INPUT; + if (!box_until_eof_) { + if (remaining_ < 4) return JXL_DEC_ERROR; + remaining_ -= 4; + } + next_in += 4; + avail_in -= 4; + pos_ += 4; + header_done_ = true; + } + + if (!brotli_dec) { + brotli_dec = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + } + + const uint8_t* next_in_before = next_in; + uint8_t* next_out_before = *next_out; + msan::MemoryIsInitialized(next_in, avail_in); + BrotliDecoderResult res = BrotliDecoderDecompressStream( + brotli_dec, &avail_in, &next_in, avail_out, next_out, nullptr); + size_t consumed = next_in - next_in_before; + size_t produced = *next_out - next_out_before; + if (res == BROTLI_DECODER_RESULT_ERROR) { + return JXL_DEC_ERROR; + } + msan::UnpoisonMemory(next_out_before, produced); + pos_ += consumed; + if (!box_until_eof_) remaining_ -= consumed; + if (res == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT) { + return JXL_DEC_NEED_MORE_INPUT; + } + if (res == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + return JXL_DEC_BOX_NEED_MORE_OUTPUT; + } + if (res == BROTLI_DECODER_RESULT_SUCCESS) { + return JXL_DEC_SUCCESS; + } + // unknown Brotli result + return JXL_DEC_ERROR; + } else { + // remaining box bytes as seen from dec->file_pos + size_t can_read = avail_in; + if (!box_until_eof_) can_read = std::min(can_read, remaining_); + size_t to_write = std::min(can_read, *avail_out); + memcpy(*next_out, next_in, to_write); + + *next_out += to_write; + *avail_out -= to_write; + if (!box_until_eof_) remaining_ -= to_write; + pos_ += to_write; + + if (to_write < can_read) return JXL_DEC_BOX_NEED_MORE_OUTPUT; + + if (!box_until_eof_ && remaining_ > 0) return JXL_DEC_NEED_MORE_INPUT; + + return JXL_DEC_SUCCESS; + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/box_content_decoder.h b/media/libjxl/src/lib/jxl/box_content_decoder.h new file mode 100644 index 000000000..41d878c13 --- /dev/null +++ b/media/libjxl/src/lib/jxl/box_content_decoder.h @@ -0,0 +1,50 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BOX_CONTENT_DECODER_H_ +#define LIB_JXL_BOX_CONTENT_DECODER_H_ + +#include +#include +#include + +#include +#include + +#include "jxl/decode.h" + +namespace jxl { + +/** Outputs the contents of a box in a streaming fashion, either directly, or + * optionally decoding with Brotli, in case of a brob box. The input must be + * the contents of a box, excluding the box header. + */ +class JxlBoxContentDecoder { + public: + JxlBoxContentDecoder(); + ~JxlBoxContentDecoder(); + + void StartBox(bool brob_decode, bool box_until_eof, size_t contents_size); + + // Outputs decoded bytes from the box, decoding with brotli if needed. + // box_pos is the position in the box content which next_in points to. + // Returns success, whether more input or output bytes are needed, or error. + JxlDecoderStatus Process(const uint8_t* next_in, size_t avail_in, + size_t box_pos, uint8_t** next_out, + size_t* avail_out); + + private: + BrotliDecoderState* brotli_dec; + + bool header_done_; + bool brob_decode_; + bool box_until_eof_; + size_t remaining_; + size_t pos_; +}; + +} // namespace jxl + +#endif // LIB_JXL_BOX_CONTENT_DECODER_H_ diff --git a/media/libjxl/src/lib/jxl/butteraugli/butteraugli.cc b/media/libjxl/src/lib/jxl/butteraugli/butteraugli.cc new file mode 100644 index 000000000..ee1a53013 --- /dev/null +++ b/media/libjxl/src/lib/jxl/butteraugli/butteraugli.cc @@ -0,0 +1,1988 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Author: Jyrki Alakuijala (jyrki.alakuijala@gmail.com) +// +// The physical architecture of butteraugli is based on the following naming +// convention: +// * Opsin - dynamics of the photosensitive chemicals in the retina +// with their immediate electrical processing +// * Xyb - hybrid opponent/trichromatic color space +// x is roughly red-subtract-green. +// y is yellow. +// b is blue. +// Xyb values are computed from Opsin mixing, not directly from rgb. +// * Mask - for visual masking +// * Hf - color modeling for spatially high-frequency features +// * Lf - color modeling for spatially low-frequency features +// * Diffmap - to cluster and build an image of error between the images +// * Blur - to hold the smoothing code + +#include "lib/jxl/butteraugli/butteraugli.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#if PROFILER_ENABLED +#include +#endif // PROFILER_ENABLED + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/butteraugli/butteraugli.cc" +#include + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/gauss_blur.h" +#include "lib/jxl/image_ops.h" + +#ifndef JXL_BUTTERAUGLI_ONCE +#define JXL_BUTTERAUGLI_ONCE + +namespace jxl { + +std::vector ComputeKernel(float sigma) { + const float m = 2.25; // Accuracy increases when m is increased. + const double scaler = -1.0 / (2.0 * sigma * sigma); + const int diff = std::max(1, m * std::fabs(sigma)); + std::vector kernel(2 * diff + 1); + for (int i = -diff; i <= diff; ++i) { + kernel[i + diff] = std::exp(scaler * i * i); + } + return kernel; +} + +void ConvolveBorderColumn(const ImageF& in, const std::vector& kernel, + const size_t x, float* BUTTERAUGLI_RESTRICT row_out) { + const size_t offset = kernel.size() / 2; + int minx = x < offset ? 0 : x - offset; + int maxx = std::min(in.xsize() - 1, x + offset); + float weight = 0.0f; + for (int j = minx; j <= maxx; ++j) { + weight += kernel[j - x + offset]; + } + float scale = 1.0f / weight; + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y); + float sum = 0.0f; + for (int j = minx; j <= maxx; ++j) { + sum += row_in[j] * kernel[j - x + offset]; + } + row_out[y] = sum * scale; + } +} + +// Computes a horizontal convolution and transposes the result. +void ConvolutionWithTranspose(const ImageF& in, + const std::vector& kernel, + ImageF* BUTTERAUGLI_RESTRICT out) { + PROFILER_FUNC; + JXL_CHECK(out->xsize() == in.ysize()); + JXL_CHECK(out->ysize() == in.xsize()); + const size_t len = kernel.size(); + const size_t offset = len / 2; + float weight_no_border = 0.0f; + for (size_t j = 0; j < len; ++j) { + weight_no_border += kernel[j]; + } + const float scale_no_border = 1.0f / weight_no_border; + const size_t border1 = std::min(in.xsize(), offset); + const size_t border2 = in.xsize() > offset ? in.xsize() - offset : 0; + std::vector scaled_kernel(len / 2 + 1); + for (size_t i = 0; i <= len / 2; ++i) { + scaled_kernel[i] = kernel[i] * scale_no_border; + } + + // middle + switch (len) { + case 7: { + PROFILER_ZONE("conv7"); + const float sk0 = scaled_kernel[0]; + const float sk1 = scaled_kernel[1]; + const float sk2 = scaled_kernel[2]; + const float sk3 = scaled_kernel[3]; + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + const float sum0 = (row_in[0] + row_in[6]) * sk0; + const float sum1 = (row_in[1] + row_in[5]) * sk1; + const float sum2 = (row_in[2] + row_in[4]) * sk2; + const float sum = (row_in[3]) * sk3 + sum0 + sum1 + sum2; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum; + } + } + } break; + case 13: { + PROFILER_ZONE("conv15"); + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[12]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[11]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[10]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[9]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[8]) * scaled_kernel[4]; + sum1 += (row_in[5] + row_in[7]) * scaled_kernel[5]; + const float sum = (row_in[6]) * scaled_kernel[6]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + case 15: { + PROFILER_ZONE("conv15"); + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[14]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[13]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[12]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[11]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[10]) * scaled_kernel[4]; + sum1 += (row_in[5] + row_in[9]) * scaled_kernel[5]; + sum2 += (row_in[6] + row_in[8]) * scaled_kernel[6]; + const float sum = (row_in[7]) * scaled_kernel[7]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + case 33: { + PROFILER_ZONE("conv33"); + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[32]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[31]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[30]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[29]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[28]) * scaled_kernel[4]; + sum1 += (row_in[5] + row_in[27]) * scaled_kernel[5]; + sum2 += (row_in[6] + row_in[26]) * scaled_kernel[6]; + sum3 += (row_in[7] + row_in[25]) * scaled_kernel[7]; + sum0 += (row_in[8] + row_in[24]) * scaled_kernel[8]; + sum1 += (row_in[9] + row_in[23]) * scaled_kernel[9]; + sum2 += (row_in[10] + row_in[22]) * scaled_kernel[10]; + sum3 += (row_in[11] + row_in[21]) * scaled_kernel[11]; + sum0 += (row_in[12] + row_in[20]) * scaled_kernel[12]; + sum1 += (row_in[13] + row_in[19]) * scaled_kernel[13]; + sum2 += (row_in[14] + row_in[18]) * scaled_kernel[14]; + sum3 += (row_in[15] + row_in[17]) * scaled_kernel[15]; + const float sum = (row_in[16]) * scaled_kernel[16]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + default: + printf("Warning: Unexpected kernel size! %" PRIuS "\n", len); + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y); + for (size_t x = border1; x < border2; ++x) { + const int d = x - offset; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + float sum = 0.0f; + size_t j; + for (j = 0; j <= len / 2; ++j) { + sum += row_in[d + j] * scaled_kernel[j]; + } + for (; j < len; ++j) { + sum += row_in[d + j] * scaled_kernel[len - 1 - j]; + } + row_out[y] = sum; + } + } + } + // left border + for (size_t x = 0; x < border1; ++x) { + ConvolveBorderColumn(in, kernel, x, out->Row(x)); + } + + // right border + for (size_t x = border2; x < in.xsize(); ++x) { + ConvolveBorderColumn(in, kernel, x, out->Row(x)); + } +} + +// A blur somewhat similar to a 2D Gaussian blur. +// See: https://en.wikipedia.org/wiki/Gaussian_blur +// +// This is a bottleneck because the sigma can be quite large (>7). We can use +// gauss_blur.cc (runtime independent of sigma, closer to a 4*sigma truncated +// Gaussian and our 2.25 in ComputeKernel), but its boundary conditions are +// zero-valued. This leads to noticeable differences at the edges of diffmaps. +// We retain a special case for 5x5 kernels (even faster than gauss_blur), +// optionally use gauss_blur followed by fixup of the borders for large images, +// or fall back to the previous truncated FIR followed by a transpose. +void Blur(const ImageF& in, float sigma, const ButteraugliParams& params, + BlurTemp* temp, ImageF* out) { + std::vector kernel = ComputeKernel(sigma); + // Separable5 does an in-place convolution, so this fast path is not safe if + // in aliases out. + if (kernel.size() == 5 && &in != out) { + float sum_weights = 0.0f; + for (const float w : kernel) { + sum_weights += w; + } + const float scale = 1.0f / sum_weights; + const float w0 = kernel[2] * scale; + const float w1 = kernel[1] * scale; + const float w2 = kernel[0] * scale; + const WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + }; + Separable5(in, Rect(in), weights, /*pool=*/nullptr, out); + return; + } + + ImageF* JXL_RESTRICT temp_t = temp->GetTransposed(in); + ConvolutionWithTranspose(in, kernel, temp_t); + ConvolutionWithTranspose(*temp_t, kernel, out); +} + +// Allows PaddedMaltaUnit to call either function via overloading. +struct MaltaTagLF {}; +struct MaltaTag {}; + +} // namespace jxl + +#endif // JXL_BUTTERAUGLI_ONCE + +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::MulSub; +using hwy::HWY_NAMESPACE::Neg; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +template +HWY_INLINE V MaximumClamp(D d, V v, double kMaxVal) { + static const double kMul = 0.724216145665; + const V mul = Set(d, kMul); + const V maxval = Set(d, kMaxVal); + // If greater than maxval or less than -maxval, replace with if_*. + const V if_pos = MulAdd(Sub(v, maxval), mul, maxval); + const V if_neg = MulSub(Add(v, maxval), mul, maxval); + const V pos_or_v = IfThenElse(Ge(v, maxval), if_pos, v); + return IfThenElse(Lt(v, Neg(maxval)), if_neg, pos_or_v); +} + +// Make area around zero less important (remove it). +template +HWY_INLINE V RemoveRangeAroundZero(const D d, const double kw, const V x) { + const auto w = Set(d, kw); + return IfThenElse(Gt(x, w), Sub(x, w), + IfThenElseZero(Lt(x, Neg(w)), Add(x, w))); +} + +// Make area around zero more important (2x it until the limit). +template +HWY_INLINE V AmplifyRangeAroundZero(const D d, const double kw, const V x) { + const auto w = Set(d, kw); + return IfThenElse(Gt(x, w), Add(x, w), + IfThenElse(Lt(x, Neg(w)), Sub(x, w), Add(x, x))); +} + +// XybLowFreqToVals converts from low-frequency XYB space to the 'vals' space. +// Vals space can be converted to L2-norm space (Euclidean and normalized) +// through visual masking. +template +HWY_INLINE void XybLowFreqToVals(const D d, const V& x, const V& y, + const V& b_arg, V* HWY_RESTRICT valx, + V* HWY_RESTRICT valy, V* HWY_RESTRICT valb) { + static const double xmuli = 32.2217497012; + static const double ymuli = 13.7697791434; + static const double bmuli = 47.504615728; + static const double y_to_b_muli = -0.362267051518; + const V xmul = Set(d, xmuli); + const V ymul = Set(d, ymuli); + const V bmul = Set(d, bmuli); + const V y_to_b_mul = Set(d, y_to_b_muli); + const V b = MulAdd(y_to_b_mul, y, b_arg); + *valb = Mul(b, bmul); + *valx = Mul(x, xmul); + *valy = Mul(y, ymul); +} + +void SuppressXByY(const ImageF& in_x, const ImageF& in_y, const double yw, + ImageF* HWY_RESTRICT out) { + JXL_DASSERT(SameSize(in_x, in_y) && SameSize(in_x, *out)); + const size_t xsize = in_x.xsize(); + const size_t ysize = in_x.ysize(); + + const HWY_FULL(float) d; + static const double s = 0.653020556257; + const auto sv = Set(d, s); + const auto one_minus_s = Set(d, 1.0 - s); + const auto ywv = Set(d, yw); + + for (size_t y = 0; y < ysize; ++y) { + const float* HWY_RESTRICT row_x = in_x.ConstRow(y); + const float* HWY_RESTRICT row_y = in_y.ConstRow(y); + float* HWY_RESTRICT row_out = out->Row(y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto vx = Load(d, row_x + x); + const auto vy = Load(d, row_y + x); + const auto scaler = + MulAdd(Div(ywv, MulAdd(vy, vy, ywv)), one_minus_s, sv); + Store(Mul(scaler, vx), d, row_out + x); + } + } +} + +static void SeparateFrequencies(size_t xsize, size_t ysize, + const ButteraugliParams& params, + BlurTemp* blur_temp, const Image3F& xyb, + PsychoImage& ps) { + PROFILER_FUNC; + const HWY_FULL(float) d; + + // Extract lf ... + static const double kSigmaLf = 7.15593339443; + static const double kSigmaHf = 3.22489901262; + static const double kSigmaUhf = 1.56416327805; + ps.mf = Image3F(xsize, ysize); + ps.hf[0] = ImageF(xsize, ysize); + ps.hf[1] = ImageF(xsize, ysize); + ps.lf = Image3F(xyb.xsize(), xyb.ysize()); + ps.mf = Image3F(xyb.xsize(), xyb.ysize()); + for (int i = 0; i < 3; ++i) { + Blur(xyb.Plane(i), kSigmaLf, params, blur_temp, &ps.lf.Plane(i)); + + // ... and keep everything else in mf. + for (size_t y = 0; y < ysize; ++y) { + const float* BUTTERAUGLI_RESTRICT row_xyb = xyb.PlaneRow(i, y); + const float* BUTTERAUGLI_RESTRICT row_lf = ps.lf.ConstPlaneRow(i, y); + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(i, y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto mf = Sub(Load(d, row_xyb + x), Load(d, row_lf + x)); + Store(mf, d, row_mf + x); + } + } + if (i == 2) { + Blur(ps.mf.Plane(i), kSigmaHf, params, blur_temp, &ps.mf.Plane(i)); + break; + } + // Divide mf into mf and hf. + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(i, y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[i].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + Store(Load(d, row_mf + x), d, row_hf + x); + } + } + Blur(ps.mf.Plane(i), kSigmaHf, params, blur_temp, &ps.mf.Plane(i)); + static const double kRemoveMfRange = 0.29; + static const double kAddMfRange = 0.1; + if (i == 0) { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[0].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto mf = Load(d, row_mf + x); + auto hf = Sub(Load(d, row_hf + x), mf); + mf = RemoveRangeAroundZero(d, kRemoveMfRange, mf); + Store(mf, d, row_mf + x); + Store(hf, d, row_hf + x); + } + } + } else { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[1].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto mf = Load(d, row_mf + x); + auto hf = Sub(Load(d, row_hf + x), mf); + + mf = AmplifyRangeAroundZero(d, kAddMfRange, mf); + Store(mf, d, row_mf + x); + Store(hf, d, row_hf + x); + } + } + } + } + + // Temporarily used as output of SuppressXByY + ps.uhf[0] = ImageF(xsize, ysize); + ps.uhf[1] = ImageF(xsize, ysize); + + // Suppress red-green by intensity change in the high freq channels. + static const double suppress = 46.0; + SuppressXByY(ps.hf[0], ps.hf[1], suppress, &ps.uhf[0]); + // hf is the SuppressXByY output, uhf will be written below. + ps.hf[0].Swap(ps.uhf[0]); + + for (int i = 0; i < 2; ++i) { + // Divide hf into hf and uhf. + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = ps.uhf[i].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[i].Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_uhf[x] = row_hf[x]; + } + } + Blur(ps.hf[i], kSigmaUhf, params, blur_temp, &ps.hf[i]); + static const double kRemoveHfRange = 1.5; + static const double kAddHfRange = 0.132; + static const double kRemoveUhfRange = 0.04; + static const double kMaxclampHf = 28.4691806922; + static const double kMaxclampUhf = 5.19175294647; + static double kMulYHf = 2.155; + static double kMulYUhf = 2.69313763794; + if (i == 0) { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = ps.uhf[0].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[0].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto hf = Load(d, row_hf + x); + auto uhf = Sub(Load(d, row_uhf + x), hf); + hf = RemoveRangeAroundZero(d, kRemoveHfRange, hf); + uhf = RemoveRangeAroundZero(d, kRemoveUhfRange, uhf); + Store(hf, d, row_hf + x); + Store(uhf, d, row_uhf + x); + } + } + } else { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = ps.uhf[1].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[1].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto hf = Load(d, row_hf + x); + hf = MaximumClamp(d, hf, kMaxclampHf); + + auto uhf = Sub(Load(d, row_uhf + x), hf); + uhf = MaximumClamp(d, uhf, kMaxclampUhf); + uhf = Mul(uhf, Set(d, kMulYUhf)); + Store(uhf, d, row_uhf + x); + + hf = Mul(hf, Set(d, kMulYHf)); + hf = AmplifyRangeAroundZero(d, kAddHfRange, hf); + Store(hf, d, row_hf + x); + } + } + } + } + // Modify range around zero code only concerns the high frequency + // planes and only the X and Y channels. + // Convert low freq xyb to vals space so that we can do a simple squared sum + // diff on the low frequencies later. + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_x = ps.lf.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_y = ps.lf.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_b = ps.lf.PlaneRow(2, y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto valx = Undefined(d); + auto valy = Undefined(d); + auto valb = Undefined(d); + XybLowFreqToVals(d, Load(d, row_x + x), Load(d, row_y + x), + Load(d, row_b + x), &valx, &valy, &valb); + Store(valx, d, row_x + x); + Store(valy, d, row_y + x); + Store(valb, d, row_b + x); + } + } +} + +namespace { +template +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d) { + return Add(Add(a, b), Add(c, d)); +} +template +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d, V e) { + return Sum(a, b, c, Add(d, e)); +} +template +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d, V e, V f, V g) { + return Sum(a, b, c, Sum(d, e, f, g)); +} +template +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d, V e, V f, V g, V h, V i) { + return Add(Add(Sum(a, b, c, d), Sum(e, f, g, h)), i); +} +} // namespace + +template +Vec MaltaUnit(MaltaTagLF /*tag*/, const D df, + const float* BUTTERAUGLI_RESTRICT d, const intptr_t xs) { + const intptr_t xs3 = 3 * xs; + + const auto center = LoadU(df, d); + + // x grows, y constant + const auto sum_yconst = Sum(LoadU(df, d - 4), LoadU(df, d - 2), center, + LoadU(df, d + 2), LoadU(df, d + 4)); + // Will return this, sum of all line kernels + auto retval = Mul(sum_yconst, sum_yconst); + { + // y grows, x constant + auto sum = Sum(LoadU(df, d - xs3 - xs), LoadU(df, d - xs - xs), center, + LoadU(df, d + xs + xs), LoadU(df, d + xs3 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // both grow + auto sum = Sum(LoadU(df, d - xs3 - 3), LoadU(df, d - xs - xs - 2), center, + LoadU(df, d + xs + xs + 2), LoadU(df, d + xs3 + 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows, x shrinks + auto sum = Sum(LoadU(df, d - xs3 + 3), LoadU(df, d - xs - xs + 2), center, + LoadU(df, d + xs + xs - 2), LoadU(df, d + xs3 - 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x shrinks 1 -> -1 + auto sum = + Sum(LoadU(df, d - xs3 - xs + 1), LoadU(df, d - xs - xs + 1), center, + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 + xs - 1)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x grows -1 -> 1 + auto sum = + Sum(LoadU(df, d - xs3 - xs - 1), LoadU(df, d - xs - xs - 1), center, + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + xs + 1)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y grows -1 to 1 + auto sum = Sum(LoadU(df, d - 4 - xs), LoadU(df, d - 2 - xs), center, + LoadU(df, d + 2 + xs), LoadU(df, d + 4 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y shrinks 1 to -1 + auto sum = Sum(LoadU(df, d - 4 + xs), LoadU(df, d - 2 + xs), center, + LoadU(df, d + 2 - xs), LoadU(df, d + 4 - xs)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1__*______ + 2___*_____ + 3_________ + 4____0____ + 5_________ + 6_____*___ + 7______*__ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 - 2), LoadU(df, d - xs - xs - 1), center, + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1______*__ + 2_____*___ + 3_________ + 4____0____ + 5_________ + 6___*_____ + 7__*______ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 + 2), LoadU(df, d - xs - xs + 1), center, + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 - 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_*_______ + 3__*______ + 4____0____ + 5______*__ + 6_______*_ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs - 3), LoadU(df, d - xs - 2), center, + LoadU(df, d + xs + 2), LoadU(df, d + xs + xs + 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_______*_ + 3______*__ + 4____0____ + 5__*______ + 6_*_______ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs + 3), LoadU(df, d - xs + 2), center, + LoadU(df, d + xs - 2), LoadU(df, d + xs + xs - 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2________* + 3______*__ + 4____0____ + 5__*______ + 6*________ + 7_________ + 8_________ */ + + auto sum = Sum(LoadU(df, d + xs + xs - 4), LoadU(df, d + xs - 2), center, + LoadU(df, d - xs + 2), LoadU(df, d - xs - xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2*________ + 3__*______ + 4____0____ + 5______*__ + 6________* + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs - 4), LoadU(df, d - xs - 2), center, + LoadU(df, d + xs + 2), LoadU(df, d + xs + xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0__*______ + 1_________ + 2___*_____ + 3_________ + 4____0____ + 5_________ + 6_____*___ + 7_________ + 8______*__ */ + auto sum = + Sum(LoadU(df, d - xs3 - xs - 2), LoadU(df, d - xs - xs - 1), center, + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + xs + 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0______*__ + 1_________ + 2_____*___ + 3_________ + 4____0____ + 5_________ + 6___*_____ + 7_________ + 8__*______ */ + auto sum = + Sum(LoadU(df, d - xs3 - xs + 2), LoadU(df, d - xs - xs + 1), center, + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 + xs - 2)); + retval = MulAdd(sum, sum, retval); + } + return retval; +} + +template +Vec MaltaUnit(MaltaTag /*tag*/, const D df, + const float* BUTTERAUGLI_RESTRICT d, const intptr_t xs) { + const intptr_t xs3 = 3 * xs; + + const auto center = LoadU(df, d); + + // x grows, y constant + const auto sum_yconst = + Sum(LoadU(df, d - 4), LoadU(df, d - 3), LoadU(df, d - 2), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + 2), + LoadU(df, d + 3), LoadU(df, d + 4)); + // Will return this, sum of all line kernels + auto retval = Mul(sum_yconst, sum_yconst); + + { + // y grows, x constant + auto sum = Sum(LoadU(df, d - xs3 - xs), LoadU(df, d - xs3), + LoadU(df, d - xs - xs), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs), + LoadU(df, d + xs3), LoadU(df, d + xs3 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // both grow + auto sum = Sum(LoadU(df, d - xs3 - 3), LoadU(df, d - xs - xs - 2), + LoadU(df, d - xs - 1), center, LoadU(df, d + xs + 1), + LoadU(df, d + xs + xs + 2), LoadU(df, d + xs3 + 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows, x shrinks + auto sum = Sum(LoadU(df, d - xs3 + 3), LoadU(df, d - xs - xs + 2), + LoadU(df, d - xs + 1), center, LoadU(df, d + xs - 1), + LoadU(df, d + xs + xs - 2), LoadU(df, d + xs3 - 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x shrinks 1 -> -1 + auto sum = Sum(LoadU(df, d - xs3 - xs + 1), LoadU(df, d - xs3 + 1), + LoadU(df, d - xs - xs + 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs - 1), + LoadU(df, d + xs3 - 1), LoadU(df, d + xs3 + xs - 1)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x grows -1 -> 1 + auto sum = Sum(LoadU(df, d - xs3 - xs - 1), LoadU(df, d - xs3 - 1), + LoadU(df, d - xs - xs - 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs + 1), + LoadU(df, d + xs3 + 1), LoadU(df, d + xs3 + xs + 1)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y grows -1 to 1 + auto sum = + Sum(LoadU(df, d - 4 - xs), LoadU(df, d - 3 - xs), LoadU(df, d - 2 - xs), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + 2 + xs), + LoadU(df, d + 3 + xs), LoadU(df, d + 4 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y shrinks 1 to -1 + auto sum = + Sum(LoadU(df, d - 4 + xs), LoadU(df, d - 3 + xs), LoadU(df, d - 2 + xs), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + 2 - xs), + LoadU(df, d + 3 - xs), LoadU(df, d + 4 - xs)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1__*______ + 2___*_____ + 3___*_____ + 4____0____ + 5_____*___ + 6_____*___ + 7______*__ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 - 2), LoadU(df, d - xs - xs - 1), + LoadU(df, d - xs - 1), center, LoadU(df, d + xs + 1), + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1______*__ + 2_____*___ + 3_____*___ + 4____0____ + 5___*_____ + 6___*_____ + 7__*______ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 + 2), LoadU(df, d - xs - xs + 1), + LoadU(df, d - xs + 1), center, LoadU(df, d + xs - 1), + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 - 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_*_______ + 3__**_____ + 4____0____ + 5_____**__ + 6_______*_ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs - 3), LoadU(df, d - xs - 2), + LoadU(df, d - xs - 1), center, LoadU(df, d + xs + 1), + LoadU(df, d + xs + 2), LoadU(df, d + xs + xs + 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_______*_ + 3_____**__ + 4____0____ + 5__**_____ + 6_*_______ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs + 3), LoadU(df, d - xs + 2), + LoadU(df, d - xs + 1), center, LoadU(df, d + xs - 1), + LoadU(df, d + xs - 2), LoadU(df, d + xs + xs - 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_________ + 3______*** + 4___*0*___ + 5***______ + 6_________ + 7_________ + 8_________ */ + + auto sum = + Sum(LoadU(df, d + xs - 4), LoadU(df, d + xs - 3), LoadU(df, d + xs - 2), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d - xs + 2), + LoadU(df, d - xs + 3), LoadU(df, d - xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_________ + 3***______ + 4___*0*___ + 5______*** + 6_________ + 7_________ + 8_________ */ + auto sum = + Sum(LoadU(df, d - xs - 4), LoadU(df, d - xs - 3), LoadU(df, d - xs - 2), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + xs + 2), + LoadU(df, d + xs + 3), LoadU(df, d + xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0___*_____ + 1___*_____ + 2___*_____ + 3____*____ + 4____0____ + 5____*____ + 6_____*___ + 7_____*___ + 8_____*___ */ + auto sum = Sum(LoadU(df, d - xs3 - xs - 1), LoadU(df, d - xs3 - 1), + LoadU(df, d - xs - xs - 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs + 1), + LoadU(df, d + xs3 + 1), LoadU(df, d + xs3 + xs + 1)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_____*___ + 1_____*___ + 2____ *___ + 3____*____ + 4____0____ + 5____*____ + 6___*_____ + 7___*_____ + 8___*_____ */ + auto sum = Sum(LoadU(df, d - xs3 - xs + 1), LoadU(df, d - xs3 + 1), + LoadU(df, d - xs - xs + 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs - 1), + LoadU(df, d + xs3 - 1), LoadU(df, d + xs3 + xs - 1)); + retval = MulAdd(sum, sum, retval); + } + return retval; +} + +// Returns MaltaUnit. Avoids bounds-checks when x0 and y0 are known +// to be far enough from the image borders. "diffs" is a packed image. +template +static BUTTERAUGLI_INLINE float PaddedMaltaUnit(const ImageF& diffs, + const size_t x0, + const size_t y0) { + const float* BUTTERAUGLI_RESTRICT d = diffs.ConstRow(y0) + x0; + const HWY_CAPPED(float, 1) df; + if ((x0 >= 4 && y0 >= 4 && x0 < (diffs.xsize() - 4) && + y0 < (diffs.ysize() - 4))) { + return GetLane(MaltaUnit(Tag(), df, d, diffs.PixelsPerRow())); + } + + PROFILER_ZONE("Padded Malta"); + float borderimage[12 * 9]; // round up to 4 + for (int dy = 0; dy < 9; ++dy) { + int y = y0 + dy - 4; + if (y < 0 || static_cast(y) >= diffs.ysize()) { + for (int dx = 0; dx < 12; ++dx) { + borderimage[dy * 12 + dx] = 0.0f; + } + continue; + } + + const float* row_diffs = diffs.ConstRow(y); + for (int dx = 0; dx < 9; ++dx) { + int x = x0 + dx - 4; + if (x < 0 || static_cast(x) >= diffs.xsize()) { + borderimage[dy * 12 + dx] = 0.0f; + } else { + borderimage[dy * 12 + dx] = row_diffs[x]; + } + } + std::fill(borderimage + dy * 12 + 9, borderimage + dy * 12 + 12, 0.0f); + } + return GetLane(MaltaUnit(Tag(), df, &borderimage[4 * 12 + 4], 12)); +} + +template +static void MaltaDiffMapT(const Tag tag, const ImageF& lum0, const ImageF& lum1, + const double w_0gt1, const double w_0lt1, + const double norm1, const double len, + const double mulli, ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + JXL_DASSERT(SameSize(lum0, lum1) && SameSize(lum0, *diffs)); + const size_t xsize_ = lum0.xsize(); + const size_t ysize_ = lum0.ysize(); + + const float kWeight0 = 0.5; + const float kWeight1 = 0.33; + + const double w_pre0gt1 = mulli * std::sqrt(kWeight0 * w_0gt1) / (len * 2 + 1); + const double w_pre0lt1 = mulli * std::sqrt(kWeight1 * w_0lt1) / (len * 2 + 1); + const float norm2_0gt1 = w_pre0gt1 * norm1; + const float norm2_0lt1 = w_pre0lt1 * norm1; + + for (size_t y = 0; y < ysize_; ++y) { + const float* HWY_RESTRICT row0 = lum0.ConstRow(y); + const float* HWY_RESTRICT row1 = lum1.ConstRow(y); + float* HWY_RESTRICT row_diffs = diffs->Row(y); + for (size_t x = 0; x < xsize_; ++x) { + const float absval = 0.5f * (std::abs(row0[x]) + std::abs(row1[x])); + const float diff = row0[x] - row1[x]; + const float scaler = norm2_0gt1 / (static_cast(norm1) + absval); + + // Primary symmetric quadratic objective. + row_diffs[x] = scaler * diff; + + const float scaler2 = norm2_0lt1 / (static_cast(norm1) + absval); + const double fabs0 = std::fabs(row0[x]); + + // Secondary half-open quadratic objectives. + const double too_small = 0.55 * fabs0; + const double too_big = 1.05 * fabs0; + + if (row0[x] < 0) { + if (row1[x] > -too_small) { + double impact = scaler2 * (row1[x] + too_small); + row_diffs[x] -= impact; + } else if (row1[x] < -too_big) { + double impact = scaler2 * (-row1[x] - too_big); + row_diffs[x] += impact; + } + } else { + if (row1[x] < too_small) { + double impact = scaler2 * (too_small - row1[x]); + row_diffs[x] += impact; + } else if (row1[x] > too_big) { + double impact = scaler2 * (row1[x] - too_big); + row_diffs[x] -= impact; + } + } + } + } + + size_t y0 = 0; + // Top + for (; y0 < 4; ++y0) { + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->PlaneRow(c, y0); + for (size_t x0 = 0; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + } + + const HWY_FULL(float) df; + const size_t aligned_x = std::max(size_t(4), Lanes(df)); + const intptr_t stride = diffs->PixelsPerRow(); + + // Middle + for (; y0 < ysize_ - 4; ++y0) { + const float* BUTTERAUGLI_RESTRICT row_in = diffs->ConstRow(y0); + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->PlaneRow(c, y0); + size_t x0 = 0; + for (; x0 < aligned_x; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + for (; x0 + Lanes(df) + 4 <= xsize_; x0 += Lanes(df)) { + auto diff = Load(df, row_diff + x0); + diff = Add(diff, MaltaUnit(Tag(), df, row_in + x0, stride)); + Store(diff, df, row_diff + x0); + } + + for (; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + } + + // Bottom + for (; y0 < ysize_; ++y0) { + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->PlaneRow(c, y0); + for (size_t x0 = 0; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + } +} + +// Need non-template wrapper functions for HWY_EXPORT. +void MaltaDiffMap(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, const double len, + const double mulli, ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + MaltaDiffMapT(MaltaTag(), lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, + diffs, block_diff_ac, c); +} + +void MaltaDiffMapLF(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, const double len, + const double mulli, ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + MaltaDiffMapT(MaltaTagLF(), lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, + diffs, block_diff_ac, c); +} + +void DiffPrecompute(const ImageF& xyb, float mul, float bias_arg, ImageF* out) { + PROFILER_FUNC; + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + const float bias = mul * bias_arg; + const float sqrt_bias = sqrt(bias); + for (size_t y = 0; y < ysize; ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = xyb.Row(y); + float* BUTTERAUGLI_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + // kBias makes sqrt behave more linearly. + row_out[x] = sqrt(mul * std::abs(row_in[x]) + bias) - sqrt_bias; + } + } +} + +// std::log(80.0) / std::log(255.0); +constexpr float kIntensityTargetNormalizationHack = 0.79079917404f; +static const float kInternalGoodQualityThreshold = + 17.8f * kIntensityTargetNormalizationHack; +static const float kGlobalScale = 1.0 / kInternalGoodQualityThreshold; + +void StoreMin3(const float v, float& min0, float& min1, float& min2) { + if (v < min2) { + if (v < min0) { + min2 = min1; + min1 = min0; + min0 = v; + } else if (v < min1) { + min2 = min1; + min1 = v; + } else { + min2 = v; + } + } +} + +// Look for smooth areas near the area of degradation. +// If the areas area generally smooth, don't do masking. +void FuzzyErosion(const ImageF& from, ImageF* to) { + const size_t xsize = from.xsize(); + const size_t ysize = from.ysize(); + static const int kStep = 3; + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + float min0 = from.Row(y)[x]; + float min1 = 2 * min0; + float min2 = min1; + if (x >= kStep) { + float v = from.Row(y)[x - kStep]; + StoreMin3(v, min0, min1, min2); + if (y >= kStep) { + float v = from.Row(y - kStep)[x - kStep]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - kStep) { + float v = from.Row(y + kStep)[x - kStep]; + StoreMin3(v, min0, min1, min2); + } + } + if (x < xsize - kStep) { + float v = from.Row(y)[x + kStep]; + StoreMin3(v, min0, min1, min2); + if (y >= kStep) { + float v = from.Row(y - kStep)[x + kStep]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - kStep) { + float v = from.Row(y + kStep)[x + kStep]; + StoreMin3(v, min0, min1, min2); + } + } + if (y >= kStep) { + float v = from.Row(y - kStep)[x]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - kStep) { + float v = from.Row(y + kStep)[x]; + StoreMin3(v, min0, min1, min2); + } + to->Row(y)[x] = (0.45f * min0 + 0.3f * min1 + 0.25f * min2); + } + } +} + +// Compute values of local frequency and dc masking based on the activity +// in the two images. img_diff_ac may be null. +void Mask(const ImageF& mask0, const ImageF& mask1, + const ButteraugliParams& params, BlurTemp* blur_temp, + ImageF* BUTTERAUGLI_RESTRICT mask, + ImageF* BUTTERAUGLI_RESTRICT diff_ac) { + // Only X and Y components are involved in masking. B's influence + // is considered less important in the high frequency area, and we + // don't model masking from lower frequency signals. + PROFILER_FUNC; + const size_t xsize = mask0.xsize(); + const size_t ysize = mask0.ysize(); + *mask = ImageF(xsize, ysize); + static const float kMul = 6.19424080439; + static const float kBias = 12.61050594197; + static const float kRadius = 2.7; + ImageF diff0(xsize, ysize); + ImageF diff1(xsize, ysize); + ImageF blurred0(xsize, ysize); + ImageF blurred1(xsize, ysize); + DiffPrecompute(mask0, kMul, kBias, &diff0); + DiffPrecompute(mask1, kMul, kBias, &diff1); + Blur(diff0, kRadius, params, blur_temp, &blurred0); + FuzzyErosion(blurred0, &diff0); + Blur(diff1, kRadius, params, blur_temp, &blurred1); + FuzzyErosion(blurred1, &diff1); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + mask->Row(y)[x] = diff0.Row(y)[x]; + if (diff_ac != nullptr) { + static const float kMaskToErrorMul = 10.0; + float diff = blurred0.Row(y)[x] - blurred1.Row(y)[x]; + diff_ac->Row(y)[x] += kMaskToErrorMul * diff * diff; + } + } + } +} + +// `diff_ac` may be null. +void MaskPsychoImage(const PsychoImage& pi0, const PsychoImage& pi1, + const size_t xsize, const size_t ysize, + const ButteraugliParams& params, Image3F* temp, + BlurTemp* blur_temp, ImageF* BUTTERAUGLI_RESTRICT mask, + ImageF* BUTTERAUGLI_RESTRICT diff_ac) { + ImageF mask0(xsize, ysize); + ImageF mask1(xsize, ysize); + static const float muls[3] = { + 2.5f, + 0.4f, + 0.4f, + }; + // Silly and unoptimized approach here. TODO(jyrki): rework this. + for (size_t y = 0; y < ysize; ++y) { + const float* BUTTERAUGLI_RESTRICT row_y_hf0 = pi0.hf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_y_hf1 = pi1.hf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_y_uhf0 = pi0.uhf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_y_uhf1 = pi1.uhf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_hf0 = pi0.hf[0].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_hf1 = pi1.hf[0].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_uhf0 = pi0.uhf[0].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_uhf1 = pi1.uhf[0].Row(y); + float* BUTTERAUGLI_RESTRICT row0 = mask0.Row(y); + float* BUTTERAUGLI_RESTRICT row1 = mask1.Row(y); + for (size_t x = 0; x < xsize; ++x) { + float xdiff0 = (row_x_uhf0[x] + row_x_hf0[x]) * muls[0]; + float xdiff1 = (row_x_uhf1[x] + row_x_hf1[x]) * muls[0]; + float ydiff0 = row_y_uhf0[x] * muls[1] + row_y_hf0[x] * muls[2]; + float ydiff1 = row_y_uhf1[x] * muls[1] + row_y_hf1[x] * muls[2]; + row0[x] = xdiff0 * xdiff0 + ydiff0 * ydiff0; + row0[x] = sqrt(row0[x]); + row1[x] = xdiff1 * xdiff1 + ydiff1 * ydiff1; + row1[x] = sqrt(row1[x]); + } + } + Mask(mask0, mask1, params, blur_temp, mask, diff_ac); +} + +double MaskY(double delta) { + static const double offset = 0.829591754942; + static const double scaler = 0.451936922203; + static const double mul = 2.5485944793; + const double c = mul / ((scaler * delta) + offset); + const double retval = kGlobalScale * (1.0 + c); + return retval * retval; +} + +double MaskDcY(double delta) { + static const double offset = 0.20025578522; + static const double scaler = 3.87449418804; + static const double mul = 0.505054525019; + const double c = mul / ((scaler * delta) + offset); + const double retval = kGlobalScale * (1.0 + c); + return retval * retval; +} + +inline float MaskColor(const float color[3], const float mask) { + return color[0] * mask + color[1] * mask + color[2] * mask; +} + +// Diffmap := sqrt of sum{diff images by multiplied by X and Y/B masks} +void CombineChannelsToDiffmap(const ImageF& mask, const Image3F& block_diff_dc, + const Image3F& block_diff_ac, float xmul, + ImageF* result) { + PROFILER_FUNC; + JXL_CHECK(SameSize(mask, *result)); + size_t xsize = mask.xsize(); + size_t ysize = mask.ysize(); + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_out = result->Row(y); + for (size_t x = 0; x < xsize; ++x) { + float val = mask.Row(y)[x]; + float maskval = MaskY(val); + float dc_maskval = MaskDcY(val); + float diff_dc[3]; + float diff_ac[3]; + for (int i = 0; i < 3; ++i) { + diff_dc[i] = block_diff_dc.PlaneRow(i, y)[x]; + diff_ac[i] = block_diff_ac.PlaneRow(i, y)[x]; + } + diff_ac[0] *= xmul; + diff_dc[0] *= xmul; + row_out[x] = + sqrt(MaskColor(diff_dc, dc_maskval) + MaskColor(diff_ac, maskval)); + } + } +} + +// Adds weighted L2 difference between i0 and i1 to diffmap. +static void L2Diff(const ImageF& i0, const ImageF& i1, const float w, + Image3F* BUTTERAUGLI_RESTRICT diffmap, size_t c) { + if (w == 0) return; + + const HWY_FULL(float) d; + const auto weight = Set(d, w); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.ConstRow(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->PlaneRow(c, y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto diff = Sub(Load(d, row0 + x), Load(d, row1 + x)); + const auto diff2 = Mul(diff, diff); + const auto prev = Load(d, row_diff + x); + Store(MulAdd(diff2, weight, prev), d, row_diff + x); + } + } +} + +// Initializes diffmap to the weighted L2 difference between i0 and i1. +static void SetL2Diff(const ImageF& i0, const ImageF& i1, const float w, + Image3F* BUTTERAUGLI_RESTRICT diffmap, size_t c) { + if (w == 0) return; + + const HWY_FULL(float) d; + const auto weight = Set(d, w); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.ConstRow(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->PlaneRow(c, y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto diff = Sub(Load(d, row0 + x), Load(d, row1 + x)); + const auto diff2 = Mul(diff, diff); + Store(Mul(diff2, weight), d, row_diff + x); + } + } +} + +// i0 is the original image. +// i1 is the deformed copy. +static void L2DiffAsymmetric(const ImageF& i0, const ImageF& i1, float w_0gt1, + float w_0lt1, + Image3F* BUTTERAUGLI_RESTRICT diffmap, size_t c) { + if (w_0gt1 == 0 && w_0lt1 == 0) { + return; + } + + const HWY_FULL(float) d; + const auto vw_0gt1 = Set(d, w_0gt1 * 0.8); + const auto vw_0lt1 = Set(d, w_0lt1 * 0.8); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.Row(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.Row(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->PlaneRow(c, y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto val0 = Load(d, row0 + x); + const auto val1 = Load(d, row1 + x); + + // Primary symmetric quadratic objective. + const auto diff = Sub(val0, val1); + auto total = MulAdd(Mul(diff, diff), vw_0gt1, Load(d, row_diff + x)); + + // Secondary half-open quadratic objectives. + const auto fabs0 = Abs(val0); + const auto too_small = Mul(Set(d, 0.4), fabs0); + const auto too_big = fabs0; + + const auto if_neg = IfThenElse( + Gt(val1, Neg(too_small)), Add(val1, too_small), + IfThenElseZero(Lt(val1, Neg(too_big)), Sub(Neg(val1), too_big))); + const auto if_pos = + IfThenElse(Lt(val1, too_small), Sub(too_small, val1), + IfThenElseZero(Gt(val1, too_big), Sub(val1, too_big))); + const auto v = IfThenElse(Lt(val0, Zero(d)), if_neg, if_pos); + total = MulAdd(vw_0lt1, Mul(v, v), total); + Store(total, d, row_diff + x); + } + } +} + +// A simple HDR compatible gamma function. +template +V Gamma(const DF df, V v) { + // ln(2) constant folded in because we want std::log but have FastLog2f. + const auto kRetMul = Set(df, 19.245013259874995f * 0.693147180559945f); + const auto kRetAdd = Set(df, -23.16046239805755); + // This should happen rarely, but may lead to a NaN in log, which is + // undesirable. Since negative photons don't exist we solve the NaNs by + // clamping here. + v = ZeroIfNegative(v); + + const auto biased = Add(v, Set(df, 9.9710635769299145)); + const auto log = FastLog2f(df, biased); + // We could fold this into a custom Log2 polynomial, but there would be + // relatively little gain. + return MulAdd(kRetMul, log, kRetAdd); +} + +template +BUTTERAUGLI_INLINE void OpsinAbsorbance(const DF df, const V& in0, const V& in1, + const V& in2, V* JXL_RESTRICT out0, + V* JXL_RESTRICT out1, + V* JXL_RESTRICT out2) { + // https://en.wikipedia.org/wiki/Photopsin absorbance modeling. + static const double mixi0 = 0.29956550340058319; + static const double mixi1 = 0.63373087833825936; + static const double mixi2 = 0.077705617820981968; + static const double mixi3 = 1.7557483643287353; + static const double mixi4 = 0.22158691104574774; + static const double mixi5 = 0.69391388044116142; + static const double mixi6 = 0.0987313588422; + static const double mixi7 = 1.7557483643287353; + static const double mixi8 = 0.02; + static const double mixi9 = 0.02; + static const double mixi10 = 0.20480129041026129; + static const double mixi11 = 12.226454707163354; + + const V mix0 = Set(df, mixi0); + const V mix1 = Set(df, mixi1); + const V mix2 = Set(df, mixi2); + const V mix3 = Set(df, mixi3); + const V mix4 = Set(df, mixi4); + const V mix5 = Set(df, mixi5); + const V mix6 = Set(df, mixi6); + const V mix7 = Set(df, mixi7); + const V mix8 = Set(df, mixi8); + const V mix9 = Set(df, mixi9); + const V mix10 = Set(df, mixi10); + const V mix11 = Set(df, mixi11); + + *out0 = MulAdd(mix0, in0, MulAdd(mix1, in1, MulAdd(mix2, in2, mix3))); + *out1 = MulAdd(mix4, in0, MulAdd(mix5, in1, MulAdd(mix6, in2, mix7))); + *out2 = MulAdd(mix8, in0, MulAdd(mix9, in1, MulAdd(mix10, in2, mix11))); + + if (Clamp) { + *out0 = Max(*out0, mix3); + *out1 = Max(*out1, mix7); + *out2 = Max(*out2, mix11); + } +} + +// `blurred` is a temporary image used inside this function and not returned. +Image3F OpsinDynamicsImage(const Image3F& rgb, const ButteraugliParams& params, + Image3F* blurred, BlurTemp* blur_temp) { + PROFILER_FUNC; + Image3F xyb(rgb.xsize(), rgb.ysize()); + const double kSigma = 1.2; + Blur(rgb.Plane(0), kSigma, params, blur_temp, &blurred->Plane(0)); + Blur(rgb.Plane(1), kSigma, params, blur_temp, &blurred->Plane(1)); + Blur(rgb.Plane(2), kSigma, params, blur_temp, &blurred->Plane(2)); + const HWY_FULL(float) df; + const auto intensity_target_multiplier = Set(df, params.intensity_target); + for (size_t y = 0; y < rgb.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_r = rgb.ConstPlaneRow(0, y); + const float* BUTTERAUGLI_RESTRICT row_g = rgb.ConstPlaneRow(1, y); + const float* BUTTERAUGLI_RESTRICT row_b = rgb.ConstPlaneRow(2, y); + const float* BUTTERAUGLI_RESTRICT row_blurred_r = + blurred->ConstPlaneRow(0, y); + const float* BUTTERAUGLI_RESTRICT row_blurred_g = + blurred->ConstPlaneRow(1, y); + const float* BUTTERAUGLI_RESTRICT row_blurred_b = + blurred->ConstPlaneRow(2, y); + float* BUTTERAUGLI_RESTRICT row_out_x = xyb.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_out_y = xyb.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_out_b = xyb.PlaneRow(2, y); + const auto min = Set(df, 1e-4f); + for (size_t x = 0; x < rgb.xsize(); x += Lanes(df)) { + auto sensitivity0 = Undefined(df); + auto sensitivity1 = Undefined(df); + auto sensitivity2 = Undefined(df); + { + // Calculate sensitivity based on the smoothed image gamma derivative. + auto pre_mixed0 = Undefined(df); + auto pre_mixed1 = Undefined(df); + auto pre_mixed2 = Undefined(df); + OpsinAbsorbance( + df, Mul(Load(df, row_blurred_r + x), intensity_target_multiplier), + Mul(Load(df, row_blurred_g + x), intensity_target_multiplier), + Mul(Load(df, row_blurred_b + x), intensity_target_multiplier), + &pre_mixed0, &pre_mixed1, &pre_mixed2); + pre_mixed0 = Max(pre_mixed0, min); + pre_mixed1 = Max(pre_mixed1, min); + pre_mixed2 = Max(pre_mixed2, min); + sensitivity0 = Div(Gamma(df, pre_mixed0), pre_mixed0); + sensitivity1 = Div(Gamma(df, pre_mixed1), pre_mixed1); + sensitivity2 = Div(Gamma(df, pre_mixed2), pre_mixed2); + sensitivity0 = Max(sensitivity0, min); + sensitivity1 = Max(sensitivity1, min); + sensitivity2 = Max(sensitivity2, min); + } + auto cur_mixed0 = Undefined(df); + auto cur_mixed1 = Undefined(df); + auto cur_mixed2 = Undefined(df); + OpsinAbsorbance( + df, Mul(Load(df, row_r + x), intensity_target_multiplier), + Mul(Load(df, row_g + x), intensity_target_multiplier), + Mul(Load(df, row_b + x), intensity_target_multiplier), &cur_mixed0, + &cur_mixed1, &cur_mixed2); + cur_mixed0 = Mul(cur_mixed0, sensitivity0); + cur_mixed1 = Mul(cur_mixed1, sensitivity1); + cur_mixed2 = Mul(cur_mixed2, sensitivity2); + // This is a kludge. The negative values should be zeroed away before + // blurring. Ideally there would be no negative values in the first place. + const auto min01 = Set(df, 1.7557483643287353f); + const auto min2 = Set(df, 12.226454707163354f); + cur_mixed0 = Max(cur_mixed0, min01); + cur_mixed1 = Max(cur_mixed1, min01); + cur_mixed2 = Max(cur_mixed2, min2); + + Store(Sub(cur_mixed0, cur_mixed1), df, row_out_x + x); + Store(Add(cur_mixed0, cur_mixed1), df, row_out_y + x); + Store(cur_mixed2, df, row_out_b + x); + } + } + return xyb; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(SeparateFrequencies); // Local function. +HWY_EXPORT(MaskPsychoImage); // Local function. +HWY_EXPORT(L2DiffAsymmetric); // Local function. +HWY_EXPORT(L2Diff); // Local function. +HWY_EXPORT(SetL2Diff); // Local function. +HWY_EXPORT(CombineChannelsToDiffmap); // Local function. +HWY_EXPORT(MaltaDiffMap); // Local function. +HWY_EXPORT(MaltaDiffMapLF); // Local function. +HWY_EXPORT(OpsinDynamicsImage); // Local function. + +#if BUTTERAUGLI_ENABLE_CHECKS + +static inline bool IsNan(const float x) { + uint32_t bits; + memcpy(&bits, &x, sizeof(bits)); + const uint32_t bitmask_exp = 0x7F800000; + return (bits & bitmask_exp) == bitmask_exp && (bits & 0x7FFFFF); +} + +static inline bool IsNan(const double x) { + uint64_t bits; + memcpy(&bits, &x, sizeof(bits)); + return (0x7ff0000000000001ULL <= bits && bits <= 0x7fffffffffffffffULL) || + (0xfff0000000000001ULL <= bits && bits <= 0xffffffffffffffffULL); +} + +static inline void CheckImage(const ImageF& image, const char* name) { + PROFILER_FUNC; + for (size_t y = 0; y < image.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row = image.Row(y); + for (size_t x = 0; x < image.xsize(); ++x) { + if (IsNan(row[x])) { + printf("NAN: Image %s @ %" PRIuS ",%" PRIuS " (of %" PRIuS ",%" PRIuS + ")\n", + name, x, y, image.xsize(), image.ysize()); + exit(1); + } + } + } +} + +#define CHECK_NAN(x, str) \ + do { \ + if (IsNan(x)) { \ + printf("%d: %s\n", __LINE__, str); \ + abort(); \ + } \ + } while (0) + +#define CHECK_IMAGE(image, name) CheckImage(image, name) + +#else // BUTTERAUGLI_ENABLE_CHECKS + +#define CHECK_NAN(x, str) +#define CHECK_IMAGE(image, name) + +#endif // BUTTERAUGLI_ENABLE_CHECKS + +// Calculate a 2x2 subsampled image for purposes of recursive butteraugli at +// multiresolution. +static Image3F SubSample2x(const Image3F& in) { + size_t xs = (in.xsize() + 1) / 2; + size_t ys = (in.ysize() + 1) / 2; + Image3F retval(xs, ys); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ys; ++y) { + for (size_t x = 0; x < xs; ++x) { + retval.PlaneRow(c, y)[x] = 0; + } + } + } + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < in.ysize(); ++y) { + for (size_t x = 0; x < in.xsize(); ++x) { + retval.PlaneRow(c, y / 2)[x / 2] += 0.25f * in.PlaneRow(c, y)[x]; + } + } + if ((in.xsize() & 1) != 0) { + for (size_t y = 0; y < retval.ysize(); ++y) { + size_t last_column = retval.xsize() - 1; + retval.PlaneRow(c, y)[last_column] *= 2.0f; + } + } + if ((in.ysize() & 1) != 0) { + for (size_t x = 0; x < retval.xsize(); ++x) { + size_t last_row = retval.ysize() - 1; + retval.PlaneRow(c, last_row)[x] *= 2.0f; + } + } + } + return retval; +} + +// Supersample src by 2x and add it to dest. +static void AddSupersampled2x(const ImageF& src, float w, ImageF& dest) { + for (size_t y = 0; y < dest.ysize(); ++y) { + for (size_t x = 0; x < dest.xsize(); ++x) { + // There will be less errors from the more averaged images. + // We take it into account to some extent using a scaler. + static const double kHeuristicMixingValue = 0.3; + dest.Row(y)[x] *= 1.0 - kHeuristicMixingValue * w; + dest.Row(y)[x] += w * src.Row(y / 2)[x / 2]; + } + } +} + +Image3F* ButteraugliComparator::Temp() const { + bool was_in_use = temp_in_use_.test_and_set(std::memory_order_acq_rel); + JXL_ASSERT(!was_in_use); + (void)was_in_use; + return &temp_; +} + +void ButteraugliComparator::ReleaseTemp() const { temp_in_use_.clear(); } + +ButteraugliComparator::ButteraugliComparator(const Image3F& rgb0, + const ButteraugliParams& params) + : xsize_(rgb0.xsize()), + ysize_(rgb0.ysize()), + params_(params), + temp_(xsize_, ysize_) { + if (xsize_ < 8 || ysize_ < 8) { + return; + } + + Image3F xyb0 = HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage)(rgb0, params, Temp(), + &blur_temp_); + ReleaseTemp(); + HWY_DYNAMIC_DISPATCH(SeparateFrequencies) + (xsize_, ysize_, params_, &blur_temp_, xyb0, pi0_); + + // Awful recursive construction of samples of different resolution. + // This is an after-thought and possibly somewhat parallel in + // functionality with the PsychoImage multi-resolution approach. + sub_.reset(new ButteraugliComparator(SubSample2x(rgb0), params)); +} + +void ButteraugliComparator::Mask(ImageF* BUTTERAUGLI_RESTRICT mask) const { + HWY_DYNAMIC_DISPATCH(MaskPsychoImage) + (pi0_, pi0_, xsize_, ysize_, params_, Temp(), &blur_temp_, mask, nullptr); + ReleaseTemp(); +} + +void ButteraugliComparator::Diffmap(const Image3F& rgb1, ImageF& result) const { + PROFILER_FUNC; + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&result); + return; + } + const Image3F xyb1 = HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage)( + rgb1, params_, Temp(), &blur_temp_); + ReleaseTemp(); + DiffmapOpsinDynamicsImage(xyb1, result); + if (sub_) { + if (sub_->xsize_ < 8 || sub_->ysize_ < 8) { + return; + } + const Image3F sub_xyb = HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage)( + SubSample2x(rgb1), params_, sub_->Temp(), &sub_->blur_temp_); + sub_->ReleaseTemp(); + ImageF subresult; + sub_->DiffmapOpsinDynamicsImage(sub_xyb, subresult); + AddSupersampled2x(subresult, 0.5, result); + } +} + +void ButteraugliComparator::DiffmapOpsinDynamicsImage(const Image3F& xyb1, + ImageF& result) const { + PROFILER_FUNC; + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&result); + return; + } + PsychoImage pi1; + HWY_DYNAMIC_DISPATCH(SeparateFrequencies) + (xsize_, ysize_, params_, &blur_temp_, xyb1, pi1); + result = ImageF(xsize_, ysize_); + DiffmapPsychoImage(pi1, result); +} + +namespace { + +void MaltaDiffMap(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + PROFILER_FUNC; + const double len = 3.75; + static const double mulli = 0.39905817637; + HWY_DYNAMIC_DISPATCH(MaltaDiffMap) + (lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, diffs, block_diff_ac, c); +} + +void MaltaDiffMapLF(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + PROFILER_FUNC; + const double len = 3.75; + static const double mulli = 0.611612573796; + HWY_DYNAMIC_DISPATCH(MaltaDiffMapLF) + (lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, diffs, block_diff_ac, c); +} + +} // namespace + +void ButteraugliComparator::DiffmapPsychoImage(const PsychoImage& pi1, + ImageF& diffmap) const { + PROFILER_FUNC; + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&diffmap); + return; + } + + const float hf_asymmetry_ = params_.hf_asymmetry; + const float xmul_ = params_.xmul; + + ImageF diffs(xsize_, ysize_); + Image3F block_diff_ac(xsize_, ysize_); + ZeroFillImage(&block_diff_ac); + static const double wUhfMalta = 1.10039032555; + static const double norm1Uhf = 71.7800275169; + MaltaDiffMap(pi0_.uhf[1], pi1.uhf[1], wUhfMalta * hf_asymmetry_, + wUhfMalta / hf_asymmetry_, norm1Uhf, &diffs, &block_diff_ac, 1); + + static const double wUhfMaltaX = 173.5; + static const double norm1UhfX = 5.0; + MaltaDiffMap(pi0_.uhf[0], pi1.uhf[0], wUhfMaltaX * hf_asymmetry_, + wUhfMaltaX / hf_asymmetry_, norm1UhfX, &diffs, &block_diff_ac, + 0); + + static const double wHfMalta = 18.7237414387; + static const double norm1Hf = 4498534.45232; + MaltaDiffMapLF(pi0_.hf[1], pi1.hf[1], wHfMalta * std::sqrt(hf_asymmetry_), + wHfMalta / std::sqrt(hf_asymmetry_), norm1Hf, &diffs, + &block_diff_ac, 1); + + static const double wHfMaltaX = 6923.99476109; + static const double norm1HfX = 8051.15833247; + MaltaDiffMapLF(pi0_.hf[0], pi1.hf[0], wHfMaltaX * std::sqrt(hf_asymmetry_), + wHfMaltaX / std::sqrt(hf_asymmetry_), norm1HfX, &diffs, + &block_diff_ac, 0); + + static const double wMfMalta = 37.0819870399; + static const double norm1Mf = 130262059.556; + MaltaDiffMapLF(pi0_.mf.Plane(1), pi1.mf.Plane(1), wMfMalta, wMfMalta, norm1Mf, + &diffs, &block_diff_ac, 1); + + static const double wMfMaltaX = 8246.75321353; + static const double norm1MfX = 1009002.70582; + MaltaDiffMapLF(pi0_.mf.Plane(0), pi1.mf.Plane(0), wMfMaltaX, wMfMaltaX, + norm1MfX, &diffs, &block_diff_ac, 0); + + static const double wmul[9] = { + 400.0, 1.50815703118, 0, + 2150.0, 10.6195433239, 16.2176043152, + 29.2353797994, 0.844626970982, 0.703646627719, + }; + Image3F block_diff_dc(xsize_, ysize_); + for (size_t c = 0; c < 3; ++c) { + if (c < 2) { // No blue channel error accumulated at HF. + HWY_DYNAMIC_DISPATCH(L2DiffAsymmetric) + (pi0_.hf[c], pi1.hf[c], wmul[c] * hf_asymmetry_, wmul[c] / hf_asymmetry_, + &block_diff_ac, c); + } + HWY_DYNAMIC_DISPATCH(L2Diff) + (pi0_.mf.Plane(c), pi1.mf.Plane(c), wmul[3 + c], &block_diff_ac, c); + HWY_DYNAMIC_DISPATCH(SetL2Diff) + (pi0_.lf.Plane(c), pi1.lf.Plane(c), wmul[6 + c], &block_diff_dc, c); + } + + ImageF mask; + HWY_DYNAMIC_DISPATCH(MaskPsychoImage) + (pi0_, pi1, xsize_, ysize_, params_, Temp(), &blur_temp_, &mask, + &block_diff_ac.Plane(1)); + ReleaseTemp(); + + HWY_DYNAMIC_DISPATCH(CombineChannelsToDiffmap) + (mask, block_diff_dc, block_diff_ac, xmul_, &diffmap); +} + +double ButteraugliScoreFromDiffmap(const ImageF& diffmap, + const ButteraugliParams* params) { + PROFILER_FUNC; + float retval = 0.0f; + for (size_t y = 0; y < diffmap.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row = diffmap.ConstRow(y); + for (size_t x = 0; x < diffmap.xsize(); ++x) { + retval = std::max(retval, row[x]); + } + } + return retval; +} + +bool ButteraugliDiffmap(const Image3F& rgb0, const Image3F& rgb1, + double hf_asymmetry, double xmul, ImageF& diffmap) { + ButteraugliParams params; + params.hf_asymmetry = hf_asymmetry; + params.xmul = xmul; + return ButteraugliDiffmap(rgb0, rgb1, params, diffmap); +} + +bool ButteraugliDiffmap(const Image3F& rgb0, const Image3F& rgb1, + const ButteraugliParams& params, ImageF& diffmap) { + PROFILER_FUNC; + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + if (xsize < 1 || ysize < 1) { + return JXL_FAILURE("Zero-sized image"); + } + if (!SameSize(rgb0, rgb1)) { + return JXL_FAILURE("Size mismatch"); + } + static const int kMax = 8; + if (xsize < kMax || ysize < kMax) { + // Butteraugli values for small (where xsize or ysize is smaller + // than 8 pixels) images are non-sensical, but most likely it is + // less disruptive to try to compute something than just give up. + // Temporarily extend the borders of the image to fit 8 x 8 size. + size_t xborder = xsize < kMax ? (kMax - xsize) / 2 : 0; + size_t yborder = ysize < kMax ? (kMax - ysize) / 2 : 0; + size_t xscaled = std::max(kMax, xsize); + size_t yscaled = std::max(kMax, ysize); + Image3F scaled0(xscaled, yscaled); + Image3F scaled1(xscaled, yscaled); + for (int i = 0; i < 3; ++i) { + for (size_t y = 0; y < yscaled; ++y) { + for (size_t x = 0; x < xscaled; ++x) { + size_t x2 = + std::min(xsize - 1, x > xborder ? x - xborder : 0); + size_t y2 = + std::min(ysize - 1, y > yborder ? y - yborder : 0); + scaled0.PlaneRow(i, y)[x] = rgb0.PlaneRow(i, y2)[x2]; + scaled1.PlaneRow(i, y)[x] = rgb1.PlaneRow(i, y2)[x2]; + } + } + } + ImageF diffmap_scaled; + const bool ok = + ButteraugliDiffmap(scaled0, scaled1, params, diffmap_scaled); + diffmap = ImageF(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + diffmap.Row(y)[x] = diffmap_scaled.Row(y + yborder)[x + xborder]; + } + } + return ok; + } + ButteraugliComparator butteraugli(rgb0, params); + butteraugli.Diffmap(rgb1, diffmap); + return true; +} + +bool ButteraugliInterface(const Image3F& rgb0, const Image3F& rgb1, + float hf_asymmetry, float xmul, ImageF& diffmap, + double& diffvalue) { + ButteraugliParams params; + params.hf_asymmetry = hf_asymmetry; + params.xmul = xmul; + return ButteraugliInterface(rgb0, rgb1, params, diffmap, diffvalue); +} + +bool ButteraugliInterface(const Image3F& rgb0, const Image3F& rgb1, + const ButteraugliParams& params, ImageF& diffmap, + double& diffvalue) { +#if PROFILER_ENABLED + auto trace_start = std::chrono::steady_clock::now(); +#endif + if (!ButteraugliDiffmap(rgb0, rgb1, params, diffmap)) { + return false; + } +#if PROFILER_ENABLED + auto trace_end = std::chrono::steady_clock::now(); + std::chrono::duration elapsed = trace_end - trace_start; + const size_t mp = rgb0.xsize() * rgb0.ysize(); + printf("diff MP/s %f\n", mp / elapsed.count() * 1E-6); +#endif + diffvalue = ButteraugliScoreFromDiffmap(diffmap, ¶ms); + return true; +} + +double ButteraugliFuzzyClass(double score) { + static const double fuzzy_width_up = 4.8; + static const double fuzzy_width_down = 4.8; + static const double m0 = 2.0; + static const double scaler = 0.7777; + double val; + if (score < 1.0) { + // val in [scaler .. 2.0] + val = m0 / (1.0 + exp((score - 1.0) * fuzzy_width_down)); + val -= 1.0; // from [1 .. 2] to [0 .. 1] + val *= 2.0 - scaler; // from [0 .. 1] to [0 .. 2.0 - scaler] + val += scaler; // from [0 .. 2.0 - scaler] to [scaler .. 2.0] + } else { + // val in [0 .. scaler] + val = m0 / (1.0 + exp((score - 1.0) * fuzzy_width_up)); + val *= scaler; + } + return val; +} + +// #define PRINT_OUT_NORMALIZATION + +double ButteraugliFuzzyInverse(double seek) { + double pos = 0; + // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter) + for (double range = 1.0; range >= 1e-10; range *= 0.5) { + double cur = ButteraugliFuzzyClass(pos); + if (cur < seek) { + pos -= range; + } else { + pos += range; + } + } +#ifdef PRINT_OUT_NORMALIZATION + if (seek == 1.0) { + fprintf(stderr, "Fuzzy inverse %g\n", pos); + } +#endif + return pos; +} + +#ifdef PRINT_OUT_NORMALIZATION +static double print_out_normalization = ButteraugliFuzzyInverse(1.0); +#endif + +namespace { + +void ScoreToRgb(double score, double good_threshold, double bad_threshold, + float rgb[3]) { + double heatmap[12][3] = { + {0, 0, 0}, {0, 0, 1}, + {0, 1, 1}, {0, 1, 0}, // Good level + {1, 1, 0}, {1, 0, 0}, // Bad level + {1, 0, 1}, {0.5, 0.5, 1.0}, + {1.0, 0.5, 0.5}, // Pastel colors for the very bad quality range. + {1.0, 1.0, 0.5}, {1, 1, 1}, + {1, 1, 1}, // Last color repeated to have a solid range of white. + }; + if (score < good_threshold) { + score = (score / good_threshold) * 0.3; + } else if (score < bad_threshold) { + score = 0.3 + + (score - good_threshold) / (bad_threshold - good_threshold) * 0.15; + } else { + score = 0.45 + (score - bad_threshold) / (bad_threshold * 12) * 0.5; + } + static const int kTableSize = sizeof(heatmap) / sizeof(heatmap[0]); + score = std::min(std::max(score * (kTableSize - 1), 0.0), + kTableSize - 2); + int ix = static_cast(score); + ix = std::min(std::max(0, ix), kTableSize - 2); // Handle NaN + double mix = score - ix; + for (int i = 0; i < 3; ++i) { + double v = mix * heatmap[ix + 1][i] + (1 - mix) * heatmap[ix][i]; + rgb[i] = pow(v, 0.5); + } +} + +} // namespace + +Image3F CreateHeatMapImage(const ImageF& distmap, double good_threshold, + double bad_threshold) { + Image3F heatmap(distmap.xsize(), distmap.ysize()); + for (size_t y = 0; y < distmap.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_distmap = distmap.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_h0 = heatmap.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_h1 = heatmap.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_h2 = heatmap.PlaneRow(2, y); + for (size_t x = 0; x < distmap.xsize(); ++x) { + const float d = row_distmap[x]; + float rgb[3]; + ScoreToRgb(d, good_threshold, bad_threshold, rgb); + row_h0[x] = rgb[0]; + row_h1[x] = rgb[1]; + row_h2[x] = rgb[2]; + } + } + return heatmap; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/butteraugli/butteraugli.h b/media/libjxl/src/lib/jxl/butteraugli/butteraugli.h new file mode 100644 index 000000000..652b9528c --- /dev/null +++ b/media/libjxl/src/lib/jxl/butteraugli/butteraugli.h @@ -0,0 +1,209 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Author: Jyrki Alakuijala (jyrki.alakuijala@gmail.com) + +#ifndef LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_ +#define LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +#define BUTTERAUGLI_ENABLE_CHECKS 0 +#define BUTTERAUGLI_RESTRICT JXL_RESTRICT + +// This is the main interface to butteraugli image similarity +// analysis function. + +namespace jxl { + +struct ButteraugliParams { + // Multiplier for penalizing new HF artifacts more than blurring away + // features. 1.0=neutral. + float hf_asymmetry = 1.0f; + + // Multiplier for the psychovisual difference in the X channel. + float xmul = 1.0f; + + // Number of nits that correspond to 1.0f input values. + float intensity_target = 80.0f; +}; + +// ButteraugliInterface defines the public interface for butteraugli. +// +// It calculates the difference between rgb0 and rgb1. +// +// rgb0 and rgb1 contain the images. rgb0[c][px] and rgb1[c][px] contains +// the red image for c == 0, green for c == 1, blue for c == 2. Location index +// px is calculated as y * xsize + x. +// +// Value of pixels of images rgb0 and rgb1 need to be represented as raw +// intensity. Most image formats store gamma corrected intensity in pixel +// values. This gamma correction has to be removed, by applying the following +// function to values in the 0-1 range: +// butteraugli_val = pow(input_val, gamma); +// A typical value of gamma is 2.2. It is usually stored in the image header. +// Take care not to confuse that value with its inverse. The gamma value should +// be always greater than one. +// Butteraugli does not work as intended if the caller does not perform +// gamma correction. +// +// hf_asymmetry is a multiplier for penalizing new HF artifacts more than +// blurring away features (1.0 -> neutral). +// +// diffmap will contain an image of the size xsize * ysize, containing +// localized differences for values px (indexed with the px the same as rgb0 +// and rgb1). diffvalue will give a global score of similarity. +// +// A diffvalue smaller than kButteraugliGood indicates that images can be +// observed as the same image. +// diffvalue larger than kButteraugliBad indicates that a difference between +// the images can be observed. +// A diffvalue between kButteraugliGood and kButteraugliBad indicates that +// a subtle difference can be observed between the images. +// +// Returns true on success. +bool ButteraugliInterface(const Image3F &rgb0, const Image3F &rgb1, + const ButteraugliParams ¶ms, ImageF &diffmap, + double &diffvalue); + +// Deprecated (calls the previous function) +bool ButteraugliInterface(const Image3F &rgb0, const Image3F &rgb1, + float hf_asymmetry, float xmul, ImageF &diffmap, + double &diffvalue); + +// Converts the butteraugli score into fuzzy class values that are continuous +// at the class boundary. The class boundary location is based on human +// raters, but the slope is arbitrary. Particularly, it does not reflect +// the expectation value of probabilities of the human raters. It is just +// expected that a smoother class boundary will allow for higher-level +// optimization algorithms to work faster. +// +// Returns 2.0 for a perfect match, and 1.0 for 'ok', 0.0 for bad. Because the +// scoring is fuzzy, a butteraugli score of 0.96 would return a class of +// around 1.9. +double ButteraugliFuzzyClass(double score); + +// Input values should be in range 0 (bad) to 2 (good). Use +// kButteraugliNormalization as normalization. +double ButteraugliFuzzyInverse(double seek); + +// Implementation details, don't use anything below or your code will +// break in the future. + +#ifdef _MSC_VER +#define BUTTERAUGLI_INLINE __forceinline +#else +#define BUTTERAUGLI_INLINE inline +#endif + +#ifdef __clang__ +// Early versions of Clang did not support __builtin_assume_aligned. +#define BUTTERAUGLI_HAS_ASSUME_ALIGNED __has_builtin(__builtin_assume_aligned) +#elif defined(__GNUC__) +#define BUTTERAUGLI_HAS_ASSUME_ALIGNED 1 +#else +#define BUTTERAUGLI_HAS_ASSUME_ALIGNED 0 +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* JXL_RESTRICT aligned = (float*)JXL_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if BUTTERAUGLI_HAS_ASSUME_ALIGNED +#define BUTTERAUGLI_ASSUME_ALIGNED(ptr, align) \ + __builtin_assume_aligned((ptr), (align)) +#else +#define BUTTERAUGLI_ASSUME_ALIGNED(ptr, align) (ptr) +#endif // BUTTERAUGLI_HAS_ASSUME_ALIGNED + +struct PsychoImage { + ImageF uhf[2]; // XY + ImageF hf[2]; // XY + Image3F mf; // XYB + Image3F lf; // XYB +}; + +// Blur needs a transposed image. +// Hold it here and only allocate on demand to reduce memory usage. +struct BlurTemp { + ImageF *GetTransposed(const ImageF &in) { + if (transposed_temp.xsize() == 0) { + transposed_temp = ImageF(in.ysize(), in.xsize()); + } + return &transposed_temp; + } + + ImageF transposed_temp; +}; + +class ButteraugliComparator { + public: + // Butteraugli is calibrated at xmul = 1.0. We add a multiplier here so that + // we can test the hypothesis that a higher weighing of the X channel would + // improve results at higher Butteraugli values. + ButteraugliComparator(const Image3F &rgb0, const ButteraugliParams ¶ms); + virtual ~ButteraugliComparator() = default; + + // Computes the butteraugli map between the original image given in the + // constructor and the distorted image give here. + void Diffmap(const Image3F &rgb1, ImageF &result) const; + + // Same as above, but OpsinDynamicsImage() was already applied. + void DiffmapOpsinDynamicsImage(const Image3F &xyb1, ImageF &result) const; + + // Same as above, but the frequency decomposition was already applied. + void DiffmapPsychoImage(const PsychoImage &pi1, ImageF &diffmap) const; + + void Mask(ImageF *BUTTERAUGLI_RESTRICT mask) const; + + private: + Image3F *Temp() const; + void ReleaseTemp() const; + + const size_t xsize_; + const size_t ysize_; + ButteraugliParams params_; + PsychoImage pi0_; + + // Shared temporary image storage to reduce the number of allocations; + // obtained via Temp(), must call ReleaseTemp when no longer needed. + mutable Image3F temp_; + mutable std::atomic_flag temp_in_use_ = ATOMIC_FLAG_INIT; + + mutable BlurTemp blur_temp_; + std::unique_ptr sub_; +}; + +// Deprecated. +bool ButteraugliDiffmap(const Image3F &rgb0, const Image3F &rgb1, + double hf_asymmetry, double xmul, ImageF &diffmap); + +bool ButteraugliDiffmap(const Image3F &rgb0, const Image3F &rgb1, + const ButteraugliParams ¶ms, ImageF &diffmap); + +double ButteraugliScoreFromDiffmap(const ImageF &diffmap, + const ButteraugliParams *params = nullptr); + +// Generate rgb-representation of the distance between two images. +Image3F CreateHeatMapImage(const ImageF &distmap, double good_threshold, + double bad_threshold); + +} // namespace jxl + +#endif // LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_ diff --git a/media/libjxl/src/lib/jxl/butteraugli_test.cc b/media/libjxl/src/lib/jxl/butteraugli_test.cc new file mode 100644 index 000000000..98ec7888a --- /dev/null +++ b/media/libjxl/src/lib/jxl/butteraugli_test.cc @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/butteraugli.h" + +#include "gtest/gtest.h" +#include "jxl/butteraugli_cxx.h" +#include "lib/jxl/test_utils.h" + +TEST(ButteraugliTest, Lossless) { + uint32_t xsize = 171; + uint32_t ysize = 219; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr)); + JxlButteraugliResultPtr result(JxlButteraugliCompute( + api.get(), xsize, ysize, &pixel_format, pixels.data(), pixels.size(), + &pixel_format, pixels.data(), pixels.size())); + EXPECT_EQ(0.0, JxlButteraugliResultGetDistance(result.get(), 8.0)); +} + +TEST(ButteraugliTest, Distmap) { + uint32_t xsize = 171; + uint32_t ysize = 219; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr)); + JxlButteraugliResultPtr result(JxlButteraugliCompute( + api.get(), xsize, ysize, &pixel_format, pixels.data(), pixels.size(), + &pixel_format, pixels.data(), pixels.size())); + EXPECT_EQ(0.0, JxlButteraugliResultGetDistance(result.get(), 8.0)); + const float* distmap; + uint32_t row_stride; + JxlButteraugliResultGetDistmap(result.get(), &distmap, &row_stride); + for (uint32_t y = 0; y < ysize; y++) { + for (uint32_t x = 0; x < xsize; x++) { + EXPECT_EQ(0.0, distmap[y * row_stride + x]); + } + } +} + +TEST(ButteraugliTest, Distorted) { + uint32_t xsize = 171; + uint32_t ysize = 219; + std::vector orig_pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + std::vector dist_pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + dist_pixels[0] += 128; + + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr)); + JxlButteraugliResultPtr result(JxlButteraugliCompute( + api.get(), xsize, ysize, &pixel_format, orig_pixels.data(), + orig_pixels.size(), &pixel_format, dist_pixels.data(), + dist_pixels.size())); + EXPECT_NE(0.0, JxlButteraugliResultGetDistance(result.get(), 8.0)); +} + +TEST(ButteraugliTest, Api) { + uint32_t xsize = 171; + uint32_t ysize = 219; + std::vector orig_pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + std::vector dist_pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + dist_pixels[0] += 128; + + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr)); + JxlButteraugliApiSetHFAsymmetry(api.get(), 1.0f); + JxlButteraugliApiSetIntensityTarget(api.get(), 250.0f); + JxlButteraugliResultPtr result(JxlButteraugliCompute( + api.get(), xsize, ysize, &pixel_format, orig_pixels.data(), + orig_pixels.size(), &pixel_format, dist_pixels.data(), + dist_pixels.size())); + double distance0 = JxlButteraugliResultGetDistance(result.get(), 8.0); + + JxlButteraugliApiSetHFAsymmetry(api.get(), 2.0f); + result.reset(JxlButteraugliCompute(api.get(), xsize, ysize, &pixel_format, + orig_pixels.data(), orig_pixels.size(), + &pixel_format, dist_pixels.data(), + dist_pixels.size())); + double distance1 = JxlButteraugliResultGetDistance(result.get(), 8.0); + + EXPECT_NE(distance0, distance1); + + JxlButteraugliApiSetIntensityTarget(api.get(), 80.0f); + result.reset(JxlButteraugliCompute(api.get(), xsize, ysize, &pixel_format, + orig_pixels.data(), orig_pixels.size(), + &pixel_format, dist_pixels.data(), + dist_pixels.size())); + double distance2 = JxlButteraugliResultGetDistance(result.get(), 8.0); + + EXPECT_NE(distance1, distance2); +} diff --git a/media/libjxl/src/lib/jxl/butteraugli_wrapper.cc b/media/libjxl/src/lib/jxl/butteraugli_wrapper.cc new file mode 100644 index 000000000..836b798d1 --- /dev/null +++ b/media/libjxl/src/lib/jxl/butteraugli_wrapper.cc @@ -0,0 +1,203 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include + +#include + +#include "jxl/butteraugli.h" +#include "jxl/parallel_runner.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_butteraugli_pnorm.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/memory_manager_internal.h" + +namespace { + +void SetMetadataFromPixelFormat(const JxlPixelFormat* pixel_format, + jxl::ImageMetadata* metadata) { + uint32_t potential_alpha_bits = 0; + switch (pixel_format->data_type) { + case JXL_TYPE_FLOAT: + metadata->SetFloat32Samples(); + potential_alpha_bits = 16; + break; + case JXL_TYPE_FLOAT16: + metadata->SetFloat16Samples(); + potential_alpha_bits = 16; + break; + case JXL_TYPE_UINT16: + metadata->SetUintSamples(16); + potential_alpha_bits = 16; + break; + case JXL_TYPE_UINT8: + metadata->SetUintSamples(8); + potential_alpha_bits = 8; + break; + default: + JXL_ABORT("Unhandled JxlDataType"); + } + if (pixel_format->num_channels == 2 || pixel_format->num_channels == 4) { + metadata->SetAlphaBits(potential_alpha_bits); + } +} + +} // namespace + +struct JxlButteraugliResultStruct { + JxlMemoryManager memory_manager; + + jxl::ImageF distmap; + jxl::ButteraugliParams params; +}; + +struct JxlButteraugliApiStruct { + // Multiplier for penalizing new HF artifacts more than blurring away + // features. 1.0=neutral. + float hf_asymmetry = 1.0f; + + // Multiplier for the psychovisual difference in the X channel. + float xmul = 1.0f; + + // Number of nits that correspond to 1.0f input values. + float intensity_target = jxl::kDefaultIntensityTarget; + + JxlCmsInterface cms; + JxlMemoryManager memory_manager; + std::unique_ptr thread_pool{nullptr}; +}; + +JxlButteraugliApi* JxlButteraugliApiCreate( + const JxlMemoryManager* memory_manager) { + JxlMemoryManager local_memory_manager; + if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) + return nullptr; + + void* alloc = + jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlButteraugliApi)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + JxlButteraugliApi* ret = new (alloc) JxlButteraugliApi(); + ret->cms = jxl::GetJxlCms(); + ret->memory_manager = local_memory_manager; + return ret; +} + +void JxlButteraugliApiSetParallelRunner(JxlButteraugliApi* api, + JxlParallelRunner parallel_runner, + void* parallel_runner_opaque) { + api->thread_pool = jxl::make_unique(parallel_runner, + parallel_runner_opaque); +} + +void JxlButteraugliApiSetHFAsymmetry(JxlButteraugliApi* api, float v) { + api->hf_asymmetry = v; +} + +void JxlButteraugliApiSetIntensityTarget(JxlButteraugliApi* api, float v) { + api->intensity_target = v; +} + +void JxlButteraugliApiDestroy(JxlButteraugliApi* api) { + if (api) { + JxlMemoryManager local_memory_manager = api->memory_manager; + // Call destructor directly since custom free function is used. + api->~JxlButteraugliApi(); + jxl::MemoryManagerFree(&local_memory_manager, api); + } +} + +JxlButteraugliResult* JxlButteraugliCompute( + const JxlButteraugliApi* api, uint32_t xsize, uint32_t ysize, + const JxlPixelFormat* pixel_format_orig, const void* buffer_orig, + size_t size_orig, const JxlPixelFormat* pixel_format_dist, + const void* buffer_dist, size_t size_dist) { + jxl::ImageMetadata orig_metadata; + SetMetadataFromPixelFormat(pixel_format_orig, &orig_metadata); + jxl::ImageBundle orig_ib(&orig_metadata); + jxl::ColorEncoding c_current; + if (pixel_format_orig->data_type == JXL_TYPE_FLOAT) { + c_current = + jxl::ColorEncoding::LinearSRGB(pixel_format_orig->num_channels < 3); + } else { + c_current = jxl::ColorEncoding::SRGB(pixel_format_orig->num_channels < 3); + } + if (!jxl::BufferToImageBundle(*pixel_format_orig, xsize, ysize, buffer_orig, + size_orig, api->thread_pool.get(), c_current, + &orig_ib)) { + return nullptr; + } + + jxl::ImageMetadata dist_metadata; + SetMetadataFromPixelFormat(pixel_format_dist, &dist_metadata); + jxl::ImageBundle dist_ib(&dist_metadata); + if (pixel_format_dist->data_type == JXL_TYPE_FLOAT) { + c_current = + jxl::ColorEncoding::LinearSRGB(pixel_format_dist->num_channels < 3); + } else { + c_current = jxl::ColorEncoding::SRGB(pixel_format_dist->num_channels < 3); + } + if (!jxl::BufferToImageBundle(*pixel_format_dist, xsize, ysize, buffer_dist, + size_dist, api->thread_pool.get(), c_current, + &dist_ib)) { + return nullptr; + } + + void* alloc = jxl::MemoryManagerAlloc(&api->memory_manager, + sizeof(JxlButteraugliResult)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + JxlButteraugliResult* result = new (alloc) JxlButteraugliResult(); + result->memory_manager = api->memory_manager; + result->params.hf_asymmetry = api->hf_asymmetry; + result->params.xmul = api->xmul; + result->params.intensity_target = api->intensity_target; + jxl::ButteraugliDistance(orig_ib, dist_ib, result->params, api->cms, + &result->distmap, api->thread_pool.get()); + + return result; +} + +float JxlButteraugliResultGetDistance(const JxlButteraugliResult* result, + float pnorm) { + return static_cast( + jxl::ComputeDistanceP(result->distmap, result->params, pnorm)); +} + +void JxlButteraugliResultGetDistmap(const JxlButteraugliResult* result, + const float** buffer, + uint32_t* row_stride) { + *buffer = result->distmap.Row(0); + *row_stride = result->distmap.PixelsPerRow(); +} + +float JxlButteraugliResultGetMaxDistance(const JxlButteraugliResult* result) { + float max_distance = 0.0; + for (uint32_t y = 0; y < result->distmap.ysize(); y++) { + for (uint32_t x = 0; x < result->distmap.xsize(); x++) { + if (result->distmap.ConstRow(y)[x] > max_distance) { + max_distance = result->distmap.ConstRow(y)[x]; + } + } + } + return max_distance; +} + +void JxlButteraugliResultDestroy(JxlButteraugliResult* result) { + if (result) { + JxlMemoryManager local_memory_manager = result->memory_manager; + // Call destructor directly since custom free function is used. + result->~JxlButteraugliResult(); + jxl::MemoryManagerFree(&local_memory_manager, result); + } +} diff --git a/media/libjxl/src/lib/jxl/byte_order_test.cc b/media/libjxl/src/lib/jxl/byte_order_test.cc new file mode 100644 index 000000000..c1ea19f31 --- /dev/null +++ b/media/libjxl/src/lib/jxl/byte_order_test.cc @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/byte_order.h" + +#include "gtest/gtest.h" + +namespace jxl { +namespace { + +TEST(ByteOrderTest, TestRoundTripBE16) { + const uint32_t in = 0x1234; + uint8_t buf[2]; + StoreBE16(in, buf); + EXPECT_EQ(in, LoadBE16(buf)); + EXPECT_NE(in, LoadLE16(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE16) { + const uint32_t in = 0x1234; + uint8_t buf[2]; + StoreLE16(in, buf); + EXPECT_EQ(in, LoadLE16(buf)); + EXPECT_NE(in, LoadBE16(buf)); +} + +TEST(ByteOrderTest, TestRoundTripBE32) { + const uint32_t in = 0xFEDCBA98u; + uint8_t buf[4]; + StoreBE32(in, buf); + EXPECT_EQ(in, LoadBE32(buf)); + EXPECT_NE(in, LoadLE32(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE32) { + const uint32_t in = 0xFEDCBA98u; + uint8_t buf[4]; + StoreLE32(in, buf); + EXPECT_EQ(in, LoadLE32(buf)); + EXPECT_NE(in, LoadBE32(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE64) { + const uint64_t in = 0xFEDCBA9876543210ull; + uint8_t buf[8]; + StoreLE64(in, buf); + EXPECT_EQ(in, LoadLE64(buf)); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/chroma_from_luma.cc b/media/libjxl/src/lib/jxl/chroma_from_luma.cc new file mode 100644 index 000000000..63d21cbb4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/chroma_from_luma.cc @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/chroma_from_luma.h" + +namespace jxl { + +ColorCorrelationMap::ColorCorrelationMap(size_t xsize, size_t ysize, bool XYB) + : ytox_map(DivCeil(xsize, kColorTileDim), DivCeil(ysize, kColorTileDim)), + ytob_map(DivCeil(xsize, kColorTileDim), DivCeil(ysize, kColorTileDim)) { + ZeroFillImage(&ytox_map); + ZeroFillImage(&ytob_map); + if (!XYB) { + base_correlation_b_ = 0; + } + RecomputeDCFactors(); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/chroma_from_luma.h b/media/libjxl/src/lib/jxl/chroma_from_luma.h new file mode 100644 index 000000000..cf2f90e43 --- /dev/null +++ b/media/libjxl/src/lib/jxl/chroma_from_luma.h @@ -0,0 +1,151 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CHROMA_FROM_LUMA_H_ +#define LIB_JXL_CHROMA_FROM_LUMA_H_ + +// Chroma-from-luma, computed using heuristics to determine the best linear +// model for the X and B channels from the Y channel. + +#include +#include + +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +// Tile is the rectangular grid of blocks that share color correlation +// parameters ("factor_x/b" such that residual_b = blue - Y * factor_b). +static constexpr size_t kColorTileDim = 64; + +static_assert(kColorTileDim % kBlockDim == 0, + "Color tile dim should be divisible by block dim"); +static constexpr size_t kColorTileDimInBlocks = kColorTileDim / kBlockDim; + +static_assert(kGroupDimInBlocks % kColorTileDimInBlocks == 0, + "Group dim should be divisible by color tile dim"); + +static constexpr uint8_t kDefaultColorFactor = 84; + +// JPEG DCT coefficients are at most 1024. CfL constants are at most 127, and +// the ratio of two entries in a JPEG quantization table is at most 255. Thus, +// since the CfL denominator is 84, this leaves 12 bits of mantissa to be used. +// For extra caution, we use 11. +static constexpr uint8_t kCFLFixedPointPrecision = 11; + +static constexpr U32Enc kColorFactorDist(Val(kDefaultColorFactor), Val(256), + BitsOffset(8, 2), BitsOffset(16, 258)); + +struct ColorCorrelationMap { + ColorCorrelationMap() = default; + // xsize/ysize are in pixels + // set XYB=false to do something close to no-op cmap (needed for now since + // cmap is mandatory) + ColorCorrelationMap(size_t xsize, size_t ysize, bool XYB = true); + + float YtoXRatio(int32_t x_factor) const { + return base_correlation_x_ + x_factor * color_scale_; + } + + float YtoBRatio(int32_t b_factor) const { + return base_correlation_b_ + b_factor * color_scale_; + } + + Status DecodeDC(BitReader* br) { + if (br->ReadFixedBits<1>() == 1) { + // All default. + return true; + } + SetColorFactor(U32Coder::Read(kColorFactorDist, br)); + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &base_correlation_x_)); + if (std::abs(base_correlation_x_) > 4.0f) { + return JXL_FAILURE("Base X correlation is out of range"); + } + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &base_correlation_b_)); + if (std::abs(base_correlation_b_) > 4.0f) { + return JXL_FAILURE("Base B correlation is out of range"); + } + ytox_dc_ = static_cast(br->ReadFixedBits()) + + std::numeric_limits::min(); + ytob_dc_ = static_cast(br->ReadFixedBits()) + + std::numeric_limits::min(); + RecomputeDCFactors(); + return true; + } + + // We consider a CfL map to be JPEG-reconstruction-compatible if base + // correlation is 0, no DC correlation is used, and we use the default color + // factor. + bool IsJPEGCompatible() const { + return base_correlation_x_ == 0 && base_correlation_b_ == 0 && + ytob_dc_ == 0 && ytox_dc_ == 0 && + color_factor_ == kDefaultColorFactor; + } + + int32_t RatioJPEG(int32_t factor) const { + return factor * (1 << kCFLFixedPointPrecision) / kDefaultColorFactor; + } + + void SetColorFactor(uint32_t factor) { + color_factor_ = factor; + color_scale_ = 1.0f / color_factor_; + RecomputeDCFactors(); + } + + void SetYToBDC(int32_t ytob_dc) { + ytob_dc_ = ytob_dc; + RecomputeDCFactors(); + } + void SetYToXDC(int32_t ytox_dc) { + ytox_dc_ = ytox_dc; + RecomputeDCFactors(); + } + + int32_t GetYToXDC() const { return ytox_dc_; } + int32_t GetYToBDC() const { return ytob_dc_; } + float GetColorFactor() const { return color_factor_; } + float GetBaseCorrelationX() const { return base_correlation_x_; } + float GetBaseCorrelationB() const { return base_correlation_b_; } + + const float* DCFactors() const { return dc_factors_; } + + void RecomputeDCFactors() { + dc_factors_[0] = YtoXRatio(ytox_dc_); + dc_factors_[2] = YtoBRatio(ytob_dc_); + } + + ImageSB ytox_map; + ImageSB ytob_map; + + private: + float dc_factors_[4] = {}; + // range of factor: -1.51 to +1.52 + uint32_t color_factor_ = kDefaultColorFactor; + float color_scale_ = 1.0f / color_factor_; + float base_correlation_x_ = 0.0f; + float base_correlation_b_ = kYToBRatio; + int32_t ytox_dc_ = 0; + int32_t ytob_dc_ = 0; +}; + +} // namespace jxl + +#endif // LIB_JXL_CHROMA_FROM_LUMA_H_ diff --git a/media/libjxl/src/lib/jxl/codec_in_out.h b/media/libjxl/src/lib/jxl/codec_in_out.h new file mode 100644 index 000000000..23f0a4afe --- /dev/null +++ b/media/libjxl/src/lib/jxl/codec_in_out.h @@ -0,0 +1,219 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CODEC_IN_OUT_H_ +#define LIB_JXL_CODEC_IN_OUT_H_ + +// Holds inputs/outputs for decoding/encoding images. + +#include + +#include +#include + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/common.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/luminance.h" +#include "lib/jxl/size_constraints.h" + +namespace jxl { + +// Per-channel interval, used to convert between (full-range) external and +// (bounded or unbounded) temp values. See external_image.cc for the definitions +// of temp/external. +struct CodecInterval { + CodecInterval() = default; + constexpr CodecInterval(float min, float max) : min(min), width(max - min) {} + // Defaults for temp. + float min = 0.0f; + float width = 1.0f; +}; + +template ::value>::type> +Status VerifyDimensions(const SizeConstraints* constraints, T xs, T ys) { + if (!constraints) return true; + + if (xs == 0 || ys == 0) return JXL_FAILURE("Empty image."); + if (xs > constraints->dec_max_xsize) return JXL_FAILURE("Image too wide."); + if (ys > constraints->dec_max_ysize) return JXL_FAILURE("Image too tall."); + + const uint64_t num_pixels = static_cast(xs) * ys; + if (num_pixels > constraints->dec_max_pixels) { + return JXL_FAILURE("Image too big."); + } + + return true; +} + +using CodecIntervals = std::array; // RGB[A] or Y[A] + +// Optional text/EXIF metadata. +struct Blobs { + std::vector exif; + std::vector iptc; + std::vector jumbf; + std::vector xmp; +}; + +// Holds a preview, a main image or one or more frames, plus the inputs/outputs +// to/from decoding/encoding. +class CodecInOut { + public: + CodecInOut() : preview_frame(&metadata.m) { + frames.reserve(1); + frames.emplace_back(&metadata.m); + } + + // Move-only. + CodecInOut(CodecInOut&&) = default; + CodecInOut& operator=(CodecInOut&&) = default; + + size_t LastStillFrame() const { + JXL_DASSERT(!frames.empty()); + size_t last = 0; + for (size_t i = 0; i < frames.size(); i++) { + last = i; + if (frames[i].duration > 0) break; + } + return last; + } + + ImageBundle& Main() { return frames[LastStillFrame()]; } + const ImageBundle& Main() const { return frames[LastStillFrame()]; } + + // If c_current.IsGray(), all planes must be identical. + void SetFromImage(Image3F&& color, const ColorEncoding& c_current) { + Main().SetFromImage(std::move(color), c_current); + SetIntensityTarget(this); + SetSize(Main().xsize(), Main().ysize()); + } + + void SetSize(size_t xsize, size_t ysize) { + JXL_CHECK(metadata.size.Set(xsize, ysize)); + } + + void CheckMetadata() const { + JXL_CHECK(metadata.m.bit_depth.bits_per_sample != 0); + JXL_CHECK(!metadata.m.color_encoding.ICC().empty()); + + if (preview_frame.xsize() != 0) preview_frame.VerifyMetadata(); + JXL_CHECK(preview_frame.metadata() == &metadata.m); + + for (const ImageBundle& ib : frames) { + ib.VerifyMetadata(); + JXL_CHECK(ib.metadata() == &metadata.m); + } + } + + size_t xsize() const { return metadata.size.xsize(); } + size_t ysize() const { return metadata.size.ysize(); } + void ShrinkTo(size_t xsize, size_t ysize) { + // preview is unaffected. + for (ImageBundle& ib : frames) { + ib.ShrinkTo(xsize, ysize); + } + SetSize(xsize, ysize); + } + + // Calls TransformTo for each ImageBundle (preview/frames). + Status TransformTo(const ColorEncoding& c_desired, const JxlCmsInterface& cms, + ThreadPool* pool = nullptr) { + if (metadata.m.have_preview) { + JXL_RETURN_IF_ERROR(preview_frame.TransformTo(c_desired, cms, pool)); + } + for (ImageBundle& ib : frames) { + JXL_RETURN_IF_ERROR(ib.TransformTo(c_desired, cms, pool)); + } + return true; + } + // Performs "PremultiplyAlpha" for each ImageBundle (preview/frames). + bool PremultiplyAlpha() { + const auto doPremultiplyAlpha = [](ImageBundle& bundle) { + if (!bundle.HasAlpha()) return; + if (!bundle.HasColor()) return; + auto* color = bundle.color(); + const auto* alpha = bundle.alpha(); + JXL_CHECK(color->ysize() == alpha->ysize()); + JXL_CHECK(color->xsize() == alpha->xsize()); + for (size_t y = 0; y < color->ysize(); y++) { + ::jxl::PremultiplyAlpha(color->PlaneRow(0, y), color->PlaneRow(1, y), + color->PlaneRow(2, y), alpha->Row(y), + color->xsize()); + } + }; + ExtraChannelInfo* eci = metadata.m.Find(ExtraChannel::kAlpha); + if (eci == nullptr || eci->alpha_associated) return false; + if (metadata.m.have_preview) { + doPremultiplyAlpha(preview_frame); + } + for (ImageBundle& ib : frames) { + doPremultiplyAlpha(ib); + } + eci->alpha_associated = true; + return true; + } + + bool UnpremultiplyAlpha() { + const auto doUnpremultiplyAlpha = [](ImageBundle& bundle) { + if (!bundle.HasAlpha()) return; + if (!bundle.HasColor()) return; + auto* color = bundle.color(); + const auto* alpha = bundle.alpha(); + JXL_CHECK(color->ysize() == alpha->ysize()); + JXL_CHECK(color->xsize() == alpha->xsize()); + for (size_t y = 0; y < color->ysize(); y++) { + ::jxl::UnpremultiplyAlpha(color->PlaneRow(0, y), color->PlaneRow(1, y), + color->PlaneRow(2, y), alpha->Row(y), + color->xsize()); + } + }; + ExtraChannelInfo* eci = metadata.m.Find(ExtraChannel::kAlpha); + if (eci == nullptr || !eci->alpha_associated) return false; + if (metadata.m.have_preview) { + doUnpremultiplyAlpha(preview_frame); + } + for (ImageBundle& ib : frames) { + doUnpremultiplyAlpha(ib); + } + eci->alpha_associated = false; + return true; + } + + // -- DECODER INPUT: + + SizeConstraints constraints; + + // -- DECODER OUTPUT: + + // Total number of pixels decoded (may differ from #frames * xsize * ysize + // if frames are cropped) + uint64_t dec_pixels = 0; + + // -- DECODER OUTPUT, ENCODER INPUT: + + // Metadata stored into / retrieved from bitstreams. + + Blobs blobs; + + CodecMetadata metadata; // applies to preview and all frames + + // If metadata.have_preview: + ImageBundle preview_frame; + + std::vector frames; // size=1 if !metadata.have_animation + + // If the image should be written to a JPEG, use this quality for encoding. + size_t jpeg_quality; +}; + +} // namespace jxl + +#endif // LIB_JXL_CODEC_IN_OUT_H_ diff --git a/media/libjxl/src/lib/jxl/codec_y4m_testonly.cc b/media/libjxl/src/lib/jxl/codec_y4m_testonly.cc new file mode 100644 index 000000000..dfcad9db5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/codec_y4m_testonly.cc @@ -0,0 +1,202 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/codec_y4m_testonly.h" + +#include + +namespace jxl { +namespace test { + +struct HeaderY4M { + size_t xsize; + size_t ysize; + size_t bits_per_sample; + int is_yuv; // Y4M: where 1 = 444, 2 = 422, 3 = 420 +}; + +// Decode Y4M images. +class Y4MParser { + public: + explicit Y4MParser(const Span bytes) + : pos_(bytes.data()), end_(pos_ + bytes.size()) {} + + // TODO(jon): support multi-frame y4m + Status ParseHeader(HeaderY4M* header, const uint8_t** pos) { + JXL_RETURN_IF_ERROR(ExpectString("YUV4MPEG2", 9)); + header->is_yuv = 3; + // TODO(jon): check if 4:2:0 is indeed the default + header->bits_per_sample = 8; + // TODO(jon): check if there's a y4m convention for higher bit depths + while (pos_ < end_) { + char next = 0; + JXL_RETURN_IF_ERROR(ReadChar(&next)); + if (next == 0x0A) break; + if (next != ' ') continue; + char field = 0; + JXL_RETURN_IF_ERROR(ReadChar(&field)); + switch (field) { + case 'W': + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + break; + case 'H': + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + break; + case 'I': + JXL_RETURN_IF_ERROR(ReadChar(&next)); + if (next != 'p') { + return JXL_FAILURE( + "Y4M: only progressive (no frame interlacing) allowed"); + } + break; + case 'C': { + char c1 = 0; + JXL_RETURN_IF_ERROR(ReadChar(&c1)); + char c2 = 0; + JXL_RETURN_IF_ERROR(ReadChar(&c2)); + char c3 = 0; + JXL_RETURN_IF_ERROR(ReadChar(&c3)); + if (c1 != '4') return JXL_FAILURE("Y4M: invalid C param"); + if (c2 == '4') { + if (c3 != '4') return JXL_FAILURE("Y4M: invalid C param"); + header->is_yuv = 1; // 444 + } else if (c2 == '2') { + if (c3 == '2') { + header->is_yuv = 2; // 422 + } else if (c3 == '0') { + header->is_yuv = 3; // 420 + } else { + return JXL_FAILURE("Y4M: invalid C param"); + } + } else { + return JXL_FAILURE("Y4M: invalid C param"); + } + } + [[fallthrough]]; + // no break: fallthrough because this field can have values like + // "C420jpeg" (we are ignoring the chroma sample location and treat + // everything like C420jpeg) + case 'F': // Framerate in fps as numerator:denominator + // TODO(jon): actually read this and set corresponding jxl + // metadata + case 'A': // Pixel aspect ratio (ignoring it, could perhaps adjust + // intrinsic dimensions based on this?) + case 'X': // Comment, ignore + // ignore the field value and go to next one + while (pos_ < end_) { + if (pos_[0] == ' ' || pos_[0] == 0x0A) break; + pos_++; + } + break; + default: + return JXL_FAILURE("Y4M: parse error"); + } + } + JXL_RETURN_IF_ERROR(ExpectString("FRAME", 5)); + while (true) { + char next = 0; + JXL_RETURN_IF_ERROR(ReadChar(&next)); + if (next == 0x0A) { + *pos = pos_; + return true; + } + } + } + + private: + Status ExpectString(const char* str, size_t len) { + // Unlikely to happen. + if (pos_ + len < pos_) return JXL_FAILURE("Y4M: overflow"); + + if (pos_ + len > end_ || strncmp(str, (const char*)pos_, len) != 0) { + return JXL_FAILURE("Y4M: expected %s", str); + } + pos_ += len; + return true; + } + + Status ReadChar(char* out) { + // Unlikely to happen. + if (pos_ + 1 < pos_) return JXL_FAILURE("Y4M: overflow"); + + if (pos_ >= end_) { + return JXL_FAILURE("Y4M: unexpected end of input"); + } + *out = *pos_; + pos_++; + return true; + } + + static bool IsDigit(const uint8_t c) { return '0' <= c && c <= '9'; } + + Status ParseUnsigned(size_t* number) { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before number"); + if (!IsDigit(*pos_)) return JXL_FAILURE("PNM: expected unsigned number"); + + *number = 0; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number *= 10; + *number += *pos_ - '0'; + ++pos_; + } + + return true; + } + + const uint8_t* pos_; + const uint8_t* const end_; +}; + +Status DecodeImageY4M(const Span bytes, CodecInOut* io) { + Y4MParser parser(bytes); + HeaderY4M header = {}; + const uint8_t* pos = nullptr; + JXL_RETURN_IF_ERROR(parser.ParseHeader(&header, &pos)); + + Image3F yuvdata(header.xsize, header.ysize); + ImageBundle bundle(&io->metadata.m); + const int hshift[3][3] = {{0, 0, 0}, {0, 1, 1}, {0, 1, 1}}; + const int vshift[3][3] = {{0, 0, 0}, {0, 0, 0}, {0, 1, 1}}; + + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < header.ysize >> vshift[header.is_yuv - 1][c]; ++y) { + float* const JXL_RESTRICT row = yuvdata.PlaneRow((c == 2 ? 2 : 1 - c), y); + if (pos + (header.xsize >> hshift[header.is_yuv - 1][c]) > + bytes.data() + bytes.size()) + return JXL_FAILURE("Not enough image data"); + for (size_t x = 0; x < header.xsize >> hshift[header.is_yuv - 1][c]; + ++x) { + row[x] = (1.f / 255.f) * ((*pos++) - 128.f); + } + } + } + bundle.SetFromImage(std::move(yuvdata), io->metadata.m.color_encoding); + bundle.color_transform = ColorTransform::kYCbCr; + + YCbCrChromaSubsampling subsampling; + uint8_t cssh[3] = { + 2, static_cast(hshift[header.is_yuv - 1][1] ? 1 : 2), + static_cast(hshift[header.is_yuv - 1][2] ? 1 : 2)}; + uint8_t cssv[3] = { + 2, static_cast(vshift[header.is_yuv - 1][1] ? 1 : 2), + static_cast(vshift[header.is_yuv - 1][2] ? 1 : 2)}; + + JXL_RETURN_IF_ERROR(subsampling.Set(cssh, cssv)); + bundle.chroma_subsampling = subsampling; + io->Main() = std::move(bundle); + + JXL_RETURN_IF_ERROR(io->metadata.m.color_encoding.SetSRGB(ColorSpace::kRGB)); + io->metadata.m.SetUintSamples(header.bits_per_sample); + io->metadata.m.SetAlphaBits(0); + io->dec_pixels = header.xsize * header.ysize; + + io->metadata.m.bit_depth.bits_per_sample = io->Main().DetectRealBitdepth(); + io->SetSize(header.xsize, header.ysize); + SetIntensityTarget(io); + return true; +} + +} // namespace test +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/codec_y4m_testonly.h b/media/libjxl/src/lib/jxl/codec_y4m_testonly.h new file mode 100644 index 000000000..f65759dfe --- /dev/null +++ b/media/libjxl/src/lib/jxl/codec_y4m_testonly.h @@ -0,0 +1,18 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace test { + +Status DecodeImageY4M(const Span bytes, CodecInOut* io); + +} // namespace test +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/coeff_order.cc b/media/libjxl/src/lib/jxl/coeff_order.cc new file mode 100644 index 000000000..399febb83 --- /dev/null +++ b/media/libjxl/src/lib/jxl/coeff_order.cc @@ -0,0 +1,155 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/coeff_order.h" + +#include + +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/lehmer_code.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +uint32_t CoeffOrderContext(uint32_t val) { + uint32_t token, nbits, bits; + HybridUintConfig(0, 0, 0).Encode(val, &token, &nbits, &bits); + return std::min(token, kPermutationContexts - 1); +} + +namespace { +Status ReadPermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br, ANSSymbolReader* reader, + const std::vector& context_map) { + std::vector lehmer(size); + // temp space needs to be as large as the next power of 2, so doubling the + // allocated size is enough. + std::vector temp(size * 2); + uint32_t end = + reader->ReadHybridUint(CoeffOrderContext(size), br, context_map) + skip; + if (end > size) { + return JXL_FAILURE("Invalid permutation size"); + } + uint32_t last = 0; + for (size_t i = skip; i < end; ++i) { + lehmer[i] = + reader->ReadHybridUint(CoeffOrderContext(last), br, context_map); + last = lehmer[i]; + if (lehmer[i] + i >= size) { + return JXL_FAILURE("Invalid lehmer code"); + } + } + if (order == nullptr) return true; + DecodeLehmerCode(lehmer.data(), temp.data(), size, order); + return true; +} + +} // namespace + +Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br) { + std::vector context_map; + ANSCode code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kPermutationContexts, &code, &context_map)); + ANSSymbolReader reader(&code, br); + JXL_RETURN_IF_ERROR( + ReadPermutation(skip, size, order, br, &reader, context_map)); + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("Invalid ANS stream"); + } + return true; +} + +namespace { + +Status DecodeCoeffOrder(AcStrategy acs, coeff_order_t* order, BitReader* br, + ANSSymbolReader* reader, + std::vector& natural_order, + const std::vector& context_map) { + PROFILER_FUNC; + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + + JXL_RETURN_IF_ERROR( + ReadPermutation(llf, size, order, br, reader, context_map)); + if (order == nullptr) return true; + for (size_t k = 0; k < size; ++k) { + order[k] = natural_order[order[k]]; + } + return true; +} + +} // namespace + +Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, + coeff_order_t* order, BitReader* br) { + uint16_t computed = 0; + std::vector context_map; + ANSCode code; + std::unique_ptr reader; + std::vector natural_order; + // Bitstream does not have histograms if no coefficient order is used. + if (used_orders != 0) { + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kPermutationContexts, &code, &context_map)); + reader = make_unique(&code, br); + } + uint32_t acs_mask = 0; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + if ((used_acs & (1 << o)) == 0) continue; + acs_mask |= 1 << kStrategyOrder[o]; + } + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + uint8_t ord = kStrategyOrder[o]; + if (computed & (1 << ord)) continue; + computed |= 1 << ord; + AcStrategy acs = AcStrategy::FromRawStrategy(o); + bool used = (acs_mask & (1 << ord)) != 0; + + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + + if (used || (used_orders & (1 << ord))) { + if (natural_order.size() < size) natural_order.resize(size); + acs.ComputeNaturalCoeffOrder(natural_order.data()); + } + + if ((used_orders & (1 << ord)) == 0) { + // No need to set the default order if no ACS uses this order. + if (used) { + for (size_t c = 0; c < 3; c++) { + memcpy(&order[CoeffOrderOffset(ord, c)], natural_order.data(), + size * sizeof(*order)); + } + } + } else { + for (size_t c = 0; c < 3; c++) { + coeff_order_t* dest = used ? &order[CoeffOrderOffset(ord, c)] : nullptr; + JXL_RETURN_IF_ERROR(DecodeCoeffOrder(acs, dest, br, reader.get(), + natural_order, context_map)); + } + } + } + if (used_orders && !reader->CheckANSFinalState()) { + return JXL_FAILURE("Invalid ANS stream"); + } + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/coeff_order.h b/media/libjxl/src/lib/jxl/coeff_order.h new file mode 100644 index 000000000..50618514d --- /dev/null +++ b/media/libjxl/src/lib/jxl/coeff_order.h @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COEFF_ORDER_H_ +#define LIB_JXL_COEFF_ORDER_H_ + +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" + +namespace jxl { + +// Those offsets get multiplied by kDCTBlockSize. +static constexpr size_t kCoeffOrderOffset[] = { + 0, 1, 2, 3, 4, 5, 6, 10, 14, 18, + 34, 50, 66, 68, 70, 72, 76, 80, 84, 92, + 100, 108, 172, 236, 300, 332, 364, 396, 652, 908, + 1164, 1292, 1420, 1548, 2572, 3596, 4620, 5132, 5644, 6156, +}; +static_assert(3 * kNumOrders + 1 == + sizeof(kCoeffOrderOffset) / sizeof(*kCoeffOrderOffset), + "Update this array when adding or removing order types."); + +static constexpr size_t CoeffOrderOffset(size_t order, size_t c) { + return kCoeffOrderOffset[3 * order + c] * kDCTBlockSize; +} + +static constexpr size_t kCoeffOrderMaxSize = + kCoeffOrderOffset[3 * kNumOrders] * kDCTBlockSize; + +// Mapping from AC strategy to order bucket. Strategies with different natural +// orders must have different buckets. +constexpr uint8_t kStrategyOrder[] = { + 0, 1, 1, 1, 2, 3, 4, 4, 5, 5, 6, 6, 1, 1, + 1, 1, 1, 1, 7, 8, 8, 9, 10, 10, 11, 12, 12, +}; + +static_assert(AcStrategy::kNumValidStrategies == + sizeof(kStrategyOrder) / sizeof(*kStrategyOrder), + "Update this array when adding or removing AC strategies."); + +constexpr uint32_t kPermutationContexts = 8; + +uint32_t CoeffOrderContext(uint32_t val); + +Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, + coeff_order_t* order, BitReader* br); + +Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br); + +} // namespace jxl + +#endif // LIB_JXL_COEFF_ORDER_H_ diff --git a/media/libjxl/src/lib/jxl/coeff_order_fwd.h b/media/libjxl/src/lib/jxl/coeff_order_fwd.h new file mode 100644 index 000000000..700e9a83d --- /dev/null +++ b/media/libjxl/src/lib/jxl/coeff_order_fwd.h @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COEFF_ORDER_FWD_H_ +#define LIB_JXL_COEFF_ORDER_FWD_H_ + +// Breaks circular dependency between ac_strategy and coeff_order. + +#include +#include + +#include "base/compiler_specific.h" + +namespace jxl { + +// Needs at least 16 bits. A 32-bit type speeds up DecodeAC by 2% at the cost of +// more memory. +using coeff_order_t = uint32_t; + +// Maximum number of orders to be used. Note that this needs to be multiplied by +// the number of channels. One per "size class" (plus one extra for DCT8), +// shared between transforms of size XxY and of size YxX. +constexpr uint8_t kNumOrders = 13; + +// DCT coefficients are laid out in such a way that the number of rows of +// coefficients is always the smaller coordinate. +JXL_INLINE constexpr size_t CoefficientRows(size_t rows, size_t columns) { + return rows < columns ? rows : columns; +} + +JXL_INLINE constexpr size_t CoefficientColumns(size_t rows, size_t columns) { + return rows < columns ? columns : rows; +} + +JXL_INLINE void CoefficientLayout(size_t* JXL_RESTRICT rows, + size_t* JXL_RESTRICT columns) { + size_t r = *rows; + size_t c = *columns; + *rows = CoefficientRows(r, c); + *columns = CoefficientColumns(r, c); +} + +} // namespace jxl + +#endif // LIB_JXL_COEFF_ORDER_FWD_H_ diff --git a/media/libjxl/src/lib/jxl/coeff_order_test.cc b/media/libjxl/src/lib/jxl/coeff_order_test.cc new file mode 100644 index 000000000..810c72559 --- /dev/null +++ b/media/libjxl/src/lib/jxl/coeff_order_test.cc @@ -0,0 +1,97 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/coeff_order.h" + +#include + +#include +#include // iota +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_coeff_order.h" + +namespace jxl { +namespace { + +void RoundtripPermutation(coeff_order_t* perm, coeff_order_t* out, size_t len, + size_t* size) { + BitWriter writer; + EncodePermutation(perm, 0, len, &writer, 0, nullptr); + writer.ZeroPadToByte(); + Status status = true; + { + BitReader reader(writer.GetSpan()); + BitReaderScopedCloser closer(&reader, &status); + ASSERT_TRUE(DecodePermutation(0, len, out, &reader)); + } + ASSERT_TRUE(status); + *size = writer.GetSpan().size(); +} + +enum Permutation { kIdentity, kFewSwaps, kFewSlides, kRandom }; + +constexpr size_t kSwaps = 32; + +void TestPermutation(Permutation kind, size_t len) { + std::vector perm(len); + std::iota(perm.begin(), perm.end(), 0); + Rng rng(0); + if (kind == kFewSwaps) { + for (size_t i = 0; i < kSwaps; i++) { + size_t a = rng.UniformU(0, len - 1); + size_t b = rng.UniformU(0, len - 1); + std::swap(perm[a], perm[b]); + } + } + if (kind == kFewSlides) { + for (size_t i = 0; i < kSwaps; i++) { + size_t a = rng.UniformU(0, len - 1); + size_t b = rng.UniformU(0, len - 1); + size_t from = std::min(a, b); + size_t to = std::max(a, b); + size_t start = perm[from]; + for (size_t j = from; j < to; j++) { + perm[j] = perm[j + 1]; + } + perm[to] = start; + } + } + if (kind == kRandom) { + rng.Shuffle(perm.data(), perm.size()); + } + std::vector out(len); + size_t size = 0; + RoundtripPermutation(perm.data(), out.data(), len, &size); + for (size_t idx = 0; idx < len; idx++) { + EXPECT_EQ(perm[idx], out[idx]); + } + printf("Encoded size: %" PRIuS "\n", size); +} + +TEST(CoeffOrderTest, IdentitySmall) { TestPermutation(kIdentity, 256); } +TEST(CoeffOrderTest, FewSlidesSmall) { TestPermutation(kFewSlides, 256); } +TEST(CoeffOrderTest, FewSwapsSmall) { TestPermutation(kFewSwaps, 256); } +TEST(CoeffOrderTest, RandomSmall) { TestPermutation(kRandom, 256); } + +TEST(CoeffOrderTest, IdentityMedium) { TestPermutation(kIdentity, 1 << 12); } +TEST(CoeffOrderTest, FewSlidesMedium) { TestPermutation(kFewSlides, 1 << 12); } +TEST(CoeffOrderTest, FewSwapsMedium) { TestPermutation(kFewSwaps, 1 << 12); } +TEST(CoeffOrderTest, RandomMedium) { TestPermutation(kRandom, 1 << 12); } + +TEST(CoeffOrderTest, IdentityBig) { TestPermutation(kIdentity, 1 << 16); } +TEST(CoeffOrderTest, FewSlidesBig) { TestPermutation(kFewSlides, 1 << 16); } +TEST(CoeffOrderTest, FewSwapsBig) { TestPermutation(kFewSwaps, 1 << 16); } +TEST(CoeffOrderTest, RandomBig) { TestPermutation(kRandom, 1 << 16); } + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/color_encoding_internal.cc b/media/libjxl/src/lib/jxl/color_encoding_internal.cc new file mode 100644 index 000000000..a2eca448c --- /dev/null +++ b/media/libjxl/src/lib/jxl/color_encoding_internal.cc @@ -0,0 +1,752 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/color_encoding_internal.h" + +#include + +#include +#include + +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/linalg.h" + +namespace jxl { +namespace { + +// Highest reasonable value for the gamma of a transfer curve. +constexpr uint32_t kMaxGamma = 8192; + +// These strings are baked into Description - do not change. + +std::string ToString(ColorSpace color_space) { + switch (color_space) { + case ColorSpace::kRGB: + return "RGB"; + case ColorSpace::kGray: + return "Gra"; + case ColorSpace::kXYB: + return "XYB"; + case ColorSpace::kUnknown: + return "CS?"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid ColorSpace %u", static_cast(color_space)); +} + +std::string ToString(WhitePoint white_point) { + switch (white_point) { + case WhitePoint::kD65: + return "D65"; + case WhitePoint::kCustom: + return "Cst"; + case WhitePoint::kE: + return "EER"; + case WhitePoint::kDCI: + return "DCI"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid WhitePoint %u", static_cast(white_point)); +} + +std::string ToString(Primaries primaries) { + switch (primaries) { + case Primaries::kSRGB: + return "SRG"; + case Primaries::k2100: + return "202"; + case Primaries::kP3: + return "DCI"; + case Primaries::kCustom: + return "Cst"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid Primaries %u", static_cast(primaries)); +} + +std::string ToString(TransferFunction transfer_function) { + switch (transfer_function) { + case TransferFunction::kSRGB: + return "SRG"; + case TransferFunction::kLinear: + return "Lin"; + case TransferFunction::k709: + return "709"; + case TransferFunction::kPQ: + return "PeQ"; + case TransferFunction::kHLG: + return "HLG"; + case TransferFunction::kDCI: + return "DCI"; + case TransferFunction::kUnknown: + return "TF?"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid TransferFunction %u", + static_cast(transfer_function)); +} + +std::string ToString(RenderingIntent rendering_intent) { + switch (rendering_intent) { + case RenderingIntent::kPerceptual: + return "Per"; + case RenderingIntent::kRelative: + return "Rel"; + case RenderingIntent::kSaturation: + return "Sat"; + case RenderingIntent::kAbsolute: + return "Abs"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid RenderingIntent %u", + static_cast(rendering_intent)); +} + +static double F64FromCustomxyI32(const int32_t i) { return i * 1E-6; } +static Status F64ToCustomxyI32(const double f, int32_t* JXL_RESTRICT i) { + if (!(-4 <= f && f <= 4)) { + return JXL_FAILURE("F64 out of bounds for CustomxyI32"); + } + *i = static_cast(roundf(f * 1E6)); + return true; +} + +Status ConvertExternalToInternalWhitePoint(const JxlWhitePoint external, + WhitePoint* internal) { + switch (external) { + case JXL_WHITE_POINT_D65: + *internal = WhitePoint::kD65; + return true; + case JXL_WHITE_POINT_CUSTOM: + *internal = WhitePoint::kCustom; + return true; + case JXL_WHITE_POINT_E: + *internal = WhitePoint::kE; + return true; + case JXL_WHITE_POINT_DCI: + *internal = WhitePoint::kDCI; + return true; + } + return JXL_FAILURE("Invalid WhitePoint enum value"); +} + +Status ConvertExternalToInternalPrimaries(const JxlPrimaries external, + Primaries* internal) { + switch (external) { + case JXL_PRIMARIES_SRGB: + *internal = Primaries::kSRGB; + return true; + case JXL_PRIMARIES_CUSTOM: + *internal = Primaries::kCustom; + return true; + case JXL_PRIMARIES_2100: + *internal = Primaries::k2100; + return true; + case JXL_PRIMARIES_P3: + *internal = Primaries::kP3; + return true; + } + return JXL_FAILURE("Invalid Primaries enum value"); +} + +Status ConvertExternalToInternalTransferFunction( + const JxlTransferFunction external, TransferFunction* internal) { + switch (external) { + case JXL_TRANSFER_FUNCTION_709: + *internal = TransferFunction::k709; + return true; + case JXL_TRANSFER_FUNCTION_UNKNOWN: + *internal = TransferFunction::kUnknown; + return true; + case JXL_TRANSFER_FUNCTION_LINEAR: + *internal = TransferFunction::kLinear; + return true; + case JXL_TRANSFER_FUNCTION_SRGB: + *internal = TransferFunction::kSRGB; + return true; + case JXL_TRANSFER_FUNCTION_PQ: + *internal = TransferFunction::kPQ; + return true; + case JXL_TRANSFER_FUNCTION_DCI: + *internal = TransferFunction::kDCI; + return true; + case JXL_TRANSFER_FUNCTION_HLG: + *internal = TransferFunction::kHLG; + return true; + case JXL_TRANSFER_FUNCTION_GAMMA: + return JXL_FAILURE("Gamma should be handled separately"); + } + return JXL_FAILURE("Invalid TransferFunction enum value"); +} + +Status ConvertExternalToInternalRenderingIntent( + const JxlRenderingIntent external, RenderingIntent* internal) { + switch (external) { + case JXL_RENDERING_INTENT_PERCEPTUAL: + *internal = RenderingIntent::kPerceptual; + return true; + case JXL_RENDERING_INTENT_RELATIVE: + *internal = RenderingIntent::kRelative; + return true; + case JXL_RENDERING_INTENT_SATURATION: + *internal = RenderingIntent::kSaturation; + return true; + case JXL_RENDERING_INTENT_ABSOLUTE: + *internal = RenderingIntent::kAbsolute; + return true; + } + return JXL_FAILURE("Invalid RenderingIntent enum value"); +} + +} // namespace + +CIExy Customxy::Get() const { + CIExy xy; + xy.x = F64FromCustomxyI32(x); + xy.y = F64FromCustomxyI32(y); + return xy; +} + +Status Customxy::Set(const CIExy& xy) { + JXL_RETURN_IF_ERROR(F64ToCustomxyI32(xy.x, &x)); + JXL_RETURN_IF_ERROR(F64ToCustomxyI32(xy.y, &y)); + size_t extension_bits, total_bits; + if (!Bundle::CanEncode(*this, &extension_bits, &total_bits)) { + return JXL_FAILURE("Unable to encode XY %f %f", xy.x, xy.y); + } + return true; +} + +bool CustomTransferFunction::SetImplicit() { + if (nonserialized_color_space == ColorSpace::kXYB) { + if (!SetGamma(1.0 / 3)) JXL_ASSERT(false); + return true; + } + return false; +} + +Status CustomTransferFunction::SetGamma(double gamma) { + if (gamma < (1.0f / kMaxGamma) || gamma > 1.0) { + return JXL_FAILURE("Invalid gamma %f", gamma); + } + + have_gamma_ = false; + if (ApproxEq(gamma, 1.0)) { + transfer_function_ = TransferFunction::kLinear; + return true; + } + if (ApproxEq(gamma, 1.0 / 2.6)) { + transfer_function_ = TransferFunction::kDCI; + return true; + } + // Don't translate 0.45.. to kSRGB nor k709 - that might change pixel + // values because those curves also have a linear part. + + have_gamma_ = true; + gamma_ = roundf(gamma * kGammaMul); + transfer_function_ = TransferFunction::kUnknown; + return true; +} + +namespace { + +std::array CreateC2(const Primaries pr, + const TransferFunction tf) { + std::array c2; + + { + ColorEncoding* c_rgb = c2.data() + 0; + c_rgb->SetColorSpace(ColorSpace::kRGB); + c_rgb->white_point = WhitePoint::kD65; + c_rgb->primaries = pr; + c_rgb->tf.SetTransferFunction(tf); + JXL_CHECK(c_rgb->CreateICC()); + } + + { + ColorEncoding* c_gray = c2.data() + 1; + c_gray->SetColorSpace(ColorSpace::kGray); + c_gray->white_point = WhitePoint::kD65; + c_gray->primaries = pr; + c_gray->tf.SetTransferFunction(tf); + JXL_CHECK(c_gray->CreateICC()); + } + + return c2; +} + +} // namespace + +const ColorEncoding& ColorEncoding::SRGB(bool is_gray) { + static std::array c2 = + CreateC2(Primaries::kSRGB, TransferFunction::kSRGB); + return c2[is_gray]; +} +const ColorEncoding& ColorEncoding::LinearSRGB(bool is_gray) { + static std::array c2 = + CreateC2(Primaries::kSRGB, TransferFunction::kLinear); + return c2[is_gray]; +} + +CIExy ColorEncoding::GetWhitePoint() const { + JXL_DASSERT(have_fields_); + CIExy xy; + switch (white_point) { + case WhitePoint::kCustom: + return white_.Get(); + + case WhitePoint::kD65: + xy.x = 0.3127; + xy.y = 0.3290; + return xy; + + case WhitePoint::kDCI: + // From https://ieeexplore.ieee.org/document/7290729 C.2 page 11 + xy.x = 0.314; + xy.y = 0.351; + return xy; + + case WhitePoint::kE: + xy.x = xy.y = 1.0 / 3; + return xy; + } + JXL_ABORT("Invalid WhitePoint %u", static_cast(white_point)); +} + +Status ColorEncoding::SetWhitePoint(const CIExy& xy) { + JXL_DASSERT(have_fields_); + if (xy.x == 0.0 || xy.y == 0.0) { + return JXL_FAILURE("Invalid white point %f %f", xy.x, xy.y); + } + if (ApproxEq(xy.x, 0.3127) && ApproxEq(xy.y, 0.3290)) { + white_point = WhitePoint::kD65; + return true; + } + if (ApproxEq(xy.x, 1.0 / 3) && ApproxEq(xy.y, 1.0 / 3)) { + white_point = WhitePoint::kE; + return true; + } + if (ApproxEq(xy.x, 0.314) && ApproxEq(xy.y, 0.351)) { + white_point = WhitePoint::kDCI; + return true; + } + white_point = WhitePoint::kCustom; + return white_.Set(xy); +} + +PrimariesCIExy ColorEncoding::GetPrimaries() const { + JXL_DASSERT(have_fields_); + JXL_ASSERT(HasPrimaries()); + PrimariesCIExy xy; + switch (primaries) { + case Primaries::kCustom: + xy.r = red_.Get(); + xy.g = green_.Get(); + xy.b = blue_.Get(); + return xy; + + case Primaries::kSRGB: + xy.r.x = 0.639998686; + xy.r.y = 0.330010138; + xy.g.x = 0.300003784; + xy.g.y = 0.600003357; + xy.b.x = 0.150002046; + xy.b.y = 0.059997204; + return xy; + + case Primaries::k2100: + xy.r.x = 0.708; + xy.r.y = 0.292; + xy.g.x = 0.170; + xy.g.y = 0.797; + xy.b.x = 0.131; + xy.b.y = 0.046; + return xy; + + case Primaries::kP3: + xy.r.x = 0.680; + xy.r.y = 0.320; + xy.g.x = 0.265; + xy.g.y = 0.690; + xy.b.x = 0.150; + xy.b.y = 0.060; + return xy; + } + JXL_ABORT("Invalid Primaries %u", static_cast(primaries)); +} + +Status ColorEncoding::SetPrimaries(const PrimariesCIExy& xy) { + JXL_DASSERT(have_fields_); + JXL_ASSERT(HasPrimaries()); + if (xy.r.x == 0.0 || xy.r.y == 0.0 || xy.g.x == 0.0 || xy.g.y == 0.0 || + xy.b.x == 0.0 || xy.b.y == 0.0) { + return JXL_FAILURE("Invalid primaries %f %f %f %f %f %f", xy.r.x, xy.r.y, + xy.g.x, xy.g.y, xy.b.x, xy.b.y); + } + + if (ApproxEq(xy.r.x, 0.64) && ApproxEq(xy.r.y, 0.33) && + ApproxEq(xy.g.x, 0.30) && ApproxEq(xy.g.y, 0.60) && + ApproxEq(xy.b.x, 0.15) && ApproxEq(xy.b.y, 0.06)) { + primaries = Primaries::kSRGB; + return true; + } + + if (ApproxEq(xy.r.x, 0.708) && ApproxEq(xy.r.y, 0.292) && + ApproxEq(xy.g.x, 0.170) && ApproxEq(xy.g.y, 0.797) && + ApproxEq(xy.b.x, 0.131) && ApproxEq(xy.b.y, 0.046)) { + primaries = Primaries::k2100; + return true; + } + if (ApproxEq(xy.r.x, 0.680) && ApproxEq(xy.r.y, 0.320) && + ApproxEq(xy.g.x, 0.265) && ApproxEq(xy.g.y, 0.690) && + ApproxEq(xy.b.x, 0.150) && ApproxEq(xy.b.y, 0.060)) { + primaries = Primaries::kP3; + return true; + } + + primaries = Primaries::kCustom; + JXL_RETURN_IF_ERROR(red_.Set(xy.r)); + JXL_RETURN_IF_ERROR(green_.Set(xy.g)); + JXL_RETURN_IF_ERROR(blue_.Set(xy.b)); + return true; +} + +Status ColorEncoding::CreateICC() { + InternalRemoveICC(); + if (!MaybeCreateProfile(*this, &icc_)) { + return JXL_FAILURE("Failed to create profile from fields"); + } + return true; +} + +std::string Description(const ColorEncoding& c_in) { + // Copy required for Implicit* + ColorEncoding c = c_in; + + std::string d = ToString(c.GetColorSpace()); + + if (!c.ImplicitWhitePoint()) { + d += '_'; + if (c.white_point == WhitePoint::kCustom) { + const CIExy wp = c.GetWhitePoint(); + d += ToString(wp.x) + ';'; + d += ToString(wp.y); + } else { + d += ToString(c.white_point); + } + } + + if (c.HasPrimaries()) { + d += '_'; + if (c.primaries == Primaries::kCustom) { + const PrimariesCIExy pr = c.GetPrimaries(); + d += ToString(pr.r.x) + ';'; + d += ToString(pr.r.y) + ';'; + d += ToString(pr.g.x) + ';'; + d += ToString(pr.g.y) + ';'; + d += ToString(pr.b.x) + ';'; + d += ToString(pr.b.y); + } else { + d += ToString(c.primaries); + } + } + + d += '_'; + d += ToString(c.rendering_intent); + + if (!c.tf.SetImplicit()) { + d += '_'; + if (c.tf.IsGamma()) { + d += 'g'; + d += ToString(c.tf.GetGamma()); + } else { + d += ToString(c.tf.GetTransferFunction()); + } + } + + return d; +} + +Customxy::Customxy() { Bundle::Init(this); } +Status Customxy::VisitFields(Visitor* JXL_RESTRICT visitor) { + uint32_t ux = PackSigned(x); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(19), BitsOffset(19, 524288), + BitsOffset(20, 1048576), + BitsOffset(21, 2097152), 0, &ux)); + x = UnpackSigned(ux); + uint32_t uy = PackSigned(y); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(19), BitsOffset(19, 524288), + BitsOffset(20, 1048576), + BitsOffset(21, 2097152), 0, &uy)); + y = UnpackSigned(uy); + return true; +} + +CustomTransferFunction::CustomTransferFunction() { Bundle::Init(this); } +Status CustomTransferFunction::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->Conditional(!SetImplicit())) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_gamma_)); + + if (visitor->Conditional(have_gamma_)) { + // Gamma is represented as a 24-bit int, the exponent used is + // gamma_ / 1e7. Valid values are (0, 1]. On the low end side, we also + // limit it to kMaxGamma/1e7. + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(24, kGammaMul, &gamma_)); + if (gamma_ > kGammaMul || + static_cast(gamma_) * kMaxGamma < kGammaMul) { + return JXL_FAILURE("Invalid gamma %u", gamma_); + } + } + + if (visitor->Conditional(!have_gamma_)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Enum(TransferFunction::kSRGB, &transfer_function_)); + } + } + + return true; +} + +ColorEncoding::ColorEncoding() { Bundle::Init(this); } +Status ColorEncoding::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &want_icc_)); + + // Always send even if want_icc_ because this affects decoding. + // We can skip the white point/primaries because they do not. + JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(ColorSpace::kRGB, &color_space_)); + + if (visitor->Conditional(!WantICC())) { + // Serialize enums. NOTE: we set the defaults to the most common values so + // ImageMetadata.all_default is true in the common case. + + if (visitor->Conditional(!ImplicitWhitePoint())) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(WhitePoint::kD65, &white_point)); + if (visitor->Conditional(white_point == WhitePoint::kCustom)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&white_)); + } + } + + if (visitor->Conditional(HasPrimaries())) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(Primaries::kSRGB, &primaries)); + if (visitor->Conditional(primaries == Primaries::kCustom)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&red_)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&green_)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&blue_)); + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&tf)); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->Enum(RenderingIntent::kRelative, &rendering_intent)); + + // We didn't have ICC, so all fields should be known. + if (color_space_ == ColorSpace::kUnknown || tf.IsUnknown()) { + return JXL_FAILURE( + "No ICC but cs %u and tf %u%s", + static_cast(color_space_), + tf.IsGamma() ? 0 + : static_cast(tf.GetTransferFunction()), + tf.IsGamma() ? "(gamma)" : ""); + } + + JXL_RETURN_IF_ERROR(CreateICC()); + } + + if (WantICC() && visitor->IsReading()) { + // Haven't called SetICC() yet, do nothing. + } else { + if (ICC().empty()) return JXL_FAILURE("Empty ICC"); + } + + return true; +} + +void ConvertInternalToExternalColorEncoding(const ColorEncoding& internal, + JxlColorEncoding* external) { + external->color_space = static_cast(internal.GetColorSpace()); + + external->white_point = static_cast(internal.white_point); + + jxl::CIExy whitepoint = internal.GetWhitePoint(); + external->white_point_xy[0] = whitepoint.x; + external->white_point_xy[1] = whitepoint.y; + + if (external->color_space == JXL_COLOR_SPACE_RGB || + external->color_space == JXL_COLOR_SPACE_UNKNOWN) { + external->primaries = static_cast(internal.primaries); + jxl::PrimariesCIExy primaries = internal.GetPrimaries(); + external->primaries_red_xy[0] = primaries.r.x; + external->primaries_red_xy[1] = primaries.r.y; + external->primaries_green_xy[0] = primaries.g.x; + external->primaries_green_xy[1] = primaries.g.y; + external->primaries_blue_xy[0] = primaries.b.x; + external->primaries_blue_xy[1] = primaries.b.y; + } + + if (internal.tf.IsGamma()) { + external->transfer_function = JXL_TRANSFER_FUNCTION_GAMMA; + external->gamma = internal.tf.GetGamma(); + } else { + external->transfer_function = + static_cast(internal.tf.GetTransferFunction()); + external->gamma = 0; + } + + external->rendering_intent = + static_cast(internal.rendering_intent); +} + +Status ConvertExternalToInternalColorEncoding(const JxlColorEncoding& external, + ColorEncoding* internal) { + internal->SetColorSpace(static_cast(external.color_space)); + + JXL_RETURN_IF_ERROR(ConvertExternalToInternalWhitePoint( + external.white_point, &internal->white_point)); + if (external.white_point == JXL_WHITE_POINT_CUSTOM) { + CIExy wp; + wp.x = external.white_point_xy[0]; + wp.y = external.white_point_xy[1]; + JXL_RETURN_IF_ERROR(internal->SetWhitePoint(wp)); + } + + if (external.color_space == JXL_COLOR_SPACE_RGB || + external.color_space == JXL_COLOR_SPACE_UNKNOWN) { + JXL_RETURN_IF_ERROR(ConvertExternalToInternalPrimaries( + external.primaries, &internal->primaries)); + if (external.primaries == JXL_PRIMARIES_CUSTOM) { + PrimariesCIExy primaries; + primaries.r.x = external.primaries_red_xy[0]; + primaries.r.y = external.primaries_red_xy[1]; + primaries.g.x = external.primaries_green_xy[0]; + primaries.g.y = external.primaries_green_xy[1]; + primaries.b.x = external.primaries_blue_xy[0]; + primaries.b.y = external.primaries_blue_xy[1]; + JXL_RETURN_IF_ERROR(internal->SetPrimaries(primaries)); + } + } + CustomTransferFunction tf; + if (external.transfer_function == JXL_TRANSFER_FUNCTION_GAMMA) { + JXL_RETURN_IF_ERROR(tf.SetGamma(external.gamma)); + } else { + TransferFunction tf_enum; + // JXL_TRANSFER_FUNCTION_GAMMA is not handled by this function since there's + // no internal enum value for it. + JXL_RETURN_IF_ERROR(ConvertExternalToInternalTransferFunction( + external.transfer_function, &tf_enum)); + tf.SetTransferFunction(tf_enum); + } + internal->tf = tf; + + JXL_RETURN_IF_ERROR(ConvertExternalToInternalRenderingIntent( + external.rendering_intent, &internal->rendering_intent)); + + // The ColorEncoding caches an ICC profile it created earlier that may no + // longer match the profile with the changed fields, so re-create it. + if (!(internal->CreateICC())) { + // This is not an error: for example, it doesn't have ICC profile creation + // implemented for XYB. This should not be returned as error, since + // ConvertExternalToInternalColorEncoding still worked correctly, and what + // matters is that internal->ICC() will not return the wrong profile. + } + + return true; +} + +/* Chromatic adaptation matrices*/ +static const float kBradford[9] = { + 0.8951f, 0.2664f, -0.1614f, -0.7502f, 1.7135f, + 0.0367f, 0.0389f, -0.0685f, 1.0296f, +}; + +static const float kBradfordInv[9] = { + 0.9869929f, -0.1470543f, 0.1599627f, 0.4323053f, 0.5183603f, + 0.0492912f, -0.0085287f, 0.0400428f, 0.9684867f, +}; + +// Adapts whitepoint x, y to D50 +Status AdaptToXYZD50(float wx, float wy, float matrix[9]) { + if (wx < 0 || wx > 1 || wy <= 0 || wy > 1) { + // Out of range values can cause division through zero + // further down with the bradford adaptation too. + return JXL_FAILURE("Invalid white point"); + } + float w[3] = {wx / wy, 1.0f, (1.0f - wx - wy) / wy}; + // 1 / tiny float can still overflow + JXL_RETURN_IF_ERROR(std::isfinite(w[0]) && std::isfinite(w[2])); + float w50[3] = {0.96422f, 1.0f, 0.82521f}; + + float lms[3]; + float lms50[3]; + + MatMul(kBradford, w, 3, 3, 1, lms); + MatMul(kBradford, w50, 3, 3, 1, lms50); + + if (lms[0] == 0 || lms[1] == 0 || lms[2] == 0) { + return JXL_FAILURE("Invalid white point"); + } + float a[9] = { + // /----> 0, 1, 2, 3, /----> 4, 5, 6, 7, /----> 8, + lms50[0] / lms[0], 0, 0, 0, lms50[1] / lms[1], 0, 0, 0, lms50[2] / lms[2], + }; + if (!std::isfinite(a[0]) || !std::isfinite(a[4]) || !std::isfinite(a[8])) { + return JXL_FAILURE("Invalid white point"); + } + + float b[9]; + MatMul(a, kBradford, 3, 3, 3, b); + MatMul(kBradfordInv, b, 3, 3, 3, matrix); + + return true; +} + +Status PrimariesToXYZ(float rx, float ry, float gx, float gy, float bx, + float by, float wx, float wy, float matrix[9]) { + if (wx < 0 || wx > 1 || wy <= 0 || wy > 1) { + return JXL_FAILURE("Invalid white point"); + } + // TODO(lode): also require rx, ry, gx, gy, bx, to be in range 0-1? ICC + // profiles in theory forbid negative XYZ values, but in practice the ACES P0 + // color space uses a negative y for the blue primary. + float primaries[9] = { + rx, gx, bx, ry, gy, by, 1.0f - rx - ry, 1.0f - gx - gy, 1.0f - bx - by}; + float primaries_inv[9]; + memcpy(primaries_inv, primaries, sizeof(float) * 9); + JXL_RETURN_IF_ERROR(Inv3x3Matrix(primaries_inv)); + + float w[3] = {wx / wy, 1.0f, (1.0f - wx - wy) / wy}; + // 1 / tiny float can still overflow + JXL_RETURN_IF_ERROR(std::isfinite(w[0]) && std::isfinite(w[2])); + float xyz[3]; + MatMul(primaries_inv, w, 3, 3, 1, xyz); + + float a[9] = { + xyz[0], 0, 0, 0, xyz[1], 0, 0, 0, xyz[2], + }; + + MatMul(primaries, a, 3, 3, 3, matrix); + return true; +} + +Status PrimariesToXYZD50(float rx, float ry, float gx, float gy, float bx, + float by, float wx, float wy, float matrix[9]) { + float toXYZ[9]; + JXL_RETURN_IF_ERROR(PrimariesToXYZ(rx, ry, gx, gy, bx, by, wx, wy, toXYZ)); + float d50[9]; + JXL_RETURN_IF_ERROR(AdaptToXYZD50(wx, wy, d50)); + + MatMul(d50, toXYZ, 3, 3, 3, matrix); + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/color_encoding_internal.h b/media/libjxl/src/lib/jxl/color_encoding_internal.h new file mode 100644 index 000000000..d9e0448fe --- /dev/null +++ b/media/libjxl/src/lib/jxl/color_encoding_internal.h @@ -0,0 +1,463 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COLOR_ENCODING_INTERNAL_H_ +#define LIB_JXL_COLOR_ENCODING_INTERNAL_H_ + +// Metadata for color space conversions. + +#include +#include +#include + +#include // std::abs +#include +#include +#include + +#include "jxl/color_encoding.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +// (All CIE units are for the standard 1931 2 degree observer) + +// Color space the color pixel data is encoded in. The color pixel data is +// 3-channel in all cases except in case of kGray, where it uses only 1 channel. +// This also determines the amount of channels used in modular encoding. +enum class ColorSpace : uint32_t { + // Trichromatic color data. This also includes CMYK if a kBlack + // ExtraChannelInfo is present. This implies, if there is an ICC profile, that + // the ICC profile uses a 3-channel color space if no kBlack extra channel is + // present, or uses color space 'CMYK' if a kBlack extra channel is present. + kRGB, + // Single-channel data. This implies, if there is an ICC profile, that the ICC + // profile also represents single-channel data and has the appropriate color + // space ('GRAY'). + kGray, + // Like kRGB, but implies fixed values for primaries etc. + kXYB, + // For non-RGB/gray data, e.g. from non-electro-optical sensors. Otherwise + // the same conditions as kRGB apply. + kUnknown +}; + +static inline const char* EnumName(ColorSpace /*unused*/) { + return "ColorSpace"; +} +static inline constexpr uint64_t EnumBits(ColorSpace /*unused*/) { + using CS = ColorSpace; + return MakeBit(CS::kRGB) | MakeBit(CS::kGray) | MakeBit(CS::kXYB) | + MakeBit(CS::kUnknown); +} + +// Values from CICP ColourPrimaries. +enum class WhitePoint : uint32_t { + kD65 = 1, // sRGB/BT.709/Display P3/BT.2020 + kCustom = 2, // Actual values encoded in separate fields + kE = 10, // XYZ + kDCI = 11, // DCI-P3 +}; + +static inline const char* EnumName(WhitePoint /*unused*/) { + return "WhitePoint"; +} +static inline constexpr uint64_t EnumBits(WhitePoint /*unused*/) { + return MakeBit(WhitePoint::kD65) | MakeBit(WhitePoint::kCustom) | + MakeBit(WhitePoint::kE) | MakeBit(WhitePoint::kDCI); +} + +// Values from CICP ColourPrimaries +enum class Primaries : uint32_t { + kSRGB = 1, // Same as BT.709 + kCustom = 2, // Actual values encoded in separate fields + k2100 = 9, // Same as BT.2020 + kP3 = 11, +}; + +static inline const char* EnumName(Primaries /*unused*/) { return "Primaries"; } +static inline constexpr uint64_t EnumBits(Primaries /*unused*/) { + using Pr = Primaries; + return MakeBit(Pr::kSRGB) | MakeBit(Pr::kCustom) | MakeBit(Pr::k2100) | + MakeBit(Pr::kP3); +} + +// Values from CICP TransferCharacteristics +enum class TransferFunction : uint32_t { + k709 = 1, + kUnknown = 2, + kLinear = 8, + kSRGB = 13, + kPQ = 16, // from BT.2100 + kDCI = 17, // from SMPTE RP 431-2 reference projector + kHLG = 18, // from BT.2100 +}; + +static inline const char* EnumName(TransferFunction /*unused*/) { + return "TransferFunction"; +} +static inline constexpr uint64_t EnumBits(TransferFunction /*unused*/) { + using TF = TransferFunction; + return MakeBit(TF::k709) | MakeBit(TF::kLinear) | MakeBit(TF::kSRGB) | + MakeBit(TF::kPQ) | MakeBit(TF::kDCI) | MakeBit(TF::kHLG) | + MakeBit(TF::kUnknown); +} + +enum class RenderingIntent : uint32_t { + // Values match ICC sRGB encodings. + kPerceptual = 0, // good for photos, requires a profile with LUT. + kRelative, // good for logos. + kSaturation, // perhaps useful for CG with fully saturated colors. + kAbsolute, // leaves white point unchanged; good for proofing. +}; + +static inline const char* EnumName(RenderingIntent /*unused*/) { + return "RenderingIntent"; +} +static inline constexpr uint64_t EnumBits(RenderingIntent /*unused*/) { + using RI = RenderingIntent; + return MakeBit(RI::kPerceptual) | MakeBit(RI::kRelative) | + MakeBit(RI::kSaturation) | MakeBit(RI::kAbsolute); +} + +// Chromaticity (Y is omitted because it is 1 for primaries/white points) +struct CIExy { + double x = 0.0; + double y = 0.0; +}; + +struct PrimariesCIExy { + CIExy r; + CIExy g; + CIExy b; +}; + +// Serializable form of CIExy. +struct Customxy : public Fields { + Customxy(); + JXL_FIELDS_NAME(Customxy) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + CIExy Get() const; + // Returns false if x or y do not fit in the encoding. + Status Set(const CIExy& xy); + + int32_t x; + int32_t y; +}; + +struct CustomTransferFunction : public Fields { + CustomTransferFunction(); + JXL_FIELDS_NAME(CustomTransferFunction) + + // Sets fields and returns true if nonserialized_color_space has an implicit + // transfer function, otherwise leaves fields unchanged and returns false. + bool SetImplicit(); + + // Gamma: only used for PNG inputs + bool IsGamma() const { return have_gamma_; } + double GetGamma() const { + JXL_ASSERT(IsGamma()); + return gamma_ * 1E-7; // (0, 1) + } + Status SetGamma(double gamma); + + TransferFunction GetTransferFunction() const { + JXL_ASSERT(!IsGamma()); + return transfer_function_; + } + void SetTransferFunction(const TransferFunction tf) { + have_gamma_ = false; + transfer_function_ = tf; + } + + bool IsUnknown() const { + return !have_gamma_ && (transfer_function_ == TransferFunction::kUnknown); + } + bool IsSRGB() const { + return !have_gamma_ && (transfer_function_ == TransferFunction::kSRGB); + } + bool IsLinear() const { + return !have_gamma_ && (transfer_function_ == TransferFunction::kLinear); + } + bool IsPQ() const { + return !have_gamma_ && (transfer_function_ == TransferFunction::kPQ); + } + bool IsHLG() const { + return !have_gamma_ && (transfer_function_ == TransferFunction::kHLG); + } + bool Is709() const { + return !have_gamma_ && (transfer_function_ == TransferFunction::k709); + } + bool IsDCI() const { + return !have_gamma_ && (transfer_function_ == TransferFunction::kDCI); + } + bool IsSame(const CustomTransferFunction& other) const { + if (have_gamma_ != other.have_gamma_) return false; + if (have_gamma_) { + if (gamma_ != other.gamma_) return false; + } else { + if (transfer_function_ != other.transfer_function_) return false; + } + return true; + } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Must be set before calling VisitFields! + ColorSpace nonserialized_color_space = ColorSpace::kRGB; + + private: + static constexpr uint32_t kGammaMul = 10000000; + + bool have_gamma_; + + // OETF exponent to go from linear to gamma-compressed. + uint32_t gamma_; // Only used if have_gamma_. + + // Can be kUnknown. + TransferFunction transfer_function_; // Only used if !have_gamma_. +}; + +// Compact encoding of data required to interpret and translate pixels to a +// known color space. Stored in Metadata. Thread-compatible. +struct ColorEncoding : public Fields { + ColorEncoding(); + JXL_FIELDS_NAME(ColorEncoding) + + // Returns ready-to-use color encodings (initialized on-demand). + static const ColorEncoding& SRGB(bool is_gray = false); + static const ColorEncoding& LinearSRGB(bool is_gray = false); + + // Returns true if an ICC profile was successfully created from fields. + // Must be called after modifying fields. Defined in color_management.cc. + Status CreateICC(); + + // Returns non-empty and valid ICC profile, unless: + // - between calling InternalRemoveICC() and CreateICC() in tests; + // - WantICC() == true and SetICC() was not yet called; + // - after a failed call to SetSRGB(), SetICC(), or CreateICC(). + const PaddedBytes& ICC() const { return icc_; } + + // Internal only, do not call except from tests. + void InternalRemoveICC() { icc_.clear(); } + + // Returns true if `icc` is assigned and decoded successfully. If so, + // subsequent WantICC() will return true until DecideIfWantICC() changes it. + // Returning false indicates data has been lost. + Status SetICC(PaddedBytes&& icc) { + if (icc.empty()) return false; + icc_ = std::move(icc); + + if (!SetFieldsFromICC()) { + InternalRemoveICC(); + return false; + } + + want_icc_ = true; + return true; + } + + // Sets the raw ICC profile bytes, without parsing the ICC, and without + // updating the direct fields such as whitepoint, primaries and color + // space. Functions to get and set fields, such as SetWhitePoint, cannot be + // used anymore after this and functions such as IsSRGB return false no matter + // what the contents of the icc profile. + Status SetICCRaw(PaddedBytes&& icc) { + if (icc.empty()) return false; + icc_ = std::move(icc); + + want_icc_ = true; + have_fields_ = false; + return true; + } + + // Returns whether to send the ICC profile in the codestream. + bool WantICC() const { return want_icc_; } + + // Return whether the direct fields are set, if false but ICC is set, only + // raw ICC bytes are known. + bool HaveFields() const { return have_fields_; } + + // Causes WantICC() to return false if ICC() can be reconstructed from fields. + // Defined in color_management.cc. + void DecideIfWantICC(); + + bool IsGray() const { return color_space_ == ColorSpace::kGray; } + bool IsCMYK() const { return cmyk_; } + size_t Channels() const { return IsGray() ? 1 : 3; } + + // Returns false if the field is invalid and unusable. + bool HasPrimaries() const { + return !IsGray() && color_space_ != ColorSpace::kXYB; + } + + // Returns true after setting the field to a value defined by color_space, + // otherwise false and leaves the field unchanged. + bool ImplicitWhitePoint() { + if (color_space_ == ColorSpace::kXYB) { + white_point = WhitePoint::kD65; + return true; + } + return false; + } + + // Returns whether the color space is known to be sRGB. If a raw unparsed ICC + // profile is set without the fields being set, this returns false, even if + // the content of the ICC profile would match sRGB. + bool IsSRGB() const { + if (!have_fields_) return false; + if (!IsGray() && color_space_ != ColorSpace::kRGB) return false; + if (white_point != WhitePoint::kD65) return false; + if (primaries != Primaries::kSRGB) return false; + if (!tf.IsSRGB()) return false; + return true; + } + + // Returns whether the color space is known to be linear sRGB. If a raw + // unparsed ICC profile is set without the fields being set, this returns + // false, even if the content of the ICC profile would match linear sRGB. + bool IsLinearSRGB() const { + if (!have_fields_) return false; + if (!IsGray() && color_space_ != ColorSpace::kRGB) return false; + if (white_point != WhitePoint::kD65) return false; + if (primaries != Primaries::kSRGB) return false; + if (!tf.IsLinear()) return false; + return true; + } + + Status SetSRGB(const ColorSpace cs, + const RenderingIntent ri = RenderingIntent::kRelative) { + InternalRemoveICC(); + JXL_ASSERT(cs == ColorSpace::kGray || cs == ColorSpace::kRGB); + color_space_ = cs; + white_point = WhitePoint::kD65; + primaries = Primaries::kSRGB; + tf.SetTransferFunction(TransferFunction::kSRGB); + rendering_intent = ri; + return CreateICC(); + } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Accessors ensure tf.nonserialized_color_space is updated at the same time. + ColorSpace GetColorSpace() const { return color_space_; } + void SetColorSpace(const ColorSpace cs) { + color_space_ = cs; + tf.nonserialized_color_space = cs; + } + + CIExy GetWhitePoint() const; + Status SetWhitePoint(const CIExy& xy); + + PrimariesCIExy GetPrimaries() const; + Status SetPrimaries(const PrimariesCIExy& xy); + + // Checks if the color spaces (including white point / primaries) are the + // same, but ignores the transfer function, rendering intent and ICC bytes. + bool SameColorSpace(const ColorEncoding& other) const { + if (color_space_ != other.color_space_) return false; + + if (white_point != other.white_point) return false; + if (white_point == WhitePoint::kCustom) { + if (white_.x != other.white_.x || white_.y != other.white_.y) + return false; + } + + if (HasPrimaries() != other.HasPrimaries()) return false; + if (HasPrimaries()) { + if (primaries != other.primaries) return false; + if (primaries == Primaries::kCustom) { + if (red_.x != other.red_.x || red_.y != other.red_.y) return false; + if (green_.x != other.green_.x || green_.y != other.green_.y) + return false; + if (blue_.x != other.blue_.x || blue_.y != other.blue_.y) return false; + } + } + return true; + } + + // Checks if the color space and transfer function are the same, ignoring + // rendering intent and ICC bytes + bool SameColorEncoding(const ColorEncoding& other) const { + return SameColorSpace(other) && tf.IsSame(other.tf); + } + + mutable bool all_default; + + // Only valid if HaveFields() + WhitePoint white_point; + Primaries primaries; // Only valid if HasPrimaries() + CustomTransferFunction tf; + RenderingIntent rendering_intent; + + private: + // Returns true if all fields have been initialized (possibly to kUnknown). + // Returns false if the ICC profile is invalid or decoding it fails. + // Defined in enc_color_management.cc. + Status SetFieldsFromICC(); + + // If true, the codestream contains an ICC profile and we do not serialize + // fields. Otherwise, fields are serialized and we create an ICC profile. + bool want_icc_; + + // When false, fields such as white_point and tf are invalid and must not be + // used. This occurs after setting a raw bytes-only ICC profile, only the + // ICC bytes may be used. The color_space_ field is still valid. + bool have_fields_ = true; + + PaddedBytes icc_; // Valid ICC profile + + ColorSpace color_space_; // Can be kUnknown + bool cmyk_ = false; + + // Only used if white_point == kCustom. + Customxy white_; + + // Only used if primaries == kCustom. + Customxy red_; + Customxy green_; + Customxy blue_; +}; + +// Returns whether the two inputs are approximately equal. +static inline bool ApproxEq(const double a, const double b, +#if JPEGXL_ENABLE_SKCMS + double max_l1 = 1E-3) { +#else + double max_l1 = 8E-5) { +#endif + // Threshold should be sufficient for ICC's 15-bit fixed-point numbers. + // We have seen differences of 7.1E-5 with lcms2 and 1E-3 with skcms. + return std::abs(a - b) <= max_l1; +} + +// Returns a representation of the ColorEncoding fields (not icc). +// Example description: "RGB_D65_SRG_Rel_Lin" +std::string Description(const ColorEncoding& c); +static inline std::ostream& operator<<(std::ostream& os, + const ColorEncoding& c) { + return os << Description(c); +} + +void ConvertInternalToExternalColorEncoding(const jxl::ColorEncoding& internal, + JxlColorEncoding* external); + +Status ConvertExternalToInternalColorEncoding(const JxlColorEncoding& external, + jxl::ColorEncoding* internal); + +Status PrimariesToXYZ(float rx, float ry, float gx, float gy, float bx, + float by, float wx, float wy, float matrix[9]); +Status PrimariesToXYZD50(float rx, float ry, float gx, float gy, float bx, + float by, float wx, float wy, float matrix[9]); +Status AdaptToXYZD50(float wx, float wy, float matrix[9]); + +} // namespace jxl + +#endif // LIB_JXL_COLOR_ENCODING_INTERNAL_H_ diff --git a/media/libjxl/src/lib/jxl/color_encoding_internal_test.cc b/media/libjxl/src/lib/jxl/color_encoding_internal_test.cc new file mode 100644 index 000000000..32bd0cc16 --- /dev/null +++ b/media/libjxl/src/lib/jxl/color_encoding_internal_test.cc @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/color_encoding_internal.h" + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/test_utils.h" + +namespace jxl { +namespace { + +TEST(ColorEncodingTest, RoundTripAll) { + for (const test::ColorEncodingDescriptor& cdesc : test::AllEncodings()) { + const ColorEncoding c_original = test::ColorEncodingFromDescriptor(cdesc); + // Verify Set(Get) yields the same white point/primaries/gamma. + { + ColorEncoding c; + EXPECT_TRUE(c.SetWhitePoint(c_original.GetWhitePoint())); + EXPECT_EQ(c_original.white_point, c.white_point); + } + { + ColorEncoding c; + EXPECT_TRUE(c.SetPrimaries(c_original.GetPrimaries())); + EXPECT_EQ(c_original.primaries, c.primaries); + } + if (c_original.tf.IsGamma()) { + ColorEncoding c; + EXPECT_TRUE(c.tf.SetGamma(c_original.tf.GetGamma())); + EXPECT_TRUE(c_original.tf.IsSame(c.tf)); + } + } +} + +TEST(ColorEncodingTest, CustomWhitePoint) { + ColorEncoding c; + // Nonsensical values + CIExy xy_in; + xy_in.x = 0.8; + xy_in.y = 0.01; + EXPECT_TRUE(c.SetWhitePoint(xy_in)); + const CIExy xy = c.GetWhitePoint(); + + ColorEncoding c2; + EXPECT_TRUE(c2.SetWhitePoint(xy)); + EXPECT_TRUE(c.SameColorSpace(c2)); +} + +TEST(ColorEncodingTest, CustomPrimaries) { + ColorEncoding c; + PrimariesCIExy xy_in; + // Nonsensical values + xy_in.r.x = -0.01; + xy_in.r.y = 0.2; + xy_in.g.x = 0.4; + xy_in.g.y = 0.401; + xy_in.b.x = 1.1; + xy_in.b.y = -1.2; + EXPECT_TRUE(c.SetPrimaries(xy_in)); + const PrimariesCIExy xy = c.GetPrimaries(); + + ColorEncoding c2; + EXPECT_TRUE(c2.SetPrimaries(xy)); + EXPECT_TRUE(c.SameColorSpace(c2)); +} + +TEST(ColorEncodingTest, CustomGamma) { + ColorEncoding c; +#ifndef JXL_CRASH_ON_ERROR + EXPECT_FALSE(c.tf.SetGamma(0.0)); + EXPECT_FALSE(c.tf.SetGamma(-1E-6)); + EXPECT_FALSE(c.tf.SetGamma(1.001)); +#endif + EXPECT_TRUE(c.tf.SetGamma(1.0)); + EXPECT_FALSE(c.tf.IsGamma()); + EXPECT_TRUE(c.tf.IsLinear()); + + EXPECT_TRUE(c.tf.SetGamma(0.123)); + EXPECT_TRUE(c.tf.IsGamma()); + const double gamma = c.tf.GetGamma(); + + ColorEncoding c2; + EXPECT_TRUE(c2.tf.SetGamma(gamma)); + EXPECT_TRUE(c.SameColorEncoding(c2)); + EXPECT_TRUE(c2.tf.IsGamma()); +} + +TEST(ColorEncodingTest, InternalExternalConversion) { + ColorEncoding source_internal; + JxlColorEncoding external; + ColorEncoding destination_internal; + + for (int i = 0; i < 100; i++) { + source_internal.SetColorSpace(static_cast(rand() % 4)); + CIExy wp; + wp.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + wp.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + EXPECT_TRUE(source_internal.SetWhitePoint(wp)); + if (source_internal.HasPrimaries()) { + PrimariesCIExy primaries; + primaries.r.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.r.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.g.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.g.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.b.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.b.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + EXPECT_TRUE(source_internal.SetPrimaries(primaries)); + } + CustomTransferFunction tf; + EXPECT_TRUE(tf.SetGamma((float(rand()) / float((RAND_MAX)) * 0.5) + 0.25)); + source_internal.tf = tf; + source_internal.rendering_intent = static_cast(rand() % 4); + + ConvertInternalToExternalColorEncoding(source_internal, &external); + EXPECT_TRUE(ConvertExternalToInternalColorEncoding(external, + &destination_internal)); + + EXPECT_EQ(source_internal.GetColorSpace(), + destination_internal.GetColorSpace()); + EXPECT_EQ(source_internal.white_point, destination_internal.white_point); + EXPECT_EQ(source_internal.GetWhitePoint().x, + destination_internal.GetWhitePoint().x); + EXPECT_EQ(source_internal.GetWhitePoint().y, + destination_internal.GetWhitePoint().y); + if (source_internal.HasPrimaries()) { + EXPECT_EQ(source_internal.GetPrimaries().r.x, + destination_internal.GetPrimaries().r.x); + EXPECT_EQ(source_internal.GetPrimaries().r.y, + destination_internal.GetPrimaries().r.y); + EXPECT_EQ(source_internal.GetPrimaries().g.x, + destination_internal.GetPrimaries().g.x); + EXPECT_EQ(source_internal.GetPrimaries().g.y, + destination_internal.GetPrimaries().g.y); + EXPECT_EQ(source_internal.GetPrimaries().b.x, + destination_internal.GetPrimaries().b.x); + EXPECT_EQ(source_internal.GetPrimaries().b.y, + destination_internal.GetPrimaries().b.y); + } + EXPECT_EQ(source_internal.tf.IsGamma(), destination_internal.tf.IsGamma()); + if (source_internal.tf.IsGamma()) { + EXPECT_EQ(source_internal.tf.GetGamma(), + destination_internal.tf.GetGamma()); + } else { + EXPECT_EQ(source_internal.tf.GetTransferFunction(), + destination_internal.tf.GetTransferFunction()); + } + EXPECT_EQ(source_internal.rendering_intent, + destination_internal.rendering_intent); + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/color_management.cc b/media/libjxl/src/lib/jxl/color_management.cc new file mode 100644 index 000000000..521a75adf --- /dev/null +++ b/media/libjxl/src/lib/jxl/color_management.cc @@ -0,0 +1,516 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/color_management.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/color_management.cc" +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/linalg.h" // MatMul, Inv3x3Matrix +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// NOTE: this is only used to provide a reasonable ICC profile that other +// software can read. Our own transforms use ExtraTF instead because that is +// more precise and supports unbounded mode. +std::vector CreateTableCurve(uint32_t N, const ExtraTF tf) { + JXL_ASSERT(N <= 4096); // ICC MFT2 only allows 4K entries + JXL_ASSERT(tf == ExtraTF::kPQ || tf == ExtraTF::kHLG); + // No point using float - LCMS converts to 16-bit for A2B/MFT. + std::vector table(N); + for (uint32_t i = 0; i < N; ++i) { + const float x = static_cast(i) / (N - 1); // 1.0 at index N - 1. + const double dx = static_cast(x); + // LCMS requires EOTF (e.g. 2.4 exponent). + double y = (tf == ExtraTF::kHLG) ? TF_HLG().DisplayFromEncoded(dx) + : TF_PQ().DisplayFromEncoded(dx); + JXL_ASSERT(y >= 0.0); + // Clamp to table range - necessary for HLG. + if (y > 1.0) y = 1.0; + // 1.0 corresponds to table value 0xFFFF. + table[i] = static_cast(roundf(y * 65535.0)); + } + return table; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(CreateTableCurve); // Local function. + +Status CIEXYZFromWhiteCIExy(const CIExy& xy, float XYZ[3]) { + // Target Y = 1. + if (std::abs(xy.y) < 1e-12) return JXL_FAILURE("Y value is too small"); + const float factor = 1 / xy.y; + XYZ[0] = xy.x * factor; + XYZ[1] = 1; + XYZ[2] = (1 - xy.x - xy.y) * factor; + return true; +} + +namespace { + +// NOTE: this is only used to provide a reasonable ICC profile that other +// software can read. Our own transforms use ExtraTF instead because that is +// more precise and supports unbounded mode. +template +std::vector CreateTableCurve(uint32_t N, const Func& func) { + JXL_ASSERT(N <= 4096); // ICC MFT2 only allows 4K entries + // No point using float - LCMS converts to 16-bit for A2B/MFT. + std::vector table(N); + for (uint32_t i = 0; i < N; ++i) { + const float x = static_cast(i) / (N - 1); // 1.0 at index N - 1. + // LCMS requires EOTF (e.g. 2.4 exponent). + double y = func.DisplayFromEncoded(static_cast(x)); + JXL_ASSERT(y >= 0.0); + // Clamp to table range - necessary for HLG. + if (y > 1.0) y = 1.0; + // 1.0 corresponds to table value 0xFFFF. + table[i] = static_cast(roundf(y * 65535.0)); + } + return table; +} + +void ICCComputeMD5(const PaddedBytes& data, uint8_t sum[16]) + JXL_NO_SANITIZE("unsigned-integer-overflow") { + PaddedBytes data64 = data; + data64.push_back(128); + // Add bytes such that ((size + 8) & 63) == 0. + size_t extra = ((64 - ((data64.size() + 8) & 63)) & 63); + data64.resize(data64.size() + extra, 0); + for (uint64_t i = 0; i < 64; i += 8) { + data64.push_back(static_cast(data.size() << 3u) >> i); + } + + static const uint32_t sineparts[64] = { + 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, + 0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, + 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340, + 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, + 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, + 0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, + 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa, + 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, + 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, + 0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, + 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391, + }; + static const uint32_t shift[64] = { + 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, + 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, + 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, + 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, + }; + + uint32_t a0 = 0x67452301, b0 = 0xefcdab89, c0 = 0x98badcfe, d0 = 0x10325476; + + for (size_t i = 0; i < data64.size(); i += 64) { + uint32_t a = a0, b = b0, c = c0, d = d0, f, g; + for (size_t j = 0; j < 64; j++) { + if (j < 16) { + f = (b & c) | ((~b) & d); + g = j; + } else if (j < 32) { + f = (d & b) | ((~d) & c); + g = (5 * j + 1) & 0xf; + } else if (j < 48) { + f = b ^ c ^ d; + g = (3 * j + 5) & 0xf; + } else { + f = c ^ (b | (~d)); + g = (7 * j) & 0xf; + } + uint32_t dg0 = data64[i + g * 4 + 0], dg1 = data64[i + g * 4 + 1], + dg2 = data64[i + g * 4 + 2], dg3 = data64[i + g * 4 + 3]; + uint32_t u = dg0 | (dg1 << 8u) | (dg2 << 16u) | (dg3 << 24u); + f += a + sineparts[j] + u; + a = d; + d = c; + c = b; + b += (f << shift[j]) | (f >> (32u - shift[j])); + } + a0 += a; + b0 += b; + c0 += c; + d0 += d; + } + sum[0] = a0; + sum[1] = a0 >> 8u; + sum[2] = a0 >> 16u; + sum[3] = a0 >> 24u; + sum[4] = b0; + sum[5] = b0 >> 8u; + sum[6] = b0 >> 16u; + sum[7] = b0 >> 24u; + sum[8] = c0; + sum[9] = c0 >> 8u; + sum[10] = c0 >> 16u; + sum[11] = c0 >> 24u; + sum[12] = d0; + sum[13] = d0 >> 8u; + sum[14] = d0 >> 16u; + sum[15] = d0 >> 24u; +} + +Status CreateICCChadMatrix(CIExy w, float result[9]) { + float m[9]; + if (w.y == 0) { // WhitePoint can not be pitch-black. + return JXL_FAILURE("Invalid WhitePoint"); + } + JXL_RETURN_IF_ERROR(AdaptToXYZD50(w.x, w.y, m)); + memcpy(result, m, sizeof(float) * 9); + return true; +} + +// Creates RGB to XYZ matrix given RGB primaries and whitepoint in xy. +Status CreateICCRGBMatrix(CIExy r, CIExy g, CIExy b, CIExy w, float result[9]) { + float m[9]; + JXL_RETURN_IF_ERROR( + PrimariesToXYZD50(r.x, r.y, g.x, g.y, b.x, b.y, w.x, w.y, m)); + memcpy(result, m, sizeof(float) * 9); + return true; +} + +void WriteICCUint32(uint32_t value, size_t pos, PaddedBytes* JXL_RESTRICT icc) { + if (icc->size() < pos + 4) icc->resize(pos + 4); + (*icc)[pos + 0] = (value >> 24u) & 255; + (*icc)[pos + 1] = (value >> 16u) & 255; + (*icc)[pos + 2] = (value >> 8u) & 255; + (*icc)[pos + 3] = value & 255; +} + +void WriteICCUint16(uint16_t value, size_t pos, PaddedBytes* JXL_RESTRICT icc) { + if (icc->size() < pos + 2) icc->resize(pos + 2); + (*icc)[pos + 0] = (value >> 8u) & 255; + (*icc)[pos + 1] = value & 255; +} + +// Writes a 4-character tag +void WriteICCTag(const char* value, size_t pos, PaddedBytes* JXL_RESTRICT icc) { + if (icc->size() < pos + 4) icc->resize(pos + 4); + memcpy(icc->data() + pos, value, 4); +} + +Status WriteICCS15Fixed16(float value, size_t pos, + PaddedBytes* JXL_RESTRICT icc) { + // "nextafterf" for 32768.0f towards zero are: + // 32767.998046875, 32767.99609375, 32767.994140625 + // Even the first value works well,... + bool ok = (-32767.995f <= value) && (value <= 32767.995f); + if (!ok) return JXL_FAILURE("ICC value is out of range / NaN"); + int32_t i = value * 65536.0f + 0.5f; + // Use two's complement + uint32_t u = static_cast(i); + WriteICCUint32(u, pos, icc); + return true; +} + +Status CreateICCHeader(const ColorEncoding& c, + PaddedBytes* JXL_RESTRICT header) { + // TODO(lode): choose color management engine name, e.g. "skia" if + // integrated in skia. + static const char* kCmm = "jxl "; + + header->resize(128, 0); + + WriteICCUint32(0, 0, header); // size, correct value filled in at end + WriteICCTag(kCmm, 4, header); + WriteICCUint32(0x04300000u, 8, header); + WriteICCTag("mntr", 12, header); + WriteICCTag(c.IsGray() ? "GRAY" : "RGB ", 16, header); + WriteICCTag("XYZ ", 20, header); + + // Three uint32_t's date/time encoding. + // TODO(lode): encode actual date and time, this is a placeholder + uint32_t year = 2019, month = 12, day = 1; + uint32_t hour = 0, minute = 0, second = 0; + WriteICCUint16(year, 24, header); + WriteICCUint16(month, 26, header); + WriteICCUint16(day, 28, header); + WriteICCUint16(hour, 30, header); + WriteICCUint16(minute, 32, header); + WriteICCUint16(second, 34, header); + + WriteICCTag("acsp", 36, header); + WriteICCTag("APPL", 40, header); + WriteICCUint32(0, 44, header); // flags + WriteICCUint32(0, 48, header); // device manufacturer + WriteICCUint32(0, 52, header); // device model + WriteICCUint32(0, 56, header); // device attributes + WriteICCUint32(0, 60, header); // device attributes + WriteICCUint32(static_cast(c.rendering_intent), 64, header); + + // Mandatory D50 white point of profile connection space + WriteICCUint32(0x0000f6d6, 68, header); + WriteICCUint32(0x00010000, 72, header); + WriteICCUint32(0x0000d32d, 76, header); + + WriteICCTag(kCmm, 80, header); + + return true; +} + +void AddToICCTagTable(const char* tag, size_t offset, size_t size, + PaddedBytes* JXL_RESTRICT tagtable, + std::vector* offsets) { + WriteICCTag(tag, tagtable->size(), tagtable); + // writing true offset deferred to later + WriteICCUint32(0, tagtable->size(), tagtable); + offsets->push_back(offset); + WriteICCUint32(size, tagtable->size(), tagtable); +} + +void FinalizeICCTag(PaddedBytes* JXL_RESTRICT tags, size_t* offset, + size_t* size) { + while ((tags->size() & 3) != 0) { + tags->push_back(0); + } + *offset += *size; + *size = tags->size() - *offset; +} + +// The input text must be ASCII, writing other characters to UTF-16 is not +// implemented. +void CreateICCMlucTag(const std::string& text, PaddedBytes* JXL_RESTRICT tags) { + WriteICCTag("mluc", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + WriteICCUint32(1, tags->size(), tags); + WriteICCUint32(12, tags->size(), tags); + WriteICCTag("enUS", tags->size(), tags); + WriteICCUint32(text.size() * 2, tags->size(), tags); + WriteICCUint32(28, tags->size(), tags); + for (size_t i = 0; i < text.size(); i++) { + tags->push_back(0); // prepend 0 for UTF-16 + tags->push_back(text[i]); + } +} + +Status CreateICCXYZTag(float xyz[3], PaddedBytes* JXL_RESTRICT tags) { + WriteICCTag("XYZ ", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + for (size_t i = 0; i < 3; ++i) { + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(xyz[i], tags->size(), tags)); + } + return true; +} + +Status CreateICCChadTag(float chad[9], PaddedBytes* JXL_RESTRICT tags) { + WriteICCTag("sf32", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + for (size_t i = 0; i < 9; i++) { + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(chad[i], tags->size(), tags)); + } + return true; +} + +void CreateICCCurvCurvTag(const std::vector& curve, + PaddedBytes* JXL_RESTRICT tags) { + size_t pos = tags->size(); + tags->resize(tags->size() + 12 + curve.size() * 2, 0); + WriteICCTag("curv", pos, tags); + WriteICCUint32(0, pos + 4, tags); + WriteICCUint32(curve.size(), pos + 8, tags); + for (size_t i = 0; i < curve.size(); i++) { + WriteICCUint16(curve[i], pos + 12 + i * 2, tags); + } +} + +Status CreateICCCurvParaTag(std::vector params, size_t curve_type, + PaddedBytes* JXL_RESTRICT tags) { + WriteICCTag("para", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + WriteICCUint16(curve_type, tags->size(), tags); + WriteICCUint16(0, tags->size(), tags); + for (size_t i = 0; i < params.size(); i++) { + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(params[i], tags->size(), tags)); + } + return true; +} +} // namespace + +Status MaybeCreateProfile(const ColorEncoding& c, + PaddedBytes* JXL_RESTRICT icc) { + PaddedBytes header, tagtable, tags; + + if (c.GetColorSpace() == ColorSpace::kUnknown || c.tf.IsUnknown()) { + return false; // Not an error + } + + switch (c.GetColorSpace()) { + case ColorSpace::kRGB: + case ColorSpace::kGray: + break; // OK + case ColorSpace::kXYB: + return JXL_FAILURE("XYB ICC not yet implemented"); + default: + return JXL_FAILURE("Invalid CS %u", + static_cast(c.GetColorSpace())); + } + + JXL_RETURN_IF_ERROR(CreateICCHeader(c, &header)); + + std::vector offsets; + // tag count, deferred to later + WriteICCUint32(0, tagtable.size(), &tagtable); + + size_t tag_offset = 0, tag_size = 0; + + CreateICCMlucTag(Description(c), &tags); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("desc", tag_offset, tag_size, &tagtable, &offsets); + + const std::string copyright = + "Copyright 2019 Google LLC, CC-BY-SA 3.0 Unported " + "license(https://creativecommons.org/licenses/by-sa/3.0/legalcode)"; + CreateICCMlucTag(copyright, &tags); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("cprt", tag_offset, tag_size, &tagtable, &offsets); + + // TODO(eustas): isn't it the other way round: gray image has d50 WhitePoint? + if (c.IsGray()) { + float wtpt[3]; + JXL_RETURN_IF_ERROR(CIEXYZFromWhiteCIExy(c.GetWhitePoint(), wtpt)); + JXL_RETURN_IF_ERROR(CreateICCXYZTag(wtpt, &tags)); + } else { + float d50[3] = {0.964203, 1.0, 0.824905}; + JXL_RETURN_IF_ERROR(CreateICCXYZTag(d50, &tags)); + } + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("wtpt", tag_offset, tag_size, &tagtable, &offsets); + + if (!c.IsGray()) { + // Chromatic adaptation matrix + float chad[9]; + JXL_RETURN_IF_ERROR(CreateICCChadMatrix(c.GetWhitePoint(), chad)); + + const PrimariesCIExy primaries = c.GetPrimaries(); + float m[9]; + JXL_RETURN_IF_ERROR(CreateICCRGBMatrix(primaries.r, primaries.g, + primaries.b, c.GetWhitePoint(), m)); + float r[3] = {m[0], m[3], m[6]}; + float g[3] = {m[1], m[4], m[7]}; + float b[3] = {m[2], m[5], m[8]}; + + JXL_RETURN_IF_ERROR(CreateICCChadTag(chad, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("chad", tag_offset, tag_size, &tagtable, &offsets); + + JXL_RETURN_IF_ERROR(CreateICCXYZTag(r, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("rXYZ", tag_offset, tag_size, &tagtable, &offsets); + + JXL_RETURN_IF_ERROR(CreateICCXYZTag(g, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("gXYZ", tag_offset, tag_size, &tagtable, &offsets); + + JXL_RETURN_IF_ERROR(CreateICCXYZTag(b, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("bXYZ", tag_offset, tag_size, &tagtable, &offsets); + } + + if (c.tf.IsGamma()) { + float gamma = 1.0 / c.tf.GetGamma(); + JXL_RETURN_IF_ERROR( + CreateICCCurvParaTag({gamma, 1.0, 0.0, 1.0, 0.0}, 3, &tags)); + } else { + switch (c.tf.GetTransferFunction()) { + case TransferFunction::kHLG: + CreateICCCurvCurvTag( + HWY_DYNAMIC_DISPATCH(CreateTableCurve)(4096, ExtraTF::kHLG), &tags); + break; + case TransferFunction::kPQ: + CreateICCCurvCurvTag( + HWY_DYNAMIC_DISPATCH(CreateTableCurve)(4096, ExtraTF::kPQ), &tags); + break; + case TransferFunction::kSRGB: + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag( + {2.4, 1.0 / 1.055, 0.055 / 1.055, 1.0 / 12.92, 0.04045}, 3, &tags)); + break; + case TransferFunction::k709: + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag( + {1.0 / 0.45, 1.0 / 1.099, 0.099 / 1.099, 1.0 / 4.5, 0.081}, 3, + &tags)); + break; + case TransferFunction::kLinear: + JXL_RETURN_IF_ERROR( + CreateICCCurvParaTag({1.0, 1.0, 0.0, 1.0, 0.0}, 3, &tags)); + break; + case TransferFunction::kDCI: + JXL_RETURN_IF_ERROR( + CreateICCCurvParaTag({2.6, 1.0, 0.0, 1.0, 0.0}, 3, &tags)); + break; + default: + JXL_ABORT("Unknown TF %u", + static_cast(c.tf.GetTransferFunction())); + } + } + FinalizeICCTag(&tags, &tag_offset, &tag_size); + if (c.IsGray()) { + AddToICCTagTable("kTRC", tag_offset, tag_size, &tagtable, &offsets); + } else { + AddToICCTagTable("rTRC", tag_offset, tag_size, &tagtable, &offsets); + AddToICCTagTable("gTRC", tag_offset, tag_size, &tagtable, &offsets); + AddToICCTagTable("bTRC", tag_offset, tag_size, &tagtable, &offsets); + } + + // Tag count + WriteICCUint32(offsets.size(), 0, &tagtable); + for (size_t i = 0; i < offsets.size(); i++) { + WriteICCUint32(offsets[i] + header.size() + tagtable.size(), 4 + 12 * i + 4, + &tagtable); + } + + // ICC profile size + WriteICCUint32(header.size() + tagtable.size() + tags.size(), 0, &header); + + *icc = header; + icc->append(tagtable); + icc->append(tags); + + // The MD5 checksum must be computed on the profile with profile flags, + // rendering intent, and region of the checksum itself, set to 0. + // TODO(lode): manually verify with a reliable tool that this creates correct + // signature (profile id) for ICC profiles. + PaddedBytes icc_sum = *icc; + if (icc_sum.size() >= 64 + 4) { + memset(icc_sum.data() + 44, 0, 4); + memset(icc_sum.data() + 64, 0, 4); + } + uint8_t checksum[16]; + ICCComputeMD5(icc_sum, checksum); + + memcpy(icc->data() + 84, checksum, sizeof(checksum)); + + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/color_management.h b/media/libjxl/src/lib/jxl/color_management.h new file mode 100644 index 000000000..f728fe589 --- /dev/null +++ b/media/libjxl/src/lib/jxl/color_management.h @@ -0,0 +1,38 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COLOR_MANAGEMENT_H_ +#define LIB_JXL_COLOR_MANAGEMENT_H_ + +// ICC profiles and color space conversions. + +#include +#include + +#include + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +enum class ExtraTF { + kNone, + kPQ, + kHLG, + kSRGB, +}; + +Status MaybeCreateProfile(const ColorEncoding& c, + PaddedBytes* JXL_RESTRICT icc); + +Status CIEXYZFromWhiteCIExy(const CIExy& xy, float XYZ[3]); + +} // namespace jxl + +#endif // LIB_JXL_COLOR_MANAGEMENT_H_ diff --git a/media/libjxl/src/lib/jxl/color_management_test.cc b/media/libjxl/src/lib/jxl/color_management_test.cc new file mode 100644 index 000000000..99382ca64 --- /dev/null +++ b/media/libjxl/src/lib/jxl/color_management_test.cc @@ -0,0 +1,319 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/color_management.h" + +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { + +std::ostream& operator<<(std::ostream& os, const CIExy& xy) { + return os << "{x=" << xy.x << ", y=" << xy.y << "}"; +} + +std::ostream& operator<<(std::ostream& os, const PrimariesCIExy& primaries) { + return os << "{r=" << primaries.r << ", g=" << primaries.g + << ", b=" << primaries.b << "}"; +} + +namespace { + +using ::testing::ElementsAre; +using ::testing::FloatNear; + +// Small enough to be fast. If changed, must update Generate*. +static constexpr size_t kWidth = 16; + +struct Globals { + // TODO(deymo): Make this a const. + static Globals* GetInstance() { + static Globals ret; + return &ret; + } + + private: + static constexpr size_t kNumThreads = 0; // only have a single row. + + Globals() : pool(kNumThreads) { + in_gray = GenerateGray(); + in_color = GenerateColor(); + out_gray = ImageF(kWidth, 1); + out_color = ImageF(kWidth * 3, 1); + + c_native = ColorEncoding::LinearSRGB(/*is_gray=*/false); + c_gray = ColorEncoding::LinearSRGB(/*is_gray=*/true); + } + + static ImageF GenerateGray() { + ImageF gray(kWidth, 1); + float* JXL_RESTRICT row = gray.Row(0); + // Increasing left to right + for (uint32_t x = 0; x < kWidth; ++x) { + row[x] = x * 1.0f / (kWidth - 1); // [0, 1] + } + return gray; + } + + static ImageF GenerateColor() { + ImageF image(kWidth * 3, 1); + float* JXL_RESTRICT interleaved = image.Row(0); + std::fill(interleaved, interleaved + kWidth * 3, 0.0f); + + // [0, 4): neutral + for (int32_t x = 0; x < 4; ++x) { + interleaved[3 * x + 0] = x * 1.0f / 3; // [0, 1] + interleaved[3 * x + 2] = interleaved[3 * x + 1] = interleaved[3 * x + 0]; + } + + // [4, 13): pure RGB with low/medium/high saturation + for (int32_t c = 0; c < 3; ++c) { + interleaved[3 * (4 + c) + c] = 0.08f + c * 0.01f; + interleaved[3 * (7 + c) + c] = 0.75f + c * 0.01f; + interleaved[3 * (10 + c) + c] = 1.0f; + } + + // [13, 16): impure, not quite saturated RGB + interleaved[3 * 13 + 0] = 0.86f; + interleaved[3 * 13 + 2] = interleaved[3 * 13 + 1] = 0.16f; + interleaved[3 * 14 + 1] = 0.87f; + interleaved[3 * 14 + 2] = interleaved[3 * 14 + 0] = 0.16f; + interleaved[3 * 15 + 2] = 0.88f; + interleaved[3 * 15 + 1] = interleaved[3 * 15 + 0] = 0.16f; + + return image; + } + + public: + ThreadPoolInternal pool; + + // ImageF so we can use VerifyRelativeError; all are interleaved RGB. + ImageF in_gray; + ImageF in_color; + ImageF out_gray; + ImageF out_color; + ColorEncoding c_native; + ColorEncoding c_gray; +}; + +class ColorManagementTest + : public ::testing::TestWithParam { + public: + static void VerifySameFields(const ColorEncoding& c, + const ColorEncoding& c2) { + ASSERT_EQ(c.rendering_intent, c2.rendering_intent); + ASSERT_EQ(c.GetColorSpace(), c2.GetColorSpace()); + ASSERT_EQ(c.white_point, c2.white_point); + if (c.HasPrimaries()) { + ASSERT_EQ(c.primaries, c2.primaries); + } + ASSERT_TRUE(c.tf.IsSame(c2.tf)); + } + + // "Same" pixels after converting g->c_native -> c -> g->c_native. + static void VerifyPixelRoundTrip(const ColorEncoding& c) { + Globals* g = Globals::GetInstance(); + const ColorEncoding& c_native = c.IsGray() ? g->c_gray : g->c_native; + const JxlCmsInterface& cms = GetJxlCms(); + ColorSpaceTransform xform_fwd(cms); + ColorSpaceTransform xform_rev(cms); + const float intensity_target = + c.tf.IsHLG() ? 1000 : kDefaultIntensityTarget; + ASSERT_TRUE(xform_fwd.Init(c_native, c, intensity_target, kWidth, + g->pool.NumThreads())); + ASSERT_TRUE(xform_rev.Init(c, c_native, intensity_target, kWidth, + g->pool.NumThreads())); + + const size_t thread = 0; + const ImageF& in = c.IsGray() ? g->in_gray : g->in_color; + ImageF* JXL_RESTRICT out = c.IsGray() ? &g->out_gray : &g->out_color; + ASSERT_TRUE(xform_fwd.Run(thread, in.Row(0), xform_fwd.BufDst(thread))); + ASSERT_TRUE(xform_rev.Run(thread, xform_fwd.BufDst(thread), out->Row(0))); + +#if JPEGXL_ENABLE_SKCMS + double max_l1 = 7E-4; + double max_rel = 4E-7; +#else + double max_l1 = 5E-5; + // Most are lower; reached 3E-7 with D60 AP0. + double max_rel = 4E-7; +#endif + if (c.IsGray()) max_rel = 2E-5; + VerifyRelativeError(in, *out, max_l1, max_rel); + } +}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(ColorManagementTestInstantiation, + ColorManagementTest, + ::testing::ValuesIn(test::AllEncodings())); + +// Exercises the ColorManagement interface for ALL ColorEncoding synthesizable +// via enums. +TEST_P(ColorManagementTest, VerifyAllProfiles) { + ColorEncoding c = ColorEncodingFromDescriptor(GetParam()); + printf("%s\n", Description(c).c_str()); + + // Can create profile. + ASSERT_TRUE(c.CreateICC()); + + // Can set an equivalent ColorEncoding from the generated ICC profile. + ColorEncoding c3; + ASSERT_TRUE(c3.SetICC(PaddedBytes(c.ICC()))); + VerifySameFields(c, c3); + + VerifyPixelRoundTrip(c); +} + +testing::Matcher CIExyIs(const double x, const double y) { + static constexpr double kMaxError = 1e-4; + return testing::AllOf( + testing::Field(&CIExy::x, testing::DoubleNear(x, kMaxError)), + testing::Field(&CIExy::y, testing::DoubleNear(y, kMaxError))); +} + +testing::Matcher PrimariesAre( + const testing::Matcher& r, const testing::Matcher& g, + const testing::Matcher& b) { + return testing::AllOf(testing::Field(&PrimariesCIExy::r, r), + testing::Field(&PrimariesCIExy::g, g), + testing::Field(&PrimariesCIExy::b, b)); +} + +TEST_F(ColorManagementTest, sRGBChromaticity) { + const ColorEncoding sRGB = ColorEncoding::SRGB(); + EXPECT_THAT(sRGB.GetWhitePoint(), CIExyIs(0.3127, 0.3290)); + EXPECT_THAT(sRGB.GetPrimaries(), + PrimariesAre(CIExyIs(0.64, 0.33), CIExyIs(0.30, 0.60), + CIExyIs(0.15, 0.06))); +} + +TEST_F(ColorManagementTest, D2700Chromaticity) { + PaddedBytes icc = ReadTestData("jxl/color_management/sRGB-D2700.icc"); + ColorEncoding sRGB_D2700; + ASSERT_TRUE(sRGB_D2700.SetICC(std::move(icc))); + + EXPECT_THAT(sRGB_D2700.GetWhitePoint(), CIExyIs(0.45986, 0.41060)); + // The illuminant-relative chromaticities of this profile's primaries are the + // same as for sRGB. It is the PCS-relative chromaticities that would be + // different. + EXPECT_THAT(sRGB_D2700.GetPrimaries(), + PrimariesAre(CIExyIs(0.64, 0.33), CIExyIs(0.30, 0.60), + CIExyIs(0.15, 0.06))); +} + +TEST_F(ColorManagementTest, D2700ToSRGB) { + PaddedBytes icc = ReadTestData("jxl/color_management/sRGB-D2700.icc"); + ColorEncoding sRGB_D2700; + ASSERT_TRUE(sRGB_D2700.SetICC(std::move(icc))); + + ColorSpaceTransform transform(GetJxlCms()); + ASSERT_TRUE(transform.Init(sRGB_D2700, ColorEncoding::SRGB(), + kDefaultIntensityTarget, 1, 1)); + const float sRGB_D2700_values[3] = {0.863, 0.737, 0.490}; + float sRGB_values[3]; + ASSERT_TRUE(transform.Run(0, sRGB_D2700_values, sRGB_values)); + EXPECT_THAT(sRGB_values, + ElementsAre(FloatNear(0.914, 1e-3), FloatNear(0.745, 1e-3), + FloatNear(0.601, 1e-3))); +} + +TEST_F(ColorManagementTest, P3HlgTo2020Hlg) { + ColorEncoding p3_hlg; + p3_hlg.SetColorSpace(ColorSpace::kRGB); + p3_hlg.white_point = WhitePoint::kD65; + p3_hlg.primaries = Primaries::kP3; + p3_hlg.tf.SetTransferFunction(TransferFunction::kHLG); + ASSERT_TRUE(p3_hlg.CreateICC()); + + ColorEncoding rec2020_hlg = p3_hlg; + rec2020_hlg.primaries = Primaries::k2100; + ASSERT_TRUE(rec2020_hlg.CreateICC()); + + ColorSpaceTransform transform(GetJxlCms()); + ASSERT_TRUE(transform.Init(p3_hlg, rec2020_hlg, 1000, 1, 1)); + const float p3_hlg_values[3] = {0., 0.75, 0.}; + float rec2020_hlg_values[3]; + ASSERT_TRUE(transform.Run(0, p3_hlg_values, rec2020_hlg_values)); + EXPECT_THAT(rec2020_hlg_values, + ElementsAre(FloatNear(0.3973, 1e-4), FloatNear(0.7382, 1e-4), + FloatNear(0.1183, 1e-4))); +} + +TEST_F(ColorManagementTest, HlgOotf) { + ColorEncoding p3_hlg; + p3_hlg.SetColorSpace(ColorSpace::kRGB); + p3_hlg.white_point = WhitePoint::kD65; + p3_hlg.primaries = Primaries::kP3; + p3_hlg.tf.SetTransferFunction(TransferFunction::kHLG); + ASSERT_TRUE(p3_hlg.CreateICC()); + + ColorSpaceTransform transform_to_1000(GetJxlCms()); + ASSERT_TRUE( + transform_to_1000.Init(p3_hlg, ColorEncoding::LinearSRGB(), 1000, 1, 1)); + // HDR reference white: https://www.itu.int/pub/R-REP-BT.2408-4-2021 + float p3_hlg_values[3] = {0.75, 0.75, 0.75}; + float linear_srgb_values[3]; + ASSERT_TRUE(transform_to_1000.Run(0, p3_hlg_values, linear_srgb_values)); + // On a 1000-nit display, HDR reference white should be 203 cd/m² which is + // 0.203 times the maximum. + EXPECT_THAT(linear_srgb_values, + ElementsAre(FloatNear(0.203, 1e-3), FloatNear(0.203, 1e-3), + FloatNear(0.203, 1e-3))); + + ColorSpaceTransform transform_to_400(GetJxlCms()); + ASSERT_TRUE( + transform_to_400.Init(p3_hlg, ColorEncoding::LinearSRGB(), 400, 1, 1)); + ASSERT_TRUE(transform_to_400.Run(0, p3_hlg_values, linear_srgb_values)); + // On a 400-nit display, it should be 100 cd/m². + EXPECT_THAT(linear_srgb_values, + ElementsAre(FloatNear(0.250, 1e-3), FloatNear(0.250, 1e-3), + FloatNear(0.250, 1e-3))); + + p3_hlg_values[2] = 0.50; + ASSERT_TRUE(transform_to_1000.Run(0, p3_hlg_values, linear_srgb_values)); + EXPECT_THAT(linear_srgb_values, + ElementsAre(FloatNear(0.201, 1e-3), FloatNear(0.201, 1e-3), + FloatNear(0.050, 1e-3))); + + ColorSpaceTransform transform_from_400(GetJxlCms()); + ASSERT_TRUE( + transform_from_400.Init(ColorEncoding::LinearSRGB(), p3_hlg, 400, 1, 1)); + linear_srgb_values[0] = linear_srgb_values[1] = linear_srgb_values[2] = 0.250; + ASSERT_TRUE(transform_from_400.Run(0, linear_srgb_values, p3_hlg_values)); + EXPECT_THAT(p3_hlg_values, + ElementsAre(FloatNear(0.75, 1e-3), FloatNear(0.75, 1e-3), + FloatNear(0.75, 1e-3))); + + ColorEncoding grayscale_hlg; + grayscale_hlg.SetColorSpace(ColorSpace::kGray); + grayscale_hlg.white_point = WhitePoint::kD65; + grayscale_hlg.tf.SetTransferFunction(TransferFunction::kHLG); + ASSERT_TRUE(grayscale_hlg.CreateICC()); + + ColorSpaceTransform grayscale_transform(GetJxlCms()); + ASSERT_TRUE(grayscale_transform.Init( + grayscale_hlg, ColorEncoding::LinearSRGB(/*is_gray=*/true), 1000, 1, 1)); + const float grayscale_hlg_value = 0.75; + float linear_grayscale_value; + ASSERT_TRUE(grayscale_transform.Run(0, &grayscale_hlg_value, + &linear_grayscale_value)); + EXPECT_THAT(linear_grayscale_value, FloatNear(0.203, 1e-3)); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/common.h b/media/libjxl/src/lib/jxl/common.h new file mode 100644 index 000000000..b213c8da6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/common.h @@ -0,0 +1,238 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COMMON_H_ +#define LIB_JXL_COMMON_H_ + +// Shared constants and helper functions. + +#include +#include +#include + +#include // numeric_limits +#include // unique_ptr +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" + +#ifndef JXL_HIGH_PRECISION +#define JXL_HIGH_PRECISION 1 +#endif + +// Macro that defines whether support for decoding JXL files to JPEG is enabled. +#ifndef JPEGXL_ENABLE_TRANSCODE_JPEG +#define JPEGXL_ENABLE_TRANSCODE_JPEG 1 +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +namespace jxl { +// Some enums and typedefs used by more than one header file. + +constexpr size_t kBitsPerByte = 8; // more clear than CHAR_BIT + +constexpr inline size_t RoundUpBitsToByteMultiple(size_t bits) { + return (bits + 7) & ~size_t(7); +} + +constexpr inline size_t RoundUpToBlockDim(size_t dim) { + return (dim + 7) & ~size_t(7); +} + +static inline bool JXL_MAYBE_UNUSED SafeAdd(const uint64_t a, const uint64_t b, + uint64_t& sum) { + sum = a + b; + return sum >= a; // no need to check b - either sum >= both or < both. +} + +template +constexpr inline T1 DivCeil(T1 a, T2 b) { + return (a + b - 1) / b; +} + +// Works for any `align`; if a power of two, compiler emits ADD+AND. +constexpr inline size_t RoundUpTo(size_t what, size_t align) { + return DivCeil(what, align) * align; +} + +constexpr double kPi = 3.14159265358979323846264338327950288; + +// Reasonable default for sRGB, matches common monitors. We map white to this +// many nits (cd/m^2) by default. Butteraugli was tuned for 250 nits, which is +// very close. +static constexpr float kDefaultIntensityTarget = 255; + +template +constexpr T Pi(T multiplier) { + return static_cast(multiplier * kPi); +} + +// Block is the square grid of pixels to which an "energy compaction" +// transformation (e.g. DCT) is applied. Each block has its own AC quantizer. +constexpr size_t kBlockDim = 8; + +constexpr size_t kDCTBlockSize = kBlockDim * kBlockDim; + +constexpr size_t kGroupDim = 256; +static_assert(kGroupDim % kBlockDim == 0, + "Group dim should be divisible by block dim"); +constexpr size_t kGroupDimInBlocks = kGroupDim / kBlockDim; + +// Maximum number of passes in an image. +constexpr size_t kMaxNumPasses = 11; + +// Maximum number of reference frames. +constexpr size_t kMaxNumReferenceFrames = 4; + +// Dimensions of a frame, in pixels, and other derived dimensions. +// Computed from FrameHeader. +// TODO(veluca): add extra channels. +struct FrameDimensions { + void Set(size_t xsize, size_t ysize, size_t group_size_shift, + size_t max_hshift, size_t max_vshift, bool modular_mode, + size_t upsampling) { + group_dim = (kGroupDim >> 1) << group_size_shift; + dc_group_dim = group_dim * kBlockDim; + xsize_upsampled = xsize; + ysize_upsampled = ysize; + this->xsize = DivCeil(xsize, upsampling); + this->ysize = DivCeil(ysize, upsampling); + xsize_blocks = DivCeil(this->xsize, kBlockDim << max_hshift) << max_hshift; + ysize_blocks = DivCeil(this->ysize, kBlockDim << max_vshift) << max_vshift; + xsize_padded = xsize_blocks * kBlockDim; + ysize_padded = ysize_blocks * kBlockDim; + if (modular_mode) { + // Modular mode doesn't have any padding. + xsize_padded = this->xsize; + ysize_padded = this->ysize; + } + xsize_upsampled_padded = xsize_padded * upsampling; + ysize_upsampled_padded = ysize_padded * upsampling; + xsize_groups = DivCeil(this->xsize, group_dim); + ysize_groups = DivCeil(this->ysize, group_dim); + xsize_dc_groups = DivCeil(xsize_blocks, group_dim); + ysize_dc_groups = DivCeil(ysize_blocks, group_dim); + num_groups = xsize_groups * ysize_groups; + num_dc_groups = xsize_dc_groups * ysize_dc_groups; + } + + // Image size without any upsampling, i.e. original_size / upsampling. + size_t xsize; + size_t ysize; + // Original image size. + size_t xsize_upsampled; + size_t ysize_upsampled; + // Image size after upsampling the padded image. + size_t xsize_upsampled_padded; + size_t ysize_upsampled_padded; + // Image size after padding to a multiple of kBlockDim (if VarDCT mode). + size_t xsize_padded; + size_t ysize_padded; + // Image size in kBlockDim blocks. + size_t xsize_blocks; + size_t ysize_blocks; + // Image size in number of groups. + size_t xsize_groups; + size_t ysize_groups; + // Image size in number of DC groups. + size_t xsize_dc_groups; + size_t ysize_dc_groups; + // Number of AC or DC groups. + size_t num_groups; + size_t num_dc_groups; + // Size of a group. + size_t group_dim; + size_t dc_group_dim; +}; + +// Prior to C++14 (i.e. C++11): provide our own make_unique +#if __cplusplus < 201402L +template +std::unique_ptr make_unique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} +#else +using std::make_unique; +#endif + +template +JXL_INLINE T Clamp1(T val, T low, T hi) { + return val < low ? low : val > hi ? hi : val; +} + +// Encodes non-negative (X) into (2 * X), negative (-X) into (2 * X - 1) +constexpr uint32_t PackSigned(int32_t value) + JXL_NO_SANITIZE("unsigned-integer-overflow") { + return (static_cast(value) << 1) ^ + ((static_cast(~value) >> 31) - 1); +} + +// Reverse to PackSigned, i.e. UnpackSigned(PackSigned(X)) == X. +// (((~value) & 1) - 1) is either 0 or 0xFF...FF and it will have an expected +// unsigned-integer-overflow. +constexpr intptr_t UnpackSigned(size_t value) + JXL_NO_SANITIZE("unsigned-integer-overflow") { + return static_cast((value >> 1) ^ (((~value) & 1) - 1)); +} + +// conversion from integer to string. +template +std::string ToString(T n) { + char data[32] = {}; + if (T(0.1) != T(0)) { + // float + snprintf(data, sizeof(data), "%g", static_cast(n)); + } else if (T(-1) > T(0)) { + // unsigned + snprintf(data, sizeof(data), "%llu", static_cast(n)); + } else { + // signed + snprintf(data, sizeof(data), "%lld", static_cast(n)); + } + return data; +} + +namespace { +static inline uint64_t DecodeVarInt(const uint8_t* input, size_t inputSize, + size_t* pos) { + size_t i; + uint64_t ret = 0; + for (i = 0; *pos + i < inputSize && i < 10; ++i) { + ret |= uint64_t(input[*pos + i] & 127) << uint64_t(7 * i); + // If the next-byte flag is not set, stop + if ((input[*pos + i] & 128) == 0) break; + } + // TODO: Return a decoding error if i == 10. + *pos += i + 1; + return ret; +} + +static inline bool EncodeVarInt(uint64_t value, size_t output_size, + size_t* output_pos, uint8_t* output) { + // While more than 7 bits of data are left, + // store 7 bits and set the next byte flag + while (value > 127) { + if (*output_pos > output_size) return false; + // |128: Set the next byte flag + output[(*output_pos)++] = ((uint8_t)(value & 127)) | 128; + // Remove the seven bits we just wrote + value >>= 7; + } + if (*output_pos > output_size) return false; + output[(*output_pos)++] = ((uint8_t)value) & 127; + return true; +} + +static inline void EncodeVarInt(uint64_t value, PaddedBytes* data) { + size_t pos = data->size(); + data->resize(data->size() + 9); + JXL_CHECK(EncodeVarInt(value, data->size(), &pos, data->data())); + data->resize(pos); +} +} // namespace + +} // namespace jxl + +#endif // LIB_JXL_COMMON_H_ diff --git a/media/libjxl/src/lib/jxl/compressed_dc.cc b/media/libjxl/src/lib/jxl/compressed_dc.cc new file mode 100644 index 000000000..3b2c32399 --- /dev/null +++ b/media/libjxl/src/lib/jxl/compressed_dc.cc @@ -0,0 +1,320 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/compressed_dc.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc" +#include +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +using D = HWY_FULL(float); +using DScalar = HWY_CAPPED(float, 1); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +// TODO(veluca): optimize constants. +const float w1 = 0.20345139757231578f; +const float w2 = 0.0334829185968739f; +const float w0 = 1.0f - 4.0f * (w1 + w2); + +template +V MaxWorkaround(V a, V b) { +#if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG <= 800 + // Prevents "Do not know how to split the result of this operator" error + return IfThenElse(a > b, a, b); +#else + return Max(a, b); +#endif +} + +template +JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor, + const float* JXL_RESTRICT row_top, + const float* JXL_RESTRICT row, + const float* JXL_RESTRICT row_bottom, + Vec* JXL_RESTRICT mc, + Vec* JXL_RESTRICT sm, + Vec* JXL_RESTRICT gap, size_t x) { + const auto tl = LoadU(d, row_top + x - 1); + const auto tc = Load(d, row_top + x); + const auto tr = LoadU(d, row_top + x + 1); + + const auto ml = LoadU(d, row + x - 1); + *mc = Load(d, row + x); + const auto mr = LoadU(d, row + x + 1); + + const auto bl = LoadU(d, row_bottom + x - 1); + const auto bc = Load(d, row_bottom + x); + const auto br = LoadU(d, row_bottom + x + 1); + + const auto w_center = Set(d, w0); + const auto w_side = Set(d, w1); + const auto w_corner = Set(d, w2); + + const auto corner = Add(Add(tl, tr), Add(bl, br)); + const auto side = Add(Add(ml, mr), Add(tc, bc)); + *sm = MulAdd(corner, w_corner, MulAdd(side, w_side, Mul(*mc, w_center))); + + const auto dc_quant = Set(d, dc_factor); + *gap = MaxWorkaround(*gap, Abs(Div(Sub(*mc, *sm), dc_quant))); +} + +template +JXL_INLINE void ComputePixel( + const float* JXL_RESTRICT dc_factors, + const float* JXL_RESTRICT* JXL_RESTRICT rows_top, + const float* JXL_RESTRICT* JXL_RESTRICT rows, + const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom, + float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) { + const D d; + auto mc_x = Undefined(d); + auto mc_y = Undefined(d); + auto mc_b = Undefined(d); + auto sm_x = Undefined(d); + auto sm_y = Undefined(d); + auto sm_b = Undefined(d); + auto gap = Set(d, 0.5f); + ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0], + &mc_x, &sm_x, &gap, x); + ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1], + &mc_y, &sm_y, &gap, x); + ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2], + &mc_b, &sm_b, &gap, x); + auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f)); + factor = ZeroIfNegative(factor); + + auto out = MulAdd(Sub(sm_x, mc_x), factor, mc_x); + Store(out, d, out_rows[0] + x); + out = MulAdd(Sub(sm_y, mc_y), factor, mc_y); + Store(out, d, out_rows[1] + x); + out = MulAdd(Sub(sm_b, mc_b), factor, mc_b); + Store(out, d, out_rows[2] + x); +} + +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool) { + const size_t xsize = dc->xsize(); + const size_t ysize = dc->ysize(); + if (ysize <= 2 || xsize <= 2) return; + + // TODO(veluca): use tile-based processing? + // TODO(veluca): decide if changes to the y channel should be propagated to + // the x and b channels through color correlation. + JXL_ASSERT(w1 + w2 < 0.25f); + + PROFILER_FUNC; + + Image3F smoothed(xsize, ysize); + // Fill in borders that the loop below will not. First and last are unused. + for (size_t c = 0; c < 3; c++) { + for (size_t y : {size_t(0), ysize - 1}) { + memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y), + xsize * sizeof(float)); + } + } + auto process_row = [&](const uint32_t y, size_t /*thread*/) { + const float* JXL_RESTRICT rows_top[3]{ + dc->ConstPlaneRow(0, y - 1), + dc->ConstPlaneRow(1, y - 1), + dc->ConstPlaneRow(2, y - 1), + }; + const float* JXL_RESTRICT rows[3] = { + dc->ConstPlaneRow(0, y), + dc->ConstPlaneRow(1, y), + dc->ConstPlaneRow(2, y), + }; + const float* JXL_RESTRICT rows_bottom[3] = { + dc->ConstPlaneRow(0, y + 1), + dc->ConstPlaneRow(1, y + 1), + dc->ConstPlaneRow(2, y + 1), + }; + float* JXL_RESTRICT rows_out[3] = { + smoothed.PlaneRow(0, y), + smoothed.PlaneRow(1, y), + smoothed.PlaneRow(2, y), + }; + for (size_t x : {size_t(0), xsize - 1}) { + for (size_t c = 0; c < 3; c++) { + rows_out[c][x] = rows[c][x]; + } + } + + size_t x = 1; + // First pixels + const size_t N = Lanes(D()); + for (; x < std::min(N, xsize - 1); x++) { + ComputePixel(dc_factors, rows_top, rows, rows_bottom, rows_out, + x); + } + // Full vectors. + for (; x + N <= xsize - 1; x += N) { + ComputePixel(dc_factors, rows_top, rows, rows_bottom, rows_out, x); + } + // Last pixels. + for (; x < xsize - 1; x++) { + ComputePixel(dc_factors, rows_top, rows, rows_bottom, rows_out, + x); + } + }; + JXL_CHECK(RunOnPool(pool, 1, ysize - 1, ThreadPool::NoInit, process_row, + "DCSmoothingRow")); + dc->Swap(smoothed); +} + +// DC dequantization. +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx) { + const HWY_FULL(float) df; + const Rebind di; // assumes pixel_type <= float + if (chroma_subsampling.Is444()) { + const auto fac_x = Set(df, dc_factors[0] * mul); + const auto fac_y = Set(df, dc_factors[1] * mul); + const auto fac_b = Set(df, dc_factors[2] * mul); + const auto cfl_fac_x = Set(df, cfl_factors[0]); + const auto cfl_fac_b = Set(df, cfl_factors[2]); + for (size_t y = 0; y < r.ysize(); y++) { + float* dec_row_x = r.PlaneRow(dc, 0, y); + float* dec_row_y = r.PlaneRow(dc, 1, y); + float* dec_row_b = r.PlaneRow(dc, 2, y); + const int32_t* quant_row_x = in.channel[1].plane.Row(y); + const int32_t* quant_row_y = in.channel[0].plane.Row(y); + const int32_t* quant_row_b = in.channel[2].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x += Lanes(di)) { + const auto in_q_x = Load(di, quant_row_x + x); + const auto in_q_y = Load(di, quant_row_y + x); + const auto in_q_b = Load(di, quant_row_b + x); + const auto in_x = Mul(ConvertTo(df, in_q_x), fac_x); + const auto in_y = Mul(ConvertTo(df, in_q_y), fac_y); + const auto in_b = Mul(ConvertTo(df, in_q_b), fac_b); + Store(in_y, df, dec_row_y + x); + Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x); + Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x); + } + } + } else { + for (size_t c : {1, 0, 2}) { + Rect rect(r.x0() >> chroma_subsampling.HShift(c), + r.y0() >> chroma_subsampling.VShift(c), + r.xsize() >> chroma_subsampling.HShift(c), + r.ysize() >> chroma_subsampling.VShift(c)); + const auto fac = Set(df, dc_factors[c] * mul); + const Channel& ch = in.channel[c < 2 ? c ^ 1 : c]; + for (size_t y = 0; y < rect.ysize(); y++) { + const int32_t* quant_row = ch.plane.Row(y); + float* row = rect.PlaneRow(dc, c, y); + for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) { + const auto in_q = Load(di, quant_row + x); + const auto in = Mul(ConvertTo(df, in_q), fac); + Store(in, df, row + x); + } + } + } + } + if (bctx.num_dc_ctxs <= 1) { + for (size_t y = 0; y < r.ysize(); y++) { + uint8_t* qdc_row = r.Row(quant_dc, y); + memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize()); + } + } else { + for (size_t y = 0; y < r.ysize(); y++) { + uint8_t* qdc_row_val = r.Row(quant_dc, y); + const int32_t* quant_row_x = + in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0)); + const int32_t* quant_row_y = + in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1)); + const int32_t* quant_row_b = + in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2)); + for (size_t x = 0; x < r.xsize(); x++) { + int bucket_x = 0, bucket_y = 0, bucket_b = 0; + for (int t : bctx.dc_thresholds[0]) { + if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++; + } + for (int t : bctx.dc_thresholds[1]) { + if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++; + } + for (int t : bctx.dc_thresholds[2]) { + if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++; + } + int bucket = bucket_x; + bucket *= bctx.dc_thresholds[2].size() + 1; + bucket += bucket_b; + bucket *= bctx.dc_thresholds[1].size() + 1; + bucket += bucket_y; + qdc_row_val[x] = bucket; + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(DequantDC); +HWY_EXPORT(AdaptiveDCSmoothing); +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(dc_factors, dc, pool); +} + +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx) { + return HWY_DYNAMIC_DISPATCH(DequantDC)(r, dc, quant_dc, in, dc_factors, mul, + cfl_factors, chroma_subsampling, bctx); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/compressed_dc.h b/media/libjxl/src/lib/jxl/compressed_dc.h new file mode 100644 index 000000000..b06e5931f --- /dev/null +++ b/media/libjxl/src/lib/jxl/compressed_dc.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COMPRESSED_DC_H_ +#define LIB_JXL_COMPRESSED_DC_H_ + +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/modular_image.h" + +// DC handling functions: encoding and decoding of DC to and from bitstream, and +// related function to initialize the per-group decoder cache. + +namespace jxl { + +// Smooth DC in already-smooth areas, to counteract banding. +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool); + +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx); + +} // namespace jxl + +#endif // LIB_JXL_COMPRESSED_DC_H_ diff --git a/media/libjxl/src/lib/jxl/convolve-inl.h b/media/libjxl/src/lib/jxl/convolve-inl.h new file mode 100644 index 000000000..054c9c6f0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve-inl.h @@ -0,0 +1,297 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_CONVOLVE_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_CONVOLVE_INL_H_ +#undef LIB_JXL_CONVOLVE_INL_H_ +#else +#define LIB_JXL_CONVOLVE_INL_H_ +#endif + +#include + +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image_ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Broadcast; +#if HWY_TARGET != HWY_SCALAR +using hwy::HWY_NAMESPACE::CombineShiftRightBytes; +#endif +using hwy::HWY_NAMESPACE::TableLookupLanes; +using hwy::HWY_NAMESPACE::Vec; + +// Synthesizes left/right neighbors from a vector of center pixels. +class Neighbors { + public: + using D = HWY_CAPPED(float, 16); + using V = Vec; + + // Returns l[i] == c[Mirror(i - 1)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL1(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {0, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // ONML'KJII +#elif HWY_TARGET == HWY_SCALAR + return c; // Same (the first mirrored value is the last valid one) +#else // 128 bit + // c = LKJI +#if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(2, 1, 0, 0))}; // KJII +#else + const D d; + // TODO(deymo): Figure out if this can be optimized using a single vsri + // instruction to convert LKJI to KJII. + HWY_ALIGN constexpr int lanes[4] = {0, 0, 1, 2}; // KJII + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } + + // Returns l[i] == c[Mirror(i - 2)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL2(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {1, 0, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // NMLK'JIIJ +#elif HWY_TARGET == HWY_SCALAR + const D d; + JXL_ASSERT(false); // unsupported, avoid calling this. + return Zero(d); +#else // 128 bit + // c = LKJI +#if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(1, 0, 0, 1))}; // JIIJ +#else + const D d; + HWY_ALIGN constexpr int lanes[4] = {1, 0, 0, 1}; // JIIJ + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } + + // Returns l[i] == c[Mirror(i - 3)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL3(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {2, 1, 0, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // MLKJ'IIJK +#elif HWY_TARGET == HWY_SCALAR + const D d; + JXL_ASSERT(false); // unsupported, avoid calling this. + return Zero(d); +#else // 128 bit + // c = LKJI +#if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(0, 0, 1, 2))}; // IIJK +#else + const D d; + HWY_ALIGN constexpr int lanes[4] = {2, 1, 0, 0}; // IIJK + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } +}; + +#if HWY_TARGET != HWY_SCALAR + +// Returns indices for SetTableIndices such that TableLookupLanes on the +// rightmost unaligned vector (rightmost sample in its most-significant lane) +// returns the mirrored values, with the mirror outside the last valid sample. +static inline const int32_t* MirrorLanes(const size_t mod) { + const HWY_CAPPED(float, 16) d; + constexpr size_t kN = MaxLanes(d); + + // For mod = `image width mod 16` 0..15: + // last full vec mirrored (mem order) loadedVec mirrorVec idxVec + // 0123456789abcdef| fedcba9876543210 fed..210 012..def 012..def + // 0123456789abcdef|0 0fedcba98765432 0fe..321 234..f00 123..eff + // 0123456789abcdef|01 10fedcba987654 10f..432 456..110 234..ffe + // 0123456789abcdef|012 210fedcba9876 210..543 67..2210 34..ffed + // 0123456789abcdef|0123 3210fedcba98 321..654 8..33210 4..ffedc + // 0123456789abcdef|01234 43210fedcba + // 0123456789abcdef|012345 543210fedc + // 0123456789abcdef|0123456 6543210fe + // 0123456789abcdef|01234567 76543210 + // 0123456789abcdef|012345678 8765432 + // 0123456789abcdef|0123456789 987654 + // 0123456789abcdef|0123456789A A9876 + // 0123456789abcdef|0123456789AB BA98 + // 0123456789abcdef|0123456789ABC CBA + // 0123456789abcdef|0123456789ABCD DC + // 0123456789abcdef|0123456789ABCDE E EDC..10f EED..210 ffe..321 +#if HWY_CAP_GE512 + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, // + 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; +#elif HWY_CAP_GE256 + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { + 1, 2, 3, 4, 5, 6, 7, 7, // + 6, 5, 4, 3, 2, 1, 0}; +#else // 128-bit + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {1, 2, 3, 3, // + 2, 1, 0}; +#endif + return idx_lanes + kN - 1 - mod; +} + +#endif // HWY_TARGET != HWY_SCALAR + +// Single entry point for convolution. +// "Strategy" (Direct*/Separable*) decides kernel size and how to evaluate it. +template +class ConvolveT { + static constexpr int64_t kRadius = Strategy::kRadius; + using Simd = HWY_CAPPED(float, 16); + + public: + static size_t MinWidth() { +#if HWY_TARGET == HWY_SCALAR + // First/Last use mirrored loads of up to +/- kRadius. + return 2 * kRadius; +#else + return Lanes(Simd()) + kRadius; +#endif + } + + // "Image" is ImageF or Image3F. + template + static void Run(const Image& in, const Rect& rect, const Weights& weights, + ThreadPool* pool, Image* out) { + PROFILER_ZONE("ConvolveT::Run"); + JXL_CHECK(SameSize(rect, *out)); + JXL_CHECK(rect.xsize() >= MinWidth()); + + static_assert(int64_t(kRadius) <= 3, + "Must handle [0, kRadius) and >= kRadius"); + switch (rect.xsize() % Lanes(Simd())) { + case 0: + return RunRows<0>(in, rect, weights, pool, out); + case 1: + return RunRows<1>(in, rect, weights, pool, out); + case 2: + return RunRows<2>(in, rect, weights, pool, out); + default: + return RunRows<3>(in, rect, weights, pool, out); + } + } + + private: + template + static JXL_INLINE void RunRow(const float* JXL_RESTRICT in, + const size_t xsize, const int64_t stride, + const WrapRow& wrap_row, const Weights& weights, + float* JXL_RESTRICT out) { + Strategy::template ConvolveRow(in, xsize, stride, wrap_row, + weights, out); + } + + template + static JXL_INLINE void RunBorderRows(const ImageF& in, const Rect& rect, + const int64_t ybegin, const int64_t yend, + const Weights& weights, ImageF* out) { + const int64_t stride = in.PixelsPerRow(); + const WrapRowMirror wrap_row(in, rect.ysize()); + for (int64_t y = ybegin; y < yend; ++y) { + RunRow(rect.ConstRow(in, y), rect.xsize(), stride, wrap_row, + weights, out->Row(y)); + } + } + + // Image3F. + template + static JXL_INLINE void RunBorderRows(const Image3F& in, const Rect& rect, + const int64_t ybegin, const int64_t yend, + const Weights& weights, Image3F* out) { + const int64_t stride = in.PixelsPerRow(); + for (int64_t y = ybegin; y < yend; ++y) { + for (size_t c = 0; c < 3; ++c) { + const WrapRowMirror wrap_row(in.Plane(c), rect.ysize()); + RunRow(rect.ConstPlaneRow(in, c, y), rect.xsize(), stride, + wrap_row, weights, out->PlaneRow(c, y)); + } + } + } + + template + static JXL_INLINE void RunInteriorRows(const ImageF& in, const Rect& rect, + const int64_t ybegin, + const int64_t yend, + const Weights& weights, + ThreadPool* pool, ImageF* out) { + const int64_t stride = in.PixelsPerRow(); + JXL_CHECK(RunOnPool( + pool, ybegin, yend, ThreadPool::NoInit, + [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { + RunRow(rect.ConstRow(in, y), rect.xsize(), stride, + WrapRowUnchanged(), weights, out->Row(y)); + }, + "Convolve")); + } + + // Image3F. + template + static JXL_INLINE void RunInteriorRows(const Image3F& in, const Rect& rect, + const int64_t ybegin, + const int64_t yend, + const Weights& weights, + ThreadPool* pool, Image3F* out) { + const int64_t stride = in.PixelsPerRow(); + JXL_CHECK(RunOnPool( + pool, ybegin, yend, ThreadPool::NoInit, + [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { + for (size_t c = 0; c < 3; ++c) { + RunRow(rect.ConstPlaneRow(in, c, y), rect.xsize(), + stride, WrapRowUnchanged(), weights, + out->PlaneRow(c, y)); + } + }, + "Convolve3")); + } + + template + static JXL_INLINE void RunRows(const Image& in, const Rect& rect, + const Weights& weights, ThreadPool* pool, + Image* out) { + const int64_t ysize = rect.ysize(); + RunBorderRows(in, rect, 0, std::min(int64_t(kRadius), ysize), + weights, out); + if (ysize > 2 * int64_t(kRadius)) { + RunInteriorRows(in, rect, int64_t(kRadius), + ysize - int64_t(kRadius), weights, pool, out); + } + if (ysize > int64_t(kRadius)) { + RunBorderRows(in, rect, ysize - int64_t(kRadius), ysize, + weights, out); + } + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_CONVOLVE_INL_H_ diff --git a/media/libjxl/src/lib/jxl/convolve.h b/media/libjxl/src/lib/jxl/convolve.h new file mode 100644 index 000000000..2fcd2d098 --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve.h @@ -0,0 +1,105 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CONVOLVE_H_ +#define LIB_JXL_CONVOLVE_H_ + +// 2D convolution. + +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// No valid values outside [0, xsize), but the strategy may still safely load +// the preceding vector, and/or round xsize up to the vector lane count. This +// avoids needing PadImage. +// Requires xsize >= kConvolveLanes + kConvolveMaxRadius. +static constexpr size_t kConvolveMaxRadius = 3; + +// Weights must already be normalized. + +struct WeightsSymmetric3 { + // d r d (each replicated 4x) + // r c r + // d r d + float c[4]; + float r[4]; + float d[4]; +}; + +struct WeightsSymmetric5 { + // The lower-right quadrant is: c r R (each replicated 4x) + // r d L + // R L D + float c[4]; + float r[4]; + float R[4]; + float d[4]; + float D[4]; + float L[4]; +}; + +// Weights for separable 5x5 filters (typically but not necessarily the same +// values for horizontal and vertical directions). The kernel must already be +// normalized, but note that values for negative offsets are omitted, so the +// given values do not sum to 1. +struct WeightsSeparable5 { + // Horizontal 1D, distances 0..2 (each replicated 4x) + float horz[3 * 4]; + float vert[3 * 4]; +}; + +// Weights for separable 7x7 filters (typically but not necessarily the same +// values for horizontal and vertical directions). The kernel must already be +// normalized, but note that values for negative offsets are omitted, so the +// given values do not sum to 1. +// +// NOTE: for >= 7x7 Gaussian kernels, it is faster to use FastGaussian instead, +// at least when images exceed the L1 cache size. +struct WeightsSeparable7 { + // Horizontal 1D, distances 0..3 (each replicated 4x) + float horz[4 * 4]; + float vert[4 * 4]; +}; + +const WeightsSymmetric3& WeightsSymmetric3Lowpass(); +const WeightsSeparable5& WeightsSeparable5Lowpass(); +const WeightsSymmetric5& WeightsSymmetric5Lowpass(); + +void SlowSymmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out); + +void SlowSeparable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out); + +void SlowSeparable7(const ImageF& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + ImageF* out); + +void Symmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* out); + +void Symmetric5(const ImageF& in, const Rect& rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out); + +void Separable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out); + +void Separable7(const ImageF& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + ImageF* out); + +} // namespace jxl + +#endif // LIB_JXL_CONVOLVE_H_ diff --git a/media/libjxl/src/lib/jxl/convolve_separable5.cc b/media/libjxl/src/lib/jxl/convolve_separable5.cc new file mode 100644 index 000000000..b26ff54bb --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve_separable5.cc @@ -0,0 +1,261 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_separable5.cc" +#include +#include + +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; + +// 5x5 convolution by separable kernel with a single scan through the input. +// This is more cache-efficient than separate horizontal/vertical passes, and +// possibly faster (given enough registers) than tiling and/or transposing. +// +// Overview: imagine a 5x5 window around a central pixel. First convolve the +// rows by multiplying the pixels with the corresponding weights from +// WeightsSeparable5.horz[abs(x_offset) * 4]. Then multiply each of these +// intermediate results by the corresponding vertical weight, i.e. +// vert[abs(y_offset) * 4]. Finally, store the sum of these values as the +// convolution result at the position of the central pixel in the output. +// +// Each of these operations uses SIMD vectors. The central pixel and most +// importantly the output are aligned, so neighnoring pixels (e.g. x_offset=1) +// require unaligned loads. Because weights are supplied in identical groups of +// 4, we can use LoadDup128 to load them (slightly faster). +// +// Uses mirrored boundary handling. Until x >= kRadius, the horizontal +// convolution uses Neighbors class to shuffle vectors as if each of its lanes +// had been loaded from the mirrored offset. Similarly, the last full vector to +// write uses mirroring. In the case of scalar vectors, Neighbors is not usable +// and the value is loaded directly. Otherwise, the number of valid pixels +// modulo the vector size enables a small optimization: for smaller offsets, +// a non-mirrored load is sufficient. +class Separable5Strategy { + using D = HWY_CAPPED(float, 16); + using V = Vec; + + public: + static constexpr int64_t kRadius = 2; + + template + static JXL_MAYBE_INLINE void ConvolveRow( + const float* const JXL_RESTRICT row_m, const size_t xsize, + const int64_t stride, const WrapRow& wrap_row, + const WeightsSeparable5& weights, float* const JXL_RESTRICT row_out) { + const D d; + const int64_t neg_stride = -stride; // allows LEA addressing. + const float* const JXL_RESTRICT row_t2 = + wrap_row(row_m + 2 * neg_stride, stride); + const float* const JXL_RESTRICT row_t1 = + wrap_row(row_m + 1 * neg_stride, stride); + const float* const JXL_RESTRICT row_b1 = + wrap_row(row_m + 1 * stride, stride); + const float* const JXL_RESTRICT row_b2 = + wrap_row(row_m + 2 * stride, stride); + + const V wh0 = LoadDup128(d, weights.horz + 0 * 4); + const V wh1 = LoadDup128(d, weights.horz + 1 * 4); + const V wh2 = LoadDup128(d, weights.horz + 2 * 4); + const V wv0 = LoadDup128(d, weights.vert + 0 * 4); + const V wv1 = LoadDup128(d, weights.vert + 1 * 4); + const V wv2 = LoadDup128(d, weights.vert + 2 * 4); + + size_t x = 0; + + // More than one iteration for scalars. + for (; x < kRadius; x += Lanes(d)) { + const V conv0 = + Mul(HorzConvolveFirst(row_m, x, xsize, wh0, wh1, wh2), wv0); + + const V conv1t = HorzConvolveFirst(row_t1, x, xsize, wh0, wh1, wh2); + const V conv1b = HorzConvolveFirst(row_b1, x, xsize, wh0, wh1, wh2); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = HorzConvolveFirst(row_t2, x, xsize, wh0, wh1, wh2); + const V conv2b = HorzConvolveFirst(row_b2, x, xsize, wh0, wh1, wh2); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + Store(conv2, d, row_out + x); + } + + // Main loop: load inputs without padding + for (; x + Lanes(d) + kRadius <= xsize; x += Lanes(d)) { + const V conv0 = Mul(HorzConvolve(row_m + x, wh0, wh1, wh2), wv0); + + const V conv1t = HorzConvolve(row_t1 + x, wh0, wh1, wh2); + const V conv1b = HorzConvolve(row_b1 + x, wh0, wh1, wh2); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = HorzConvolve(row_t2 + x, wh0, wh1, wh2); + const V conv2b = HorzConvolve(row_b2 + x, wh0, wh1, wh2); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + Store(conv2, d, row_out + x); + } + + // Last full vector to write (the above loop handled mod >= kRadius) +#if HWY_TARGET == HWY_SCALAR + while (x < xsize) { +#else + if (kSizeModN < kRadius) { +#endif + const V conv0 = + Mul(HorzConvolveLast(row_m, x, xsize, wh0, wh1, wh2), wv0); + + const V conv1t = + HorzConvolveLast(row_t1, x, xsize, wh0, wh1, wh2); + const V conv1b = + HorzConvolveLast(row_b1, x, xsize, wh0, wh1, wh2); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = + HorzConvolveLast(row_t2, x, xsize, wh0, wh1, wh2); + const V conv2b = + HorzConvolveLast(row_b2, x, xsize, wh0, wh1, wh2); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + Store(conv2, d, row_out + x); + x += Lanes(d); + } + + // If mod = 0, the above vector was the last. + if (kSizeModN != 0) { + for (; x < xsize; ++x) { + float mul = 0.0f; + for (int64_t dy = -kRadius; dy <= kRadius; ++dy) { + const float wy = weights.vert[std::abs(dy) * 4]; + const float* clamped_row = wrap_row(row_m + dy * stride, stride); + for (int64_t dx = -kRadius; dx <= kRadius; ++dx) { + const float wx = weights.horz[std::abs(dx) * 4]; + const int64_t clamped_x = Mirror(x + dx, xsize); + mul += clamped_row[clamped_x] * wx * wy; + } + } + row_out[x] = mul; + } + } + } + + private: + // Same as HorzConvolve for the first/last vector in a row. + static JXL_MAYBE_INLINE V HorzConvolveFirst( + const float* const JXL_RESTRICT row, const int64_t x, const int64_t xsize, + const V wh0, const V wh1, const V wh2) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = Mul(c, wh0); + +#if HWY_TARGET == HWY_SCALAR + const V l1 = LoadU(d, row + Mirror(x - 1, xsize)); + const V l2 = LoadU(d, row + Mirror(x - 2, xsize)); +#else + (void)xsize; + const V l1 = Neighbors::FirstL1(c); + const V l2 = Neighbors::FirstL2(c); +#endif + + const V r1 = LoadU(d, row + x + 1); + const V r2 = LoadU(d, row + x + 2); + + const V mul1 = MulAdd(Add(l1, r1), wh1, mul0); + const V mul2 = MulAdd(Add(l2, r2), wh2, mul1); + return mul2; + } + + template + static JXL_MAYBE_INLINE V + HorzConvolveLast(const float* const JXL_RESTRICT row, const int64_t x, + const int64_t xsize, const V wh0, const V wh1, const V wh2) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = Mul(c, wh0); + + const V l1 = LoadU(d, row + x - 1); + const V l2 = LoadU(d, row + x - 2); + + V r1, r2; +#if HWY_TARGET == HWY_SCALAR + r1 = LoadU(d, row + Mirror(x + 1, xsize)); + r2 = LoadU(d, row + Mirror(x + 2, xsize)); +#else + const size_t N = Lanes(d); + if (kSizeModN == 0) { + r2 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 2))); + r1 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 1))); + } else { // == 1 + const auto last = LoadU(d, row + xsize - N); + r2 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1))); + r1 = last; + } +#endif + + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = Add(l1, r1); + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = Add(l2, r2); + const V mul2 = MulAdd(sum2, wh2, mul1); + return mul2; + } + + // Requires kRadius valid pixels before/after pos. + static JXL_MAYBE_INLINE V HorzConvolve(const float* const JXL_RESTRICT pos, + const V wh0, const V wh1, + const V wh2) { + const D d; + const V c = LoadU(d, pos); + const V mul0 = Mul(c, wh0); + + // Loading anew is faster than combining vectors. + const V l1 = LoadU(d, pos - 1); + const V r1 = LoadU(d, pos + 1); + const V l2 = LoadU(d, pos - 2); + const V r2 = LoadU(d, pos + 2); + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = Add(l1, r1); + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = Add(l2, r2); + const V mul2 = MulAdd(sum2, wh2, mul1); + return mul2; + } +}; + +void Separable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out) { + using Conv = ConvolveT; + if (rect.xsize() >= Conv::MinWidth()) { + return Conv::Run(in, rect, weights, pool, out); + } + + return SlowSeparable5(in, rect, weights, pool, out); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Separable5); +void Separable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out) { + return HWY_DYNAMIC_DISPATCH(Separable5)(in, rect, weights, pool, out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/convolve_separable7.cc b/media/libjxl/src/lib/jxl/convolve_separable7.cc new file mode 100644 index 000000000..086dfd22b --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve_separable7.cc @@ -0,0 +1,285 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_separable7.cc" +#include +#include + +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; + +// 7x7 convolution by separable kernel with a single scan through the input. +// Extended version of Separable5, see documentation there. +class Separable7Strategy { + using D = HWY_CAPPED(float, 16); + using V = Vec; + + public: + static constexpr int64_t kRadius = 3; + + template + static JXL_MAYBE_INLINE void ConvolveRow( + const float* const JXL_RESTRICT row_m, const size_t xsize, + const int64_t stride, const WrapRow& wrap_row, + const WeightsSeparable7& weights, float* const JXL_RESTRICT row_out) { + const D d; + const int64_t neg_stride = -stride; // allows LEA addressing. + const float* const JXL_RESTRICT row_t3 = + wrap_row(row_m + 3 * neg_stride, stride); + const float* const JXL_RESTRICT row_t2 = + wrap_row(row_m + 2 * neg_stride, stride); + const float* const JXL_RESTRICT row_t1 = + wrap_row(row_m + 1 * neg_stride, stride); + const float* const JXL_RESTRICT row_b1 = + wrap_row(row_m + 1 * stride, stride); + const float* const JXL_RESTRICT row_b2 = + wrap_row(row_m + 2 * stride, stride); + const float* const JXL_RESTRICT row_b3 = + wrap_row(row_m + 3 * stride, stride); + + const V wh0 = LoadDup128(d, weights.horz + 0 * 4); + const V wh1 = LoadDup128(d, weights.horz + 1 * 4); + const V wh2 = LoadDup128(d, weights.horz + 2 * 4); + const V wh3 = LoadDup128(d, weights.horz + 3 * 4); + const V wv0 = LoadDup128(d, weights.vert + 0 * 4); + const V wv1 = LoadDup128(d, weights.vert + 1 * 4); + const V wv2 = LoadDup128(d, weights.vert + 2 * 4); + const V wv3 = LoadDup128(d, weights.vert + 3 * 4); + + size_t x = 0; + + // More than one iteration for scalars. + for (; x < kRadius; x += Lanes(d)) { + const V conv0 = + Mul(HorzConvolveFirst(row_m, x, xsize, wh0, wh1, wh2, wh3), wv0); + + const V conv1t = HorzConvolveFirst(row_t1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1b = HorzConvolveFirst(row_b1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = HorzConvolveFirst(row_t2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2b = HorzConvolveFirst(row_b2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + + const V conv3t = HorzConvolveFirst(row_t3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3b = HorzConvolveFirst(row_b3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3 = MulAdd(Add(conv3t, conv3b), wv3, conv2); + + Store(conv3, d, row_out + x); + } + + // Main loop: load inputs without padding + for (; x + Lanes(d) + kRadius <= xsize; x += Lanes(d)) { + const V conv0 = Mul(HorzConvolve(row_m + x, wh0, wh1, wh2, wh3), wv0); + + const V conv1t = HorzConvolve(row_t1 + x, wh0, wh1, wh2, wh3); + const V conv1b = HorzConvolve(row_b1 + x, wh0, wh1, wh2, wh3); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = HorzConvolve(row_t2 + x, wh0, wh1, wh2, wh3); + const V conv2b = HorzConvolve(row_b2 + x, wh0, wh1, wh2, wh3); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + + const V conv3t = HorzConvolve(row_t3 + x, wh0, wh1, wh2, wh3); + const V conv3b = HorzConvolve(row_b3 + x, wh0, wh1, wh2, wh3); + const V conv3 = MulAdd(Add(conv3t, conv3b), wv3, conv2); + + Store(conv3, d, row_out + x); + } + + // Last full vector to write (the above loop handled mod >= kRadius) +#if HWY_TARGET == HWY_SCALAR + while (x < xsize) { +#else + if (kSizeModN < kRadius) { +#endif + const V conv0 = + Mul(HorzConvolveLast(row_m, x, xsize, wh0, wh1, wh2, wh3), + wv0); + + const V conv1t = + HorzConvolveLast(row_t1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1b = + HorzConvolveLast(row_b1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = + HorzConvolveLast(row_t2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2b = + HorzConvolveLast(row_b2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + + const V conv3t = + HorzConvolveLast(row_t3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3b = + HorzConvolveLast(row_b3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3 = MulAdd(Add(conv3t, conv3b), wv3, conv2); + + Store(conv3, d, row_out + x); + x += Lanes(d); + } + + // If mod = 0, the above vector was the last. + if (kSizeModN != 0) { + for (; x < xsize; ++x) { + float mul = 0.0f; + for (int64_t dy = -kRadius; dy <= kRadius; ++dy) { + const float wy = weights.vert[std::abs(dy) * 4]; + const float* clamped_row = wrap_row(row_m + dy * stride, stride); + for (int64_t dx = -kRadius; dx <= kRadius; ++dx) { + const float wx = weights.horz[std::abs(dx) * 4]; + const int64_t clamped_x = Mirror(x + dx, xsize); + mul += clamped_row[clamped_x] * wx * wy; + } + } + row_out[x] = mul; + } + } + } + + private: + // Same as HorzConvolve for the first/last vector in a row. + static JXL_MAYBE_INLINE V HorzConvolveFirst( + const float* const JXL_RESTRICT row, const int64_t x, const int64_t xsize, + const V wh0, const V wh1, const V wh2, const V wh3) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = Mul(c, wh0); + +#if HWY_TARGET == HWY_SCALAR + const V l1 = LoadU(d, row + Mirror(x - 1, xsize)); + const V l2 = LoadU(d, row + Mirror(x - 2, xsize)); + const V l3 = LoadU(d, row + Mirror(x - 3, xsize)); +#else + (void)xsize; + const V l1 = Neighbors::FirstL1(c); + const V l2 = Neighbors::FirstL2(c); + const V l3 = Neighbors::FirstL3(c); +#endif + + const V r1 = LoadU(d, row + x + 1); + const V r2 = LoadU(d, row + x + 2); + const V r3 = LoadU(d, row + x + 3); + + const V mul1 = MulAdd(Add(l1, r1), wh1, mul0); + const V mul2 = MulAdd(Add(l2, r2), wh2, mul1); + const V mul3 = MulAdd(Add(l3, r3), wh3, mul2); + return mul3; + } + + template + static JXL_MAYBE_INLINE V HorzConvolveLast( + const float* const JXL_RESTRICT row, const int64_t x, const int64_t xsize, + const V wh0, const V wh1, const V wh2, const V wh3) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = Mul(c, wh0); + + const V l1 = LoadU(d, row + x - 1); + const V l2 = LoadU(d, row + x - 2); + const V l3 = LoadU(d, row + x - 3); + + V r1, r2, r3; +#if HWY_TARGET == HWY_SCALAR + r1 = LoadU(d, row + Mirror(x + 1, xsize)); + r2 = LoadU(d, row + Mirror(x + 2, xsize)); + r3 = LoadU(d, row + Mirror(x + 3, xsize)); +#else + const size_t N = Lanes(d); + if (kSizeModN == 0) { + r3 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 3))); + r2 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 2))); + r1 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 1))); + } else if (kSizeModN == 1) { + const auto last = LoadU(d, row + xsize - N); + r3 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 2))); + r2 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1))); + r1 = last; + } else /* kSizeModN >= 2 */ { + const auto last = LoadU(d, row + xsize - N); + r3 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1))); + r2 = last; + r1 = LoadU(d, row + x + 1); + } +#endif + + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = Add(l1, r1); + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = Add(l2, r2); + const V mul2 = MulAdd(sum2, wh2, mul1); + const V sum3 = Add(l3, r3); + const V mul3 = MulAdd(sum3, wh3, mul2); + return mul3; + } + + // Returns one vector of horizontal convolution results; lane i is the result + // for pixel pos + i. This is the fast path for interior pixels, i.e. kRadius + // valid pixels before/after pos. + static JXL_MAYBE_INLINE V HorzConvolve(const float* const JXL_RESTRICT pos, + const V wh0, const V wh1, const V wh2, + const V wh3) { + const D d; + const V c = LoadU(d, pos); + const V mul0 = Mul(c, wh0); + + // TODO(janwas): better to Combine + const V l1 = LoadU(d, pos - 1); + const V r1 = LoadU(d, pos + 1); + const V l2 = LoadU(d, pos - 2); + const V r2 = LoadU(d, pos + 2); + const V l3 = LoadU(d, pos - 3); + const V r3 = LoadU(d, pos + 3); + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = Add(l1, r1); + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = Add(l2, r2); + const V mul2 = MulAdd(sum2, wh2, mul1); + const V sum3 = Add(l3, r3); + const V mul3 = MulAdd(sum3, wh3, mul2); + return mul3; + } +}; + +void Separable7(const ImageF& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + ImageF* out) { + using Conv = ConvolveT; + if (rect.xsize() >= Conv::MinWidth()) { + return Conv::Run(in, rect, weights, pool, out); + } + + return SlowSeparable7(in, rect, weights, pool, out); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Separable7); +void Separable7(const ImageF& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + ImageF* out) { + return HWY_DYNAMIC_DISPATCH(Separable7)(in, rect, weights, pool, out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/convolve_slow.cc b/media/libjxl/src/lib/jxl/convolve_slow.cc new file mode 100644 index 000000000..fffe5f74c --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve_slow.cc @@ -0,0 +1,212 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#include "lib/jxl/convolve-inl.h" + +namespace jxl { + +//------------------------------------------------------------------------------ +// Kernels + +// 4 instances of a given literal value, useful as input to LoadDup128. +#define JXL_REP4(literal) literal, literal, literal, literal + +// Concentrates energy in low-frequency components (e.g. for antialiasing). +const WeightsSymmetric3& WeightsSymmetric3Lowpass() { + // Computed by research/convolve_weights.py's cubic spline approximations of + // prolate spheroidal wave functions. + constexpr float w0 = 0.36208932f; + constexpr float w1 = 0.12820096f; + constexpr float w2 = 0.03127668f; + static constexpr WeightsSymmetric3 weights = { + {JXL_REP4(w0)}, {JXL_REP4(w1)}, {JXL_REP4(w2)}}; + return weights; +} + +const WeightsSeparable5& WeightsSeparable5Lowpass() { + constexpr float w0 = 0.41714928f; + constexpr float w1 = 0.25539268f; + constexpr float w2 = 0.03603267f; + static constexpr WeightsSeparable5 weights = { + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}, + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}}; + return weights; +} + +const WeightsSymmetric5& WeightsSymmetric5Lowpass() { + static constexpr WeightsSymmetric5 weights = { + {JXL_REP4(0.1740135f)}, {JXL_REP4(0.1065369f)}, {JXL_REP4(0.0150310f)}, + {JXL_REP4(0.0652254f)}, {JXL_REP4(0.0012984f)}, {JXL_REP4(0.0092025f)}}; + return weights; +} + +const WeightsSeparable5& WeightsSeparable5Gaussian1() { + constexpr float w0 = 0.38774f; + constexpr float w1 = 0.24477f; + constexpr float w2 = 0.06136f; + static constexpr WeightsSeparable5 weights = { + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}, + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}}; + return weights; +} + +const WeightsSeparable5& WeightsSeparable5Gaussian2() { + constexpr float w0 = 0.250301f; + constexpr float w1 = 0.221461f; + constexpr float w2 = 0.153388f; + static constexpr WeightsSeparable5 weights = { + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}, + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}}; + return weights; +} + +#undef JXL_REP4 + +//------------------------------------------------------------------------------ +// Slow + +namespace { + +template +float SlowSymmetric3Pixel(const ImageF& in, const int64_t ix, const int64_t iy, + const int64_t xsize, const int64_t ysize, + const WeightsSymmetric3& weights) { + float sum = 0.0f; + + // ix: image; kx: kernel + for (int64_t ky = -1; ky <= 1; ky++) { + const int64_t y = WrapY()(iy + ky, ysize); + const float* JXL_RESTRICT row_in = in.ConstRow(static_cast(y)); + + const float wc = ky == 0 ? weights.c[0] : weights.r[0]; + const float wlr = ky == 0 ? weights.r[0] : weights.d[0]; + + const int64_t xm1 = WrapX()(ix - 1, xsize); + const int64_t xp1 = WrapX()(ix + 1, xsize); + sum += row_in[ix] * wc + (row_in[xm1] + row_in[xp1]) * wlr; + } + return sum; +} + +template +void SlowSymmetric3Row(const ImageF& in, const int64_t iy, const int64_t xsize, + const int64_t ysize, const WeightsSymmetric3& weights, + float* JXL_RESTRICT row_out) { + row_out[0] = + SlowSymmetric3Pixel(in, 0, iy, xsize, ysize, weights); + for (int64_t ix = 1; ix < xsize - 1; ix++) { + row_out[ix] = SlowSymmetric3Pixel(in, ix, iy, xsize, + ysize, weights); + } + { + const int64_t ix = xsize - 1; + row_out[ix] = SlowSymmetric3Pixel(in, ix, iy, xsize, + ysize, weights); + } +} + +} // namespace + +void SlowSymmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out) { + PROFILER_FUNC; + + const int64_t xsize = static_cast(rect.xsize()); + const int64_t ysize = static_cast(rect.ysize()); + const int64_t kRadius = 1; + + JXL_CHECK(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t iy = task; + float* JXL_RESTRICT out_row = out->Row(static_cast(iy)); + + if (iy < kRadius || iy >= ysize - kRadius) { + SlowSymmetric3Row(in, iy, xsize, ysize, weights, out_row); + } else { + SlowSymmetric3Row(in, iy, xsize, ysize, weights, + out_row); + } + }, + "SlowSymmetric3")); +} + +namespace { + +// Separable kernels, any radius. +float SlowSeparablePixel(const ImageF& in, const Rect& rect, const int64_t x, + const int64_t y, const int64_t radius, + const float* JXL_RESTRICT horz_weights, + const float* JXL_RESTRICT vert_weights) { + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + const WrapMirror wrap; + + float mul = 0.0f; + for (int dy = -radius; dy <= radius; ++dy) { + const float wy = vert_weights[std::abs(dy) * 4]; + const size_t sy = wrap(y + dy, ysize); + JXL_CHECK(sy < ysize); + const float* const JXL_RESTRICT row = rect.ConstRow(in, sy); + for (int dx = -radius; dx <= radius; ++dx) { + const float wx = horz_weights[std::abs(dx) * 4]; + const size_t sx = wrap(x + dx, xsize); + JXL_CHECK(sx < xsize); + mul += row[sx] * wx * wy; + } + } + return mul; +} + +} // namespace + +void SlowSeparable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out) { + PROFILER_FUNC; + const float* horz_weights = &weights.horz[0]; + const float* vert_weights = &weights.vert[0]; + + const size_t ysize = rect.ysize(); + JXL_CHECK(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + + float* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < rect.xsize(); ++x) { + row_out[x] = SlowSeparablePixel(in, rect, x, y, /*radius=*/2, + horz_weights, vert_weights); + } + }, + "SlowSeparable5")); +} + +void SlowSeparable7(const ImageF& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + ImageF* out) { + PROFILER_FUNC; + const float* horz_weights = &weights.horz[0]; + const float* vert_weights = &weights.vert[0]; + + const size_t ysize = rect.ysize(); + JXL_CHECK(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + + float* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < rect.xsize(); ++x) { + row_out[x] = SlowSeparablePixel(in, rect, x, y, /*radius=*/3, + horz_weights, vert_weights); + } + }, + "SlowSeparable7")); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/convolve_symmetric3.cc b/media/libjxl/src/lib/jxl/convolve_symmetric3.cc new file mode 100644 index 000000000..06b59dfb6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve_symmetric3.cc @@ -0,0 +1,194 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_symmetric3.cc" +#include +#include + +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; + +template +static V WeightedSum(const ImageF& in, const WrapY wrap_y, const size_t ix, + const int64_t iy, const size_t ysize, const V wx0, + const V wx1, const V wx2) { + const HWY_FULL(float) d; + const float* JXL_RESTRICT center = in.ConstRow(wrap_y(iy, ysize)) + ix; + const auto in_m2 = LoadU(d, center - 2); + const auto in_p2 = LoadU(d, center + 2); + const auto in_m1 = LoadU(d, center - 1); + const auto in_p1 = LoadU(d, center + 1); + const auto in_00 = Load(d, center); + const auto sum_2 = Mul(wx2, Add(in_m2, in_p2)); + const auto sum_1 = Mul(wx1, Add(in_m1, in_p1)); + const auto sum_0 = Mul(wx0, in_00); + return Add(sum_2, Add(sum_1, sum_0)); +} + +// 3x3 convolution by symmetric kernel with a single scan through the input. +class Symmetric3Strategy { + using D = HWY_CAPPED(float, 16); + using V = Vec; + + public: + static constexpr int64_t kRadius = 1; + + // Only accesses pixels in [0, xsize). + template + static JXL_MAYBE_INLINE void ConvolveRow( + const float* const JXL_RESTRICT row_m, const size_t xsize, + const int64_t stride, const WrapRow& wrap_row, + const WeightsSymmetric3& weights, float* const JXL_RESTRICT row_out) { + const D d; + // t, m, b = top, middle, bottom row; + const float* const JXL_RESTRICT row_t = wrap_row(row_m - stride, stride); + const float* const JXL_RESTRICT row_b = wrap_row(row_m + stride, stride); + + // Must load in advance - compiler doesn't understand LoadDup128 and + // schedules them too late. + const V w0 = LoadDup128(d, weights.c); + const V w1 = LoadDup128(d, weights.r); + const V w2 = LoadDup128(d, weights.d); + + // l, c, r = left, center, right. Leftmost vector: need FirstL1. + { + const V tc = LoadU(d, row_t + 0); + const V mc = LoadU(d, row_m + 0); + const V bc = LoadU(d, row_b + 0); + const V tl = Neighbors::FirstL1(tc); + const V tr = LoadU(d, row_t + 0 + 1); + const V ml = Neighbors::FirstL1(mc); + const V mr = LoadU(d, row_m + 0 + 1); + const V bl = Neighbors::FirstL1(bc); + const V br = LoadU(d, row_b + 0 + 1); + const V conv = + WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + Store(conv, d, row_out + 0); + } + + // Loop as long as we can load enough new values: + const size_t N = Lanes(d); + size_t x = N; + for (; x + N + kRadius <= xsize; x += N) { + const auto conv = ConvolveValid(row_t, row_m, row_b, x, w0, w1, w2); + Store(conv, d, row_out + x); + } + + // For final (partial) vector: + const V tc = LoadU(d, row_t + x); + const V mc = LoadU(d, row_m + x); + const V bc = LoadU(d, row_b + x); + + V tr, mr, br; +#if HWY_TARGET == HWY_SCALAR + tr = tc; // Single-lane => mirrored right neighbor = center value. + mr = mc; + br = bc; +#else + if (kSizeModN == 0) { + // The above loop didn't handle the last vector because it needs an + // additional right neighbor (generated via mirroring). + auto mirror = SetTableIndices(d, MirrorLanes(N - 1)); + tr = TableLookupLanes(tc, mirror); + mr = TableLookupLanes(mc, mirror); + br = TableLookupLanes(bc, mirror); + } else { + auto mirror = SetTableIndices(d, MirrorLanes((xsize % N) - 1)); + // Loads last valid value into uppermost lane and mirrors. + tr = TableLookupLanes(LoadU(d, row_t + xsize - N), mirror); + mr = TableLookupLanes(LoadU(d, row_m + xsize - N), mirror); + br = TableLookupLanes(LoadU(d, row_b + xsize - N), mirror); + } +#endif + + const V tl = LoadU(d, row_t + x - 1); + const V ml = LoadU(d, row_m + x - 1); + const V bl = LoadU(d, row_b + x - 1); + const V conv = WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + Store(conv, d, row_out + x); + } + + private: + // Returns sum{x_i * w_i}. + template + static JXL_MAYBE_INLINE V WeightedSum(const V tl, const V tc, const V tr, + const V ml, const V mc, const V mr, + const V bl, const V bc, const V br, + const V w0, const V w1, const V w2) { + const V sum_tb = Add(tc, bc); + + // Faster than 5 mul + 4 FMA. + const V mul0 = Mul(mc, w0); + const V sum_lr = Add(ml, mr); + + const V x1 = Add(sum_tb, sum_lr); + const V mul1 = MulAdd(x1, w1, mul0); + + const V sum_t2 = Add(tl, tr); + const V sum_b2 = Add(bl, br); + const V x2 = Add(sum_t2, sum_b2); + const V mul2 = MulAdd(x2, w2, mul1); + return mul2; + } + + static JXL_MAYBE_INLINE V ConvolveValid(const float* JXL_RESTRICT row_t, + const float* JXL_RESTRICT row_m, + const float* JXL_RESTRICT row_b, + const int64_t x, const V w0, + const V w1, const V w2) { + const D d; + const V tc = LoadU(d, row_t + x); + const V mc = LoadU(d, row_m + x); + const V bc = LoadU(d, row_b + x); + const V tl = LoadU(d, row_t + x - 1); + const V tr = LoadU(d, row_t + x + 1); + const V ml = LoadU(d, row_m + x - 1); + const V mr = LoadU(d, row_m + x + 1); + const V bl = LoadU(d, row_b + x - 1); + const V br = LoadU(d, row_b + x + 1); + return WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + } +}; + +void Symmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* out) { + using Conv = ConvolveT; + if (rect.xsize() >= Conv::MinWidth()) { + return Conv::Run(in, rect, weights, pool, out); + } + + return SlowSymmetric3(in, rect, weights, pool, out); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Symmetric3); +void Symmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* out) { + return HWY_DYNAMIC_DISPATCH(Symmetric3)(in, rect, weights, pool, out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/convolve_symmetric5.cc b/media/libjxl/src/lib/jxl/convolve_symmetric5.cc new file mode 100644 index 000000000..55a16899c --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve_symmetric5.cc @@ -0,0 +1,185 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_symmetric5.cc" +#include +#include + +#include "lib/jxl/common.h" // RoundUpTo +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::Vec; + +// Weighted sum of 1x5 pixels around ix, iy with [wx2 wx1 wx0 wx1 wx2]. +template +static float WeightedSumBorder(const ImageF& in, const WrapY wrap_y, + const int64_t ix, const int64_t iy, + const size_t xsize, const size_t ysize, + const float wx0, const float wx1, + const float wx2) { + const WrapMirror wrap_x; + const float* JXL_RESTRICT row = in.ConstRow(wrap_y(iy, ysize)); + const float in_m2 = row[wrap_x(ix - 2, xsize)]; + const float in_p2 = row[wrap_x(ix + 2, xsize)]; + const float in_m1 = row[wrap_x(ix - 1, xsize)]; + const float in_p1 = row[wrap_x(ix + 1, xsize)]; + const float in_00 = row[ix]; + const float sum_2 = wx2 * (in_m2 + in_p2); + const float sum_1 = wx1 * (in_m1 + in_p1); + const float sum_0 = wx0 * in_00; + return sum_2 + sum_1 + sum_0; +} + +template +static V WeightedSum(const ImageF& in, const WrapY wrap_y, const size_t ix, + const int64_t iy, const size_t ysize, const V wx0, + const V wx1, const V wx2) { + const HWY_FULL(float) d; + const float* JXL_RESTRICT center = in.ConstRow(wrap_y(iy, ysize)) + ix; + const auto in_m2 = LoadU(d, center - 2); + const auto in_p2 = LoadU(d, center + 2); + const auto in_m1 = LoadU(d, center - 1); + const auto in_p1 = LoadU(d, center + 1); + const auto in_00 = Load(d, center); + const auto sum_2 = Mul(wx2, Add(in_m2, in_p2)); + const auto sum_1 = Mul(wx1, Add(in_m1, in_p1)); + const auto sum_0 = Mul(wx0, in_00); + return Add(sum_2, Add(sum_1, sum_0)); +} + +// Produces result for one pixel +template +float Symmetric5Border(const ImageF& in, const Rect& rect, const int64_t ix, + const int64_t iy, const WeightsSymmetric5& weights) { + const float w0 = weights.c[0]; + const float w1 = weights.r[0]; + const float w2 = weights.R[0]; + const float w4 = weights.d[0]; + const float w5 = weights.L[0]; + const float w8 = weights.D[0]; + + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + const WrapY wrap_y; + // Unrolled loop over all 5 rows of the kernel. + float sum0 = WeightedSumBorder(in, wrap_y, ix, iy, xsize, ysize, w0, w1, w2); + + sum0 += WeightedSumBorder(in, wrap_y, ix, iy - 2, xsize, ysize, w2, w5, w8); + float sum1 = + WeightedSumBorder(in, wrap_y, ix, iy + 2, xsize, ysize, w2, w5, w8); + + sum0 += WeightedSumBorder(in, wrap_y, ix, iy - 1, xsize, ysize, w1, w4, w5); + sum1 += WeightedSumBorder(in, wrap_y, ix, iy + 1, xsize, ysize, w1, w4, w5); + + return sum0 + sum1; +} + +// Produces result for one vector's worth of pixels +template +static void Symmetric5Interior(const ImageF& in, const Rect& rect, + const int64_t ix, const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + const HWY_FULL(float) d; + + const auto w0 = LoadDup128(d, weights.c); + const auto w1 = LoadDup128(d, weights.r); + const auto w2 = LoadDup128(d, weights.R); + const auto w4 = LoadDup128(d, weights.d); + const auto w5 = LoadDup128(d, weights.L); + const auto w8 = LoadDup128(d, weights.D); + + const size_t ysize = rect.ysize(); + const WrapY wrap_y; + // Unrolled loop over all 5 rows of the kernel. + auto sum0 = WeightedSum(in, wrap_y, ix, iy, ysize, w0, w1, w2); + + sum0 = Add(sum0, WeightedSum(in, wrap_y, ix, iy - 2, ysize, w2, w5, w8)); + auto sum1 = WeightedSum(in, wrap_y, ix, iy + 2, ysize, w2, w5, w8); + + sum0 = Add(sum0, WeightedSum(in, wrap_y, ix, iy - 1, ysize, w1, w4, w5)); + sum1 = Add(sum1, WeightedSum(in, wrap_y, ix, iy + 1, ysize, w1, w4, w5)); + + Store(Add(sum0, sum1), d, row_out + ix); +} + +template +static void Symmetric5Row(const ImageF& in, const Rect& rect, const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + const int64_t kRadius = 2; + const size_t xsize = rect.xsize(); + + size_t ix = 0; + const HWY_FULL(float) d; + const size_t N = Lanes(d); + const size_t aligned_x = RoundUpTo(kRadius, N); + for (; ix < std::min(aligned_x, xsize); ++ix) { + row_out[ix] = Symmetric5Border(in, rect, ix, iy, weights); + } + for (; ix + N + kRadius <= xsize; ix += N) { + Symmetric5Interior(in, rect, ix, iy, weights, row_out); + } + for (; ix < xsize; ++ix) { + row_out[ix] = Symmetric5Border(in, rect, ix, iy, weights); + } +} + +static JXL_NOINLINE void Symmetric5BorderRow(const ImageF& in, const Rect& rect, + const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + return Symmetric5Row(in, rect, iy, weights, row_out); +} + +// Semi-vectorized (interior pixels Fonly); called directly like slow::, unlike +// the fully vectorized strategies below. +void Symmetric5(const ImageF& in, const Rect& rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out) { + PROFILER_FUNC; + + const size_t ysize = rect.ysize(); + JXL_CHECK(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t iy = task; + + if (iy < 2 || iy >= static_cast(ysize) - 2) { + Symmetric5BorderRow(in, rect, iy, weights, out->Row(iy)); + } else { + Symmetric5Row(in, rect, iy, weights, out->Row(iy)); + } + }, + "Symmetric5x5Convolution")); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Symmetric5); +void Symmetric5(const ImageF& in, const Rect& rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out) { + return HWY_DYNAMIC_DISPATCH(Symmetric5)(in, rect, weights, pool, out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/convolve_test.cc b/media/libjxl/src/lib/jxl/convolve_test.cc new file mode 100644 index 000000000..2d75c31ea --- /dev/null +++ b/media/libjxl/src/lib/jxl/convolve_test.cc @@ -0,0 +1,250 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_test.cc" +#include +#include +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" + +#ifndef JXL_DEBUG_CONVOLVE +#define JXL_DEBUG_CONVOLVE 0 +#endif + +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +void TestNeighbors() { + const Neighbors::D d; + const Neighbors::V v = Iota(d, 0); + HWY_ALIGN float actual[hwy::kTestMaxVectorSize / sizeof(float)] = {0}; + + HWY_ALIGN float first_l1[hwy::kTestMaxVectorSize / sizeof(float)] = { + 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; + Store(Neighbors::FirstL1(v), d, actual); + const size_t N = Lanes(d); + EXPECT_EQ(std::vector(first_l1, first_l1 + N), + std::vector(actual, actual + N)); + +#if HWY_TARGET != HWY_SCALAR + HWY_ALIGN float first_l2[hwy::kTestMaxVectorSize / sizeof(float)] = { + 1, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}; + Store(Neighbors::FirstL2(v), d, actual); + EXPECT_EQ(std::vector(first_l2, first_l2 + N), + std::vector(actual, actual + N)); + + HWY_ALIGN float first_l3[hwy::kTestMaxVectorSize / sizeof(float)] = { + 2, 1, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + Store(Neighbors::FirstL3(v), d, actual); + EXPECT_EQ(std::vector(first_l3, first_l3 + N), + std::vector(actual, actual + N)); +#endif // HWY_TARGET != HWY_SCALAR +} + +void VerifySymmetric3(const size_t xsize, const size_t ysize, ThreadPool* pool, + Rng* rng) { + const Rect rect(0, 0, xsize, ysize); + + ImageF in(xsize, ysize); + GenerateImage(*rng, &in, 0.0f, 1.0f); + + ImageF out_expected(xsize, ysize); + ImageF out_actual(xsize, ysize); + + const WeightsSymmetric3& weights = WeightsSymmetric3Lowpass(); + Symmetric3(in, rect, weights, pool, &out_expected); + SlowSymmetric3(in, rect, weights, pool, &out_actual); + + VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f); +} + +// Ensures Symmetric and Separable give the same result. +void VerifySymmetric5(const size_t xsize, const size_t ysize, ThreadPool* pool, + Rng* rng) { + const Rect rect(0, 0, xsize, ysize); + + ImageF in(xsize, ysize); + GenerateImage(*rng, &in, 0.0f, 1.0f); + + ImageF out_expected(xsize, ysize); + ImageF out_actual(xsize, ysize); + + Separable5(in, Rect(in), WeightsSeparable5Lowpass(), pool, &out_expected); + Symmetric5(in, rect, WeightsSymmetric5Lowpass(), pool, &out_actual); + + VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f); +} + +void VerifySeparable5(const size_t xsize, const size_t ysize, ThreadPool* pool, + Rng* rng) { + const Rect rect(0, 0, xsize, ysize); + + ImageF in(xsize, ysize); + GenerateImage(*rng, &in, 0.0f, 1.0f); + + ImageF out_expected(xsize, ysize); + ImageF out_actual(xsize, ysize); + + const WeightsSeparable5& weights = WeightsSeparable5Lowpass(); + Separable5(in, Rect(in), weights, pool, &out_expected); + SlowSeparable5(in, rect, weights, pool, &out_actual); + + VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f); +} + +void VerifySeparable7(const size_t xsize, const size_t ysize, ThreadPool* pool, + Rng* rng) { + const Rect rect(0, 0, xsize, ysize); + + ImageF in(xsize, ysize); + GenerateImage(*rng, &in, 0.0f, 1.0f); + + ImageF out_expected(xsize, ysize); + ImageF out_actual(xsize, ysize); + + // Gaussian sigma 1.0 + const WeightsSeparable7 weights = {{HWY_REP4(0.383103f), HWY_REP4(0.241843f), + HWY_REP4(0.060626f), HWY_REP4(0.00598f)}, + {HWY_REP4(0.383103f), HWY_REP4(0.241843f), + HWY_REP4(0.060626f), HWY_REP4(0.00598f)}}; + + SlowSeparable7(in, rect, weights, pool, &out_expected); + Separable7(in, Rect(in), weights, pool, &out_actual); + + VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f); +} + +// For all xsize/ysize and kernels: +void TestConvolve() { + TestNeighbors(); + + ThreadPoolInternal pool(4); + EXPECT_EQ(true, + RunOnPool( + &pool, kConvolveMaxRadius, 40, ThreadPool::NoInit, + [](const uint32_t task, size_t /*thread*/) { + const size_t xsize = task; + Rng rng(129 + 13 * xsize); + + ThreadPool* null_pool = nullptr; + ThreadPoolInternal pool3(3); + for (size_t ysize = kConvolveMaxRadius; ysize < 16; ++ysize) { + JXL_DEBUG(JXL_DEBUG_CONVOLVE, + "%" PRIuS " x %" PRIuS " (target %" PRIx64 + ")===============================", + xsize, ysize, static_cast(HWY_TARGET)); + + JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym3------------------"); + VerifySymmetric3(xsize, ysize, null_pool, &rng); + VerifySymmetric3(xsize, ysize, &pool3, &rng); + + JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym5------------------"); + VerifySymmetric5(xsize, ysize, null_pool, &rng); + VerifySymmetric5(xsize, ysize, &pool3, &rng); + + JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sep5------------------"); + VerifySeparable5(xsize, ysize, null_pool, &rng); + VerifySeparable5(xsize, ysize, &pool3, &rng); + + JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sep7------------------"); + VerifySeparable7(xsize, ysize, null_pool, &rng); + VerifySeparable7(xsize, ysize, &pool3, &rng); + } + }, + "TestConvolve")); +} + +// Measures durations, verifies results, prints timings. `unpredictable1` +// must have value 1 (unknown to the compiler to prevent elision). +template +void BenchmarkConv(const char* caption, const Conv& conv, + const hwy::FuncInput unpredictable1) { + const size_t kNumInputs = 1; + const hwy::FuncInput inputs[kNumInputs] = {unpredictable1}; + hwy::Result results[kNumInputs]; + + const size_t kDim = 160; // in+out fit in L2 + ImageF in(kDim, kDim); + ZeroFillImage(&in); + in.Row(kDim / 2)[kDim / 2] = unpredictable1; + ImageF out(kDim, kDim); + + hwy::Params p; + p.verbose = false; + p.max_evals = 7; + p.target_rel_mad = 0.002; + const size_t num_results = MeasureClosure( + [&in, &conv, &out](const hwy::FuncInput input) { + conv(in, &out); + return out.Row(input)[0]; + }, + inputs, kNumInputs, results, p); + if (num_results != kNumInputs) { + fprintf(stderr, "MeasureClosure failed.\n"); + } + for (size_t i = 0; i < num_results; ++i) { + const double seconds = static_cast(results[i].ticks) / + hwy::platform::InvariantTicksPerSecond(); + printf("%12s: %7.2f MP/s (MAD=%4.2f%%)\n", caption, + kDim * kDim * 1E-6 / seconds, + static_cast(results[i].variability) * 100.0); + } +} + +struct ConvSymmetric3 { + void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const { + ThreadPool* null_pool = nullptr; + Symmetric3(in, Rect(in), WeightsSymmetric3Lowpass(), null_pool, out); + } +}; + +struct ConvSeparable5 { + void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const { + ThreadPool* null_pool = nullptr; + Separable5(in, Rect(in), WeightsSeparable5Lowpass(), null_pool, out); + } +}; + +void BenchmarkAll() { +#if 0 // disabled to avoid test timeouts, run manually on demand + const hwy::FuncInput unpredictable1 = time(nullptr) != 1234; + BenchmarkConv("Symmetric3", ConvSymmetric3(), unpredictable1); + BenchmarkConv("Separable5", ConvSeparable5(), unpredictable1); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class ConvolveTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(ConvolveTest); + +HWY_EXPORT_AND_TEST_P(ConvolveTest, TestConvolve); + +HWY_EXPORT_AND_TEST_P(ConvolveTest, BenchmarkAll); + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/data_parallel_test.cc b/media/libjxl/src/lib/jxl/data_parallel_test.cc new file mode 100644 index 000000000..dd6ea625f --- /dev/null +++ b/media/libjxl/src/lib/jxl/data_parallel_test.cc @@ -0,0 +1,89 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/data_parallel.h" + +#include "gtest/gtest.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/test_utils.h" + +namespace jxl { +namespace { + +class DataParallelTest : public ::testing::Test { + protected: + // A fake class to verify that DataParallel is properly calling the + // client-provided runner functions. + static int FakeRunner(void* runner_opaque, void* jpegxl_opaque, + JxlParallelRunInit init, JxlParallelRunFunction func, + uint32_t start_range, uint32_t end_range) { + DataParallelTest* self = static_cast(runner_opaque); + self->runner_called_++; + self->jpegxl_opaque_ = jpegxl_opaque; + self->init_ = init; + self->func_ = func; + self->start_range_ = start_range; + self->end_range_ = end_range; + return self->runner_return_; + } + + ThreadPool pool_{&DataParallelTest::FakeRunner, this}; + + // Number of times FakeRunner() was called. + int runner_called_ = 0; + + // Parameters passed to FakeRunner. + void* jpegxl_opaque_ = nullptr; + JxlParallelRunInit init_ = nullptr; + JxlParallelRunFunction func_ = nullptr; + uint32_t start_range_ = -1; + uint32_t end_range_ = -1; + + // Return value that FakeRunner will return. + int runner_return_ = 0; +}; + +// JxlParallelRunInit interface. +typedef int (*JxlParallelRunInit)(); +int TestInit(void* jpegxl_opaque, size_t num_threads) { return 0; } + +} // namespace + +TEST_F(DataParallelTest, RunnerCalledParameters) { + EXPECT_TRUE(pool_.Run( + 1234, 5678, [](size_t /* num_threads */) { return true; }, + [](uint32_t /* task */, size_t /* thread */) { return; })); + EXPECT_EQ(1, runner_called_); + EXPECT_NE(nullptr, init_); + EXPECT_NE(nullptr, func_); + EXPECT_NE(nullptr, jpegxl_opaque_); + EXPECT_EQ(1234u, start_range_); + EXPECT_EQ(5678u, end_range_); +} + +TEST_F(DataParallelTest, RunnerFailurePropagates) { + runner_return_ = -1; // FakeRunner return value. + EXPECT_FALSE(pool_.Run( + 1234, 5678, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; })); + EXPECT_FALSE(RunOnPool( + nullptr, 1234, 5678, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; }, "Test")); +} + +TEST_F(DataParallelTest, RunnerNotCalledOnEmptyRange) { + runner_return_ = -1; // FakeRunner return value. + EXPECT_TRUE(pool_.Run( + 123, 123, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; })); + EXPECT_TRUE(RunOnPool( + nullptr, 123, 123, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; }, "Test")); + // We don't call the external runner when the range is empty. We don't even + // need to call the init function. + EXPECT_EQ(0, runner_called_); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dct-inl.h b/media/libjxl/src/lib/jxl/dct-inl.h new file mode 100644 index 000000000..532606075 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dct-inl.h @@ -0,0 +1,334 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast SIMD floating-point (I)DCT, any power of two. + +#if defined(LIB_JXL_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DCT_INL_H_ +#undef LIB_JXL_DCT_INL_H_ +#else +#define LIB_JXL_DCT_INL_H_ +#endif + +#include + +#include + +#include "lib/jxl/dct_block-inl.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/transpose-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::NegMulAdd; +using hwy::HWY_NAMESPACE::Sub; + +template +struct FVImpl { + using type = HWY_CAPPED(float, SZ); +}; + +template <> +struct FVImpl<0> { + using type = HWY_FULL(float); +}; + +template +using FV = typename FVImpl::type; + +// Implementation of Lowest Complexity Self Recursive Radix-2 DCT II/III +// Algorithms, by Siriani M. Perera and Jianhua Liu. + +template +struct CoeffBundle { + static void AddReverse(const float* JXL_RESTRICT ain1, + const float* JXL_RESTRICT ain2, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N; i++) { + auto in1 = Load(FV(), ain1 + i * SZ); + auto in2 = Load(FV(), ain2 + (N - i - 1) * SZ); + Store(Add(in1, in2), FV(), aout + i * SZ); + } + } + static void SubReverse(const float* JXL_RESTRICT ain1, + const float* JXL_RESTRICT ain2, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N; i++) { + auto in1 = Load(FV(), ain1 + i * SZ); + auto in2 = Load(FV(), ain2 + (N - i - 1) * SZ); + Store(Sub(in1, in2), FV(), aout + i * SZ); + } + } + static void B(float* JXL_RESTRICT coeff) { + auto sqrt2 = Set(FV(), kSqrt2); + auto in1 = Load(FV(), coeff); + auto in2 = Load(FV(), coeff + SZ); + Store(MulAdd(in1, sqrt2, in2), FV(), coeff); + for (size_t i = 1; i + 1 < N; i++) { + auto in1 = Load(FV(), coeff + i * SZ); + auto in2 = Load(FV(), coeff + (i + 1) * SZ); + Store(Add(in1, in2), FV(), coeff + i * SZ); + } + } + static void BTranspose(float* JXL_RESTRICT coeff) { + for (size_t i = N - 1; i > 0; i--) { + auto in1 = Load(FV(), coeff + i * SZ); + auto in2 = Load(FV(), coeff + (i - 1) * SZ); + Store(Add(in1, in2), FV(), coeff + i * SZ); + } + auto sqrt2 = Set(FV(), kSqrt2); + auto in1 = Load(FV(), coeff); + Store(Mul(in1, sqrt2), FV(), coeff); + } + // Ideally optimized away by compiler (except the multiply). + static void InverseEvenOdd(const float* JXL_RESTRICT ain, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = Load(FV(), ain + i * SZ); + Store(in1, FV(), aout + 2 * i * SZ); + } + for (size_t i = N / 2; i < N; i++) { + auto in1 = Load(FV(), ain + i * SZ); + Store(in1, FV(), aout + (2 * (i - N / 2) + 1) * SZ); + } + } + // Ideally optimized away by compiler. + static void ForwardEvenOdd(const float* JXL_RESTRICT ain, size_t ain_stride, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = LoadU(FV(), ain + 2 * i * ain_stride); + Store(in1, FV(), aout + i * SZ); + } + for (size_t i = N / 2; i < N; i++) { + auto in1 = LoadU(FV(), ain + (2 * (i - N / 2) + 1) * ain_stride); + Store(in1, FV(), aout + i * SZ); + } + } + // Invoked on full vector. + static void Multiply(float* JXL_RESTRICT coeff) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = Load(FV(), coeff + (N / 2 + i) * SZ); + auto mul = Set(FV(), WcMultipliers::kMultipliers[i]); + Store(Mul(in1, mul), FV(), coeff + (N / 2 + i) * SZ); + } + } + static void MultiplyAndAdd(const float* JXL_RESTRICT coeff, + float* JXL_RESTRICT out, size_t out_stride) { + for (size_t i = 0; i < N / 2; i++) { + auto mul = Set(FV(), WcMultipliers::kMultipliers[i]); + auto in1 = Load(FV(), coeff + i * SZ); + auto in2 = Load(FV(), coeff + (N / 2 + i) * SZ); + auto out1 = MulAdd(mul, in2, in1); + auto out2 = NegMulAdd(mul, in2, in1); + StoreU(out1, FV(), out + i * out_stride); + StoreU(out2, FV(), out + (N - i - 1) * out_stride); + } + } + template + static void LoadFromBlock(const Block& in, size_t off, + float* JXL_RESTRICT coeff) { + for (size_t i = 0; i < N; i++) { + Store(in.LoadPart(FV(), i, off), FV(), coeff + i * SZ); + } + } + template + static void StoreToBlockAndScale(const float* JXL_RESTRICT coeff, + const Block& out, size_t off) { + auto mul = Set(FV(), 1.0f / N); + for (size_t i = 0; i < N; i++) { + out.StorePart(FV(), Mul(mul, Load(FV(), coeff + i * SZ)), i, off); + } + } +}; + +template +struct DCT1DImpl; + +template +struct DCT1DImpl<1, SZ> { + JXL_INLINE void operator()(float* JXL_RESTRICT mem) {} +}; + +template +struct DCT1DImpl<2, SZ> { + JXL_INLINE void operator()(float* JXL_RESTRICT mem) { + auto in1 = Load(FV(), mem); + auto in2 = Load(FV(), mem + SZ); + Store(Add(in1, in2), FV(), mem); + Store(Sub(in1, in2), FV(), mem + SZ); + } +}; + +template +struct DCT1DImpl { + void operator()(float* JXL_RESTRICT mem) { + // This is relatively small (4kB with 64-DCT and AVX-512) + HWY_ALIGN float tmp[N * SZ]; + CoeffBundle::AddReverse(mem, mem + N / 2 * SZ, tmp); + DCT1DImpl()(tmp); + CoeffBundle::SubReverse(mem, mem + N / 2 * SZ, tmp + N / 2 * SZ); + CoeffBundle::Multiply(tmp); + DCT1DImpl()(tmp + N / 2 * SZ); + CoeffBundle::B(tmp + N / 2 * SZ); + CoeffBundle::InverseEvenOdd(tmp, mem); + } +}; + +template +struct IDCT1DImpl; + +template +struct IDCT1DImpl<1, SZ> { + JXL_INLINE void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride) { + StoreU(LoadU(FV(), from), FV(), to); + } +}; + +template +struct IDCT1DImpl<2, SZ> { + JXL_INLINE void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride) { + JXL_DASSERT(from_stride >= SZ); + JXL_DASSERT(to_stride >= SZ); + auto in1 = LoadU(FV(), from); + auto in2 = LoadU(FV(), from + from_stride); + StoreU(Add(in1, in2), FV(), to); + StoreU(Sub(in1, in2), FV(), to + to_stride); + } +}; + +template +struct IDCT1DImpl { + void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride) { + JXL_DASSERT(from_stride >= SZ); + JXL_DASSERT(to_stride >= SZ); + // This is relatively small (4kB with 64-DCT and AVX-512) + HWY_ALIGN float tmp[N * SZ]; + CoeffBundle::ForwardEvenOdd(from, from_stride, tmp); + IDCT1DImpl()(tmp, SZ, tmp, SZ); + CoeffBundle::BTranspose(tmp + N / 2 * SZ); + IDCT1DImpl()(tmp + N / 2 * SZ, SZ, tmp + N / 2 * SZ, SZ); + CoeffBundle::MultiplyAndAdd(tmp, to, to_stride); + } +}; + +template +void DCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp) { + size_t M = M_or_0 != 0 ? M_or_0 : Mp; + constexpr size_t SZ = MaxLanes(FV()); + HWY_ALIGN float tmp[N * SZ]; + for (size_t i = 0; i < M; i += Lanes(FV())) { + // TODO(veluca): consider removing the temporary memory here (as is done in + // IDCT), if it turns out that some compilers don't optimize away the loads + // and this is performance-critical. + CoeffBundle::LoadFromBlock(from, i, tmp); + DCT1DImpl()(tmp); + CoeffBundle::StoreToBlockAndScale(tmp, to, i); + } +} + +template +void IDCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp) { + size_t M = M_or_0 != 0 ? M_or_0 : Mp; + constexpr size_t SZ = MaxLanes(FV()); + for (size_t i = 0; i < M; i += Lanes(FV())) { + IDCT1DImpl()(from.Address(0, i), from.Stride(), to.Address(0, i), + to.Stride()); + } +} + +template +struct DCT1D { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return DCT1DWrapper(from, to, M); + } +}; + +template +struct DCT1D MaxLanes(FV<0>()))>::type> { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return NoInlineWrapper(DCT1DWrapper, from, to, M); + } +}; + +template +struct IDCT1D { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return IDCT1DWrapper(from, to, M); + } +}; + +template +struct IDCT1D MaxLanes(FV<0>()))>::type> { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return NoInlineWrapper(IDCT1DWrapper, from, to, + M); + } +}; + +// Computes the maybe-transposed, scaled DCT of a block, that needs to be +// HWY_ALIGN'ed. +template +struct ComputeScaledDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // floats. + template + HWY_MAYBE_UNUSED void operator()(const From& from, float* to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + if (ROWS < COLS) { + DCT1D()(from, DCTTo(block, COLS)); + Transpose::Run(DCTFrom(block, COLS), DCTTo(to, ROWS)); + DCT1D()(DCTFrom(to, ROWS), DCTTo(block, ROWS)); + Transpose::Run(DCTFrom(block, ROWS), DCTTo(to, COLS)); + } else { + DCT1D()(from, DCTTo(to, COLS)); + Transpose::Run(DCTFrom(to, COLS), DCTTo(block, ROWS)); + DCT1D()(DCTFrom(block, ROWS), DCTTo(to, ROWS)); + } + } +}; +// Computes the maybe-transposed, scaled IDCT of a block, that needs to be +// HWY_ALIGN'ed. +template +struct ComputeScaledIDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // floats. + template + HWY_MAYBE_UNUSED void operator()(float* JXL_RESTRICT from, const To& to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + // Reverse the steps done in ComputeScaledDCT. + if (ROWS < COLS) { + Transpose::Run(DCTFrom(from, COLS), DCTTo(block, ROWS)); + IDCT1D()(DCTFrom(block, ROWS), DCTTo(from, ROWS)); + Transpose::Run(DCTFrom(from, ROWS), DCTTo(block, COLS)); + IDCT1D()(DCTFrom(block, COLS), to); + } else { + IDCT1D()(DCTFrom(from, ROWS), DCTTo(block, ROWS)); + Transpose::Run(DCTFrom(block, ROWS), DCTTo(from, COLS)); + IDCT1D()(DCTFrom(from, COLS), to); + } + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); +#endif // LIB_JXL_DCT_INL_H_ diff --git a/media/libjxl/src/lib/jxl/dct_block-inl.h b/media/libjxl/src/lib/jxl/dct_block-inl.h new file mode 100644 index 000000000..50646a737 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dct_block-inl.h @@ -0,0 +1,108 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Adapters for DCT input/output: from/to contiguous blocks or image rows. + +#if defined(LIB_JXL_DCT_BLOCK_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DCT_BLOCK_INL_H_ +#undef LIB_JXL_DCT_BLOCK_INL_H_ +#else +#define LIB_JXL_DCT_BLOCK_INL_H_ +#endif + +#include + +#include + +#include "lib/jxl/base/status.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Vec; + +// Block: (x, y) <-> (N * y + x) +// Lines: (x, y) <-> (stride * y + x) +// +// I.e. Block is a specialization of Lines with fixed stride. +// +// FromXXX should implement Read and Load (Read vector). +// ToXXX should implement Write and Store (Write vector). + +template +using BlockDesc = HWY_CAPPED(float, N); + +// Here and in the following, the SZ template parameter specifies the number of +// values to load/store. Needed because we want to handle 4x4 sub-blocks of +// 16x16 blocks. +class DCTFrom { + public: + DCTFrom(const float* data, size_t stride) : stride_(stride), data_(data) {} + + template + HWY_INLINE Vec LoadPart(D, const size_t row, size_t i) const { + JXL_DASSERT(Lanes(D()) <= stride_); + // Since these functions are used also for DC, no alignment at all is + // guaranteed in the case of floating blocks. + // TODO(veluca): consider using a different class for DC-to-LF and + // DC-from-LF, or copying DC values to/from a temporary aligned location. + return LoadU(D(), Address(row, i)); + } + + HWY_INLINE float Read(const size_t row, const size_t i) const { + return *Address(row, i); + } + + constexpr HWY_INLINE const float* Address(const size_t row, + const size_t i) const { + return data_ + row * stride_ + i; + } + + size_t Stride() const { return stride_; } + + private: + size_t stride_; + const float* JXL_RESTRICT data_; +}; + +class DCTTo { + public: + DCTTo(float* data, size_t stride) : stride_(stride), data_(data) {} + + template + HWY_INLINE void StorePart(D, const Vec& v, const size_t row, + size_t i) const { + JXL_DASSERT(Lanes(D()) <= stride_); + // Since these functions are used also for DC, no alignment at all is + // guaranteed in the case of floating blocks. + // TODO(veluca): consider using a different class for DC-to-LF and + // DC-from-LF, or copying DC values to/from a temporary aligned location. + StoreU(v, D(), Address(row, i)); + } + + HWY_INLINE void Write(float v, const size_t row, const size_t i) const { + *Address(row, i) = v; + } + + constexpr HWY_INLINE float* Address(const size_t row, const size_t i) const { + return data_ + row * stride_ + i; + } + + size_t Stride() const { return stride_; } + + private: + size_t stride_; + float* JXL_RESTRICT data_; +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DCT_BLOCK_INL_H_ diff --git a/media/libjxl/src/lib/jxl/dct_for_test.h b/media/libjxl/src/lib/jxl/dct_for_test.h new file mode 100644 index 000000000..8e32aa7ef --- /dev/null +++ b/media/libjxl/src/lib/jxl/dct_for_test.h @@ -0,0 +1,99 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DCT_FOR_TEST_H_ +#define LIB_JXL_DCT_FOR_TEST_H_ + +// Unoptimized DCT only for use in tests. + +#include // memcpy + +#include +#include + +#include "lib/jxl/common.h" // Pi + +namespace jxl { + +namespace test { +static inline double alpha(int u) { return u == 0 ? 0.7071067811865475 : 1.0; } + +// N-DCT on M columns, divided by sqrt(N). Matches the definition in the spec. +template +void DCT1D(double block[N * M], double out[N * M]) { + std::vector matrix(N * N); + const double scale = std::sqrt(2.0) / N; + for (size_t y = 0; y < N; y++) { + for (size_t u = 0; u < N; u++) { + matrix[N * u + y] = alpha(u) * cos((y + 0.5) * u * Pi(1.0 / N)) * scale; + } + } + for (size_t x = 0; x < M; x++) { + for (size_t u = 0; u < N; u++) { + out[M * u + x] = 0; + for (size_t y = 0; y < N; y++) { + out[M * u + x] += matrix[N * u + y] * block[M * y + x]; + } + } + } +} + +// N-IDCT on M columns, multiplied by sqrt(N). Matches the definition in the +// spec. +template +void IDCT1D(double block[N * M], double out[N * M]) { + std::vector matrix(N * N); + const double scale = std::sqrt(2.0); + for (size_t y = 0; y < N; y++) { + for (size_t u = 0; u < N; u++) { + // Transpose of DCT matrix. + matrix[N * y + u] = alpha(u) * cos((y + 0.5) * u * Pi(1.0 / N)) * scale; + } + } + for (size_t x = 0; x < M; x++) { + for (size_t u = 0; u < N; u++) { + out[M * u + x] = 0; + for (size_t y = 0; y < N; y++) { + out[M * u + x] += matrix[N * u + y] * block[M * y + x]; + } + } + } +} + +template +void TransposeBlock(double in[N * M], double out[M * N]) { + for (size_t x = 0; x < N; x++) { + for (size_t y = 0; y < M; y++) { + out[y * N + x] = in[x * M + y]; + } + } +} +} // namespace test + +// Untransposed DCT. +template +void DCTSlow(double block[N * N]) { + constexpr size_t kBlockSize = N * N; + std::vector g(kBlockSize); + test::DCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); + test::DCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); +} + +// Untransposed IDCT. +template +void IDCTSlow(double block[N * N]) { + constexpr size_t kBlockSize = N * N; + std::vector g(kBlockSize); + test::IDCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); + test::IDCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); +} + +} // namespace jxl + +#endif // LIB_JXL_DCT_FOR_TEST_H_ diff --git a/media/libjxl/src/lib/jxl/dct_scales.cc b/media/libjxl/src/lib/jxl/dct_scales.cc new file mode 100644 index 000000000..f9e89a601 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dct_scales.cc @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dct_scales.h" + +namespace jxl { + +// Definition of constexpr arrays. +constexpr float DCTResampleScales<1, 8>::kScales[]; +constexpr float DCTResampleScales<2, 16>::kScales[]; +constexpr float DCTResampleScales<4, 32>::kScales[]; +constexpr float DCTResampleScales<8, 64>::kScales[]; +constexpr float DCTResampleScales<16, 128>::kScales[]; +constexpr float DCTResampleScales<32, 256>::kScales[]; +constexpr float DCTResampleScales<8, 1>::kScales[]; +constexpr float DCTResampleScales<16, 2>::kScales[]; +constexpr float DCTResampleScales<32, 4>::kScales[]; +constexpr float DCTResampleScales<64, 8>::kScales[]; +constexpr float DCTResampleScales<128, 16>::kScales[]; +constexpr float DCTResampleScales<256, 32>::kScales[]; +constexpr float WcMultipliers<4>::kMultipliers[]; +constexpr float WcMultipliers<8>::kMultipliers[]; +constexpr float WcMultipliers<16>::kMultipliers[]; +constexpr float WcMultipliers<32>::kMultipliers[]; +constexpr float WcMultipliers<64>::kMultipliers[]; +constexpr float WcMultipliers<128>::kMultipliers[]; +constexpr float WcMultipliers<256>::kMultipliers[]; + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dct_scales.h b/media/libjxl/src/lib/jxl/dct_scales.h new file mode 100644 index 000000000..23af03d60 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dct_scales.h @@ -0,0 +1,379 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DCT_SCALES_H_ +#define LIB_JXL_DCT_SCALES_H_ + +// Scaling factors. + +#include + +namespace jxl { + +static constexpr float kSqrt2 = 1.41421356237f; +static constexpr float kSqrt0_5 = 0.70710678118f; + +// For n != 0, the n-th basis function of a N-DCT, evaluated in pixel k, has a +// value of cos((k+1/2) n/(2N) pi). When downsampling by 2x, we average +// the values for pixel k and k+1 to get the value for pixel (k/2), thus we get +// +// [cos((k+1/2) n/N pi) + cos((k+3/2) n/N pi)]/2 = +// cos(n/(2N) pi) cos((k+1) n/N pi) = +// cos(n/(2N) pi) cos(((k/2)+1/2) n/(N/2) pi) +// +// which is exactly the same as the value of pixel k/2 of a N/2-sized DCT, +// except for the cos(n/(2N) pi) scaling factor (which does *not* +// depend on the pixel). Thus, when using the lower-frequency coefficients of a +// DCT-N to compute a DCT-(N/2), they should be scaled by this constant. Scaling +// factors for a DCT-(N/4) etc can then be obtained by successive +// multiplications. The structs below contain the above-mentioned scaling +// factors. +// +// Python code for the tables below: +// +// for i in range(N // 8): +// v = math.cos(i / (2 * N) * math.pi) +// v *= math.cos(i / (N) * math.pi) +// v *= math.cos(i / (N / 2) * math.pi) +// print(v, end=", ") + +template +struct DCTResampleScales; + +template <> +struct DCTResampleScales<8, 1> { + static constexpr float kScales[] = { + 1.000000000000000000, + }; +}; + +template <> +struct DCTResampleScales<16, 2> { + static constexpr float kScales[] = { + 1.000000000000000000, + 0.901764195028874394, + }; +}; + +template <> +struct DCTResampleScales<32, 4> { + static constexpr float kScales[] = { + 1.000000000000000000, + 0.974886821136879522, + 0.901764195028874394, + 0.787054918159101335, + }; +}; + +template <> +struct DCTResampleScales<64, 8> { + static constexpr float kScales[] = { + 1.0000000000000000, 0.9936866130906366, 0.9748868211368796, + 0.9440180941651672, 0.9017641950288744, 0.8490574973847023, + 0.7870549181591013, 0.7171081282466044, + }; +}; + +template <> +struct DCTResampleScales<128, 16> { + static constexpr float kScales[] = { + 1.0, + 0.9984194528776054, + 0.9936866130906366, + 0.9858278282666936, + 0.9748868211368796, + 0.9609244059440204, + 0.9440180941651672, + 0.9242615922757944, + 0.9017641950288744, + 0.8766500784429904, + 0.8490574973847023, + 0.8191378932865928, + 0.7870549181591013, + 0.7529833816270532, + 0.7171081282466044, + 0.6796228528314651, + }; +}; + +template <> +struct DCTResampleScales<256, 32> { + static constexpr float kScales[] = { + 1.0, + 0.9996047255830407, + 0.9984194528776054, + 0.9964458326264695, + 0.9936866130906366, + 0.9901456355893141, + 0.9858278282666936, + 0.9807391980963174, + 0.9748868211368796, + 0.9682788310563117, + 0.9609244059440204, + 0.9528337534340876, + 0.9440180941651672, + 0.9344896436056892, + 0.9242615922757944, + 0.913348084400198, + 0.9017641950288744, + 0.8895259056651056, + 0.8766500784429904, + 0.8631544288990163, + 0.8490574973847023, + 0.8343786191696513, + 0.8191378932865928, + 0.8033561501721485, + 0.7870549181591013, + 0.7702563888779096, + 0.7529833816270532, + 0.7352593067735488, + 0.7171081282466044, + 0.6985543251889097, + 0.6796228528314651, + 0.6603391026591464, + }; +}; + +// Inverses of the above. +template <> +struct DCTResampleScales<1, 8> { + static constexpr float kScales[] = { + 1.000000000000000000, + }; +}; + +template <> +struct DCTResampleScales<2, 16> { + static constexpr float kScales[] = { + 1.000000000000000000, + 1.108937353592731823, + }; +}; + +template <> +struct DCTResampleScales<4, 32> { + static constexpr float kScales[] = { + 1.000000000000000000, + 1.025760096781116015, + 1.108937353592731823, + 1.270559368765487251, + }; +}; + +template <> +struct DCTResampleScales<8, 64> { + static constexpr float kScales[] = { + 1.0000000000000000, 1.0063534990068217, 1.0257600967811158, + 1.0593017296817173, 1.1089373535927318, 1.1777765381970435, + 1.2705593687654873, 1.3944898413647777, + }; +}; + +template <> +struct DCTResampleScales<16, 128> { + static constexpr float kScales[] = { + 1.0, + 1.0015830492062623, + 1.0063534990068217, + 1.0143759095928793, + 1.0257600967811158, + 1.0406645869480142, + 1.0593017296817173, + 1.0819447744633812, + 1.1089373535927318, + 1.1407059950032632, + 1.1777765381970435, + 1.2207956782315876, + 1.2705593687654873, + 1.3280505578213306, + 1.3944898413647777, + 1.4714043176061107, + }; +}; + +template <> +struct DCTResampleScales<32, 256> { + static constexpr float kScales[] = { + 1.0, + 1.0003954307206069, + 1.0015830492062623, + 1.0035668445360069, + 1.0063534990068217, + 1.009952439375063, + 1.0143759095928793, + 1.0196390660647288, + 1.0257600967811158, + 1.0327603660498115, + 1.0406645869480142, + 1.049501024072585, + 1.0593017296817173, + 1.0701028169146336, + 1.0819447744633812, + 1.0948728278734026, + 1.1089373535927318, + 1.124194353004584, + 1.1407059950032632, + 1.158541237256391, + 1.1777765381970435, + 1.1984966740820495, + 1.2207956782315876, + 1.244777922949508, + 1.2705593687654873, + 1.2982690107339132, + 1.3280505578213306, + 1.3600643892400104, + 1.3944898413647777, + 1.4315278911623237, + 1.4714043176061107, + 1.5143734423314616, + }; +}; + +// Constants for DCT implementation. Generated by the following snippet: +// for i in range(N // 2): +// print(1.0 / (2 * math.cos((i + 0.5) * math.pi / N)), end=", ") +template +struct WcMultipliers; + +template <> +struct WcMultipliers<4> { + static constexpr float kMultipliers[] = { + 0.541196100146197, + 1.3065629648763764, + }; +}; + +template <> +struct WcMultipliers<8> { + static constexpr float kMultipliers[] = { + 0.5097955791041592, + 0.6013448869350453, + 0.8999762231364156, + 2.5629154477415055, + }; +}; + +template <> +struct WcMultipliers<16> { + static constexpr float kMultipliers[] = { + 0.5024192861881557, 0.5224986149396889, 0.5669440348163577, + 0.6468217833599901, 0.7881546234512502, 1.060677685990347, + 1.7224470982383342, 5.101148618689155, + }; +}; + +template <> +struct WcMultipliers<32> { + static constexpr float kMultipliers[] = { + 0.5006029982351963, 0.5054709598975436, 0.5154473099226246, + 0.5310425910897841, 0.5531038960344445, 0.5829349682061339, + 0.6225041230356648, 0.6748083414550057, 0.7445362710022986, + 0.8393496454155268, 0.9725682378619608, 1.1694399334328847, + 1.4841646163141662, 2.057781009953411, 3.407608418468719, + 10.190008123548033, + }; +}; +template <> +struct WcMultipliers<64> { + static constexpr float kMultipliers[] = { + 0.500150636020651, 0.5013584524464084, 0.5037887256810443, + 0.5074711720725553, 0.5124514794082247, 0.5187927131053328, + 0.52657731515427, 0.535909816907992, 0.5469204379855088, + 0.5597698129470802, 0.57465518403266, 0.5918185358574165, + 0.6115573478825099, 0.6342389366884031, 0.6603198078137061, + 0.6903721282002123, 0.7251205223771985, 0.7654941649730891, + 0.8127020908144905, 0.8683447152233481, 0.9345835970364075, + 1.0144082649970547, 1.1120716205797176, 1.233832737976571, + 1.3892939586328277, 1.5939722833856311, 1.8746759800084078, + 2.282050068005162, 2.924628428158216, 4.084611078129248, + 6.796750711673633, 20.373878167231453, + }; +}; +template <> +struct WcMultipliers<128> { + static constexpr float kMultipliers[] = { + 0.5000376519155477, 0.5003390374428216, 0.5009427176380873, + 0.5018505174842379, 0.5030651913013697, 0.5045904432216454, + 0.5064309549285542, 0.5085924210498143, 0.5110815927066812, + 0.5139063298475396, 0.5170756631334912, 0.5205998663018917, + 0.524490540114724, 0.5287607092074876, 0.5334249333971333, + 0.538499435291984, 0.5440022463817783, 0.549953374183236, + 0.5563749934898856, 0.5632916653417023, 0.5707305880121454, + 0.5787218851348208, 0.5872989370937893, 0.5964987630244563, + 0.606362462272146, 0.6169357260050706, 0.6282694319707711, + 0.6404203382416639, 0.6534518953751283, 0.6674352009263413, + 0.6824501259764195, 0.6985866506472291, 0.7159464549705746, + 0.7346448236478627, 0.7548129391165311, 0.776600658233963, + 0.8001798956216941, 0.8257487738627852, 0.8535367510066064, + 0.8838110045596234, 0.9168844461846523, 0.9531258743921193, + 0.9929729612675466, 1.036949040910389, 1.0856850642580145, + 1.1399486751015042, 1.2006832557294167, 1.2690611716991191, + 1.346557628206286, 1.4350550884414341, 1.5369941008524954, + 1.6555965242641195, 1.7952052190778898, 1.961817848571166, + 2.163957818751979, 2.4141600002500763, 2.7316450287739396, + 3.147462191781909, 3.7152427383269746, 4.5362909369693565, + 5.827688377844654, 8.153848602466814, 13.58429025728446, + 40.744688103351834, + }; +}; + +template <> +struct WcMultipliers<256> { + static constexpr float kMultipliers[128] = { + 0.5000094125358878, 0.500084723455784, 0.5002354020255269, + 0.5004615618093246, 0.5007633734146156, 0.5011410648064231, + 0.5015949217281668, 0.502125288230386, 0.5027325673091954, + 0.5034172216566842, 0.5041797745258774, 0.5050208107132756, + 0.5059409776624396, 0.5069409866925212, 0.5080216143561264, + 0.509183703931388, 0.5104281670536573, 0.5117559854927805, + 0.5131682130825206, 0.5146659778093218, 0.516250484068288, + 0.5179230150949777, 0.5196849355823947, 0.5215376944933958, + 0.5234828280796439, 0.52552196311921, 0.5276568203859896, + 0.5298892183652453, 0.5322210772308335, 0.5346544231010253, + 0.537191392591309, 0.5398342376841637, 0.5425853309375497, + 0.545447171055775, 0.5484223888484947, 0.551513753605893, + 0.554724179920619, 0.5580567349898085, 0.5615146464335654, + 0.5651013106696203, 0.5688203018875696, 0.5726753816701664, + 0.5766705093136241, 0.5808098529038624, 0.5850978012111273, + 0.58953897647151, 0.5941382481306648, 0.5989007476325463, + 0.6038318843443582, 0.6089373627182432, 0.614223200800649, + 0.6196957502119484, 0.6253617177319102, 0.6312281886412079, + 0.6373026519855411, 0.6435930279473415, 0.6501076975307724, + 0.6568555347890955, 0.6638459418498757, 0.6710888870233562, + 0.6785949463131795, 0.6863753486870501, 0.6944420255086364, + 0.7028076645818034, 0.7114857693151208, 0.7204907235796304, + 0.7298378629074134, 0.7395435527641373, 0.749625274727372, + 0.7601017215162176, 0.7709929019493761, 0.7823202570613161, + 0.7941067887834509, 0.8063772028037925, 0.8191580674598145, + 0.83247799080191, 0.8463678182968619, 0.860860854031955, + 0.8759931087426972, 0.8918035785352535, 0.9083345588266809, + 0.9256319988042384, 0.9437459026371479, 0.962730784794803, + 0.9826461881778968, 1.0035572754078206, 1.0255355056139732, + 1.048659411496106, 1.0730154944316674, 1.0986992590905857, + 1.1258164135986009, 1.1544842669978943, 1.184833362908442, + 1.217009397314603, 1.2511754798461228, 1.287514812536712, + 1.326233878832723, 1.3675662599582539, 1.411777227500661, + 1.459169302866857, 1.5100890297227016, 1.5649352798258847, + 1.6241695131835794, 1.6883285509131505, 1.7580406092704062, + 1.8340456094306077, 1.9172211551275689, 2.0086161135167564, + 2.1094945286246385, 2.22139377701127, 2.346202662531156, + 2.486267909203593, 2.644541877144861, 2.824791402350551, + 3.0318994541759925, 3.2723115884254845, 3.5547153325075804, + 3.891107790700307, 4.298537526449054, 4.802076008665048, + 5.440166215091329, 6.274908408039339, 7.413566756422303, + 9.058751453879703, 11.644627325175037, 16.300023088031555, + 27.163977662448232, 81.48784219222516, + }; +}; + +// Apply the DCT algorithm-intrinsic constants to DCTResampleScale. +template +constexpr float DCTTotalResampleScale(size_t x) { + return DCTResampleScales::kScales[x]; +} + +} // namespace jxl + +#endif // LIB_JXL_DCT_SCALES_H_ diff --git a/media/libjxl/src/lib/jxl/dct_test.cc b/media/libjxl/src/lib/jxl/dct_test.cc new file mode 100644 index 000000000..8c9bc27fc --- /dev/null +++ b/media/libjxl/src/lib/jxl/dct_test.cc @@ -0,0 +1,390 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dct_test.cc" +#include +#include +#include + +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/dct_for_test.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/image.h" +#include "lib/jxl/test_utils.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// Computes the in-place NxN DCT of block. +// Requires that block is HWY_ALIGN'ed. +// +// Performs ComputeTransposedScaledDCT and then transposes and scales it to +// obtain "vanilla" DCT. +template +void ComputeDCT(float block[N * N]) { + HWY_ALIGN float tmp_block[N * N]; + HWY_ALIGN float scratch_space[N * N]; + ComputeScaledDCT()(DCTFrom(block, N), tmp_block, scratch_space); + + // Untranspose. + Transpose::Run(DCTFrom(tmp_block, N), DCTTo(block, N)); +} + +// Computes the in-place 8x8 iDCT of block. +// Requires that block is HWY_ALIGN'ed. +template +void ComputeIDCT(float block[N * N]) { + HWY_ALIGN float tmp_block[N * N]; + HWY_ALIGN float scratch_space[N * N]; + // Untranspose. + Transpose::Run(DCTFrom(block, N), DCTTo(tmp_block, N)); + + ComputeScaledIDCT()(tmp_block, DCTTo(block, N), scratch_space); +} + +template +void TransposeTestT(float accuracy) { + constexpr size_t kBlockSize = N * N; + HWY_ALIGN float src[kBlockSize]; + DCTTo to_src(src, N); + for (size_t y = 0; y < N; ++y) { + for (size_t x = 0; x < N; ++x) { + to_src.Write(y * N + x, y, x); + } + } + HWY_ALIGN float dst[kBlockSize]; + Transpose::Run(DCTFrom(src, N), DCTTo(dst, N)); + DCTFrom from_dst(dst, N); + for (size_t y = 0; y < N; ++y) { + for (size_t x = 0; x < N; ++x) { + float expected = x * N + y; + float actual = from_dst.Read(y, x); + EXPECT_NEAR(expected, actual, accuracy) << "x = " << x << ", y = " << y; + } + } +} + +void TransposeTest() { + TransposeTestT<8>(1e-7f); + TransposeTestT<16>(1e-7f); + TransposeTestT<32>(1e-7f); +} + +template +void ColumnDctRoundtripT(float accuracy) { + constexpr size_t kBlockSize = N * N; + // Though we are only interested in single column result, dct.h has built-in + // limit on minimal number of columns processed. So, to be safe, we do + // regular 8x8 block transformation. On the bright side - we could check all + // 8 basis vectors at once. + HWY_ALIGN float block[kBlockSize]; + DCTTo to(block, N); + DCTFrom from(block, N); + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < N; ++j) { + to.Write((i == j) ? 1.0f : 0.0f, i, j); + } + } + + // Running (I)DCT on the same memory block seems to trigger a compiler bug on + // ARMv7 with clang6. + HWY_ALIGN float tmp[kBlockSize]; + DCTTo to_tmp(tmp, N); + DCTFrom from_tmp(tmp, N); + + DCT1D()(from, to_tmp); + IDCT1D()(from_tmp, to); + + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < N; ++j) { + float expected = (i == j) ? 1.0f : 0.0f; + float actual = from.Read(i, j); + EXPECT_NEAR(expected, actual, accuracy) << " i=" << i << ", j=" << j; + } + } +} + +void ColumnDctRoundtrip() { + ColumnDctRoundtripT<8>(1e-6f); + ColumnDctRoundtripT<16>(1e-6f); + ColumnDctRoundtripT<32>(1e-6f); +} + +template +void TestDctAccuracy(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + HWY_ALIGN float fast[kBlockSize] = {0.0f}; + double slow[kBlockSize] = {0.0}; + fast[i] = 1.0; + slow[i] = 1.0; + DCTSlow(slow); + ComputeDCT(fast); + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(fast[k], slow[k], accuracy / N) + << "i = " << i << ", k = " << k << ", N = " << N; + } + } +} + +template +void TestIdctAccuracy(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + HWY_ALIGN float fast[kBlockSize] = {0.0f}; + double slow[kBlockSize] = {0.0}; + fast[i] = 1.0; + slow[i] = 1.0; + IDCTSlow(slow); + ComputeIDCT(fast); + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(fast[k], slow[k], accuracy * N) + << "i = " << i << ", k = " << k << ", N = " << N; + } + } +} + +template +void TestInverseT(float accuracy) { + ThreadPoolInternal pool(N < 32 ? 0 : 8); + enum { kBlockSize = N * N }; + EXPECT_TRUE(RunOnPool( + &pool, 0, kBlockSize, ThreadPool::NoInit, + [accuracy](const uint32_t task, size_t /*thread*/) { + const size_t i = static_cast(task); + HWY_ALIGN float x[kBlockSize] = {0.0f}; + x[i] = 1.0; + + ComputeIDCT(x); + ComputeDCT(x); + + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(x[k], (k == i) ? 1.0f : 0.0f, accuracy) + << "i = " << i << ", k = " << k; + } + }, + "TestInverse")); +} + +void InverseTest() { + TestInverseT<8>(1e-6f); + TestInverseT<16>(1e-6f); + TestInverseT<32>(3e-6f); +} + +template +void TestDctTranspose(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + for (size_t j = 0; j < kBlockSize; ++j) { + // We check that = . + // That means (Me_j)_i = (M^\dagger{}e_i)_j + + // x := Me_j + HWY_ALIGN float x[kBlockSize] = {0.0f}; + x[j] = 1.0; + ComputeIDCT(x); + // y := M^\dagger{}e_i + HWY_ALIGN float y[kBlockSize] = {0.0f}; + y[i] = 1.0; + ComputeDCT(y); + + EXPECT_NEAR(x[i] / N, y[j] * N, accuracy) << "i = " << i << ", j = " << j; + } + } +} + +template +void TestSlowInverse(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + double x[kBlockSize] = {0.0f}; + x[i] = 1.0; + + DCTSlow(x); + IDCTSlow(x); + + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(x[k], (k == i) ? 1.0f : 0.0f, accuracy) + << "i = " << i << ", k = " << k; + } + } +} + +template +void TestRectInverseT(float accuracy) { + constexpr size_t kBlockSize = ROWS * COLS; + for (size_t i = 0; i < kBlockSize; ++i) { + HWY_ALIGN float x[kBlockSize] = {0.0f}; + HWY_ALIGN float out[kBlockSize] = {0.0f}; + x[i] = 1.0; + HWY_ALIGN float coeffs[kBlockSize] = {0.0f}; + HWY_ALIGN float scratch_space[kBlockSize * 2]; + + ComputeScaledDCT()(DCTFrom(x, COLS), coeffs, scratch_space); + ComputeScaledIDCT()(coeffs, DCTTo(out, COLS), scratch_space); + + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(out[k], (k == i) ? 1.0f : 0.0f, accuracy) + << "i = " << i << ", k = " << k << " ROWS = " << ROWS + << " COLS = " << COLS; + } + } +} + +void TestRectInverse() { + TestRectInverseT<16, 32>(1e-6f); + TestRectInverseT<8, 32>(1e-6f); + TestRectInverseT<8, 16>(1e-6f); + TestRectInverseT<4, 8>(1e-6f); + TestRectInverseT<2, 4>(1e-6f); + TestRectInverseT<1, 4>(1e-6f); + TestRectInverseT<1, 2>(1e-6f); + + TestRectInverseT<32, 16>(1e-6f); + TestRectInverseT<32, 8>(1e-6f); + TestRectInverseT<16, 8>(1e-6f); + TestRectInverseT<8, 4>(1e-6f); + TestRectInverseT<4, 2>(1e-6f); + TestRectInverseT<4, 1>(1e-6f); + TestRectInverseT<2, 1>(1e-6f); +} + +template +void TestRectTransposeT(float accuracy) { + constexpr size_t kBlockSize = ROWS * COLS; + HWY_ALIGN float scratch_space[kBlockSize * 2]; + for (size_t px = 0; px < COLS; ++px) { + for (size_t py = 0; py < ROWS; ++py) { + HWY_ALIGN float x1[kBlockSize] = {0.0f}; + HWY_ALIGN float x2[kBlockSize] = {0.0f}; + HWY_ALIGN float coeffs1[kBlockSize] = {0.0f}; + HWY_ALIGN float coeffs2[kBlockSize] = {0.0f}; + x1[py * COLS + px] = 1; + x2[px * ROWS + py] = 1; + + constexpr size_t OUT_ROWS = ROWS < COLS ? ROWS : COLS; + constexpr size_t OUT_COLS = ROWS < COLS ? COLS : ROWS; + + ComputeScaledDCT()(DCTFrom(x1, COLS), coeffs1, scratch_space); + ComputeScaledDCT()(DCTFrom(x2, ROWS), coeffs2, scratch_space); + + for (size_t x = 0; x < OUT_COLS; ++x) { + for (size_t y = 0; y < OUT_ROWS; ++y) { + EXPECT_NEAR(coeffs1[y * OUT_COLS + x], coeffs2[y * OUT_COLS + x], + accuracy) + << " px = " << px << ", py = " << py << ", x = " << x + << ", y = " << y; + } + } + } + } +} + +void TestRectTranspose() { + TestRectTransposeT<16, 32>(1e-6f); + TestRectTransposeT<8, 32>(1e-6f); + TestRectTransposeT<8, 16>(1e-6f); + TestRectTransposeT<4, 8>(1e-6f); + TestRectTransposeT<2, 4>(1e-6f); + TestRectTransposeT<1, 4>(1e-6f); + TestRectTransposeT<1, 2>(1e-6f); + + // Identical to 8, 16 + // TestRectTranspose<16, 8>(1e-6f); +} + +void TestDctAccuracyShard(size_t shard) { + if (shard == 0) { + TestDctAccuracy<1>(1.1E-7f); + TestDctAccuracy<2>(1.1E-7f); + TestDctAccuracy<4>(1.1E-7f); + TestDctAccuracy<8>(1.1E-7f); + TestDctAccuracy<16>(1.3E-7f); + } + TestDctAccuracy<32>(1.1E-7f, 32 * shard, 32 * (shard + 1)); +} + +void TestIdctAccuracyShard(size_t shard) { + if (shard == 0) { + TestIdctAccuracy<1>(1E-7f); + TestIdctAccuracy<2>(1E-7f); + TestIdctAccuracy<4>(1E-7f); + TestIdctAccuracy<8>(1E-7f); + TestIdctAccuracy<16>(1E-7f); + } + TestIdctAccuracy<32>(1E-7f, 32 * shard, 32 * (shard + 1)); +} + +void TestDctTransposeShard(size_t shard) { + if (shard == 0) { + TestDctTranspose<8>(1E-6f); + TestDctTranspose<16>(1E-6f); + } + TestDctTranspose<32>(3E-6f, 32 * shard, 32 * (shard + 1)); +} + +void TestSlowInverseShard(size_t shard) { + if (shard == 0) { + TestSlowInverse<1>(1E-5f); + TestSlowInverse<2>(1E-5f); + TestSlowInverse<4>(1E-5f); + TestSlowInverse<8>(1E-5f); + TestSlowInverse<16>(1E-5f); + } + TestSlowInverse<32>(1E-5f, 32 * shard, 32 * (shard + 1)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class TransposeTest : public hwy::TestWithParamTarget {}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(TransposeTest); + +HWY_EXPORT_AND_TEST_P(TransposeTest, TransposeTest); +HWY_EXPORT_AND_TEST_P(TransposeTest, InverseTest); +HWY_EXPORT_AND_TEST_P(TransposeTest, ColumnDctRoundtrip); +HWY_EXPORT_AND_TEST_P(TransposeTest, TestRectInverse); +HWY_EXPORT_AND_TEST_P(TransposeTest, TestRectTranspose); + +// Tests in the DctShardedTest class are sharded for N=32. +class DctShardedTest : public ::hwy::TestWithParamTargetAndT {}; + +std::vector ShardRange(uint32_t n) { +#ifdef JXL_DISABLE_SLOW_TESTS + JXL_ASSERT(n > 6); + std::vector ret = {0, 1, 3, 5, n - 1}; +#else + std::vector ret(n); + std::iota(ret.begin(), ret.end(), 0); +#endif // JXL_DISABLE_SLOW_TESTS + return ret; +} + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(DctShardedTest, + ::testing::ValuesIn(ShardRange(32))); + +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestDctAccuracyShard); +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestIdctAccuracyShard); +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestDctTransposeShard); +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestSlowInverseShard); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/dct_util.h b/media/libjxl/src/lib/jxl/dct_util.h new file mode 100644 index 000000000..fb6ce3b97 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dct_util.h @@ -0,0 +1,86 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DCT_UTIL_H_ +#define LIB_JXL_DCT_UTIL_H_ + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +union ACPtr { + int32_t* ptr32; + int16_t* ptr16; + ACPtr() = default; + explicit ACPtr(int16_t* p) : ptr16(p) {} + explicit ACPtr(int32_t* p) : ptr32(p) {} +}; + +union ConstACPtr { + const int32_t* ptr32; + const int16_t* ptr16; + ConstACPtr() = default; + explicit ConstACPtr(const int16_t* p) : ptr16(p) {} + explicit ConstACPtr(const int32_t* p) : ptr32(p) {} +}; + +enum class ACType { k16 = 0, k32 = 1 }; + +class ACImage { + public: + virtual ~ACImage() = default; + virtual ACType Type() const = 0; + virtual ACPtr PlaneRow(size_t c, size_t y, size_t xbase) = 0; + virtual ConstACPtr PlaneRow(size_t c, size_t y, size_t xbase) const = 0; + virtual size_t PixelsPerRow() const = 0; + virtual void ZeroFill() = 0; + virtual void ZeroFillPlane(size_t c) = 0; + virtual bool IsEmpty() const = 0; +}; + +template +class ACImageT final : public ACImage { + public: + ACImageT() = default; + ACImageT(size_t xsize, size_t ysize) { + static_assert( + std::is_same::value || std::is_same::value, + "ACImage must be either 32- or 16- bit"); + img_ = Image3(xsize, ysize); + } + ACType Type() const override { + return sizeof(T) == 2 ? ACType::k16 : ACType::k32; + } + ACPtr PlaneRow(size_t c, size_t y, size_t xbase) override { + return ACPtr(img_.PlaneRow(c, y) + xbase); + } + ConstACPtr PlaneRow(size_t c, size_t y, size_t xbase) const override { + return ConstACPtr(img_.PlaneRow(c, y) + xbase); + } + + size_t PixelsPerRow() const override { return img_.PixelsPerRow(); } + + void ZeroFill() override { ZeroFillImage(&img_); } + + void ZeroFillPlane(size_t c) override { ZeroFillImage(&img_.Plane(c)); } + + bool IsEmpty() const override { + return img_.xsize() == 0 || img_.ysize() == 0; + } + + private: + Image3 img_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DCT_UTIL_H_ diff --git a/media/libjxl/src/lib/jxl/dec_ans.cc b/media/libjxl/src/lib/jxl/dec_ans.cc new file mode 100644 index 000000000..c9145472e --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_ans.cc @@ -0,0 +1,374 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_ans.h" + +#include + +#include + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/fields.h" + +namespace jxl { +namespace { + +// Decodes a number in the range [0..255], by reading 1 - 11 bits. +inline int DecodeVarLenUint8(BitReader* input) { + if (input->ReadFixedBits<1>()) { + int nbits = static_cast(input->ReadFixedBits<3>()); + if (nbits == 0) { + return 1; + } else { + return static_cast(input->ReadBits(nbits)) + (1 << nbits); + } + } + return 0; +} + +// Decodes a number in the range [0..65535], by reading 1 - 21 bits. +inline int DecodeVarLenUint16(BitReader* input) { + if (input->ReadFixedBits<1>()) { + int nbits = static_cast(input->ReadFixedBits<4>()); + if (nbits == 0) { + return 1; + } else { + return static_cast(input->ReadBits(nbits)) + (1 << nbits); + } + } + return 0; +} + +Status ReadHistogram(int precision_bits, std::vector* counts, + BitReader* input) { + int simple_code = input->ReadBits(1); + if (simple_code == 1) { + int i; + int symbols[2] = {0}; + int max_symbol = 0; + const int num_symbols = input->ReadBits(1) + 1; + for (i = 0; i < num_symbols; ++i) { + symbols[i] = DecodeVarLenUint8(input); + if (symbols[i] > max_symbol) max_symbol = symbols[i]; + } + counts->resize(max_symbol + 1); + if (num_symbols == 1) { + (*counts)[symbols[0]] = 1 << precision_bits; + } else { + if (symbols[0] == symbols[1]) { // corrupt data + return false; + } + (*counts)[symbols[0]] = input->ReadBits(precision_bits); + (*counts)[symbols[1]] = (1 << precision_bits) - (*counts)[symbols[0]]; + } + } else { + int is_flat = input->ReadBits(1); + if (is_flat == 1) { + int alphabet_size = DecodeVarLenUint8(input) + 1; + *counts = CreateFlatHistogram(alphabet_size, 1 << precision_bits); + return true; + } + + uint32_t shift; + { + // TODO(veluca): speed up reading with table lookups. + int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); + int log = 0; + for (; log < upper_bound_log; log++) { + if (input->ReadFixedBits<1>() == 0) break; + } + shift = (input->ReadBits(log) | (1 << log)) - 1; + if (shift > ANS_LOG_TAB_SIZE + 1) { + return JXL_FAILURE("Invalid shift value"); + } + } + + int length = DecodeVarLenUint8(input) + 3; + counts->resize(length); + int total_count = 0; + + static const uint8_t huff[128][2] = { + {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + }; + + std::vector logcounts(counts->size()); + int omit_log = -1; + int omit_pos = -1; + // This array remembers which symbols have an RLE length. + std::vector same(counts->size(), 0); + for (size_t i = 0; i < logcounts.size(); ++i) { + input->Refill(); // for PeekFixedBits + Advance + int idx = input->PeekFixedBits<7>(); + input->Consume(huff[idx][0]); + logcounts[i] = huff[idx][1]; + // The RLE symbol. + if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) { + int rle_length = DecodeVarLenUint8(input); + same[i] = rle_length + 5; + i += rle_length + 3; + continue; + } + if (logcounts[i] > omit_log) { + omit_log = logcounts[i]; + omit_pos = i; + } + } + // Invalid input, e.g. due to invalid usage of RLE. + if (omit_pos < 0) return JXL_FAILURE("Invalid histogram."); + if (static_cast(omit_pos) + 1 < logcounts.size() && + logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) { + return JXL_FAILURE("Invalid histogram."); + } + int prev = 0; + int numsame = 0; + for (size_t i = 0; i < logcounts.size(); ++i) { + if (same[i]) { + // RLE sequence, let this loop output the same count for the next + // iterations. + numsame = same[i] - 1; + prev = i > 0 ? (*counts)[i - 1] : 0; + } + if (numsame > 0) { + (*counts)[i] = prev; + numsame--; + } else { + int code = logcounts[i]; + // omit_pos may not be negative at this point (checked before). + if (i == static_cast(omit_pos)) { + continue; + } else if (code == 0) { + continue; + } else if (code == 1) { + (*counts)[i] = 1; + } else { + int bitcount = GetPopulationCountPrecision(code - 1, shift); + (*counts)[i] = (1 << (code - 1)) + + (input->ReadBits(bitcount) << (code - 1 - bitcount)); + } + } + total_count += (*counts)[i]; + } + (*counts)[omit_pos] = (1 << precision_bits) - total_count; + if ((*counts)[omit_pos] <= 0) { + // The histogram we've read sums to more than total_count (including at + // least 1 for the omitted value). + return JXL_FAILURE("Invalid histogram count."); + } + } + return true; +} + +} // namespace + +Status DecodeANSCodes(const size_t num_histograms, + const size_t max_alphabet_size, BitReader* in, + ANSCode* result) { + result->degenerate_symbols.resize(num_histograms, -1); + if (result->use_prefix_code) { + JXL_ASSERT(max_alphabet_size <= 1 << PREFIX_MAX_BITS); + result->huffman_data.resize(num_histograms); + std::vector alphabet_sizes(num_histograms); + for (size_t c = 0; c < num_histograms; c++) { + alphabet_sizes[c] = DecodeVarLenUint16(in) + 1; + if (alphabet_sizes[c] > max_alphabet_size) { + return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]); + } + } + for (size_t c = 0; c < num_histograms; c++) { + if (alphabet_sizes[c] > 1) { + if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) { + if (!in->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for huffman code"); + } + return JXL_FAILURE("Invalid huffman tree number %" PRIuS + ", alphabet size %u", + c, alphabet_sizes[c]); + } + } else { + // 0-bit codes does not require extension tables. + result->huffman_data[c].table_.clear(); + result->huffman_data[c].table_.resize(1u << kHuffmanTableBits); + } + for (const auto& h : result->huffman_data[c].table_) { + if (h.bits <= kHuffmanTableBits) { + result->UpdateMaxNumBits(c, h.value); + } + } + } + } else { + JXL_ASSERT(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE); + result->alias_tables = + AllocateArray(num_histograms * (1 << result->log_alpha_size) * + sizeof(AliasTable::Entry)); + AliasTable::Entry* alias_tables = + reinterpret_cast(result->alias_tables.get()); + for (size_t c = 0; c < num_histograms; ++c) { + std::vector counts; + if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) { + return JXL_FAILURE("Invalid histogram bitstream."); + } + if (counts.size() > max_alphabet_size) { + return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size()); + } + while (!counts.empty() && counts.back() == 0) { + counts.pop_back(); + } + for (size_t s = 0; s < counts.size(); s++) { + if (counts[s] != 0) { + result->UpdateMaxNumBits(c, s); + } + } + // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol. + int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1); + for (int s = 0; s < degenerate_symbol; ++s) { + if (counts[s] != 0) { + degenerate_symbol = -1; + break; + } + } + result->degenerate_symbols[c] = degenerate_symbol; + InitAliasTable(counts, ANS_TAB_SIZE, result->log_alpha_size, + alias_tables + c * (1 << result->log_alpha_size)); + } + } + return true; +} +Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config, + BitReader* br) { + br->Refill(); + size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1)); + size_t msb_in_token = 0, lsb_in_token = 0; + if (split_exponent != log_alpha_size) { + // otherwise, msb/lsb don't matter. + size_t nbits = CeilLog2Nonzero(split_exponent + 1); + msb_in_token = br->ReadBits(nbits); + if (msb_in_token > split_exponent) { + // This could be invalid here already and we need to check this before + // we use its value to read more bits. + return JXL_FAILURE("Invalid HybridUintConfig"); + } + nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1); + lsb_in_token = br->ReadBits(nbits); + } + if (lsb_in_token + msb_in_token > split_exponent) { + return JXL_FAILURE("Invalid HybridUintConfig"); + } + *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token); + return true; +} + +Status DecodeUintConfigs(size_t log_alpha_size, + std::vector* uint_config, + BitReader* br) { + // TODO(veluca): RLE? + for (size_t i = 0; i < uint_config->size(); i++) { + JXL_RETURN_IF_ERROR( + DecodeUintConfig(log_alpha_size, &(*uint_config)[i], br)); + } + return true; +} + +LZ77Params::LZ77Params() { Bundle::Init(this); } +Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled)); + if (!visitor->Conditional(enabled)) return true; + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096), + BitsOffset(15, 8), 224, &min_symbol)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5), + BitsOffset(8, 9), 3, &min_length)); + return true; +} + +void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) { + HybridUintConfig* cfg = &uint_config[ctx]; + // LZ77 symbols use a different uint config. + if (lz77.enabled && lz77.nonserialized_distance_context != ctx && + symbol >= lz77.min_symbol) { + symbol -= lz77.min_symbol; + cfg = &lz77.length_uint_config; + } + size_t split_token = cfg->split_token; + size_t msb_in_token = cfg->msb_in_token; + size_t lsb_in_token = cfg->lsb_in_token; + size_t split_exponent = cfg->split_exponent; + if (symbol < split_token) { + max_num_bits = std::max(max_num_bits, split_exponent); + return; + } + uint32_t n_extra_bits = + split_exponent - (msb_in_token + lsb_in_token) + + ((symbol - split_token) >> (msb_in_token + lsb_in_token)); + size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1; + max_num_bits = std::max(max_num_bits, total_bits); +} + +Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, + std::vector* context_map, bool disallow_lz77) { + PROFILER_FUNC; + JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77)); + if (code->lz77.enabled) { + num_contexts++; + JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8, + &code->lz77.length_uint_config, br)); + } + if (code->lz77.enabled && disallow_lz77) { + return JXL_FAILURE("Using LZ77 when explicitly disallowed"); + } + size_t num_histograms = 1; + context_map->resize(num_contexts); + if (num_contexts > 1) { + JXL_RETURN_IF_ERROR(DecodeContextMap(context_map, &num_histograms, br)); + } + code->lz77.nonserialized_distance_context = context_map->back(); + code->use_prefix_code = br->ReadFixedBits<1>(); + if (code->use_prefix_code) { + code->log_alpha_size = PREFIX_MAX_BITS; + } else { + code->log_alpha_size = br->ReadFixedBits<2>() + 5; + } + code->uint_config.resize(num_histograms); + JXL_RETURN_IF_ERROR( + DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br)); + const size_t max_alphabet_size = 1 << code->log_alpha_size; + JXL_RETURN_IF_ERROR( + DecodeANSCodes(num_histograms, max_alphabet_size, br, code)); + // When using LZ77, flat codes might result in valid codestreams with + // histograms that potentially allow very large bit counts. + // TODO(veluca): in principle, a valid codestream might contain a histogram + // that could allow very large numbers of bits that is never used during ANS + // decoding. There's no benefit to doing that, though. + if (!code->lz77.enabled && code->max_num_bits > 32) { + // Just emit a warning as there are many opportunities for false positives. + JXL_WARNING("Histogram can represent numbers that are too large: %" PRIuS + "\n", + code->max_num_bits); + } + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_ans.h b/media/libjxl/src/lib/jxl/dec_ans.h new file mode 100644 index 000000000..0f4406745 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_ans.h @@ -0,0 +1,462 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_ANS_H_ +#define LIB_JXL_DEC_ANS_H_ + +// Library to decode the ANS population counts from the bit-stream and build a +// decoding table from them. + +#include +#include + +#include +#include + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_huffman.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +class ANSSymbolReader; + +// Experiments show that best performance is typically achieved for a +// split-exponent of 3 or 4. Trend seems to be that '4' is better +// for large-ish pictures, and '3' better for rather small-ish pictures. +// This is plausible - the more special symbols we have, the better +// statistics we need to get a benefit out of them. + +// Our hybrid-encoding scheme has dedicated tokens for the smallest +// (1 << split_exponents) numbers, and for the rest +// encodes (number of bits) + (msb_in_token sub-leading binary digits) + +// (lsb_in_token lowest binary digits) in the token, with the remaining bits +// then being encoded as data. +// +// Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0. +// +// Numbers N in [0 .. 15]: +// These get represented as (token=N, bits=''). +// Numbers N >= 16: +// If n is such that 2**n <= N < 2**(n+1), +// and m = N - 2**n is the 'mantissa', +// these get represented as: +// (token=split_token + +// ((n - split_exponent) * 4) + +// (m >> (n - msb_in_token)), +// bits=m & (1 << (n - msb_in_token)) - 1) +// Specifically, we would get: +// N = 0 - 15: (token=N, nbits=0, bits='') +// N = 16 (10000): (token=16, nbits=2, bits='00') +// N = 17 (10001): (token=16, nbits=2, bits='01') +// N = 20 (10100): (token=17, nbits=2, bits='00') +// N = 24 (11000): (token=18, nbits=2, bits='00') +// N = 28 (11100): (token=19, nbits=2, bits='00') +// N = 32 (100000): (token=20, nbits=3, bits='000') +// N = 65535: (token=63, nbits=13, bits='1111111111111') +struct HybridUintConfig { + uint32_t split_exponent; + uint32_t split_token; + uint32_t msb_in_token; + uint32_t lsb_in_token; + JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token, + uint32_t* JXL_RESTRICT nbits, + uint32_t* JXL_RESTRICT bits) const { + if (value < split_token) { + *token = value; + *nbits = 0; + *bits = 0; + } else { + uint32_t n = FloorLog2Nonzero(value); + uint32_t m = value - (1 << n); + *token = split_token + + ((n - split_exponent) << (msb_in_token + lsb_in_token)) + + ((m >> (n - msb_in_token)) << lsb_in_token) + + (m & ((1 << lsb_in_token) - 1)); + *nbits = n - msb_in_token - lsb_in_token; + *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1); + } + } + + explicit HybridUintConfig(uint32_t split_exponent = 4, + uint32_t msb_in_token = 2, + uint32_t lsb_in_token = 0) + : split_exponent(split_exponent), + split_token(1 << split_exponent), + msb_in_token(msb_in_token), + lsb_in_token(lsb_in_token) { + JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token); + } +}; + +struct LZ77Params : public Fields { + LZ77Params(); + JXL_FIELDS_NAME(LZ77Params) + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + bool enabled; + + // Symbols above min_symbol use a special hybrid uint encoding and + // represent a length, to be added to min_length. + uint32_t min_symbol; + uint32_t min_length; + + // Not serialized by VisitFields. + HybridUintConfig length_uint_config{0, 0, 0}; + + size_t nonserialized_distance_context; +}; + +static constexpr size_t kWindowSize = 1 << 20; +static constexpr size_t kNumSpecialDistances = 120; +// Table of special distance codes from WebP lossless. +static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = { + {0, 1}, {1, 0}, {1, 1}, {-1, 1}, {0, 2}, {2, 0}, {1, 2}, {-1, 2}, + {2, 1}, {-2, 1}, {2, 2}, {-2, 2}, {0, 3}, {3, 0}, {1, 3}, {-1, 3}, + {3, 1}, {-3, 1}, {2, 3}, {-2, 3}, {3, 2}, {-3, 2}, {0, 4}, {4, 0}, + {1, 4}, {-1, 4}, {4, 1}, {-4, 1}, {3, 3}, {-3, 3}, {2, 4}, {-2, 4}, + {4, 2}, {-4, 2}, {0, 5}, {3, 4}, {-3, 4}, {4, 3}, {-4, 3}, {5, 0}, + {1, 5}, {-1, 5}, {5, 1}, {-5, 1}, {2, 5}, {-2, 5}, {5, 2}, {-5, 2}, + {4, 4}, {-4, 4}, {3, 5}, {-3, 5}, {5, 3}, {-5, 3}, {0, 6}, {6, 0}, + {1, 6}, {-1, 6}, {6, 1}, {-6, 1}, {2, 6}, {-2, 6}, {6, 2}, {-6, 2}, + {4, 5}, {-4, 5}, {5, 4}, {-5, 4}, {3, 6}, {-3, 6}, {6, 3}, {-6, 3}, + {0, 7}, {7, 0}, {1, 7}, {-1, 7}, {5, 5}, {-5, 5}, {7, 1}, {-7, 1}, + {4, 6}, {-4, 6}, {6, 4}, {-6, 4}, {2, 7}, {-2, 7}, {7, 2}, {-7, 2}, + {3, 7}, {-3, 7}, {7, 3}, {-7, 3}, {5, 6}, {-5, 6}, {6, 5}, {-6, 5}, + {8, 0}, {4, 7}, {-4, 7}, {7, 4}, {-7, 4}, {8, 1}, {8, 2}, {6, 6}, + {-6, 6}, {8, 3}, {5, 7}, {-5, 7}, {7, 5}, {-7, 5}, {8, 4}, {6, 7}, + {-6, 7}, {7, 6}, {-7, 6}, {8, 5}, {7, 7}, {-7, 7}, {8, 6}, {8, 7}}; + +struct ANSCode { + CacheAlignedUniquePtr alias_tables; + std::vector huffman_data; + std::vector uint_config; + std::vector degenerate_symbols; + bool use_prefix_code; + uint8_t log_alpha_size; // for ANS. + LZ77Params lz77; + // Maximum number of bits necessary to represent the result of a + // ReadHybridUint call done with this ANSCode. + size_t max_num_bits = 0; + void UpdateMaxNumBits(size_t ctx, size_t symbol); +}; + +class ANSSymbolReader { + public: + // Invalid symbol reader, to be overwritten. + ANSSymbolReader() = default; + ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, + size_t distance_multiplier = 0) + : alias_tables_( + reinterpret_cast(code->alias_tables.get())), + huffman_data_(code->huffman_data.data()), + use_prefix_code_(code->use_prefix_code), + configs(code->uint_config.data()) { + if (!use_prefix_code_) { + state_ = static_cast(br->ReadFixedBits<32>()); + log_alpha_size_ = code->log_alpha_size; + log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size; + entry_size_minus_1_ = (1 << log_entry_size_) - 1; + } else { + state_ = (ANS_SIGNATURE << 16u); + } + if (!code->lz77.enabled) return; + // a std::vector incurs unacceptable decoding speed loss because of + // initialization. + lz77_window_storage_ = AllocateArray(kWindowSize * sizeof(uint32_t)); + lz77_window_ = reinterpret_cast(lz77_window_storage_.get()); + lz77_ctx_ = code->lz77.nonserialized_distance_context; + lz77_length_uint_ = code->lz77.length_uint_config; + lz77_threshold_ = code->lz77.min_symbol; + lz77_min_length_ = code->lz77.min_length; + num_special_distances_ = + distance_multiplier == 0 ? 0 : kNumSpecialDistances; + for (size_t i = 0; i < num_special_distances_; i++) { + int dist = kSpecialDistances[i][0]; + dist += static_cast(distance_multiplier) * kSpecialDistances[i][1]; + if (dist < 1) dist = 1; + special_distances_[i] = dist; + } + } + + JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); + + const AliasTable::Entry* table = + &alias_tables_[histo_idx << log_alpha_size_]; + const AliasTable::Symbol symbol = + AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); + state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset; + +#if 1 + // Branchless version is about equally fast on SKX. + const uint32_t new_state = + (state_ << 16u) | static_cast(br->PeekFixedBits<16>()); + const bool normalize = state_ < (1u << 16u); + state_ = normalize ? new_state : state_; + br->Consume(normalize ? 16 : 0); +#else + if (JXL_UNLIKELY(state_ < (1u << 16u))) { + state_ = (state_ << 16u) | br->PeekFixedBits<16>(); + br->Consume(16); + } +#endif + const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u); + AliasTable::Prefetch(table, next_res, log_entry_size_); + + return symbol.value; + } + + JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + return huffman_data_[histo_idx].ReadSymbol(br); + } + + JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + // TODO(veluca): hoist if in hotter loops. + if (JXL_UNLIKELY(use_prefix_code_)) { + return ReadSymbolHuffWithoutRefill(histo_idx, br); + } + return ReadSymbolANSWithoutRefill(histo_idx, br); + } + + JXL_INLINE size_t ReadSymbol(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + br->Refill(); + return ReadSymbolWithoutRefill(histo_idx, br); + } + +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + bool CheckANSFinalState() const { return true; } +#else + bool CheckANSFinalState() const { return state_ == (ANS_SIGNATURE << 16u); } +#endif + + template + static JXL_INLINE uint32_t ReadHybridUintConfig( + const HybridUintConfig& config, size_t token, BitReader* br) { + size_t split_token = config.split_token; + size_t msb_in_token = config.msb_in_token; + size_t lsb_in_token = config.lsb_in_token; + size_t split_exponent = config.split_exponent; + // Fast-track version of hybrid integer decoding. + if (token < split_token) return token; + uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) + + ((token - split_token) >> (msb_in_token + lsb_in_token)); + // Max amount of bits for ReadBits is 32 and max valid left shift is 29 + // bits. However, for speed no error is propagated here, instead limit the + // nbits size. If nbits > 29, the code stream is invalid, but no error is + // returned. + // Note that in most cases we will emit an error if the histogram allows + // representing numbers that would cause invalid shifts, but we need to + // keep this check as when LZ77 is enabled it might make sense to have an + // histogram that could in principle cause invalid shifts. + nbits &= 31u; + uint32_t low = token & ((1 << lsb_in_token) - 1); + token >>= lsb_in_token; + const size_t bits = br->PeekBits(nbits); + br->Consume(nbits); + size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1))) + << nbits) | + bits) + << lsb_in_token) | + low; + // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not + // fit uint32_t + return static_cast(ret); + } + + // Takes a *clustered* idx. Can only use if HuffRleOnly() is true. + void ReadHybridUintClusteredHuffRleOnly(size_t ctx, + BitReader* JXL_RESTRICT br, + uint32_t* value, uint32_t* run) { + JXL_DASSERT(HuffRleOnly()); + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + size_t token = ReadSymbolHuffWithoutRefill(ctx, br); + if (JXL_UNLIKELY(token >= lz77_threshold_)) { + *run = + ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + + lz77_min_length_ - 1; + return; + } + *value = ReadHybridUintConfig(configs[ctx], token, br); + } + bool HuffRleOnly() { + if (lz77_window_ == nullptr) return false; + if (!use_prefix_code_) return false; + for (size_t i = 0; i < kHuffmanTableBits; i++) { + if (huffman_data_[lz77_ctx_].table_[i].bits) return false; + if (huffman_data_[lz77_ctx_].table_[i].value != 1) return false; + } + if (configs[lz77_ctx_].split_token > 1) return false; + return true; + } + + // Takes a *clustered* idx. + size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) { + if (JXL_UNLIKELY(num_to_copy_ > 0)) { + size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; + num_to_copy_--; + lz77_window_[(num_decoded_++) & kWindowMask] = ret; + return ret; + } + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + size_t token = ReadSymbolWithoutRefill(ctx, br); + if (JXL_UNLIKELY(token >= lz77_threshold_)) { + num_to_copy_ = + ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + + lz77_min_length_; + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + // Distance code. + size_t token = ReadSymbolWithoutRefill(lz77_ctx_, br); + size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], token, br); + if (JXL_LIKELY(distance < num_special_distances_)) { + distance = special_distances_[distance]; + } else { + distance = distance + 1 - num_special_distances_; + } + if (JXL_UNLIKELY(distance > num_decoded_)) { + distance = num_decoded_; + } + if (JXL_UNLIKELY(distance > kWindowSize)) { + distance = kWindowSize; + } + copy_pos_ = num_decoded_ - distance; + if (JXL_UNLIKELY(distance == 0)) { + JXL_DASSERT(lz77_window_ != nullptr); + // distance 0 -> num_decoded_ == copy_pos_ == 0 + size_t to_fill = std::min(num_to_copy_, kWindowSize); + memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0])); + } + // TODO(eustas): overflow; mark BitReader as unhealthy + if (num_to_copy_ < lz77_min_length_) return 0; + return ReadHybridUintClustered(ctx, br); // will trigger a copy. + } + size_t ret = ReadHybridUintConfig(configs[ctx], token, br); + if (lz77_window_) lz77_window_[(num_decoded_++) & kWindowMask] = ret; + return ret; + } + + JXL_INLINE size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br, + const std::vector& context_map) { + return ReadHybridUintClustered(context_map[ctx], br); + } + + // ctx is a *clustered* context! + // This function will modify the ANS state as if `count` symbols have been + // decoded. + bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) { + // TODO(veluca): No optimization for Huffman mode yet. + if (use_prefix_code_) return false; + // TODO(eustas): propagate "degenerate_symbol" to simplify this method. + const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); + const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_]; + AliasTable::Symbol symbol = + AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); + if (symbol.freq != ANS_TAB_SIZE) return false; + if (configs[ctx].split_token <= symbol.value) return false; + if (symbol.value >= lz77_threshold_) return false; + *value = symbol.value; + if (lz77_window_) { + for (size_t i = 0; i < count; i++) { + lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value; + } + } + return true; + } + + static constexpr size_t kMaxCheckpointInterval = 512; + struct Checkpoint { + uint32_t state; + uint32_t num_to_copy; + uint32_t copy_pos; + uint32_t num_decoded; + uint32_t lz77_window[kMaxCheckpointInterval]; + }; + void Save(Checkpoint* checkpoint) { + checkpoint->state = state_; + checkpoint->num_decoded = num_decoded_; + checkpoint->num_to_copy = num_to_copy_; + checkpoint->copy_pos = copy_pos_; + if (lz77_window_) { + size_t win_start = num_decoded_ & kWindowMask; + size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; + if (win_end > win_start) { + memcpy(checkpoint->lz77_window, lz77_window_ + win_start, + (win_end - win_start) * sizeof(*lz77_window_)); + } else { + memcpy(checkpoint->lz77_window, lz77_window_ + win_start, + (kWindowSize - win_start) * sizeof(*lz77_window_)); + memcpy(checkpoint->lz77_window + (kWindowSize - win_start), + lz77_window_, win_end * sizeof(*lz77_window_)); + } + } + } + void Restore(const Checkpoint& checkpoint) { + state_ = checkpoint.state; + JXL_DASSERT(num_decoded_ <= + checkpoint.num_decoded + kMaxCheckpointInterval); + num_decoded_ = checkpoint.num_decoded; + num_to_copy_ = checkpoint.num_to_copy; + copy_pos_ = checkpoint.copy_pos; + if (lz77_window_) { + size_t win_start = num_decoded_ & kWindowMask; + size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; + if (win_end > win_start) { + memcpy(lz77_window_ + win_start, checkpoint.lz77_window, + (win_end - win_start) * sizeof(*lz77_window_)); + } else { + memcpy(lz77_window_ + win_start, checkpoint.lz77_window, + (kWindowSize - win_start) * sizeof(*lz77_window_)); + memcpy(lz77_window_, checkpoint.lz77_window + (kWindowSize - win_start), + win_end * sizeof(*lz77_window_)); + } + } + } + + private: + const AliasTable::Entry* JXL_RESTRICT alias_tables_; // not owned + const HuffmanDecodingData* huffman_data_; + bool use_prefix_code_; + uint32_t state_ = ANS_SIGNATURE << 16u; + const HybridUintConfig* JXL_RESTRICT configs; + uint32_t log_alpha_size_{}; + uint32_t log_entry_size_{}; + uint32_t entry_size_minus_1_{}; + + // LZ77 structures and constants. + static constexpr size_t kWindowMask = kWindowSize - 1; + CacheAlignedUniquePtr lz77_window_storage_; + uint32_t* lz77_window_ = nullptr; + uint32_t num_decoded_ = 0; + uint32_t num_to_copy_ = 0; + uint32_t copy_pos_ = 0; + uint32_t lz77_ctx_ = 0; + uint32_t lz77_min_length_ = 0; + uint32_t lz77_threshold_ = 1 << 20; // bigger than any symbol. + HybridUintConfig lz77_length_uint_; + uint32_t special_distances_[kNumSpecialDistances]{}; + uint32_t num_special_distances_{}; +}; + +Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, + std::vector* context_map, + bool disallow_lz77 = false); + +// Exposed for tests. +Status DecodeUintConfigs(size_t log_alpha_size, + std::vector* uint_config, + BitReader* br); + +} // namespace jxl + +#endif // LIB_JXL_DEC_ANS_H_ diff --git a/media/libjxl/src/lib/jxl/dec_bit_reader.h b/media/libjxl/src/lib/jxl/dec_bit_reader.h new file mode 100644 index 000000000..df70284e3 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_bit_reader.h @@ -0,0 +1,354 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_BIT_READER_H_ +#define LIB_JXL_DEC_BIT_READER_H_ + +// Bounds-checked bit reader; 64-bit buffer with support for deferred refills +// and switching to reading byte-aligned words. + +#include +#include +#include // memcpy + +#ifdef __BMI2__ +#include +#endif + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" + +namespace jxl { + +// Reads bits previously written to memory by BitWriter. Uses unaligned 8-byte +// little-endian loads. +class BitReader { + public: + static constexpr size_t kMaxBitsPerCall = 56; + + // Constructs an invalid BitReader, to be overwritten before usage. + BitReader() + : buf_(0), + bits_in_buf_(0), + next_byte_{nullptr}, + end_minus_8_{nullptr}, + first_byte_(nullptr) {} + BitReader(const BitReader&) = delete; + + // bytes need not be aligned nor padded! + template + explicit BitReader(const ArrayLike& bytes) + : buf_(0), + bits_in_buf_(0), + next_byte_(bytes.data()), + // Assumes first_byte_ >= 8. + end_minus_8_(bytes.data() - 8 + bytes.size()), + first_byte_(bytes.data()) { + Refill(); + } + ~BitReader() { + // Close() must be called before destroying an initialized bit reader. + // Invalid bit readers will have a nullptr in first_byte_. + JXL_ASSERT(close_called_ || !first_byte_); + } + + // Move operator needs to invalidate the other BitReader such that it is + // irrelevant if we call Close() on it or not. + BitReader& operator=(BitReader&& other) noexcept { + // Ensure the current instance was already closed, before we overwrite it + // with other. + JXL_ASSERT(close_called_ || !first_byte_); + + JXL_DASSERT(!other.close_called_); + buf_ = other.buf_; + bits_in_buf_ = other.bits_in_buf_; + next_byte_ = other.next_byte_; + end_minus_8_ = other.end_minus_8_; + first_byte_ = other.first_byte_; + overread_bytes_ = other.overread_bytes_; + close_called_ = other.close_called_; + + other.first_byte_ = nullptr; + other.next_byte_ = nullptr; + return *this; + } + BitReader& operator=(const BitReader& other) = delete; + + // For time-critical reads, refills can be shared by multiple reads. + // Based on variant 4 (plus bounds-checking), see + // fgiesen.wordpress.com/2018/02/20/reading-bits-in-far-too-many-ways-part-2/ + JXL_INLINE void Refill() { + if (JXL_UNLIKELY(next_byte_ > end_minus_8_)) { + BoundsCheckedRefill(); + } else { + // It's safe to load 64 bits; insert valid (possibly nonzero) bits above + // bits_in_buf_. The shift requires bits_in_buf_ < 64. + buf_ |= LoadLE64(next_byte_) << bits_in_buf_; + + // Advance by bytes fully absorbed into the buffer. + next_byte_ += (63 - bits_in_buf_) >> 3; + + // We absorbed a multiple of 8 bits, so the lower 3 bits of bits_in_buf_ + // must remain unchanged, otherwise the next refill's shifted bits will + // not align with buf_. Set the three upper bits so the result >= 56. + bits_in_buf_ |= 56; + JXL_DASSERT(56 <= bits_in_buf_ && bits_in_buf_ < 64); + } + } + + // Returns the bits that would be returned by Read without calling Advance(). + // It is legal to PEEK at more bits than present in the bitstream (required + // by Huffman), and those bits will be zero. + template + JXL_INLINE uint64_t PeekFixedBits() const { + static_assert(N <= kMaxBitsPerCall, "Reading too many bits in one call."); + JXL_DASSERT(!close_called_); + return buf_ & ((1ULL << N) - 1); + } + + JXL_INLINE uint64_t PeekBits(size_t nbits) const { + JXL_DASSERT(nbits <= kMaxBitsPerCall); + JXL_DASSERT(!close_called_); + + // Slightly faster but requires BMI2. It is infeasible to make the many + // callers reside between begin/end_target, especially because only the + // callers in dec_ans are time-critical. Therefore only enabled if the + // entire binary is compiled for (and thus requires) BMI2. +#if defined(__BMI2__) && defined(__x86_64__) + return _bzhi_u64(buf_, nbits); +#else + const uint64_t mask = (1ULL << nbits) - 1; + return buf_ & mask; +#endif + } + + // Removes bits from the buffer. Need not match the previous Peek size, but + // the buffer must contain at least num_bits (this prevents consuming more + // than the total number of bits). + JXL_INLINE void Consume(size_t num_bits) { + JXL_DASSERT(!close_called_); + JXL_DASSERT(bits_in_buf_ >= num_bits); +#ifdef JXL_CRASH_ON_ERROR + // When JXL_CRASH_ON_ERROR is defined, it is a fatal error to read more bits + // than available in the stream. A non-zero overread_bytes_ implies that + // next_byte_ is already at the end of the stream, so we don't need to + // check that. + JXL_ASSERT(bits_in_buf_ >= num_bits + overread_bytes_ * kBitsPerByte); +#endif + bits_in_buf_ -= num_bits; + buf_ >>= num_bits; + } + + JXL_INLINE uint64_t ReadBits(size_t nbits) { + JXL_DASSERT(!close_called_); + Refill(); + const uint64_t bits = PeekBits(nbits); + Consume(nbits); + return bits; + } + + template + JXL_INLINE uint64_t ReadFixedBits() { + JXL_DASSERT(!close_called_); + Refill(); + const uint64_t bits = PeekFixedBits(); + Consume(N); + return bits; + } + + // Equivalent to calling ReadFixedBits(1) `skip` times, but much faster. + // `skip` is typically large. + void SkipBits(size_t skip) { + JXL_DASSERT(!close_called_); + // Buffer is large enough - don't zero buf_ below. + if (JXL_UNLIKELY(skip <= bits_in_buf_)) { + Consume(skip); + return; + } + + // First deduct what we can satisfy from the buffer + skip -= bits_in_buf_; + bits_in_buf_ = 0; + // Not enough to call Advance - that may leave some bits in the buffer + // which were previously ABOVE bits_in_buf. + buf_ = 0; + + // Skip whole bytes + const size_t whole_bytes = skip / kBitsPerByte; + skip %= kBitsPerByte; + if (JXL_UNLIKELY(whole_bytes > + static_cast(end_minus_8_ + 8 - next_byte_))) { + // This is already an overflow condition (skipping past the end of the bit + // stream). However if we increase next_byte_ too much we risk overflowing + // that value and potentially making it valid again (next_byte_ < end). + // This will set next_byte_ to the end of the stream and still consume + // some bits in overread_bytes_, however the TotalBitsConsumed() will be + // incorrect (still larger than the TotalBytes()). + next_byte_ = end_minus_8_ + 8; + skip += kBitsPerByte; + } else { + next_byte_ += whole_bytes; + } + + Refill(); + Consume(skip); + } + + size_t TotalBitsConsumed() const { + const size_t bytes_read = static_cast(next_byte_ - first_byte_); + return (bytes_read + overread_bytes_) * kBitsPerByte - bits_in_buf_; + } + + Status JumpToByteBoundary() { + const size_t remainder = TotalBitsConsumed() % kBitsPerByte; + if (remainder == 0) return true; + if (JXL_UNLIKELY(ReadBits(kBitsPerByte - remainder) != 0)) { + return JXL_FAILURE("Non-zero padding bits"); + } + return true; + } + + // For interoperability with other bitreaders (for resuming at + // non-byte-aligned positions). + const uint8_t* FirstByte() const { return first_byte_; } + size_t TotalBytes() const { + return static_cast(end_minus_8_ + 8 - first_byte_); + } + + // Returns span of the remaining (unconsumed) bytes, e.g. for passing to + // external decoders such as Brotli. + Span GetSpan() const { + JXL_DASSERT(first_byte_ != nullptr); + JXL_ASSERT(TotalBitsConsumed() % kBitsPerByte == 0); + const size_t offset = TotalBitsConsumed() / kBitsPerByte; // no remainder + JXL_ASSERT(offset <= TotalBytes()); + return Span(first_byte_ + offset, TotalBytes() - offset); + } + + // Returns whether all the bits read so far have been within the input bounds. + // When reading past the EOF, the Read*() and Consume() functions return zeros + // but flag a failure when calling Close() without checking this function. + Status AllReadsWithinBounds() { + // Mark up to which point the user checked the out of bounds condition. If + // the user handles the condition at higher level (e.g. fetch more bytes + // from network, return a custom JXL_FAILURE, ...), Close() should not + // output a debug error (which would break tests with JXL_CRASH_ON_ERROR + // even when legitimately handling the situation at higher level). This is + // used by Bundle::CanRead. + checked_out_of_bounds_bits_ = TotalBitsConsumed(); + if (TotalBitsConsumed() > TotalBytes() * kBitsPerByte) { + return false; + } + return true; + } + + // Close the bit reader and return whether all the previous reads were + // successful. Close must be called once. + Status Close() { + JXL_DASSERT(!close_called_); + close_called_ = true; + if (!first_byte_) return true; + if (TotalBitsConsumed() > checked_out_of_bounds_bits_ && + TotalBitsConsumed() > TotalBytes() * kBitsPerByte) { + return JXL_FAILURE("Read more bits than available in the bit_reader"); + } + return true; + } + + private: + // Separate function avoids inlining this relatively cold code into callers. + JXL_NOINLINE void BoundsCheckedRefill() { + PROFILER_FUNC; + const uint8_t* end = end_minus_8_ + 8; + + // Read whole bytes until we have [56, 64) bits (same as LoadLE64) + for (; bits_in_buf_ < 64 - kBitsPerByte; bits_in_buf_ += kBitsPerByte) { + if (next_byte_ >= end) break; + buf_ |= static_cast(*next_byte_++) << bits_in_buf_; + } + JXL_DASSERT(bits_in_buf_ < 64); + + // Add extra bytes as 0 at the end of the stream in the bit_buffer_. If + // these bits are read, Close() will return a failure. + size_t extra_bytes = (63 - bits_in_buf_) / kBitsPerByte; + overread_bytes_ += extra_bytes; + bits_in_buf_ += extra_bytes * kBitsPerByte; + + JXL_DASSERT(bits_in_buf_ < 64); + JXL_DASSERT(bits_in_buf_ >= 56); + } + + JXL_NOINLINE uint32_t BoundsCheckedReadByteAlignedWord() { + if (next_byte_ + 1 < end_minus_8_ + 8) { + uint32_t ret = LoadLE16(next_byte_); + next_byte_ += 2; + return ret; + } + overread_bytes_ += 2; + return 0; + } + + uint64_t buf_; + size_t bits_in_buf_; // [0, 64) + const uint8_t* JXL_RESTRICT next_byte_; + const uint8_t* end_minus_8_; // for refill bounds check + const uint8_t* first_byte_; // for GetSpan + + // Number of bytes past the end that were loaded into the buf_. These bytes + // are not read from memory, but instead assumed 0. It is an error (likely due + // to an invalid stream) to Consume() more bits than specified in the range + // passed to the constructor. + uint64_t overread_bytes_{0}; + bool close_called_{false}; + + uint64_t checked_out_of_bounds_bits_{0}; +}; + +// Closes a BitReader when the BitReaderScopedCloser goes out of scope. When +// closing the bit reader, if the status result was failure it sets this failure +// to the passed variable pointer. Typical usage. +// +// Status ret = true; +// { +// BitReader reader(...); +// BitReaderScopedCloser reader_closer(&reader, &ret); +// +// // ... code that can return errors here ... +// } +// // ... more code that doesn't use the BitReader. +// return ret; + +class BitReaderScopedCloser { + public: + BitReaderScopedCloser(BitReader* reader, Status* status) + : reader_(reader), status_(status) { + JXL_DASSERT(reader_ != nullptr); + JXL_DASSERT(status_ != nullptr); + } + ~BitReaderScopedCloser() { + if (reader_ != nullptr) { + Status close_ret = reader_->Close(); + if (!close_ret) *status_ = close_ret; + } + } + void CloseAndSuppressError() { + JXL_ASSERT(reader_ != nullptr); + (void)reader_->Close(); + reader_ = nullptr; + } + BitReaderScopedCloser(const BitReaderScopedCloser&) = delete; + + private: + BitReader* reader_; + Status* status_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_BIT_READER_H_ diff --git a/media/libjxl/src/lib/jxl/dec_cache.cc b/media/libjxl/src/lib/jxl/dec_cache.cc new file mode 100644 index 000000000..b819b5104 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_cache.cc @@ -0,0 +1,235 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_cache.h" + +#include "lib/jxl/blending.h" +#include "lib/jxl/render_pipeline/stage_blending.h" +#include "lib/jxl/render_pipeline/stage_chroma_upsampling.h" +#include "lib/jxl/render_pipeline/stage_epf.h" +#include "lib/jxl/render_pipeline/stage_from_linear.h" +#include "lib/jxl/render_pipeline/stage_gaborish.h" +#include "lib/jxl/render_pipeline/stage_noise.h" +#include "lib/jxl/render_pipeline/stage_patches.h" +#include "lib/jxl/render_pipeline/stage_splines.h" +#include "lib/jxl/render_pipeline/stage_spot.h" +#include "lib/jxl/render_pipeline/stage_to_linear.h" +#include "lib/jxl/render_pipeline/stage_tone_mapping.h" +#include "lib/jxl/render_pipeline/stage_upsampling.h" +#include "lib/jxl/render_pipeline/stage_write.h" +#include "lib/jxl/render_pipeline/stage_xyb.h" +#include "lib/jxl/render_pipeline/stage_ycbcr.h" + +namespace jxl { + +Status PassesDecoderState::PreparePipeline(ImageBundle* decoded, + PipelineOptions options) { + const FrameHeader& frame_header = shared->frame_header; + size_t num_c = 3 + frame_header.nonserialized_metadata->m.num_extra_channels; + if ((frame_header.flags & FrameHeader::kNoise) != 0) { + num_c += 3; + } + + if (frame_header.CanBeReferenced()) { + // Necessary so that SetInputSizes() can allocate output buffers as needed. + frame_storage_for_referencing = ImageBundle(decoded->metadata()); + } + + RenderPipeline::Builder builder(num_c); + + if (options.use_slow_render_pipeline) { + builder.UseSimpleImplementation(); + } + + if (!frame_header.chroma_subsampling.Is444()) { + for (size_t c = 0; c < 3; c++) { + if (frame_header.chroma_subsampling.HShift(c) != 0) { + builder.AddStage(GetChromaUpsamplingStage(c, /*horizontal=*/true)); + } + if (frame_header.chroma_subsampling.VShift(c) != 0) { + builder.AddStage(GetChromaUpsamplingStage(c, /*horizontal=*/false)); + } + } + } + + if (frame_header.loop_filter.gab) { + builder.AddStage(GetGaborishStage(frame_header.loop_filter)); + } + + { + const LoopFilter& lf = frame_header.loop_filter; + if (lf.epf_iters >= 3) { + builder.AddStage(GetEPFStage(lf, sigma, 0)); + } + if (lf.epf_iters >= 1) { + builder.AddStage(GetEPFStage(lf, sigma, 1)); + } + if (lf.epf_iters >= 2) { + builder.AddStage(GetEPFStage(lf, sigma, 2)); + } + } + + bool late_ec_upsample = frame_header.upsampling != 1; + for (auto ecups : frame_header.extra_channel_upsampling) { + if (ecups != frame_header.upsampling) { + // If patches are applied, either frame_header.upsampling == 1 or + // late_ec_upsample is true. + late_ec_upsample = false; + } + } + + if (!late_ec_upsample) { + for (size_t ec = 0; ec < frame_header.extra_channel_upsampling.size(); + ec++) { + if (frame_header.extra_channel_upsampling[ec] != 1) { + builder.AddStage(GetUpsamplingStage( + frame_header.nonserialized_metadata->transform_data, 3 + ec, + CeilLog2Nonzero(frame_header.extra_channel_upsampling[ec]))); + } + } + } + + if ((frame_header.flags & FrameHeader::kPatches) != 0) { + builder.AddStage( + GetPatchesStage(&shared->image_features.patches, + 3 + shared->metadata->m.num_extra_channels)); + } + if ((frame_header.flags & FrameHeader::kSplines) != 0) { + builder.AddStage(GetSplineStage(&shared->image_features.splines)); + } + + if (frame_header.upsampling != 1) { + size_t nb_channels = + 3 + + (late_ec_upsample ? frame_header.extra_channel_upsampling.size() : 0); + for (size_t c = 0; c < nb_channels; c++) { + builder.AddStage(GetUpsamplingStage( + frame_header.nonserialized_metadata->transform_data, c, + CeilLog2Nonzero(frame_header.upsampling))); + } + } + + if ((frame_header.flags & FrameHeader::kNoise) != 0) { + builder.AddStage(GetConvolveNoiseStage(num_c - 3)); + builder.AddStage(GetAddNoiseStage(shared->image_features.noise_params, + shared->cmap, num_c - 3)); + } + if (frame_header.dc_level != 0) { + builder.AddStage(GetWriteToImage3FStage( + &shared_storage.dc_frames[frame_header.dc_level - 1])); + } + + if (frame_header.CanBeReferenced() && + frame_header.save_before_color_transform) { + builder.AddStage(GetWriteToImageBundleStage( + &frame_storage_for_referencing, output_encoding_info.color_encoding)); + } + + bool has_alpha = false; + size_t alpha_c = 0; + for (size_t i = 0; i < decoded->metadata()->extra_channel_info.size(); i++) { + if (decoded->metadata()->extra_channel_info[i].type == + ExtraChannel::kAlpha) { + has_alpha = true; + alpha_c = 3 + i; + break; + } + } + + size_t width = options.coalescing + ? frame_header.nonserialized_metadata->xsize() + : shared->frame_dim.xsize_upsampled; + size_t height = options.coalescing + ? frame_header.nonserialized_metadata->ysize() + : shared->frame_dim.ysize_upsampled; + + if (fast_xyb_srgb8_conversion) { + JXL_ASSERT(!NeedsBlending(this)); + JXL_ASSERT(!frame_header.CanBeReferenced() || + frame_header.save_before_color_transform); + JXL_ASSERT(!options.render_spotcolors || + !decoded->metadata()->Find(ExtraChannel::kSpotColor)); + builder.AddStage(GetFastXYBTosRGB8Stage(rgb_output, rgb_stride, width, + height, rgb_output_is_rgba, + has_alpha, alpha_c)); + } else { + bool linear = false; + if (frame_header.color_transform == ColorTransform::kYCbCr) { + builder.AddStage(GetYCbCrStage()); + } else if (frame_header.color_transform == ColorTransform::kXYB) { + builder.AddStage(GetXYBStage(output_encoding_info.opsin_params)); + linear = true; + } // Nothing to do for kNone. + + if (options.coalescing && NeedsBlending(this)) { + if (linear) { + builder.AddStage(GetFromLinearStage(output_encoding_info)); + linear = false; + } + builder.AddStage( + GetBlendingStage(this, output_encoding_info.color_encoding)); + } + + if (options.coalescing && frame_header.CanBeReferenced() && + !frame_header.save_before_color_transform) { + if (linear) { + builder.AddStage(GetFromLinearStage(output_encoding_info)); + linear = false; + } + builder.AddStage(GetWriteToImageBundleStage( + &frame_storage_for_referencing, output_encoding_info.color_encoding)); + } + + if (options.render_spotcolors && + frame_header.nonserialized_metadata->m.Find(ExtraChannel::kSpotColor)) { + for (size_t i = 0; i < decoded->metadata()->extra_channel_info.size(); + i++) { + // Don't use Find() because there may be multiple spot color channels. + const ExtraChannelInfo& eci = + decoded->metadata()->extra_channel_info[i]; + if (eci.type == ExtraChannel::kSpotColor) { + builder.AddStage(GetSpotColorStage(3 + i, eci.spot_color)); + } + } + } + + auto tone_mapping_stage = GetToneMappingStage(output_encoding_info); + if (tone_mapping_stage) { + if (!linear) { + auto to_linear_stage = GetToLinearStage(output_encoding_info); + if (!to_linear_stage) { + return JXL_FAILURE( + "attempting to perform tone mapping on colorspace not " + "convertible to linear"); + } + builder.AddStage(std::move(to_linear_stage)); + linear = true; + } + builder.AddStage(std::move(tone_mapping_stage)); + } + + if (linear) { + builder.AddStage(GetFromLinearStage(output_encoding_info)); + linear = false; + } + + if (pixel_callback.IsPresent()) { + builder.AddStage(GetWriteToPixelCallbackStage( + pixel_callback, width, height, rgb_output_is_rgba, has_alpha, + unpremul_alpha, alpha_c)); + } else if (rgb_output) { + builder.AddStage(GetWriteToU8Stage(rgb_output, rgb_stride, height, + rgb_output_is_rgba, has_alpha, + alpha_c)); + } else { + builder.AddStage(GetWriteToImageBundleStage( + decoded, output_encoding_info.color_encoding)); + } + } + render_pipeline = std::move(builder).Finalize(shared->frame_dim); + return render_pipeline->IsInitialized(); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_cache.h b/media/libjxl/src/lib/jxl/dec_cache.h new file mode 100644 index 000000000..7105ba8ba --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_cache.h @@ -0,0 +1,252 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_CACHE_H_ +#define LIB_JXL_DEC_CACHE_H_ + +#include + +#include +#include // HWY_ALIGN_MAX + +#include "jxl/decode.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dec_group_border.h" +#include "lib/jxl/dec_noise.h" +#include "lib/jxl/image.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/render_pipeline/render_pipeline.h" +#include "lib/jxl/render_pipeline/stage_upsampling.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +constexpr size_t kSigmaBorder = 1; +constexpr size_t kSigmaPadding = 2; + +struct PixelCallback { + PixelCallback() = default; + PixelCallback(JxlImageOutInitCallback init, JxlImageOutRunCallback run, + JxlImageOutDestroyCallback destroy, void* init_opaque) + : init(init), run(run), destroy(destroy), init_opaque(init_opaque) { +#if JXL_ENABLE_ASSERT + const bool has_init = init != nullptr; + const bool has_run = run != nullptr; + const bool has_destroy = destroy != nullptr; + JXL_ASSERT(has_init == has_run && has_run == has_destroy); +#endif + } + + bool IsPresent() const { return run != nullptr; } + + void* Init(size_t num_threads, size_t num_pixels) const { + return init(init_opaque, num_threads, num_pixels); + } + + JxlImageOutInitCallback init = nullptr; + JxlImageOutRunCallback run = nullptr; + JxlImageOutDestroyCallback destroy = nullptr; + void* init_opaque = nullptr; +}; + +// Per-frame decoder state. All the images here should be accessed through a +// group rect (either with block units or pixel units). +struct PassesDecoderState { + PassesSharedState shared_storage; + // Allows avoiding copies for encoder loop. + const PassesSharedState* JXL_RESTRICT shared = &shared_storage; + + // 8x upsampling stage for DC. + std::unique_ptr upsampler8x; + + // For ANS decoding. + std::vector code; + std::vector> context_map; + + // Multiplier to be applied to the quant matrices of the x channel. + float x_dm_multiplier; + float b_dm_multiplier; + + // Sigma values for EPF. + ImageF sigma; + + // RGB8 output buffer. If not nullptr, image data will be written to this + // buffer instead of being written to the output ImageBundle. The image data + // is assumed to have the stride given by `rgb_stride`, hence row `i` starts + // at position `i * rgb_stride`. + uint8_t* rgb_output; + size_t rgb_stride = 0; + + // Whether to use int16 float-XYB-to-uint8-srgb conversion. + bool fast_xyb_srgb8_conversion; + + // If true, rgb_output or callback output is RGBA using 4 instead of 3 bytes + // per pixel. + bool rgb_output_is_rgba; + // If true, the RGBA output will be unpremultiplied before writing to the + // output callback (the output buffer case is handled in ConvertToExternal). + bool unpremul_alpha; + + // Callback for line-by-line output. + PixelCallback pixel_callback; + + // Buffer of upsampling * kApplyImageFeaturesTileDim ones. + std::vector opaque_alpha; + // One row per thread + std::vector> pixel_callback_rows; + + // Used for seeding noise. + size_t visible_frame_index = 0; + size_t nonvisible_frame_index = 0; + + // Keep track of the transform types used. + std::atomic used_acs{0}; + + // Storage for coefficients if in "accumulate" mode. + std::unique_ptr coefficients = make_unique>(0, 0); + + // Rendering pipeline. + std::unique_ptr render_pipeline; + + // Storage for the current frame if it can be referenced by future frames. + ImageBundle frame_storage_for_referencing; + + struct PipelineOptions { + bool use_slow_render_pipeline; + bool coalescing; + bool render_spotcolors; + }; + + Status PreparePipeline(ImageBundle* decoded, PipelineOptions options); + + // Information for colour conversions. + OutputEncodingInfo output_encoding_info; + + // Initializes decoder-specific structures using information from *shared. + Status Init() { + x_dm_multiplier = + std::pow(1 / (1.25f), shared->frame_header.x_qm_scale - 2.0f); + b_dm_multiplier = + std::pow(1 / (1.25f), shared->frame_header.b_qm_scale - 2.0f); + + rgb_output = nullptr; + rgb_output_is_rgba = false; + unpremul_alpha = false; + fast_xyb_srgb8_conversion = false; + pixel_callback = PixelCallback(); + used_acs = 0; + + upsampler8x = GetUpsamplingStage(shared->metadata->transform_data, 0, 3); + if (shared->frame_header.loop_filter.epf_iters > 0) { + sigma = ImageF(shared->frame_dim.xsize_blocks + 2 * kSigmaPadding, + shared->frame_dim.ysize_blocks + 2 * kSigmaPadding); + } + return true; + } + + // Initialize the decoder state after all of DC is decoded. + Status InitForAC(ThreadPool* pool) { + shared_storage.coeff_order_size = 0; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + if (((1 << o) & used_acs) == 0) continue; + uint8_t ord = kStrategyOrder[o]; + shared_storage.coeff_order_size = + std::max(kCoeffOrderOffset[3 * (ord + 1)] * kDCTBlockSize, + shared_storage.coeff_order_size); + } + size_t sz = shared_storage.frame_header.passes.num_passes * + shared_storage.coeff_order_size; + if (sz > shared_storage.coeff_orders.size()) { + shared_storage.coeff_orders.resize(sz); + } + return true; + } + + // Fills the `state->filter_weights.sigma` image with the precomputed sigma + // values in the area inside `block_rect`. Accesses the AC strategy, quant + // field and epf_sharpness fields in the corresponding positions. + void ComputeSigma(const Rect& block_rect, PassesDecoderState* state); +}; + +// Temp images required for decoding a single group. Reduces memory allocations +// for large images because we only initialize min(#threads, #groups) instances. +struct GroupDecCache { + void InitOnce(size_t num_passes, size_t used_acs) { + PROFILER_FUNC; + + for (size_t i = 0; i < num_passes; i++) { + if (num_nzeroes[i].xsize() == 0) { + // Allocate enough for a whole group - partial groups on the + // right/bottom border just use a subset. The valid size is passed via + // Rect. + + num_nzeroes[i] = Image3I(kGroupDimInBlocks, kGroupDimInBlocks); + } + } + size_t max_block_area = 0; + + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + AcStrategy acs = AcStrategy::FromRawStrategy(o); + if ((used_acs & (1 << o)) == 0) continue; + size_t area = + acs.covered_blocks_x() * acs.covered_blocks_y() * kDCTBlockSize; + max_block_area = std::max(area, max_block_area); + } + + if (max_block_area > max_block_area_) { + max_block_area_ = max_block_area; + // We need 3x float blocks for dequantized coefficients and 1x for scratch + // space for transforms. + float_memory_ = hwy::AllocateAligned(max_block_area_ * 4); + // We need 3x int32 or int16 blocks for quantized coefficients. + int32_memory_ = hwy::AllocateAligned(max_block_area_ * 3); + int16_memory_ = hwy::AllocateAligned(max_block_area_ * 3); + } + + dec_group_block = float_memory_.get(); + scratch_space = dec_group_block + max_block_area_ * 3; + dec_group_qblock = int32_memory_.get(); + dec_group_qblock16 = int16_memory_.get(); + } + + void InitDCBufferOnce() { + if (dc_buffer.xsize() == 0) { + dc_buffer = ImageF(kGroupDimInBlocks + kRenderPipelineXOffset * 2, + kGroupDimInBlocks + 4); + } + } + + // Scratch space used by DecGroupImpl(). + float* dec_group_block; + int32_t* dec_group_qblock; + int16_t* dec_group_qblock16; + + // For TransformToPixels. + float* scratch_space; + // Note that scratch_space is never used at the same time as dec_group_qblock. + // Moreover, only one of dec_group_qblock16 is ever used. + // TODO(veluca): figure out if we can save allocations. + + // AC decoding + Image3I num_nzeroes[kMaxNumPasses]; + + // Buffer for DC upsampling. + ImageF dc_buffer; + + private: + hwy::AlignedFreeUniquePtr float_memory_; + hwy::AlignedFreeUniquePtr int32_memory_; + hwy::AlignedFreeUniquePtr int16_memory_; + size_t max_block_area_ = 0; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_CACHE_H_ diff --git a/media/libjxl/src/lib/jxl/dec_context_map.cc b/media/libjxl/src/lib/jxl/dec_context_map.cc new file mode 100644 index 000000000..93c59f773 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_context_map.cc @@ -0,0 +1,105 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_context_map.h" + +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/entropy_coder.h" + +namespace jxl { + +namespace { + +void MoveToFront(uint8_t* v, uint8_t index) { + uint8_t value = v[index]; + uint8_t i = index; + for (; i; --i) v[i] = v[i - 1]; + v[0] = value; +} + +void InverseMoveToFrontTransform(uint8_t* v, int v_len) { + uint8_t mtf[256]; + int i; + for (i = 0; i < 256; ++i) { + mtf[i] = static_cast(i); + } + for (i = 0; i < v_len; ++i) { + uint8_t index = v[i]; + v[i] = mtf[index]; + if (index) MoveToFront(mtf, index); + } +} + +Status VerifyContextMap(const std::vector& context_map, + const size_t num_htrees) { + std::vector have_htree(num_htrees); + size_t num_found = 0; + for (const uint8_t htree : context_map) { + if (htree >= num_htrees) { + return JXL_FAILURE("Invalid histogram index in context map."); + } + if (!have_htree[htree]) { + have_htree[htree] = true; + ++num_found; + } + } + if (num_found != num_htrees) { + return JXL_FAILURE("Incomplete context map."); + } + return true; +} + +} // namespace + +Status DecodeContextMap(std::vector* context_map, size_t* num_htrees, + BitReader* input) { + bool is_simple = input->ReadFixedBits<1>(); + if (is_simple) { + int bits_per_entry = input->ReadFixedBits<2>(); + if (bits_per_entry != 0) { + for (size_t i = 0; i < context_map->size(); i++) { + (*context_map)[i] = input->ReadBits(bits_per_entry); + } + } else { + std::fill(context_map->begin(), context_map->end(), 0); + } + } else { + bool use_mtf = input->ReadFixedBits<1>(); + ANSCode code; + std::vector dummy_ctx_map; + // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't + // make sense in non-malicious bitstreams, and could cause a stack overflow + // in malicious bitstreams by making every context map require its own + // context map. + JXL_RETURN_IF_ERROR( + DecodeHistograms(input, 1, &code, &dummy_ctx_map, + /*disallow_lz77=*/context_map->size() <= 2)); + ANSSymbolReader reader(&code, input); + size_t i = 0; + while (i < context_map->size()) { + uint32_t sym = reader.ReadHybridUint(0, input, dummy_ctx_map); + if (sym >= kMaxClusters) { + return JXL_FAILURE("Invalid cluster ID"); + } + (*context_map)[i] = sym; + i++; + } + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("Invalid context map"); + } + if (use_mtf) { + InverseMoveToFrontTransform(context_map->data(), context_map->size()); + } + } + *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1; + return VerifyContextMap(*context_map, *num_htrees); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_context_map.h b/media/libjxl/src/lib/jxl/dec_context_map.h new file mode 100644 index 000000000..95b8a0ca9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_context_map.h @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_CONTEXT_MAP_H_ +#define LIB_JXL_DEC_CONTEXT_MAP_H_ + +#include +#include + +#include + +#include "lib/jxl/dec_bit_reader.h" + +namespace jxl { + +// Context map uses uint8_t. +constexpr size_t kMaxClusters = 256; + +// Reads the context map from the bit stream. On calling this function, +// context_map->size() must be the number of possible context ids. +// Sets *num_htrees to the number of different histogram ids in +// *context_map. +Status DecodeContextMap(std::vector* context_map, size_t* num_htrees, + BitReader* input); + +} // namespace jxl + +#endif // LIB_JXL_DEC_CONTEXT_MAP_H_ diff --git a/media/libjxl/src/lib/jxl/dec_external_image.cc b/media/libjxl/src/lib/jxl/dec_external_image.cc new file mode 100644 index 000000000..abf3ed433 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_external_image.cc @@ -0,0 +1,505 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_external_image.h" + +#include + +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_external_image.cc" +#include +#include + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::NearestInt; + +// TODO(jon): check if this can be replaced by a FloatToU16 function +void FloatToU32(const float* in, uint32_t* out, size_t num, float mul, + size_t bits_per_sample) { + const HWY_FULL(float) d; + const hwy::HWY_NAMESPACE::Rebind du; + + // Unpoison accessing partially-uninitialized vectors with memory sanitizer. + // This is because we run NearestInt() on the vector, which triggers msan even + // it it safe to do so since the values are not mixed between lanes. + const size_t num_round_up = RoundUpTo(num, Lanes(d)); + msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num)); + + const auto one = Set(d, 1.0f); + const auto scale = Set(d, mul); + for (size_t x = 0; x < num; x += Lanes(d)) { + auto v = Load(d, in + x); + // Clamp turns NaN to 'min'. + v = Clamp(v, Zero(d), one); + auto i = NearestInt(Mul(v, scale)); + Store(BitCast(du, i), du, out + x); + } + + // Poison back the output. + msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num)); +} + +void FloatToF16(const float* in, hwy::float16_t* out, size_t num) { + const HWY_FULL(float) d; + const hwy::HWY_NAMESPACE::Rebind du; + + // Unpoison accessing partially-uninitialized vectors with memory sanitizer. + // This is because we run DemoteTo() on the vector which triggers msan. + const size_t num_round_up = RoundUpTo(num, Lanes(d)); + msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num)); + + for (size_t x = 0; x < num; x += Lanes(d)) { + auto v = Load(d, in + x); + auto v16 = DemoteTo(du, v); + Store(v16, du, out + x); + } + + // Poison back the output. + msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { +namespace { + +// Stores a float in big endian +void StoreBEFloat(float value, uint8_t* p) { + uint32_t u; + memcpy(&u, &value, 4); + StoreBE32(u, p); +} + +// Stores a float in little endian +void StoreLEFloat(float value, uint8_t* p) { + uint32_t u; + memcpy(&u, &value, 4); + StoreLE32(u, p); +} + +// The orientation may not be identity. +// TODO(lode): SIMDify where possible +template +Status UndoOrientation(jxl::Orientation undo_orientation, const Plane& image, + Plane& out, jxl::ThreadPool* pool) { + const size_t xsize = image.xsize(); + const size_t ysize = image.ysize(); + + if (undo_orientation == Orientation::kFlipHorizontal) { + out = Plane(xsize, ysize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[xsize - x - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kRotate180) { + out = Plane(xsize, ysize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(ysize - y - 1); + for (size_t x = 0; x < xsize; ++x) { + row_out[xsize - x - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kFlipVertical) { + out = Plane(xsize, ysize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(ysize - y - 1); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kTranspose) { + out = Plane(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(x)[y] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kRotate90) { + out = Plane(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(x)[ysize - y - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kAntiTranspose) { + out = Plane(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(xsize - x - 1)[ysize - y - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kRotate270) { + out = Plane(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(xsize - x - 1)[y] = row_in[x]; + } + }, + "UndoOrientation")); + } + return true; +} +} // namespace + +HWY_EXPORT(FloatToU32); +HWY_EXPORT(FloatToF16); + +namespace { + +using StoreFuncType = void(uint32_t value, uint8_t* dest); +template +void StoreUintRow(uint32_t* JXL_RESTRICT* rows_u32, size_t num_channels, + size_t xsize, size_t bytes_per_sample, + uint8_t* JXL_RESTRICT out) { + for (size_t x = 0; x < xsize; ++x) { + for (size_t c = 0; c < num_channels; c++) { + StoreFunc(rows_u32[c][x], + out + (num_channels * x + c) * bytes_per_sample); + } + } +} + +template +void StoreFloatRow(const float* JXL_RESTRICT* rows_in, size_t num_channels, + size_t xsize, uint8_t* JXL_RESTRICT out) { + for (size_t x = 0; x < xsize; ++x) { + for (size_t c = 0; c < num_channels; c++) { + StoreFunc(rows_in[c][x], out + (num_channels * x + c) * sizeof(float)); + } + } +} + +void JXL_INLINE Store8(uint32_t value, uint8_t* dest) { *dest = value & 0xff; } + +// Maximum number of channels for the ConvertChannelsToExternal function. +const size_t kConvertMaxChannels = 4; + +// Converts a list of channels to an interleaved image, applying transformations +// when needed. +// The input channels are given as a (non-const!) array of channel pointers and +// interleaved in that order. +// +// Note: if a pointer in channels[] is nullptr, a 1.0 value will be used +// instead. This is useful for handling when a user requests an alpha channel +// from an image that doesn't have one. The first channel in the list may not +// be nullptr, since it is used to determine the image size. +Status ConvertChannelsToExternal(const ImageF* channels[], size_t num_channels, + size_t bits_per_sample, bool float_out, + JxlEndianness endianness, size_t stride, + jxl::ThreadPool* pool, void* out_image, + size_t out_size, + const PixelCallback& out_callback, + jxl::Orientation undo_orientation) { + JXL_DASSERT(num_channels != 0 && num_channels <= kConvertMaxChannels); + JXL_DASSERT(channels[0] != nullptr); + JXL_CHECK(float_out ? bits_per_sample == 16 || bits_per_sample == 32 + : bits_per_sample > 0 && bits_per_sample <= 16); + if (!!out_image == out_callback.IsPresent()) { + return JXL_FAILURE( + "Must provide either an out_image or an out_callback, but not both."); + } + + const size_t bytes_per_channel = DivCeil(bits_per_sample, jxl::kBitsPerByte); + const size_t bytes_per_pixel = num_channels * bytes_per_channel; + + std::vector> row_out_callback; + const auto FreeCallbackOpaque = [&out_callback](void* p) { + out_callback.destroy(p); + }; + std::unique_ptr out_run_opaque( + nullptr, FreeCallbackOpaque); + auto InitOutCallback = [&](size_t num_threads) -> Status { + if (out_callback.IsPresent()) { + out_run_opaque.reset(out_callback.Init(num_threads, stride)); + JXL_RETURN_IF_ERROR(out_run_opaque != nullptr); + row_out_callback.resize(num_threads); + for (size_t i = 0; i < num_threads; ++i) { + row_out_callback[i].resize(stride); + } + } + return true; + }; + + // Channels used to store the transformed original channels if needed. + ImageF temp_channels[kConvertMaxChannels]; + if (undo_orientation != Orientation::kIdentity) { + for (size_t c = 0; c < num_channels; ++c) { + if (channels[c]) { + JXL_RETURN_IF_ERROR(UndoOrientation(undo_orientation, *channels[c], + temp_channels[c], pool)); + channels[c] = &(temp_channels[c]); + } + } + } + + // First channel may not be nullptr. + size_t xsize = channels[0]->xsize(); + size_t ysize = channels[0]->ysize(); + if (stride < bytes_per_pixel * xsize) { + return JXL_FAILURE("stride is smaller than scanline width in bytes: %" PRIuS + " vs %" PRIuS, + stride, bytes_per_pixel * xsize); + } + if (!out_callback.IsPresent() && + out_size < (ysize - 1) * stride + bytes_per_pixel * xsize) { + return JXL_FAILURE("out_size is too small to store image"); + } + + const bool little_endian = + endianness == JXL_LITTLE_ENDIAN || + (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian()); + + // Handle the case where a channel is nullptr by creating a single row with + // ones to use instead. + ImageF ones; + for (size_t c = 0; c < num_channels; ++c) { + if (!channels[c]) { + ones = ImageF(xsize, 1); + FillImage(1.0f, &ones); + break; + } + } + + if (float_out) { + if (bits_per_sample == 16) { + bool swap_endianness = little_endian != IsLittleEndian(); + Plane f16_cache; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), + [&](size_t num_threads) { + f16_cache = + Plane(xsize, num_channels * num_threads); + return InitOutCallback(num_threads); + }, + [&](const uint32_t task, const size_t thread) { + const int64_t y = task; + const float* JXL_RESTRICT row_in[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0); + } + hwy::float16_t* JXL_RESTRICT row_f16[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_f16[c] = f16_cache.Row(c + thread * num_channels); + HWY_DYNAMIC_DISPATCH(FloatToF16) + (row_in[c], row_f16[c], xsize); + } + uint8_t* row_out = + out_callback.IsPresent() + ? row_out_callback[thread].data() + : &(reinterpret_cast(out_image))[stride * y]; + // interleave the one scanline + hwy::float16_t* row_f16_out = + reinterpret_cast(row_out); + for (size_t x = 0; x < xsize; x++) { + for (size_t c = 0; c < num_channels; c++) { + row_f16_out[x * num_channels + c] = row_f16[c][x]; + } + } + if (swap_endianness) { + size_t size = xsize * num_channels * 2; + for (size_t i = 0; i < size; i += 2) { + std::swap(row_out[i + 0], row_out[i + 1]); + } + } + if (out_callback.IsPresent()) { + out_callback.run(out_run_opaque.get(), thread, 0, y, xsize, + row_out); + } + }, + "ConvertF16")); + } else if (bits_per_sample == 32) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), + [&](size_t num_threads) { return InitOutCallback(num_threads); }, + [&](const uint32_t task, const size_t thread) { + const int64_t y = task; + uint8_t* row_out = + out_callback.IsPresent() + ? row_out_callback[thread].data() + : &(reinterpret_cast(out_image))[stride * y]; + const float* JXL_RESTRICT row_in[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0); + } + if (little_endian) { + StoreFloatRow(row_in, num_channels, xsize, row_out); + } else { + StoreFloatRow(row_in, num_channels, xsize, row_out); + } + if (out_callback.IsPresent()) { + out_callback.run(out_run_opaque.get(), thread, 0, y, xsize, + row_out); + } + }, + "ConvertFloat")); + } else { + return JXL_FAILURE("float other than 16-bit and 32-bit not supported"); + } + } else { + // Multiplier to convert from floating point 0-1 range to the integer + // range. + float mul = (1ull << bits_per_sample) - 1; + Plane u32_cache; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), + [&](size_t num_threads) { + u32_cache = Plane(xsize, num_channels * num_threads); + return InitOutCallback(num_threads); + }, + [&](const uint32_t task, const size_t thread) { + const int64_t y = task; + uint8_t* row_out = + out_callback.IsPresent() + ? row_out_callback[thread].data() + : &(reinterpret_cast(out_image))[stride * y]; + const float* JXL_RESTRICT row_in[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0); + } + uint32_t* JXL_RESTRICT row_u32[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_u32[c] = u32_cache.Row(c + thread * num_channels); + // row_u32[] is a per-thread temporary row storage, this isn't + // intended to be initialized on a previous run. + msan::PoisonMemory(row_u32[c], xsize * sizeof(row_u32[c][0])); + HWY_DYNAMIC_DISPATCH(FloatToU32) + (row_in[c], row_u32[c], xsize, mul, bits_per_sample); + } + if (bits_per_sample <= 8) { + StoreUintRow(row_u32, num_channels, xsize, 1, row_out); + } else { + if (little_endian) { + StoreUintRow(row_u32, num_channels, xsize, 2, row_out); + } else { + StoreUintRow(row_u32, num_channels, xsize, 2, row_out); + } + } + if (out_callback.IsPresent()) { + out_callback.run(out_run_opaque.get(), thread, 0, y, xsize, + row_out); + } + }, + "ConvertUint")); + } + return true; +} + +} // namespace + +Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample, + bool float_out, size_t num_channels, + JxlEndianness endianness, size_t stride, + jxl::ThreadPool* pool, void* out_image, + size_t out_size, const PixelCallback& out_callback, + jxl::Orientation undo_orientation, + bool unpremul_alpha) { + bool want_alpha = num_channels == 2 || num_channels == 4; + size_t color_channels = num_channels <= 2 ? 1 : 3; + + const Image3F* color = &ib.color(); + // Undo premultiplied alpha. + Image3F unpremul; + if (ib.AlphaIsPremultiplied() && ib.HasAlpha() && unpremul_alpha) { + unpremul = Image3F(color->xsize(), color->ysize()); + CopyImageTo(*color, &unpremul); + for (size_t y = 0; y < unpremul.ysize(); y++) { + UnpremultiplyAlpha(unpremul.PlaneRow(0, y), unpremul.PlaneRow(1, y), + unpremul.PlaneRow(2, y), ib.alpha().Row(y), + unpremul.xsize()); + } + color = &unpremul; + } + + const ImageF* channels[kConvertMaxChannels]; + size_t c = 0; + for (; c < color_channels; c++) { + channels[c] = &color->Plane(c); + } + if (want_alpha) { + channels[c++] = ib.HasAlpha() ? &ib.alpha() : nullptr; + } + JXL_ASSERT(num_channels == c); + + return ConvertChannelsToExternal( + channels, num_channels, bits_per_sample, float_out, endianness, stride, + pool, out_image, out_size, out_callback, undo_orientation); +} + +Status ConvertToExternal(const jxl::ImageF& channel, size_t bits_per_sample, + bool float_out, JxlEndianness endianness, + size_t stride, jxl::ThreadPool* pool, void* out_image, + size_t out_size, const PixelCallback& out_callback, + jxl::Orientation undo_orientation) { + const ImageF* channels[1]; + channels[0] = &channel; + return ConvertChannelsToExternal(channels, 1, bits_per_sample, float_out, + endianness, stride, pool, out_image, + out_size, out_callback, undo_orientation); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/dec_external_image.h b/media/libjxl/src/lib/jxl/dec_external_image.h new file mode 100644 index 000000000..9b3b8bf66 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_external_image.h @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_EXTERNAL_IMAGE_H_ +#define LIB_JXL_DEC_EXTERNAL_IMAGE_H_ + +// Interleaved image for color transforms and Codec. + +#include +#include + +#include "jxl/decode.h" +#include "jxl/types.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Converts ib to interleaved void* pixel buffer with the given format. +// bits_per_sample: must be 16 or 32 if float_out is true, and at most 16 +// if it is false. No bit packing is done. +// num_channels: must be 1, 2, 3 or 4 for gray, gray+alpha, RGB, RGB+alpha. +// This supports the features needed for the C API and does not perform +// color space conversion. +// TODO(lode): support rectangle crop. +// stride_out is output scanline size in bytes, must be >= +// output_xsize * output_bytes_per_pixel. +// undo_orientation is an EXIF orientation to undo. Depending on the +// orientation, the output xsize and ysize are swapped compared to input +// xsize and ysize. +Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample, + bool float_out, size_t num_channels, + JxlEndianness endianness, size_t stride_out, + jxl::ThreadPool* thread_pool, void* out_image, + size_t out_size, const PixelCallback& out_callback, + jxl::Orientation undo_orientation, + bool unpremul_alpha = false); + +// Converts single-channel image to interleaved void* pixel buffer with the +// given format, with a single channel. +// This supports the features needed for the C API to get extra channels. +// Arguments are similar to the multi-channel function above. +Status ConvertToExternal(const jxl::ImageF& channel, size_t bits_per_sample, + bool float_out, JxlEndianness endianness, + size_t stride_out, jxl::ThreadPool* thread_pool, + void* out_image, size_t out_size, + const PixelCallback& out_callback, + jxl::Orientation undo_orientation); +} // namespace jxl + +#endif // LIB_JXL_DEC_EXTERNAL_IMAGE_H_ diff --git a/media/libjxl/src/lib/jxl/dec_external_image_gbench.cc b/media/libjxl/src/lib/jxl/dec_external_image_gbench.cc new file mode 100644 index 000000000..0011792a5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_external_image_gbench.cc @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +// Decoder case, interleaves an internal float image. +void BM_DecExternalImage_ConvertImageRGBA(benchmark::State& state) { + const size_t kNumIter = 5; + size_t xsize = state.range(); + size_t ysize = state.range(); + size_t num_channels = 4; + + ImageMetadata im; + im.SetAlphaBits(8); + ImageBundle ib(&im); + Image3F color(xsize, ysize); + ZeroFillImage(&color); + ib.SetFromImage(std::move(color), ColorEncoding::SRGB()); + ImageF alpha(xsize, ysize); + ZeroFillImage(&alpha); + ib.SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + const size_t bytes_per_row = xsize * num_channels; + std::vector interleaved(bytes_per_row * ysize); + + for (auto _ : state) { + for (size_t i = 0; i < kNumIter; ++i) { + JXL_CHECK(ConvertToExternal( + ib, + /*bits_per_sample=*/8, + /*float_out=*/false, num_channels, JXL_NATIVE_ENDIAN, + /*stride*/ bytes_per_row, + /*thread_pool=*/nullptr, interleaved.data(), interleaved.size(), + /*out_callback=*/{}, + /*undo_orientation=*/jxl::Orientation::kIdentity)); + } + } + + // Pixels per second. + state.SetItemsProcessed(kNumIter * state.iterations() * xsize * ysize); + state.SetBytesProcessed(kNumIter * state.iterations() * interleaved.size()); +} + +BENCHMARK(BM_DecExternalImage_ConvertImageRGBA) + ->RangeMultiplier(2) + ->Range(256, 2048); + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_frame.cc b/media/libjxl/src/lib/jxl/dec_frame.cc new file mode 100644 index 000000000..c90bb4720 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_frame.cc @@ -0,0 +1,878 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_frame.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "jxl/types.h" +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_group.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/luminance.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/splines.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +Status DecodeGlobalDCInfo(BitReader* reader, bool is_jpeg, + PassesDecoderState* state, ThreadPool* pool) { + PROFILER_FUNC; + JXL_RETURN_IF_ERROR(state->shared_storage.quantizer.Decode(reader)); + + JXL_RETURN_IF_ERROR( + DecodeBlockCtxMap(reader, &state->shared_storage.block_ctx_map)); + + JXL_RETURN_IF_ERROR(state->shared_storage.cmap.DecodeDC(reader)); + + // Pre-compute info for decoding a group. + if (is_jpeg) { + state->shared_storage.quantizer.ClearDCMul(); // Don't dequant DC + } + + state->shared_storage.ac_strategy.FillInvalid(); + return true; +} +} // namespace + +Status DecodeFrame(PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool, + const uint8_t* next_in, size_t avail_in, + ImageBundle* decoded, const CodecMetadata& metadata, + bool use_slow_rendering_pipeline) { + FrameDecoder frame_decoder(dec_state, metadata, pool, + use_slow_rendering_pipeline); + + BitReader reader(Span(next_in, avail_in)); + JXL_RETURN_IF_ERROR(frame_decoder.InitFrame(&reader, decoded, + /*is_preview=*/false, + /*output_needed=*/true)); + JXL_RETURN_IF_ERROR(reader.AllReadsWithinBounds()); + size_t header_bytes = reader.TotalBitsConsumed() / kBitsPerByte; + JXL_RETURN_IF_ERROR(reader.Close()); + + size_t processed_bytes = header_bytes; + Status close_ok = true; + std::vector> section_readers; + { + std::vector> section_closers; + std::vector section_info; + std::vector section_status; + size_t pos = header_bytes; + for (auto toc_entry : frame_decoder.Toc()) { + JXL_RETURN_IF_ERROR(pos + toc_entry.size <= avail_in); + auto br = make_unique( + Span(next_in + pos, toc_entry.size)); + section_info.emplace_back( + FrameDecoder::SectionInfo{br.get(), toc_entry.id}); + section_closers.emplace_back( + make_unique(br.get(), &close_ok)); + section_readers.emplace_back(std::move(br)); + pos += toc_entry.size; + } + section_status.resize(section_info.size()); + JXL_RETURN_IF_ERROR(frame_decoder.ProcessSections( + section_info.data(), section_info.size(), section_status.data())); + for (size_t i = 0; i < section_status.size(); i++) { + JXL_RETURN_IF_ERROR(section_status[i] == FrameDecoder::kDone); + processed_bytes += frame_decoder.Toc()[i].size; + } + } + JXL_RETURN_IF_ERROR(close_ok); + JXL_RETURN_IF_ERROR(frame_decoder.FinalizeFrame()); + decoded->SetDecodedBytes(processed_bytes); + return true; +} + +Status FrameDecoder::InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded, + bool is_preview, bool output_needed) { + PROFILER_FUNC; + decoded_ = decoded; + JXL_ASSERT(is_finalized_); + + // Reset the dequantization matrices to their default values. + dec_state_->shared_storage.matrices = DequantMatrices(); + + frame_header_.nonserialized_is_preview = is_preview; + JXL_ASSERT(frame_header_.nonserialized_metadata != nullptr); + JXL_RETURN_IF_ERROR(ReadFrameHeader(br, &frame_header_)); + frame_dim_ = frame_header_.ToFrameDimensions(); + JXL_DEBUG_V(2, "FrameHeader: %s", frame_header_.DebugString().c_str()); + + const size_t num_passes = frame_header_.passes.num_passes; + const size_t num_groups = frame_dim_.num_groups; + + // If the previous frame was not a kRegularFrame, `decoded` may have different + // dimensions; must reset to avoid errors. + decoded->RemoveColor(); + decoded->ClearExtraChannels(); + + decoded->duration = frame_header_.animation_frame.duration; + + if (!frame_header_.nonserialized_is_preview && + (frame_header_.is_last || frame_header_.animation_frame.duration > 0) && + (frame_header_.frame_type == kRegularFrame || + frame_header_.frame_type == kSkipProgressive)) { + ++dec_state_->visible_frame_index; + dec_state_->nonvisible_frame_index = 0; + } else { + ++dec_state_->nonvisible_frame_index; + } + + // Read TOC. + const bool has_ac_global = true; + const size_t toc_entries = NumTocEntries(num_groups, frame_dim_.num_dc_groups, + num_passes, has_ac_global); + std::vector sizes; + std::vector permutation; + JXL_RETURN_IF_ERROR(ReadToc(toc_entries, br, &sizes, &permutation)); + bool have_permutation = !permutation.empty(); + toc_.resize(toc_entries); + section_sizes_sum_ = 0; + for (size_t i = 0; i < toc_entries; ++i) { + toc_[i].size = sizes[i]; + size_t index = have_permutation ? permutation[i] : i; + toc_[index].id = i; + if (section_sizes_sum_ + toc_[i].size < section_sizes_sum_) { + return JXL_FAILURE("group offset overflow"); + } + section_sizes_sum_ += toc_[i].size; + } + + JXL_DASSERT((br->TotalBitsConsumed() % kBitsPerByte) == 0); + const size_t group_codes_begin = br->TotalBitsConsumed() / kBitsPerByte; + JXL_DASSERT(!toc_.empty()); + + // Overflow check. + if (group_codes_begin + section_sizes_sum_ < group_codes_begin) { + return JXL_FAILURE("Invalid group codes"); + } + + if (!frame_header_.chroma_subsampling.Is444() && + !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) && + frame_header_.encoding == FrameEncoding::kVarDCT) { + return JXL_FAILURE( + "Non-444 chroma subsampling is not allowed when adaptive DC " + "smoothing is enabled"); + } + + if (!output_needed) return true; + JXL_RETURN_IF_ERROR( + InitializePassesSharedState(frame_header_, &dec_state_->shared_storage)); + JXL_RETURN_IF_ERROR(dec_state_->Init()); + modular_frame_decoder_.Init(frame_dim_); + + if (decoded->IsJPEG()) { + if (frame_header_.encoding == FrameEncoding::kModular) { + return JXL_FAILURE("Cannot output JPEG from Modular"); + } + jpeg::JPEGData* jpeg_data = decoded->jpeg_data.get(); + size_t num_components = jpeg_data->components.size(); + if (num_components != 1 && num_components != 3) { + return JXL_FAILURE("Invalid number of components"); + } + if (frame_header_.nonserialized_metadata->m.xyb_encoded) { + return JXL_FAILURE("Cannot decode to JPEG an XYB image"); + } + auto jpeg_c_map = JpegOrder(ColorTransform::kYCbCr, num_components == 1); + decoded->jpeg_data->width = frame_dim_.xsize; + decoded->jpeg_data->height = frame_dim_.ysize; + for (size_t c = 0; c < num_components; c++) { + auto& component = jpeg_data->components[jpeg_c_map[c]]; + component.width_in_blocks = + frame_dim_.xsize_blocks >> frame_header_.chroma_subsampling.HShift(c); + component.height_in_blocks = + frame_dim_.ysize_blocks >> frame_header_.chroma_subsampling.VShift(c); + component.h_samp_factor = + 1 << frame_header_.chroma_subsampling.RawHShift(c); + component.v_samp_factor = + 1 << frame_header_.chroma_subsampling.RawVShift(c); + component.coeffs.resize(component.width_in_blocks * + component.height_in_blocks * jxl::kDCTBlockSize); + } + } + + // Clear the state. + decoded_dc_global_ = false; + decoded_ac_global_ = false; + is_finalized_ = false; + finalized_dc_ = false; + num_sections_done_ = 0; + decoded_dc_groups_.clear(); + decoded_dc_groups_.resize(frame_dim_.num_dc_groups); + decoded_passes_per_ac_group_.clear(); + decoded_passes_per_ac_group_.resize(frame_dim_.num_groups, 0); + processed_section_.clear(); + processed_section_.resize(toc_.size()); + allocated_ = false; + return true; +} + +Status FrameDecoder::ProcessDCGlobal(BitReader* br) { + PROFILER_FUNC; + PassesSharedState& shared = dec_state_->shared_storage; + if (shared.frame_header.flags & FrameHeader::kPatches) { + bool uses_extra_channels = false; + JXL_RETURN_IF_ERROR(shared.image_features.patches.Decode( + br, frame_dim_.xsize_padded, frame_dim_.ysize_padded, + &uses_extra_channels)); + if (uses_extra_channels && frame_header_.upsampling != 1) { + for (size_t ecups : frame_header_.extra_channel_upsampling) { + if (ecups != frame_header_.upsampling) { + return JXL_FAILURE( + "Cannot use extra channels in patches if color channels are " + "subsampled differently from extra channels"); + } + } + } + } else { + shared.image_features.patches.Clear(); + } + shared.image_features.splines.Clear(); + if (shared.frame_header.flags & FrameHeader::kSplines) { + JXL_RETURN_IF_ERROR(shared.image_features.splines.Decode( + br, frame_dim_.xsize * frame_dim_.ysize)); + } + if (shared.frame_header.flags & FrameHeader::kNoise) { + JXL_RETURN_IF_ERROR(DecodeNoise(br, &shared.image_features.noise_params)); + } + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.DecodeDC(br)); + + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR( + jxl::DecodeGlobalDCInfo(br, decoded_->IsJPEG(), dec_state_, pool_)); + } + // Splines' draw cache uses the color correlation map. + if (shared.frame_header.flags & FrameHeader::kSplines) { + JXL_RETURN_IF_ERROR(shared.image_features.splines.InitializeDrawCache( + frame_dim_.xsize_upsampled, frame_dim_.ysize_upsampled, + dec_state_->shared->cmap)); + } + Status dec_status = modular_frame_decoder_.DecodeGlobalInfo( + br, frame_header_, /*allow_truncated_group=*/false); + if (dec_status.IsFatalError()) return dec_status; + if (dec_status) { + decoded_dc_global_ = true; + } + return dec_status; +} + +Status FrameDecoder::ProcessDCGroup(size_t dc_group_id, BitReader* br) { + PROFILER_FUNC; + const size_t gx = dc_group_id % frame_dim_.xsize_dc_groups; + const size_t gy = dc_group_id / frame_dim_.xsize_dc_groups; + const LoopFilter& lf = dec_state_->shared->frame_header.loop_filter; + if (frame_header_.encoding == FrameEncoding::kVarDCT && + !(frame_header_.flags & FrameHeader::kUseDcFrame)) { + JXL_RETURN_IF_ERROR( + modular_frame_decoder_.DecodeVarDCTDC(dc_group_id, br, dec_state_)); + } + const Rect mrect(gx * frame_dim_.dc_group_dim, gy * frame_dim_.dc_group_dim, + frame_dim_.dc_group_dim, frame_dim_.dc_group_dim); + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup( + mrect, br, 3, 1000, ModularStreamId::ModularDC(dc_group_id), + /*zerofill=*/false, nullptr, nullptr, + /*allow_truncated=*/false)); + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR( + modular_frame_decoder_.DecodeAcMetadata(dc_group_id, br, dec_state_)); + } else if (lf.epf_iters > 0) { + FillImage(kInvSigmaNum / lf.epf_sigma_for_modular, &dec_state_->sigma); + } + decoded_dc_groups_[dc_group_id] = uint8_t{true}; + return true; +} + +void FrameDecoder::FinalizeDC() { + // Do Adaptive DC smoothing if enabled. This *must* happen between all the + // ProcessDCGroup and ProcessACGroup. + if (frame_header_.encoding == FrameEncoding::kVarDCT && + !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) && + !(frame_header_.flags & FrameHeader::kUseDcFrame)) { + AdaptiveDCSmoothing(dec_state_->shared->quantizer.MulDC(), + &dec_state_->shared_storage.dc_storage, pool_); + } + + finalized_dc_ = true; +} + +Status FrameDecoder::AllocateOutput() { + if (allocated_) return true; + modular_frame_decoder_.MaybeDropFullImage(); + decoded_->origin = dec_state_->shared->frame_header.frame_origin; + JXL_RETURN_IF_ERROR(dec_state_->InitForAC(nullptr)); + allocated_ = true; + return true; +} + +Status FrameDecoder::ProcessACGlobal(BitReader* br) { + JXL_CHECK(finalized_dc_); + + // Decode AC group. + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.Decode( + br, &modular_frame_decoder_)); + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.EnsureComputed( + dec_state_->used_acs)); + + size_t num_histo_bits = + CeilLog2Nonzero(dec_state_->shared->frame_dim.num_groups); + dec_state_->shared_storage.num_histograms = + 1 + br->ReadBits(num_histo_bits); + + dec_state_->code.resize(kMaxNumPasses); + dec_state_->context_map.resize(kMaxNumPasses); + // Read coefficient orders and histograms. + size_t max_num_bits_ac = 0; + for (size_t i = 0; + i < dec_state_->shared_storage.frame_header.passes.num_passes; i++) { + uint16_t used_orders = U32Coder::Read(kOrderEnc, br); + JXL_RETURN_IF_ERROR(DecodeCoeffOrders( + used_orders, dec_state_->used_acs, + &dec_state_->shared_storage + .coeff_orders[i * dec_state_->shared_storage.coeff_order_size], + br)); + size_t num_contexts = + dec_state_->shared->num_histograms * + dec_state_->shared_storage.block_ctx_map.NumACContexts(); + JXL_RETURN_IF_ERROR(DecodeHistograms( + br, num_contexts, &dec_state_->code[i], &dec_state_->context_map[i])); + // Add extra values to enable the cheat in hot loop of DecodeACVarBlock. + dec_state_->context_map[i].resize( + num_contexts + kZeroDensityContextLimit - kZeroDensityContextCount); + max_num_bits_ac = + std::max(max_num_bits_ac, dec_state_->code[i].max_num_bits); + } + max_num_bits_ac += CeilLog2Nonzero( + dec_state_->shared_storage.frame_header.passes.num_passes); + // 16-bit buffer for decoding to JPEG are not implemented. + // TODO(veluca): figure out the exact limit - 16 should still work with + // 16-bit buffers, but we are excluding it for safety. + bool use_16_bit = max_num_bits_ac < 16 && !decoded_->IsJPEG(); + bool store = frame_header_.passes.num_passes > 1; + size_t xs = store ? kGroupDim * kGroupDim : 0; + size_t ys = store ? frame_dim_.num_groups : 0; + if (use_16_bit) { + dec_state_->coefficients = make_unique>(xs, ys); + } else { + dec_state_->coefficients = make_unique>(xs, ys); + } + if (store) { + dec_state_->coefficients->ZeroFill(); + } + } + + // Set JPEG decoding data. + if (decoded_->IsJPEG()) { + decoded_->color_transform = frame_header_.color_transform; + decoded_->chroma_subsampling = frame_header_.chroma_subsampling; + const std::vector& qe = + dec_state_->shared_storage.matrices.encodings(); + if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW || + std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) { + return JXL_FAILURE( + "Quantization table is not a JPEG quantization table."); + } + jpeg::JPEGData* jpeg_data = decoded_->jpeg_data.get(); + size_t num_components = jpeg_data->components.size(); + bool is_gray = (num_components == 1); + auto jpeg_c_map = JpegOrder(frame_header_.color_transform, is_gray); + size_t qt_set = 0; + for (size_t c = 0; c < num_components; c++) { + // TODO(eustas): why 1-st quant table for gray? + size_t quant_c = is_gray ? 1 : c; + size_t qpos = jpeg_data->components[jpeg_c_map[c]].quant_idx; + JXL_CHECK(qpos != jpeg_data->quant.size()); + qt_set |= 1 << qpos; + for (size_t x = 0; x < 8; x++) { + for (size_t y = 0; y < 8; y++) { + jpeg_data->quant[qpos].values[x * 8 + y] = + (*qe[0].qraw.qtable)[quant_c * 64 + y * 8 + x]; + } + } + } + for (size_t i = 0; i < jpeg_data->quant.size(); i++) { + if (qt_set & (1 << i)) continue; + if (i == 0) return JXL_FAILURE("First quant table unused."); + // Unused quant table is set to copy of previous quant table + for (size_t j = 0; j < 64; j++) { + jpeg_data->quant[i].values[j] = jpeg_data->quant[i - 1].values[j]; + } + } + } + decoded_ac_global_ = true; + return true; +} + +Status FrameDecoder::ProcessACGroup(size_t ac_group_id, + BitReader* JXL_RESTRICT* br, + size_t num_passes, size_t thread, + bool force_draw, bool dc_only) { + PROFILER_ZONE("process_group"); + size_t group_dim = frame_dim_.group_dim; + const size_t gx = ac_group_id % frame_dim_.xsize_groups; + const size_t gy = ac_group_id / frame_dim_.xsize_groups; + const size_t x = gx * group_dim; + const size_t y = gy * group_dim; + JXL_DEBUG_V(3, + "Processing AC group %" PRIuS "(%" PRIuS ",%" PRIuS + ") group_dim: %" PRIuS " decoded passes: %u new passes: %" PRIuS, + ac_group_id, gx, gy, group_dim, + decoded_passes_per_ac_group_[ac_group_id], num_passes); + + RenderPipelineInput render_pipeline_input = + dec_state_->render_pipeline->GetInputBuffers(ac_group_id, thread); + + bool should_run_pipeline = true; + + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + group_dec_caches_[thread].InitOnce(frame_header_.passes.num_passes, + dec_state_->used_acs); + JXL_RETURN_IF_ERROR(DecodeGroup(br, num_passes, ac_group_id, dec_state_, + &group_dec_caches_[thread], thread, + render_pipeline_input, decoded_, + decoded_passes_per_ac_group_[ac_group_id], + force_draw, dc_only, &should_run_pipeline)); + } + + // don't limit to image dimensions here (is done in DecodeGroup) + const Rect mrect(x, y, group_dim, group_dim); + bool modular_ready = false; + size_t pass0 = decoded_passes_per_ac_group_[ac_group_id]; + size_t pass1 = + force_draw ? frame_header_.passes.num_passes : pass0 + num_passes; + for (size_t i = pass0; i < pass1; ++i) { + int minShift, maxShift; + frame_header_.passes.GetDownsamplingBracket(i, minShift, maxShift); + bool modular_pass_ready = true; + if (i < pass0 + num_passes) { + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup( + mrect, br[i - pass0], minShift, maxShift, + ModularStreamId::ModularAC(ac_group_id, i), + /*zerofill=*/false, dec_state_, &render_pipeline_input, + /*allow_truncated=*/false, &modular_pass_ready)); + } else { + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup( + mrect, nullptr, minShift, maxShift, + ModularStreamId::ModularAC(ac_group_id, i), /*zerofill=*/true, + dec_state_, &render_pipeline_input, + /*allow_truncated=*/false, &modular_pass_ready)); + } + if (modular_pass_ready) modular_ready = true; + } + decoded_passes_per_ac_group_[ac_group_id] += num_passes; + + if ((frame_header_.flags & FrameHeader::kNoise) != 0) { + PROFILER_ZONE("GenerateNoise"); + size_t noise_c_start = + 3 + frame_header_.nonserialized_metadata->m.num_extra_channels; + // When the color channels are downsampled, we need to generate more noise + // input for the current group than just the group dimensions. + std::pair rects[3]; + for (size_t iy = 0; iy < frame_header_.upsampling; iy++) { + for (size_t ix = 0; ix < frame_header_.upsampling; ix++) { + for (size_t c = 0; c < 3; c++) { + auto r = render_pipeline_input.GetBuffer(noise_c_start + c); + rects[c].first = r.first; + size_t x1 = r.second.x0() + r.second.xsize(); + size_t y1 = r.second.y0() + r.second.ysize(); + rects[c].second = Rect(r.second.x0() + ix * group_dim, + r.second.y0() + iy * group_dim, group_dim, + group_dim, x1, y1); + } + Random3Planes(dec_state_->visible_frame_index, + dec_state_->nonvisible_frame_index, + (gx * frame_header_.upsampling + ix) * group_dim, + (gy * frame_header_.upsampling + iy) * group_dim, + rects[0], rects[1], rects[2]); + } + } + } + + if (!modular_frame_decoder_.UsesFullImage() && !decoded_->IsJPEG()) { + if (should_run_pipeline && modular_ready) { + render_pipeline_input.Done(); + } else if (force_draw) { + return JXL_FAILURE("Modular group decoding failed."); + } + } + return true; +} + +void FrameDecoder::MarkSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status) { + num_sections_done_ += num; + for (size_t i = 0; i < num; i++) { + if (section_status[i] != SectionStatus::kDone) { + processed_section_[sections[i].id] = false; + num_sections_done_--; + } + } +} + +Status FrameDecoder::ProcessSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status) { + if (num == 0) return true; // Nothing to process + std::fill(section_status, section_status + num, SectionStatus::kSkipped); + size_t dc_global_sec = num; + size_t ac_global_sec = num; + std::vector dc_group_sec(frame_dim_.num_dc_groups, num); + std::vector> ac_group_sec( + frame_dim_.num_groups, + std::vector(frame_header_.passes.num_passes, num)); + // This keeps track of the number of ac passes we want to process during this + // call of ProcessSections. + std::vector desired_num_ac_passes(frame_dim_.num_groups); + bool single_section = + frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1; + if (single_section) { + JXL_ASSERT(num == 1); + JXL_ASSERT(sections[0].id == 0); + if (processed_section_[0] == false) { + processed_section_[0] = true; + ac_group_sec[0].resize(1); + dc_global_sec = ac_global_sec = dc_group_sec[0] = ac_group_sec[0][0] = 0; + desired_num_ac_passes[0] = 1; + } else { + section_status[0] = SectionStatus::kDuplicate; + } + } else { + size_t ac_global_index = frame_dim_.num_dc_groups + 1; + for (size_t i = 0; i < num; i++) { + JXL_ASSERT(sections[i].id < processed_section_.size()); + if (processed_section_[sections[i].id]) { + section_status[i] = SectionStatus::kDuplicate; + continue; + } + if (sections[i].id == 0) { + dc_global_sec = i; + } else if (sections[i].id < ac_global_index) { + dc_group_sec[sections[i].id - 1] = i; + } else if (sections[i].id == ac_global_index) { + ac_global_sec = i; + } else { + size_t ac_idx = sections[i].id - ac_global_index - 1; + size_t acg = ac_idx % frame_dim_.num_groups; + size_t acp = ac_idx / frame_dim_.num_groups; + if (acp >= frame_header_.passes.num_passes) { + return JXL_FAILURE("Invalid section ID"); + } + ac_group_sec[acg][acp] = i; + } + processed_section_[sections[i].id] = true; + } + // Count number of new passes per group. + for (size_t g = 0; g < ac_group_sec.size(); g++) { + size_t j = 0; + for (; j + decoded_passes_per_ac_group_[g] < + frame_header_.passes.num_passes; + j++) { + if (ac_group_sec[g][j + decoded_passes_per_ac_group_[g]] == num) { + break; + } + } + desired_num_ac_passes[g] = j; + } + } + if (dc_global_sec != num) { + Status dc_global_status = ProcessDCGlobal(sections[dc_global_sec].br); + if (dc_global_status.IsFatalError()) return dc_global_status; + if (dc_global_status) { + section_status[dc_global_sec] = SectionStatus::kDone; + } else { + section_status[dc_global_sec] = SectionStatus::kPartial; + } + } + + std::atomic has_error{false}; + if (decoded_dc_global_) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool_, 0, dc_group_sec.size(), ThreadPool::NoInit, + [this, &dc_group_sec, &num, §ions, §ion_status, &has_error]( + size_t i, size_t thread) { + if (dc_group_sec[i] != num) { + if (!ProcessDCGroup(i, sections[dc_group_sec[i]].br)) { + has_error = true; + } else { + section_status[dc_group_sec[i]] = SectionStatus::kDone; + } + } + }, + "DecodeDCGroup")); + } + if (has_error) return JXL_FAILURE("Error in DC group"); + + if (*std::min_element(decoded_dc_groups_.begin(), decoded_dc_groups_.end()) && + !finalized_dc_) { + PassesDecoderState::PipelineOptions pipeline_options; + pipeline_options.use_slow_render_pipeline = use_slow_rendering_pipeline_; + pipeline_options.coalescing = coalescing_; + pipeline_options.render_spotcolors = render_spotcolors_; + JXL_RETURN_IF_ERROR( + dec_state_->PreparePipeline(decoded_, pipeline_options)); + FinalizeDC(); + JXL_RETURN_IF_ERROR(AllocateOutput()); + if (progressive_detail_ >= JxlProgressiveDetail::kDC) { + MarkSections(sections, num, section_status); + return true; + } + } + + if (finalized_dc_ && ac_global_sec != num && !decoded_ac_global_) { + JXL_RETURN_IF_ERROR(ProcessACGlobal(sections[ac_global_sec].br)); + section_status[ac_global_sec] = SectionStatus::kDone; + } + + if (progressive_detail_ >= JxlProgressiveDetail::kLastPasses) { + // Mark that we only want the next progression pass. + size_t target_complete_passes = NextNumPassesToPause(); + for (size_t i = 0; i < ac_group_sec.size(); i++) { + desired_num_ac_passes[i] = + std::min(desired_num_ac_passes[i], + target_complete_passes - decoded_passes_per_ac_group_[i]); + } + } + + if (decoded_ac_global_) { + // Mark all the AC groups that we received as not complete yet. + for (size_t i = 0; i < ac_group_sec.size(); i++) { + if (desired_num_ac_passes[i] != 0) { + dec_state_->render_pipeline->ClearDone(i); + } + } + + JXL_RETURN_IF_ERROR(RunOnPool( + pool_, 0, ac_group_sec.size(), + [this](size_t num_threads) { + return PrepareStorage(num_threads, + decoded_passes_per_ac_group_.size()); + }, + [this, &ac_group_sec, &desired_num_ac_passes, &num, §ions, + §ion_status, &has_error](size_t g, size_t thread) { + if (desired_num_ac_passes[g] == 0) { + // no new AC pass, nothing to do + return; + } + (void)num; + size_t first_pass = decoded_passes_per_ac_group_[g]; + BitReader* JXL_RESTRICT readers[kMaxNumPasses]; + for (size_t i = 0; i < desired_num_ac_passes[g]; i++) { + JXL_ASSERT(ac_group_sec[g][first_pass + i] != num); + readers[i] = sections[ac_group_sec[g][first_pass + i]].br; + } + if (!ProcessACGroup(g, readers, desired_num_ac_passes[g], + GetStorageLocation(thread, g), + /*force_draw=*/false, /*dc_only=*/false)) { + has_error = true; + } else { + for (size_t i = 0; i < desired_num_ac_passes[g]; i++) { + section_status[ac_group_sec[g][first_pass + i]] = + SectionStatus::kDone; + } + } + }, + "DecodeGroup")); + } + if (has_error) return JXL_FAILURE("Error in AC group"); + + MarkSections(sections, num, section_status); + return true; +} + +Status FrameDecoder::Flush() { + bool has_blending = frame_header_.blending_info.mode != BlendMode::kReplace || + frame_header_.custom_size_or_origin; + for (const auto& blending_info_ec : + frame_header_.extra_channel_blending_info) { + if (blending_info_ec.mode != BlendMode::kReplace) has_blending = true; + } + // No early Flush() if blending is enabled. + if (has_blending && !is_finalized_) { + return false; + } + // No early Flush() - nothing to do - if the frame is a kSkipProgressive + // frame. + if (frame_header_.frame_type == FrameType::kSkipProgressive && + !is_finalized_) { + return true; + } + if (decoded_->IsJPEG()) { + // Nothing to do. + return true; + } + JXL_RETURN_IF_ERROR(AllocateOutput()); + + uint32_t completely_decoded_ac_pass = *std::min_element( + decoded_passes_per_ac_group_.begin(), decoded_passes_per_ac_group_.end()); + if (completely_decoded_ac_pass < frame_header_.passes.num_passes) { + // We don't have all AC yet: force a draw of all the missing areas. + // Mark all sections as not complete. + for (size_t i = 0; i < decoded_passes_per_ac_group_.size(); i++) { + if (decoded_passes_per_ac_group_[i] < frame_header_.passes.num_passes) { + dec_state_->render_pipeline->ClearDone(i); + } + } + std::atomic has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool_, 0, decoded_passes_per_ac_group_.size(), + [this](const size_t num_threads) { + return PrepareStorage(num_threads, + decoded_passes_per_ac_group_.size()); + }, + [this, &has_error](const uint32_t g, size_t thread) { + if (decoded_passes_per_ac_group_[g] == + frame_header_.passes.num_passes) { + // This group was drawn already, nothing to do. + return; + } + BitReader* JXL_RESTRICT readers[kMaxNumPasses] = {}; + bool ok = ProcessACGroup( + g, readers, /*num_passes=*/0, GetStorageLocation(thread, g), + /*force_draw=*/true, /*dc_only=*/!decoded_ac_global_); + if (!ok) has_error = true; + }, + "ForceDrawGroup")); + if (has_error) { + return JXL_FAILURE("Drawing groups failed"); + } + } + + // undo global modular transforms and copy int pixel buffers to float ones + JXL_RETURN_IF_ERROR(modular_frame_decoder_.FinalizeDecoding(dec_state_, pool_, + is_finalized_)); + + return true; +} + +int FrameDecoder::SavedAs(const FrameHeader& header) { + if (header.frame_type == FrameType::kDCFrame) { + // bits 16, 32, 64, 128 for DC level + return 16 << (header.dc_level - 1); + } else if (header.CanBeReferenced()) { + // bits 1, 2, 4 and 8 for the references + return 1 << header.save_as_reference; + } + + return 0; +} + +bool FrameDecoder::HasEverything() const { + if (!decoded_dc_global_) return false; + if (!decoded_ac_global_) return false; + for (auto& have_dc_group : decoded_dc_groups_) { + if (!have_dc_group) return false; + } + for (auto& nb_passes : decoded_passes_per_ac_group_) { + if (nb_passes < frame_header_.passes.num_passes) return false; + } + return true; +} + +int FrameDecoder::References() const { + if (is_finalized_) { + return 0; + } + if (!HasEverything()) return 0; + + int result = 0; + + // Blending + if (frame_header_.frame_type == FrameType::kRegularFrame || + frame_header_.frame_type == FrameType::kSkipProgressive) { + bool cropped = frame_header_.custom_size_or_origin; + if (cropped || frame_header_.blending_info.mode != BlendMode::kReplace) { + result |= (1 << frame_header_.blending_info.source); + } + const auto& extra = frame_header_.extra_channel_blending_info; + for (size_t i = 0; i < extra.size(); ++i) { + if (cropped || extra[i].mode != BlendMode::kReplace) { + result |= (1 << extra[i].source); + } + } + } + + // Patches + if (frame_header_.flags & FrameHeader::kPatches) { + result |= dec_state_->shared->image_features.patches.GetReferences(); + } + + // DC Level + if (frame_header_.flags & FrameHeader::kUseDcFrame) { + // Reads from the next dc level + int dc_level = frame_header_.dc_level + 1; + // bits 16, 32, 64, 128 for DC level + result |= (16 << (dc_level - 1)); + } + + return result; +} + +Status FrameDecoder::FinalizeFrame() { + if (is_finalized_) { + return JXL_FAILURE("FinalizeFrame called multiple times"); + } + is_finalized_ = true; + if (decoded_->IsJPEG()) { + // Nothing to do. + return true; + } + if (!finalized_dc_) { + // We don't have all of DC, and render pipeline is not created yet, so we + // can not call Flush() yet. + return JXL_FAILURE("FinalizeFrame called before DC was fully decoded"); + } + + JXL_RETURN_IF_ERROR(Flush()); + + if (frame_header_.CanBeReferenced()) { + auto& info = dec_state_->shared_storage + .reference_frames[frame_header_.save_as_reference]; + info.storage = std::move(dec_state_->frame_storage_for_referencing); + info.ib_is_in_xyb = frame_header_.save_before_color_transform; + info.frame = &info.storage; + } + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_frame.h b/media/libjxl/src/lib/jxl/dec_frame.h new file mode 100644 index 000000000..62c61c0fa --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_frame.h @@ -0,0 +1,331 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_FRAME_H_ +#define LIB_JXL_DEC_FRAME_H_ + +#include + +#include "jxl/decode.h" +#include "jxl/types.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/blending.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Decodes a frame. Groups may be processed in parallel by `pool`. +// `metadata` is the metadata that applies to all frames of the codestream +// `decoded->metadata` must already be set and must match metadata.m. +// Used in the encoder to model decoder behaviour, and in tests. +Status DecodeFrame(PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool, + const uint8_t* next_in, size_t avail_in, + ImageBundle* decoded, const CodecMetadata& metadata, + bool use_slow_rendering_pipeline = false); + +// TODO(veluca): implement "forced drawing". +class FrameDecoder { + public: + // All parameters must outlive the FrameDecoder. + FrameDecoder(PassesDecoderState* dec_state, const CodecMetadata& metadata, + ThreadPool* pool, bool use_slow_rendering_pipeline) + : dec_state_(dec_state), + pool_(pool), + frame_header_(&metadata), + use_slow_rendering_pipeline_(use_slow_rendering_pipeline) {} + + void SetRenderSpotcolors(bool rsc) { render_spotcolors_ = rsc; } + void SetCoalescing(bool c) { coalescing_ = c; } + + // Read FrameHeader and table of contents from the given BitReader. + // Also checks frame dimensions for their limits, and sets the output + // image buffer. + Status InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded, + bool is_preview, bool output_needed); + + struct SectionInfo { + BitReader* JXL_RESTRICT br; + size_t id; + }; + + struct TocEntry { + size_t size; + size_t id; + }; + + enum SectionStatus { + // Processed correctly. + kDone = 0, + // Skipped because other required sections were not yet processed. + kSkipped = 1, + // Skipped because the section was already processed. + kDuplicate = 2, + // Only partially decoded: the section will need to be processed again. + kPartial = 3, + }; + + // Processes `num` sections; each SectionInfo contains the index + // of the section and a BitReader that only contains the data of the section. + // `section_status` should point to `num` elements, and will be filled with + // information about whether each section was processed or not. + // A section is a part of the encoded file that is indexed by the TOC. + Status ProcessSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status); + + // Flushes all the data decoded so far to pixels. + Status Flush(); + + // Runs final operations once a frame data is decoded. + // Must be called exactly once per frame, after all calls to ProcessSections. + Status FinalizeFrame(); + + // Returns dependencies of this frame on reference ids as a bit mask: bits 0-3 + // indicate reference frame 0-3 for patches and blending, bits 4-7 indicate DC + // frames this frame depends on. Only returns a valid result after all calls + // to ProcessSections are finished and before FinalizeFrame. + int References() const; + + // Returns reference id of storage location where this frame is stored as a + // bit flag, or 0 if not stored. + // Matches the bit mask used for GetReferences: bits 0-3 indicate it is stored + // for patching or blending, bits 4-7 indicate DC frame. + // Unlike References, can be ran at any time as + // soon as the frame header is known. + static int SavedAs(const FrameHeader& header); + + uint64_t SumSectionSizes() const { return section_sizes_sum_; } + const std::vector& Toc() const { return toc_; } + + const FrameHeader& GetFrameHeader() const { return frame_header_; } + + // Returns whether a DC image has been decoded, accessible at low resolution + // at passes.shared_storage.dc_storage + bool HasDecodedDC() const { return finalized_dc_; } + bool HasDecodedAll() const { return toc_.size() == num_sections_done_; } + + size_t NumCompletePasses() const { + return *std::min_element(decoded_passes_per_ac_group_.begin(), + decoded_passes_per_ac_group_.end()); + }; + + // If enabled, ProcessSections will stop and return true when the DC + // sections have been processed, instead of starting the AC sections. This + // will only occur if supported (that is, flushing will produce a valid + // 1/8th*1/8th resolution image). The return value of true then does not mean + // all sections have been processed, use HasDecodedDC and HasDecodedAll + // to check the true finished state. + // Returns the progressive detail that will be effective for the frame. + JxlProgressiveDetail SetPauseAtProgressive(JxlProgressiveDetail prog_detail) { + bool single_section = + frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1; + if (frame_header_.frame_type != kSkipProgressive && + // If there's only one group and one pass, there is no separate section + // for DC and the entire full resolution image is available at once. + !single_section && + // If extra channels are encoded with modular without squeeze, they + // don't support DC. If the are encoded with squeeze, DC works in theory + // but the implementation may not yet correctly support this for Flush. + // Therefore, can't correctly pause for a progressive step if there is + // an extra channel (including alpha channel) + // TOOD(firsching): Check if this is still the case. + decoded_->metadata()->extra_channel_info.empty() && + // DC is not guaranteed to be available in modular mode and may be a + // black image. If squeeze is used, it may be available depending on the + // current implementation. + // TODO(lode): do return DC if it's known that flushing at this point + // will produce a valid 1/8th downscaled image with modular encoding. + frame_header_.encoding == FrameEncoding::kVarDCT) { + progressive_detail_ = prog_detail; + } else { + progressive_detail_ = JxlProgressiveDetail::kFrames; + } + if (progressive_detail_ >= JxlProgressiveDetail::kPasses) { + for (size_t i = 1; i < frame_header_.passes.num_passes; ++i) { + passes_to_pause_.push_back(i); + } + } else if (progressive_detail_ >= JxlProgressiveDetail::kLastPasses) { + for (size_t i = 0; i < frame_header_.passes.num_downsample; ++i) { + passes_to_pause_.push_back(frame_header_.passes.last_pass[i] + 1); + } + // The format does not guarantee that these values are sorted. + std::sort(passes_to_pause_.begin(), passes_to_pause_.end()); + } + return progressive_detail_; + } + + size_t NextNumPassesToPause() const { + auto it = std::upper_bound(passes_to_pause_.begin(), passes_to_pause_.end(), + NumCompletePasses()); + return (it != passes_to_pause_.end() ? *it + : std::numeric_limits::max()); + } + + void MaybeSetUnpremultiplyAlpha(bool unpremul_alpha) { + const jxl::ExtraChannelInfo* alpha = + decoded_->metadata()->Find(jxl::ExtraChannel::kAlpha); + if (alpha && alpha->alpha_associated && unpremul_alpha) { + dec_state_->unpremul_alpha = true; + } + } + + // Sets the buffer to which uint8 sRGB pixels will be decoded. This is not + // supported for all images. If it succeeds, HasRGBBuffer() will return true. + // If it does not succeed, the image is decoded to the ImageBundle passed to + // InitFrame instead. + // If an output callback is set, this function *may not* be called. + // + // @param undo_orientation: if true, indicates the frame decoder should apply + // the exif orientation to bring the image to the intended display + // orientation. Performing this operation is not yet supported, so this + // results in not setting the buffer if the image has a non-identity EXIF + // orientation. When outputting to the ImageBundle, no orientation is undone. + void MaybeSetRGB8OutputBuffer(uint8_t* rgb_output, size_t stride, + bool is_rgba, bool undo_orientation) const { + if (!CanDoLowMemoryPath(undo_orientation) || dec_state_->unpremul_alpha) { + return; + } + dec_state_->rgb_output = rgb_output; + dec_state_->rgb_output_is_rgba = is_rgba; + dec_state_->rgb_stride = stride; + JXL_ASSERT(!dec_state_->pixel_callback.IsPresent()); +#if !JXL_HIGH_PRECISION + if (decoded_->metadata()->xyb_encoded && + dec_state_->output_encoding_info.color_encoding.IsSRGB() && + dec_state_->output_encoding_info.all_default_opsin && + dec_state_->output_encoding_info.desired_intensity_target == + dec_state_->output_encoding_info.orig_intensity_target && + HasFastXYBTosRGB8() && frame_header_.needs_color_transform()) { + dec_state_->fast_xyb_srgb8_conversion = true; + } +#endif + } + + // Same as MaybeSetRGB8OutputBuffer, but with a float callback. This is not + // supported for all images. If it succeeds, HasRGBBuffer() will return true. + // If it does not succeed, the image is decoded to the ImageBundle passed to + // InitFrame instead. + // If a RGB8 output buffer is set, this function *may not* be called. + // + // @param undo_orientation: if true, indicates the frame decoder should apply + // the exif orientation to bring the image to the intended display + // orientation. Performing this operation is not yet supported, so this + // results in not setting the buffer if the image has a non-identity EXIF + // orientation. When outputting to the ImageBundle, no orientation is undone. + void MaybeSetFloatCallback(const PixelCallback& pixel_callback, bool is_rgba, + bool unpremul_alpha, bool undo_orientation) const { + if (!CanDoLowMemoryPath(undo_orientation)) return; + dec_state_->pixel_callback = pixel_callback; + dec_state_->rgb_output_is_rgba = is_rgba; + JXL_ASSERT(dec_state_->rgb_output == nullptr); + } + + // Returns true if the rgb output buffer passed by MaybeSetRGB8OutputBuffer + // has been/will be populated by Flush() / FinalizeFrame(), or if a pixel + // callback has been used. + bool HasRGBBuffer() const { + return dec_state_->rgb_output != nullptr || + dec_state_->pixel_callback.IsPresent(); + } + + private: + Status ProcessDCGlobal(BitReader* br); + Status ProcessDCGroup(size_t dc_group_id, BitReader* br); + void FinalizeDC(); + Status AllocateOutput(); + Status ProcessACGlobal(BitReader* br); + Status ProcessACGroup(size_t ac_group_id, BitReader* JXL_RESTRICT* br, + size_t num_passes, size_t thread, bool force_draw, + bool dc_only); + void MarkSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status); + + // Allocates storage for parallel decoding using up to `num_threads` threads + // of up to `num_tasks` tasks. The value of `thread` passed to + // `GetStorageLocation` must be smaller than the `num_threads` value passed + // here. The value of `task` passed to `GetStorageLocation` must be smaller + // than the value of `num_tasks` passed here. + Status PrepareStorage(size_t num_threads, size_t num_tasks) { + size_t storage_size = std::min(num_threads, num_tasks); + if (storage_size > group_dec_caches_.size()) { + group_dec_caches_.resize(storage_size); + } + use_task_id_ = num_threads > num_tasks; + bool use_group_ids = (modular_frame_decoder_.UsesFullImage() && + (frame_header_.encoding == FrameEncoding::kVarDCT || + (frame_header_.flags & FrameHeader::kNoise))); + if (dec_state_->render_pipeline) { + JXL_RETURN_IF_ERROR(dec_state_->render_pipeline->PrepareForThreads( + storage_size, use_group_ids)); + } + return true; + } + + size_t GetStorageLocation(size_t thread, size_t task) { + if (use_task_id_) return task; + return thread; + } + + // If the image has default exif orientation (or has an orientation but should + // not be undone) and no blending, the current frame cannot be referenced by + // future frames, there are no spot colors to be rendered, and alpha is not + // premultiplied, then low memory options can be used + // (uint8 output buffer or float pixel callback). + // TODO(veluca): reduce this set of restrictions. + bool CanDoLowMemoryPath(bool undo_orientation) const { + return !(undo_orientation && + decoded_->metadata()->GetOrientation() != Orientation::kIdentity); + } + + PassesDecoderState* dec_state_; + ThreadPool* pool_; + std::vector toc_; + uint64_t section_sizes_sum_; + // TODO(veluca): figure out the duplication between these and dec_state_. + FrameHeader frame_header_; + FrameDimensions frame_dim_; + ImageBundle* decoded_; + ModularFrameDecoder modular_frame_decoder_; + bool render_spotcolors_ = true; + bool coalescing_ = true; + + std::vector processed_section_; + std::vector decoded_passes_per_ac_group_; + std::vector decoded_dc_groups_; + bool decoded_dc_global_; + bool decoded_ac_global_; + bool HasEverything() const; + bool finalized_dc_ = true; + size_t num_sections_done_ = 0; + bool is_finalized_ = true; + bool allocated_ = false; + + std::vector group_dec_caches_; + + // Whether or not the task id should be used for storage indexing, instead of + // the thread id. + bool use_task_id_ = false; + + // Testing setting: whether or not to use the slow rendering pipeline. + bool use_slow_rendering_pipeline_; + + JxlProgressiveDetail progressive_detail_ = kFrames; + // Number of completed passes where section decoding should pause. + // Used for progressive details at least kLastPasses. + std::vector passes_to_pause_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_FRAME_H_ diff --git a/media/libjxl/src/lib/jxl/dec_group.cc b/media/libjxl/src/lib/jxl/dec_group.cc new file mode 100644 index 000000000..689bc8177 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_group.cc @@ -0,0 +1,800 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_group.h" + +#include +#include + +#include +#include +#include + +#include "lib/jxl/frame_header.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc" +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer-inl.h" +#include "lib/jxl/quantizer.h" + +#ifndef LIB_JXL_DEC_GROUP_CC +#define LIB_JXL_DEC_GROUP_CC +namespace jxl { + +// Interface for reading groups for DecodeGroupImpl. +class GetBlock { + public: + virtual void StartRow(size_t by) = 0; + virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, + size_t size, size_t log2_covered_blocks, + ACPtr block[3], ACType ac_type) = 0; + virtual ~GetBlock() {} +}; + +// Controls whether DecodeGroupImpl renders to pixels or not. +enum DrawMode { + // Render to pixels. + kDraw = 0, + // Don't render to pixels. + kDontDraw = 1, +}; + +} // namespace jxl +#endif // LIB_JXL_DEC_GROUP_CC + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::ShiftRight; + +using D = HWY_FULL(float); +using DU = HWY_FULL(uint32_t); +using DI = HWY_FULL(int32_t); +using DI16 = Rebind; +constexpr D d; +constexpr DI di; +constexpr DI16 di16; + +// TODO(veluca): consider SIMDfying. +void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) { + for (size_t x = 0; x < 8; x++) { + for (size_t y = x + 1; y < 8; y++) { + std::swap(block[y * 8 + x], block[x * 8 + y]); + } + } +} + +template +void DequantLane(Vec scaled_dequant_x, Vec scaled_dequant_y, + Vec scaled_dequant_b, + const float* JXL_RESTRICT dequant_matrices, size_t size, + size_t k, Vec x_cc_mul, Vec b_cc_mul, + const float* JXL_RESTRICT biases, ACPtr qblock[3], + float* JXL_RESTRICT block) { + const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x); + const auto y_mul = + Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y); + const auto b_mul = + Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b); + + Vec quantized_x_int; + Vec quantized_y_int; + Vec quantized_b_int; + if (ac_type == ACType::k16) { + Rebind di16; + quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k)); + quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k)); + quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k)); + } else { + quantized_x_int = Load(di, qblock[0].ptr32 + k); + quantized_y_int = Load(di, qblock[1].ptr32 + k); + quantized_b_int = Load(di, qblock[2].ptr32 + k); + } + + const auto dequant_x_cc = + Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul); + const auto dequant_y = + Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul); + const auto dequant_b_cc = + Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul); + + const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc); + const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc); + Store(dequant_x, d, block + k); + Store(dequant_y, d, block + size + k); + Store(dequant_b, d, block + 2 * size + k); +} + +template +void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant, + float x_dm_multiplier, float b_dm_multiplier, Vec x_cc_mul, + Vec b_cc_mul, size_t kind, size_t size, + const Quantizer& quantizer, size_t covered_blocks, + const size_t* sbx, + const float* JXL_RESTRICT* JXL_RESTRICT dc_row, + size_t dc_stride, const float* JXL_RESTRICT biases, + ACPtr qblock[3], float* JXL_RESTRICT block) { + PROFILER_FUNC; + + const auto scaled_dequant_s = inv_global_scale / quant; + + const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier); + const auto scaled_dequant_y = Set(d, scaled_dequant_s); + const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier); + + const float* dequant_matrices = quantizer.DequantMatrix(kind, 0); + + for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) { + DequantLane(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b, + dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases, + qblock, block); + } + for (size_t c = 0; c < 3; c++) { + LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride, + block + c * size); + } +} + +Status DecodeGroupImpl(GetBlock* JXL_RESTRICT get_block, + GroupDecCache* JXL_RESTRICT group_dec_cache, + PassesDecoderState* JXL_RESTRICT dec_state, + size_t thread, size_t group_idx, + RenderPipelineInput& render_pipeline_input, + ImageBundle* decoded, DrawMode draw) { + // TODO(veluca): investigate cache usage in this function. + PROFILER_FUNC; + const Rect block_rect = dec_state->shared->BlockGroupRect(group_idx); + const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy; + + const size_t xsize_blocks = block_rect.xsize(); + const size_t ysize_blocks = block_rect.ysize(); + + const size_t dc_stride = dec_state->shared->dc->PixelsPerRow(); + + const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale(); + + const YCbCrChromaSubsampling& cs = + dec_state->shared->frame_header.chroma_subsampling; + + size_t idct_stride[3]; + for (size_t c = 0; c < 3; c++) { + idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow(); + } + + HWY_ALIGN int32_t scaled_qtable[64 * 3]; + + ACType ac_type = dec_state->coefficients->Type(); + auto dequant_block = ac_type == ACType::k16 ? DequantBlock + : DequantBlock; + // Whether or not coefficients should be stored for future usage, and/or read + // from past usage. + bool accumulate = !dec_state->coefficients->IsEmpty(); + // Offset of the current block in the group. + size_t offset = 0; + + std::array jpeg_c_map; + bool jpeg_is_gray = false; + std::array dcoff = {}; + + // TODO(veluca): all of this should be done only once per image. + if (decoded->IsJPEG()) { + if (!dec_state->shared->cmap.IsJPEGCompatible()) { + return JXL_FAILURE("The CfL map is not JPEG-compatible"); + } + jpeg_is_gray = (decoded->jpeg_data->components.size() == 1); + jpeg_c_map = JpegOrder(dec_state->shared->frame_header.color_transform, + jpeg_is_gray); + const std::vector& qe = + dec_state->shared->matrices.encodings(); + if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW || + std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) { + return JXL_FAILURE( + "Quantization table is not a JPEG quantization table."); + } + for (size_t c = 0; c < 3; c++) { + if (dec_state->shared->frame_header.color_transform == + ColorTransform::kNone) { + dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c]; + } + for (size_t i = 0; i < 64; i++) { + // Transpose the matrix, as it will be used on the transposed block. + int n = qe[0].qraw.qtable->at(64 + i); + int d = qe[0].qraw.qtable->at(64 * c + i); + if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) { + return JXL_FAILURE("Invalid JPEG quantization table"); + } + scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] = + (1 << kCFLFixedPointPrecision) * n / d; + } + } + } + + size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)}; + size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)}; + Rect r[3]; + for (size_t i = 0; i < 3; i++) { + r[i] = + Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i], + block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]); + if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(), + dec_state->shared->dc->Plane(i).ysize()})) { + return JXL_FAILURE("Frame dimensions are too big for the image."); + } + } + + for (size_t by = 0; by < ysize_blocks; ++by) { + get_block->StartRow(by); + size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]}; + + const int32_t* JXL_RESTRICT row_quant = + block_rect.ConstRow(dec_state->shared->raw_quant_field, by); + + const float* JXL_RESTRICT dc_rows[3] = { + r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]), + r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]), + r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]), + }; + + const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks; + AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by); + + const int8_t* JXL_RESTRICT row_cmap[3] = { + dec_state->shared->cmap.ytox_map.ConstRow(ty), + nullptr, + dec_state->shared->cmap.ytob_map.ConstRow(ty), + }; + + float* JXL_RESTRICT idct_row[3]; + int16_t* JXL_RESTRICT jpeg_row[3]; + for (size_t c = 0; c < 3; c++) { + idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row( + render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim); + if (decoded->IsJPEG()) { + auto& component = decoded->jpeg_data->components[jpeg_c_map[c]]; + jpeg_row[c] = + component.coeffs.data() + + (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) * + kDCTBlockSize; + } + } + + size_t bx = 0; + for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks); + tx++) { + size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks; + auto x_cc_mul = + Set(d, dec_state->shared->cmap.YtoXRatio(row_cmap[0][abs_tx])); + auto b_cc_mul = + Set(d, dec_state->shared->cmap.YtoBRatio(row_cmap[2][abs_tx])); + // Increment bx by llf_x because those iterations would otherwise + // immediately continue (!IsFirstBlock). Reduces mispredictions. + for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) { + size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]}; + AcStrategy acs = acs_row[bx]; + const size_t llf_x = acs.covered_blocks_x(); + + // Can only happen in the second or lower rows of a varblock. + if (JXL_UNLIKELY(!acs.IsFirstBlock())) { + bx += llf_x; + continue; + } + PROFILER_ZONE("DecodeGroupImpl inner"); + const size_t log2_covered_blocks = acs.log2_covered_blocks(); + + const size_t covered_blocks = 1 << log2_covered_blocks; + const size_t size = covered_blocks * kDCTBlockSize; + + ACPtr qblock[3]; + if (accumulate) { + for (size_t c = 0; c < 3; c++) { + qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset); + } + } else { + // No point in reading from bitstream without accumulating and not + // drawing. + JXL_ASSERT(draw == kDraw); + if (ac_type == ACType::k16) { + memset(group_dec_cache->dec_group_qblock16, 0, + size * 3 * sizeof(int16_t)); + for (size_t c = 0; c < 3; c++) { + qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size; + } + } else { + memset(group_dec_cache->dec_group_qblock, 0, + size * 3 * sizeof(int32_t)); + for (size_t c = 0; c < 3; c++) { + qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size; + } + } + } + JXL_RETURN_IF_ERROR(get_block->LoadBlock( + bx, by, acs, size, log2_covered_blocks, qblock, ac_type)); + offset += size; + if (draw == kDontDraw) { + bx += llf_x; + continue; + } + + if (JXL_UNLIKELY(decoded->IsJPEG())) { + if (acs.Strategy() != AcStrategy::Type::DCT) { + return JXL_FAILURE( + "Can only decode to JPEG if only DCT-8 is used."); + } + + HWY_ALIGN int32_t transposed_dct_y[64]; + for (size_t c : {1, 0, 2}) { + // Propagate only Y for grayscale. + if (jpeg_is_gray && c != 1) { + continue; + } + if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) { + continue; + } + int16_t* JXL_RESTRICT jpeg_pos = + jpeg_row[c] + sbx[c] * kDCTBlockSize; + // JPEG XL is transposed, JPEG is not. + auto transposed_dct = qblock[c].ptr32; + Transpose8x8InPlace(transposed_dct); + // No CfL - no need to store the y block converted to integers. + if (!cs.Is444() || + (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) { + for (size_t i = 0; i < 64; i += Lanes(d)) { + const auto ini = Load(di, transposed_dct + i); + const auto ini16 = DemoteTo(di16, ini); + StoreU(ini16, di16, jpeg_pos + i); + } + } else if (c == 1) { + // Y channel: save for restoring X/B, but nothing else to do. + for (size_t i = 0; i < 64; i += Lanes(d)) { + const auto ini = Load(di, transposed_dct + i); + Store(ini, di, transposed_dct_y + i); + const auto ini16 = DemoteTo(di16, ini); + StoreU(ini16, di16, jpeg_pos + i); + } + } else { + // transposed_dct_y contains the y channel block, transposed. + const auto scale = Set( + di, dec_state->shared->cmap.RatioJPEG(row_cmap[c][abs_tx])); + const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1)); + for (int i = 0; i < 64; i += Lanes(d)) { + auto in = Load(di, transposed_dct + i); + auto in_y = Load(di, transposed_dct_y + i); + auto qt = Load(di, scaled_qtable + c * size + i); + auto coeff_scale = ShiftRight( + Add(Mul(qt, scale), round)); + auto cfl_factor = ShiftRight( + Add(Mul(in_y, coeff_scale), round)); + StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i); + } + } + jpeg_pos[0] = + Clamp1(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047); + } + } else { + HWY_ALIGN float* const block = group_dec_cache->dec_group_block; + // Dequantize and add predictions. + dequant_block( + acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier, + dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(), + size, dec_state->shared->quantizer, + acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows, + dc_stride, + dec_state->output_encoding_info.opsin_params.quant_biases, qblock, + block); + + for (size_t c : {1, 0, 2}) { + if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) { + continue; + } + // IDCT + float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim; + TransformToPixels(acs.Strategy(), block + c * size, idct_pos, + idct_stride[c], group_dec_cache->scratch_space); + } + } + bx += llf_x; + } + } + } + if (draw == kDontDraw) { + return true; + } + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { +// Decode quantized AC coefficients of DCT blocks. +// LLF components in the output block will not be modified. +template +Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks, + int32_t* JXL_RESTRICT row_nzeros, + const int32_t* JXL_RESTRICT row_nzeros_top, + size_t nzeros_stride, size_t c, size_t bx, size_t by, + size_t lbx, AcStrategy acs, + const coeff_order_t* JXL_RESTRICT coeff_order, + BitReader* JXL_RESTRICT br, + ANSSymbolReader* JXL_RESTRICT decoder, + const std::vector& context_map, + const uint8_t* qdc_row, const int32_t* qf_row, + const BlockCtxMap& block_ctx_map, ACPtr block, + size_t shift = 0) { + PROFILER_FUNC; + // Equal to number of LLF coefficients. + const size_t covered_blocks = 1 << log2_covered_blocks; + const size_t size = covered_blocks * kDCTBlockSize; + int32_t predicted_nzeros = + PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32); + + size_t ord = kStrategyOrder[acs.RawStrategy()]; + const coeff_order_t* JXL_RESTRICT order = + &coeff_order[CoeffOrderOffset(ord, c)]; + + size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c); + const int32_t nzero_ctx = + block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset; + + size_t nzeros = decoder->ReadHybridUint(nzero_ctx, br, context_map); + if (nzeros + covered_blocks > size) { + return JXL_FAILURE("Invalid AC: nzeros too large"); + } + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + row_nzeros[bx + x + y * nzeros_stride] = + (nzeros + covered_blocks - 1) >> log2_covered_blocks; + } + } + + const size_t histo_offset = + ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx); + + // Skip LLF + { + PROFILER_ZONE("AcDecSkipLLF, reader"); + size_t prev = (nzeros > size / 16 ? 0 : 1); + for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) { + const size_t ctx = + histo_offset + ZeroDensityContext(nzeros, k, covered_blocks, + log2_covered_blocks, prev); + const size_t u_coeff = decoder->ReadHybridUint(ctx, br, context_map); + // Hand-rolled version of UnpackSigned, shifting before the conversion to + // signed integer to avoid undefined behavior of shifting negative + // numbers. + const size_t magnitude = u_coeff >> 1; + const size_t neg_sign = (~u_coeff) & 1; + const intptr_t coeff = + static_cast((magnitude ^ (neg_sign - 1)) << shift); + if (ac_type == ACType::k16) { + block.ptr16[order[k]] += coeff; + } else { + block.ptr32[order[k]] += coeff; + } + prev = static_cast(u_coeff != 0); + nzeros -= prev; + } + if (JXL_UNLIKELY(nzeros != 0)) { + return JXL_FAILURE("Invalid AC: nzeros not 0. Block (%" PRIuS ", %" PRIuS + "), channel %" PRIuS, + bx, by, c); + } + } + return true; +} + +// Structs used by DecodeGroupImpl to get a quantized block. +// GetBlockFromBitstream uses ANS decoding (and thus keeps track of row +// pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient +// image provided by the encoder. + +struct GetBlockFromBitstream : public GetBlock { + void StartRow(size_t by) override { + qf_row = rect.ConstRow(*qf, by); + for (size_t c = 0; c < 3; c++) { + size_t sby = by >> vshift[c]; + quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0(); + for (size_t i = 0; i < num_passes; i++) { + row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby); + row_nzeros_top[i][c] = + sby == 0 + ? nullptr + : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1); + } + } + } + + Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size, + size_t log2_covered_blocks, ACPtr block[3], + ACType ac_type) override { + auto decode_ac_varblock = ac_type == ACType::k16 + ? DecodeACVarBlock + : DecodeACVarBlock; + for (size_t c : {1, 0, 2}) { + size_t sbx = bx >> hshift[c]; + size_t sby = by >> vshift[c]; + if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) { + continue; + } + + for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) { + JXL_RETURN_IF_ERROR(decode_ac_varblock( + ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c], + row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs, + &coeff_orders[pass * coeff_order_size], readers[pass], + &decoders[pass], context_map[pass], quant_dc_row, qf_row, + *block_ctx_map, block[c], shift_for_pass[pass])); + } + } + return true; + } + + Status Init(BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes, + size_t group_idx, size_t histo_selector_bits, const Rect& rect, + GroupDecCache* JXL_RESTRICT group_dec_cache, + PassesDecoderState* dec_state, size_t first_pass) { + for (size_t i = 0; i < 3; i++) { + hshift[i] = dec_state->shared->frame_header.chroma_subsampling.HShift(i); + vshift[i] = dec_state->shared->frame_header.chroma_subsampling.VShift(i); + } + this->coeff_order_size = dec_state->shared->coeff_order_size; + this->coeff_orders = + dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size; + this->context_map = dec_state->context_map.data() + first_pass; + this->readers = readers; + this->num_passes = num_passes; + this->shift_for_pass = + dec_state->shared->frame_header.passes.shift + first_pass; + this->group_dec_cache = group_dec_cache; + this->rect = rect; + block_ctx_map = &dec_state->shared->block_ctx_map; + qf = &dec_state->shared->raw_quant_field; + quant_dc = &dec_state->shared->quant_dc; + + for (size_t pass = 0; pass < num_passes; pass++) { + // Select which histogram set to use among those of the current pass. + size_t cur_histogram = 0; + if (histo_selector_bits != 0) { + cur_histogram = readers[pass]->ReadBits(histo_selector_bits); + } + if (cur_histogram >= dec_state->shared->num_histograms) { + return JXL_FAILURE("Invalid histogram selector"); + } + ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts(); + + decoders[pass] = + ANSSymbolReader(&dec_state->code[pass + first_pass], readers[pass]); + } + nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow(); + for (size_t i = 0; i < num_passes; i++) { + JXL_ASSERT( + nzeros_stride == + static_cast(group_dec_cache->num_nzeroes[i].PixelsPerRow())); + } + return true; + } + + const uint32_t* shift_for_pass = nullptr; // not owned + const coeff_order_t* JXL_RESTRICT coeff_orders; + size_t coeff_order_size; + const std::vector* JXL_RESTRICT context_map; + ANSSymbolReader decoders[kMaxNumPasses]; + BitReader* JXL_RESTRICT* JXL_RESTRICT readers; + size_t num_passes; + size_t ctx_offset[kMaxNumPasses]; + size_t nzeros_stride; + int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3]; + const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3]; + GroupDecCache* JXL_RESTRICT group_dec_cache; + const BlockCtxMap* block_ctx_map; + const ImageI* qf; + const ImageB* quant_dc; + const int32_t* qf_row; + const uint8_t* quant_dc_row; + Rect rect; + size_t hshift[3], vshift[3]; +}; + +struct GetBlockFromEncoder : public GetBlock { + void StartRow(size_t by) override {} + + Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size, + size_t log2_covered_blocks, ACPtr block[3], + ACType ac_type) override { + JXL_DASSERT(ac_type == ACType::k32); + for (size_t c = 0; c < 3; c++) { + // for each pass + for (size_t i = 0; i < quantized_ac->size(); i++) { + for (size_t k = 0; k < size; k++) { + // TODO(veluca): SIMD. + block[c].ptr32[k] += + rows[i][c][offset + k] * (1 << shift_for_pass[i]); + } + } + } + offset += size; + return true; + } + + GetBlockFromEncoder(const std::vector>& ac, + size_t group_idx, const uint32_t* shift_for_pass) + : quantized_ac(&ac), shift_for_pass(shift_for_pass) { + // TODO(veluca): not supported with chroma subsampling. + for (size_t i = 0; i < quantized_ac->size(); i++) { + JXL_CHECK((*quantized_ac)[i]->Type() == ACType::k32); + for (size_t c = 0; c < 3; c++) { + rows[i][c] = (*quantized_ac)[i]->PlaneRow(c, group_idx, 0).ptr32; + } + } + } + + const std::vector>* JXL_RESTRICT quantized_ac; + size_t offset = 0; + const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3]; + const uint32_t* shift_for_pass = nullptr; // not owned +}; + +HWY_EXPORT(DecodeGroupImpl); + +} // namespace + +Status DecodeGroup(BitReader* JXL_RESTRICT* JXL_RESTRICT readers, + size_t num_passes, size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, size_t first_pass, + bool force_draw, bool dc_only, bool* should_run_pipeline) { + PROFILER_FUNC; + + DrawMode draw = (num_passes + first_pass == + dec_state->shared->frame_header.passes.num_passes) || + force_draw + ? kDraw + : kDontDraw; + + if (should_run_pipeline) { + *should_run_pipeline = draw != kDontDraw; + } + + if (draw == kDraw && num_passes == 0 && first_pass == 0) { + group_dec_cache->InitDCBufferOnce(); + const YCbCrChromaSubsampling& cs = + dec_state->shared->frame_header.chroma_subsampling; + for (size_t c : {0, 1, 2}) { + size_t hs = cs.HShift(c); + size_t vs = cs.VShift(c); + // We reuse filter_input_storage here as it is not currently in use. + const Rect src_rect_precs = dec_state->shared->BlockGroupRect(group_idx); + const Rect src_rect = + Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs, + src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs); + const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(), + src_rect.ysize()); + CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2, + copy_rect, &group_dec_cache->dc_buffer); + // Mirrorpad. Interleaving left and right padding ensures that padding + // works out correctly even for images with DC size of 1. + for (size_t y = 0; y < src_rect.ysize() + 4; y++) { + size_t xend = kRenderPipelineXOffset + + (dec_state->shared->dc->Plane(c).xsize() >> hs) - + src_rect.x0(); + for (size_t ix = 0; ix < 2; ix++) { + if (src_rect.x0() == 0) { + group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] = + group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix]; + } + if (src_rect.x0() + src_rect.xsize() + 2 >= + (dec_state->shared->dc->xsize() >> hs)) { + group_dec_cache->dc_buffer.Row(y)[xend + ix] = + group_dec_cache->dc_buffer.Row(y)[xend - ix - 1]; + } + } + } + Rect dst_rect = render_pipeline_input.GetBuffer(c).second; + ImageF* upsampling_dst = render_pipeline_input.GetBuffer(c).first; + JXL_ASSERT(dst_rect.IsInside(*upsampling_dst)); + + RenderPipelineStage::RowInfo input_rows(1, std::vector(5)); + RenderPipelineStage::RowInfo output_rows(1, std::vector(8)); + for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize(); + y++) { + for (ssize_t iy = 0; iy < 5; iy++) { + input_rows[0][iy] = group_dec_cache->dc_buffer.Row( + Mirror(ssize_t(y) + iy - 2, + dec_state->shared->dc->Plane(c).ysize() >> vs) + + 2 - src_rect.y0()); + } + for (size_t iy = 0; iy < 8; iy++) { + output_rows[0][iy] = + dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) - + kRenderPipelineXOffset; + } + // Arguments set to 0/nullptr are not used. + dec_state->upsampler8x->ProcessRow(input_rows, output_rows, + /*xextra=*/0, src_rect.xsize(), 0, 0, + thread); + } + } + return true; + } + + size_t histo_selector_bits = 0; + if (dc_only) { + JXL_ASSERT(num_passes == 0); + } else { + JXL_ASSERT(dec_state->shared->num_histograms > 0); + histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms); + } + + GetBlockFromBitstream get_block; + JXL_RETURN_IF_ERROR( + get_block.Init(readers, num_passes, group_idx, histo_selector_bits, + dec_state->shared->BlockGroupRect(group_idx), + group_dec_cache, dec_state, first_pass)); + + JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)( + &get_block, group_dec_cache, dec_state, thread, group_idx, + render_pipeline_input, decoded, draw)); + + for (size_t pass = 0; pass < num_passes; pass++) { + if (!get_block.decoders[pass].CheckANSFinalState()) { + return JXL_FAILURE("ANS checksum failure."); + } + } + return true; +} + +Status DecodeGroupForRoundtrip(const std::vector>& ac, + size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, + size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, + AuxOut* aux_out) { + PROFILER_FUNC; + + GetBlockFromEncoder get_block(ac, group_idx, + dec_state->shared->frame_header.passes.shift); + group_dec_cache->InitOnce( + /*num_passes=*/0, + /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1); + + return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)( + &get_block, group_dec_cache, dec_state, thread, group_idx, + render_pipeline_input, decoded, kDraw); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/dec_group.h b/media/libjxl/src/lib/jxl/dec_group.h new file mode 100644 index 000000000..e50d22d2f --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_group.h @@ -0,0 +1,49 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_GROUP_H_ +#define LIB_JXL_DEC_GROUP_H_ + +#include +#include + +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +Status DecodeGroup(BitReader* JXL_RESTRICT* JXL_RESTRICT readers, + size_t num_passes, size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, size_t first_pass, + bool force_draw, bool dc_only, bool* should_run_pipeline); + +Status DecodeGroupForRoundtrip(const std::vector>& ac, + size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, + size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, + AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_DEC_GROUP_H_ diff --git a/media/libjxl/src/lib/jxl/dec_group_border.cc b/media/libjxl/src/lib/jxl/dec_group_border.cc new file mode 100644 index 000000000..4bee3ae6e --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_group_border.cc @@ -0,0 +1,184 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_group_border.h" + +#include + +namespace jxl { + +void GroupBorderAssigner::Init(const FrameDimensions& frame_dim) { + frame_dim_ = frame_dim; + size_t num_corners = + (frame_dim_.xsize_groups + 1) * (frame_dim_.ysize_groups + 1); + counters_.reset(new std::atomic[num_corners]); + // Initialize counters. + for (size_t y = 0; y < frame_dim_.ysize_groups + 1; y++) { + for (size_t x = 0; x < frame_dim_.xsize_groups + 1; x++) { + // Counters at image borders don't have anything on the other side, we + // pre-fill their value to have more uniform handling afterwards. + uint8_t init_value = 0; + if (x == 0) { + init_value |= kTopLeft | kBottomLeft; + } + if (x == frame_dim_.xsize_groups) { + init_value |= kTopRight | kBottomRight; + } + if (y == 0) { + init_value |= kTopLeft | kTopRight; + } + if (y == frame_dim_.ysize_groups) { + init_value |= kBottomLeft | kBottomRight; + } + counters_[y * (frame_dim_.xsize_groups + 1) + x] = init_value; + } + } +} + +void GroupBorderAssigner::ClearDone(size_t group_id) { + size_t x = group_id % frame_dim_.xsize_groups; + size_t y = group_id / frame_dim_.xsize_groups; + size_t top_left_idx = y * (frame_dim_.xsize_groups + 1) + x; + size_t top_right_idx = y * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_right_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_left_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x; + counters_[top_left_idx].fetch_and(~kBottomRight); + counters_[top_right_idx].fetch_and(~kBottomLeft); + counters_[bottom_left_idx].fetch_and(~kTopRight); + counters_[bottom_right_idx].fetch_and(~kTopLeft); +} + +// Looking at each corner between groups, we can guarantee that the four +// involved groups will agree between each other regarding the order in which +// each of the four groups terminated. Thus, the last of the four groups +// gets the responsibility of handling the corner. For borders, every border +// is assigned to its top corner (for vertical borders) or to its left corner +// (for horizontal borders): the order as seen on those corners will decide who +// handles that border. + +void GroupBorderAssigner::GroupDone(size_t group_id, size_t padx, size_t pady, + Rect* rects_to_finalize, + size_t* num_to_finalize) { + size_t x = group_id % frame_dim_.xsize_groups; + size_t y = group_id / frame_dim_.xsize_groups; + Rect block_rect(x * frame_dim_.group_dim / kBlockDim, + y * frame_dim_.group_dim / kBlockDim, + frame_dim_.group_dim / kBlockDim, + frame_dim_.group_dim / kBlockDim, frame_dim_.xsize_blocks, + frame_dim_.ysize_blocks); + + size_t top_left_idx = y * (frame_dim_.xsize_groups + 1) + x; + size_t top_right_idx = y * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_right_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_left_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x; + + auto fetch_status = [this](size_t idx, uint8_t bit) { + // Note that the acq-rel semantics of this fetch are actually needed to + // ensure that the pixel data of the group is already written to memory. + size_t status = counters_[idx].fetch_or(bit); + JXL_DASSERT((bit & status) == 0); + return bit | status; + }; + + size_t top_left_status = fetch_status(top_left_idx, kBottomRight); + size_t top_right_status = fetch_status(top_right_idx, kBottomLeft); + size_t bottom_right_status = fetch_status(bottom_right_idx, kTopLeft); + size_t bottom_left_status = fetch_status(bottom_left_idx, kTopRight); + + size_t x1 = block_rect.x0() + block_rect.xsize(); + size_t y1 = block_rect.y0() + block_rect.ysize(); + + bool is_last_group_x = frame_dim_.xsize_groups == x + 1; + bool is_last_group_y = frame_dim_.ysize_groups == y + 1; + + // Start of border of neighbouring group, end of border of this group, start + // of border of this group (on the other side), end of border of next group. + size_t xpos[4] = { + block_rect.x0() == 0 ? 0 : block_rect.x0() * kBlockDim - padx, + block_rect.x0() == 0 + ? 0 + : std::min(frame_dim_.xsize, block_rect.x0() * kBlockDim + padx), + is_last_group_x ? frame_dim_.xsize : x1 * kBlockDim - padx, + std::min(frame_dim_.xsize, x1 * kBlockDim + padx)}; + size_t ypos[4] = { + block_rect.y0() == 0 ? 0 : block_rect.y0() * kBlockDim - pady, + block_rect.y0() == 0 + ? 0 + : std::min(frame_dim_.ysize, block_rect.y0() * kBlockDim + pady), + is_last_group_y ? frame_dim_.ysize : y1 * kBlockDim - pady, + std::min(frame_dim_.ysize, y1 * kBlockDim + pady)}; + + *num_to_finalize = 0; + auto append_rect = [&](size_t x0, size_t x1, size_t y0, size_t y1) { + Rect rect(xpos[x0], ypos[y0], xpos[x1] - xpos[x0], ypos[y1] - ypos[y0]); + if (rect.xsize() == 0 || rect.ysize() == 0) return; + JXL_DASSERT(*num_to_finalize < kMaxToFinalize); + rects_to_finalize[(*num_to_finalize)++] = rect; + }; + + // Because of how group borders are assigned, it is impossible that we need to + // process the left and right side of some area but not the center area. Thus, + // we compute the first/last part to process in every horizontal strip and + // merge them together. We first collect a mask of what parts should be + // processed. + // We do this horizontally rather than vertically because horizontal borders + // are larger. + bool available_parts_mask[3][3] = {}; // [x][y] + // Center + available_parts_mask[1][1] = true; + // Corners + if (top_left_status == 0xF) available_parts_mask[0][0] = true; + if (top_right_status == 0xF) available_parts_mask[2][0] = true; + if (bottom_right_status == 0xF) available_parts_mask[2][2] = true; + if (bottom_left_status == 0xF) available_parts_mask[0][2] = true; + // Other borders + if (top_left_status & kTopRight) available_parts_mask[1][0] = true; + if (top_left_status & kBottomLeft) available_parts_mask[0][1] = true; + if (top_right_status & kBottomRight) available_parts_mask[2][1] = true; + if (bottom_left_status & kBottomRight) available_parts_mask[1][2] = true; + + // Collect horizontal ranges. + constexpr size_t kNoSegment = 3; + std::pair horizontal_segments[3] = {{kNoSegment, kNoSegment}, + {kNoSegment, kNoSegment}, + {kNoSegment, kNoSegment}}; + for (size_t y = 0; y < 3; y++) { + for (size_t x = 0; x < 3; x++) { + if (!available_parts_mask[x][y]) continue; + JXL_DASSERT(horizontal_segments[y].second == kNoSegment || + horizontal_segments[y].second == x); + JXL_DASSERT((horizontal_segments[y].first == kNoSegment) == + (horizontal_segments[y].second == kNoSegment)); + if (horizontal_segments[y].first == kNoSegment) { + horizontal_segments[y].first = x; + } + horizontal_segments[y].second = x + 1; + } + } + if (horizontal_segments[0] == horizontal_segments[1] && + horizontal_segments[0] == horizontal_segments[2]) { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 3); + } else if (horizontal_segments[0] == horizontal_segments[1]) { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 2); + append_rect(horizontal_segments[2].first, horizontal_segments[2].second, 2, + 3); + } else if (horizontal_segments[1] == horizontal_segments[2]) { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 1); + append_rect(horizontal_segments[1].first, horizontal_segments[1].second, 1, + 3); + } else { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 1); + append_rect(horizontal_segments[1].first, horizontal_segments[1].second, 1, + 2); + append_rect(horizontal_segments[2].first, horizontal_segments[2].second, 2, + 3); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_group_border.h b/media/libjxl/src/lib/jxl/dec_group_border.h new file mode 100644 index 000000000..2d974c998 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_group_border.h @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_GROUP_BORDER_H_ +#define LIB_JXL_DEC_GROUP_BORDER_H_ + +#include + +#include + +#include "lib/jxl/base/arch_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +class GroupBorderAssigner { + public: + // Prepare the GroupBorderAssigner to handle a given frame. + void Init(const FrameDimensions& frame_dim); + // Marks a group as done, and returns the (at most 3) rects to run + // FinalizeImageRect on. `block_rect` must be the rect corresponding + // to the given `group_id`, measured in blocks. + void GroupDone(size_t group_id, size_t padx, size_t pady, + Rect* rects_to_finalize, size_t* num_to_finalize); + // Marks a group as not-done, for running re-paints. + void ClearDone(size_t group_id); + + static constexpr size_t kMaxToFinalize = 3; + + private: + FrameDimensions frame_dim_; + std::unique_ptr[]> counters_; + + // Constants to identify group positions relative to the corners. + static constexpr uint8_t kTopLeft = 0x01; + static constexpr uint8_t kTopRight = 0x02; + static constexpr uint8_t kBottomRight = 0x04; + static constexpr uint8_t kBottomLeft = 0x08; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_GROUP_BORDER_H_ diff --git a/media/libjxl/src/lib/jxl/dec_huffman.cc b/media/libjxl/src/lib/jxl/dec_huffman.cc new file mode 100644 index 000000000..05b275773 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_huffman.cc @@ -0,0 +1,255 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_huffman.h" + +#include /* for memset */ + +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/huffman_table.h" + +namespace jxl { + +static const int kCodeLengthCodes = 18; +static const uint8_t kCodeLengthCodeOrder[kCodeLengthCodes] = { + 1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15, +}; +static const uint8_t kDefaultCodeLength = 8; +static const uint8_t kCodeLengthRepeatCode = 16; + +int ReadHuffmanCodeLengths(const uint8_t* code_length_code_lengths, + int num_symbols, uint8_t* code_lengths, + BitReader* br) { + int symbol = 0; + uint8_t prev_code_len = kDefaultCodeLength; + int repeat = 0; + uint8_t repeat_code_len = 0; + int space = 32768; + HuffmanCode table[32]; + + uint16_t counts[16] = {0}; + for (int i = 0; i < kCodeLengthCodes; ++i) { + ++counts[code_length_code_lengths[i]]; + } + if (!BuildHuffmanTable(table, 5, code_length_code_lengths, kCodeLengthCodes, + &counts[0])) { + return 0; + } + + while (symbol < num_symbols && space > 0) { + const HuffmanCode* p = table; + uint8_t code_len; + br->Refill(); + p += br->PeekFixedBits<5>(); + br->Consume(p->bits); + code_len = (uint8_t)p->value; + if (code_len < kCodeLengthRepeatCode) { + repeat = 0; + code_lengths[symbol++] = code_len; + if (code_len != 0) { + prev_code_len = code_len; + space -= 32768u >> code_len; + } + } else { + const int extra_bits = code_len - 14; + int old_repeat; + int repeat_delta; + uint8_t new_len = 0; + if (code_len == kCodeLengthRepeatCode) { + new_len = prev_code_len; + } + if (repeat_code_len != new_len) { + repeat = 0; + repeat_code_len = new_len; + } + old_repeat = repeat; + if (repeat > 0) { + repeat -= 2; + repeat <<= extra_bits; + } + repeat += (int)br->ReadBits(extra_bits) + 3; + repeat_delta = repeat - old_repeat; + if (symbol + repeat_delta > num_symbols) { + return 0; + } + memset(&code_lengths[symbol], repeat_code_len, (size_t)repeat_delta); + symbol += repeat_delta; + if (repeat_code_len != 0) { + space -= repeat_delta << (15 - repeat_code_len); + } + } + } + if (space != 0) { + return 0; + } + memset(&code_lengths[symbol], 0, (size_t)(num_symbols - symbol)); + return true; +} + +static JXL_INLINE bool ReadSimpleCode(size_t alphabet_size, BitReader* br, + HuffmanCode* table) { + size_t max_bits = + (alphabet_size > 1u) ? FloorLog2Nonzero(alphabet_size - 1u) + 1 : 0; + + size_t num_symbols = br->ReadFixedBits<2>() + 1; + + uint16_t symbols[4] = {0}; + for (size_t i = 0; i < num_symbols; ++i) { + uint16_t symbol = br->ReadBits(max_bits); + if (symbol >= alphabet_size) { + return false; + } + symbols[i] = symbol; + } + + for (size_t i = 0; i < num_symbols - 1; ++i) { + for (size_t j = i + 1; j < num_symbols; ++j) { + if (symbols[i] == symbols[j]) return false; + } + } + + // 4 symbols have to option to encode. + if (num_symbols == 4) num_symbols += br->ReadFixedBits<1>(); + + const auto swap_symbols = [&symbols](size_t i, size_t j) { + uint16_t t = symbols[j]; + symbols[j] = symbols[i]; + symbols[i] = t; + }; + + size_t table_size = 1; + switch (num_symbols) { + case 1: + table[0] = {0, symbols[0]}; + break; + case 2: + if (symbols[0] > symbols[1]) swap_symbols(0, 1); + table[0] = {1, symbols[0]}; + table[1] = {1, symbols[1]}; + table_size = 2; + break; + case 3: + if (symbols[1] > symbols[2]) swap_symbols(1, 2); + table[0] = {1, symbols[0]}; + table[2] = {1, symbols[0]}; + table[1] = {2, symbols[1]}; + table[3] = {2, symbols[2]}; + table_size = 4; + break; + case 4: { + for (size_t i = 0; i < 3; ++i) { + for (size_t j = i + 1; j < 4; ++j) { + if (symbols[i] > symbols[j]) swap_symbols(i, j); + } + } + table[0] = {2, symbols[0]}; + table[2] = {2, symbols[1]}; + table[1] = {2, symbols[2]}; + table[3] = {2, symbols[3]}; + table_size = 4; + break; + } + case 5: { + if (symbols[2] > symbols[3]) swap_symbols(2, 3); + table[0] = {1, symbols[0]}; + table[1] = {2, symbols[1]}; + table[2] = {1, symbols[0]}; + table[3] = {3, symbols[2]}; + table[4] = {1, symbols[0]}; + table[5] = {2, symbols[1]}; + table[6] = {1, symbols[0]}; + table[7] = {3, symbols[3]}; + table_size = 8; + break; + } + default: { + // Unreachable. + return false; + } + } + + const uint32_t goal_size = 1u << kHuffmanTableBits; + while (table_size != goal_size) { + memcpy(&table[table_size], &table[0], + (size_t)table_size * sizeof(table[0])); + table_size <<= 1; + } + + return true; +} + +bool HuffmanDecodingData::ReadFromBitStream(size_t alphabet_size, + BitReader* br) { + if (alphabet_size > (1 << PREFIX_MAX_BITS)) return false; + + /* simple_code_or_skip is used as follows: + 1 for simple code; + 0 for no skipping, 2 skips 2 code lengths, 3 skips 3 code lengths */ + uint32_t simple_code_or_skip = br->ReadFixedBits<2>(); + if (simple_code_or_skip == 1u) { + table_.resize(1u << kHuffmanTableBits); + return ReadSimpleCode(alphabet_size, br, table_.data()); + } + + std::vector code_lengths(alphabet_size, 0); + uint8_t code_length_code_lengths[kCodeLengthCodes] = {0}; + int space = 32; + int num_codes = 0; + /* Static Huffman code for the code length code lengths */ + static const HuffmanCode huff[16] = { + {2, 0}, {2, 4}, {2, 3}, {3, 2}, {2, 0}, {2, 4}, {2, 3}, {4, 1}, + {2, 0}, {2, 4}, {2, 3}, {3, 2}, {2, 0}, {2, 4}, {2, 3}, {4, 5}, + }; + for (size_t i = simple_code_or_skip; i < kCodeLengthCodes && space > 0; ++i) { + const int code_len_idx = kCodeLengthCodeOrder[i]; + const HuffmanCode* p = huff; + uint8_t v; + br->Refill(); + p += br->PeekFixedBits<4>(); + br->Consume(p->bits); + v = (uint8_t)p->value; + code_length_code_lengths[code_len_idx] = v; + if (v != 0) { + space -= (32u >> v); + ++num_codes; + } + } + bool ok = (num_codes == 1 || space == 0) && + ReadHuffmanCodeLengths(code_length_code_lengths, alphabet_size, + &code_lengths[0], br); + + if (!ok) return false; + uint16_t counts[16] = {0}; + for (size_t i = 0; i < alphabet_size; ++i) { + ++counts[code_lengths[i]]; + } + table_.resize(alphabet_size + 376); + uint32_t table_size = + BuildHuffmanTable(table_.data(), kHuffmanTableBits, &code_lengths[0], + alphabet_size, &counts[0]); + table_.resize(table_size); + return (table_size > 0); +} + +// Decodes the next Huffman coded symbol from the bit-stream. +uint16_t HuffmanDecodingData::ReadSymbol(BitReader* br) const { + size_t n_bits; + const HuffmanCode* table = table_.data(); + table += br->PeekBits(kHuffmanTableBits); + n_bits = table->bits; + if (n_bits > kHuffmanTableBits) { + br->Consume(kHuffmanTableBits); + n_bits -= kHuffmanTableBits; + table += table->value; + table += br->PeekBits(n_bits); + } + br->Consume(table->bits); + return table->value; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_huffman.h b/media/libjxl/src/lib/jxl/dec_huffman.h new file mode 100644 index 000000000..162c3e309 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_huffman.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_HUFFMAN_H_ +#define LIB_JXL_DEC_HUFFMAN_H_ + +#include +#include + +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/huffman_table.h" + +namespace jxl { + +static constexpr size_t kHuffmanTableBits = 8u; + +struct HuffmanDecodingData { + // Decodes the Huffman code lengths from the bit-stream and fills in the + // pre-allocated table with the corresponding 2-level Huffman decoding table. + // Returns false if the Huffman code lengths can not de decoded. + bool ReadFromBitStream(size_t alphabet_size, BitReader* br); + + uint16_t ReadSymbol(BitReader* br) const; + + std::vector table_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_HUFFMAN_H_ diff --git a/media/libjxl/src/lib/jxl/dec_modular.cc b/media/libjxl/src/lib/jxl/dec_modular.cc new file mode 100644 index 000000000..bf85eaa05 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_modular.cc @@ -0,0 +1,774 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_modular.h" + +#include + +#include +#include +#include + +#include "lib/jxl/frame_header.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc" +#include +#include + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::Rebind; + +void MultiplySum(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const pixel_type* const JXL_RESTRICT row_in_Y, + const float factor, float* const JXL_RESTRICT row_out) { + const HWY_FULL(float) df; + const Rebind di; // assumes pixel_type <= float + const auto factor_v = Set(df, factor); + for (size_t x = 0; x < xsize; x += Lanes(di)) { + const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x)); + const auto out = Mul(ConvertTo(df, in), factor_v); + Store(out, df, row_out + x); + } +} + +void RgbFromSingle(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const float factor, float* out_r, float* out_g, + float* out_b) { + const HWY_FULL(float) df; + const Rebind di; // assumes pixel_type <= float + + const auto factor_v = Set(df, factor); + for (size_t x = 0; x < xsize; x += Lanes(di)) { + const auto in = Load(di, row_in + x); + const auto out = Mul(ConvertTo(df, in), factor_v); + Store(out, df, out_r + x); + Store(out, df, out_g + x); + Store(out, df, out_b + x); + } +} + +void SingleFromSingle(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const float factor, float* row_out) { + const HWY_FULL(float) df; + const Rebind di; // assumes pixel_type <= float + + const auto factor_v = Set(df, factor); + for (size_t x = 0; x < xsize; x += Lanes(di)) { + const auto in = Load(di, row_in + x); + const auto out = Mul(ConvertTo(df, in), factor_v); + Store(out, df, row_out + x); + } +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(MultiplySum); // Local function +HWY_EXPORT(RgbFromSingle); // Local function +HWY_EXPORT(SingleFromSingle); // Local function + +// Slow conversion using double precision multiplication, only +// needed when the bit depth is too high for single precision +void SingleFromSingleAccurate(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const double factor, float* row_out) { + for (size_t x = 0; x < xsize; x++) { + row_out[x] = row_in[x] * factor; + } +} + +// convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int +// back to binary32 float +void int_to_float(const pixel_type* const JXL_RESTRICT row_in, + float* const JXL_RESTRICT row_out, const size_t xsize, + const int bits, const int exp_bits) { + if (bits == 32) { + JXL_ASSERT(sizeof(pixel_type) == sizeof(float)); + JXL_ASSERT(exp_bits == 8); + memcpy(row_out, row_in, xsize * sizeof(float)); + return; + } + int exp_bias = (1 << (exp_bits - 1)) - 1; + int sign_shift = bits - 1; + int mant_bits = bits - exp_bits - 1; + int mant_shift = 23 - mant_bits; + for (size_t x = 0; x < xsize; ++x) { + uint32_t f; + memcpy(&f, &row_in[x], 4); + int signbit = (f >> sign_shift); + f &= (1 << sign_shift) - 1; + if (f == 0) { + row_out[x] = (signbit ? -0.f : 0.f); + continue; + } + int exp = (f >> mant_bits); + int mantissa = (f & ((1 << mant_bits) - 1)); + mantissa <<= mant_shift; + // Try to normalize only if there is space for maneuver. + if (exp == 0 && exp_bits < 8) { + // subnormal number + while ((mantissa & 0x800000) == 0) { + mantissa <<= 1; + exp--; + } + exp++; + // remove leading 1 because it is implicit now + mantissa &= 0x7fffff; + } + exp -= exp_bias; + // broke up the arbitrary float into its parts, now reassemble into + // binary32 + exp += 127; + JXL_ASSERT(exp >= 0); + f = (signbit ? 0x80000000 : 0); + f |= (exp << 23); + f |= mantissa; + memcpy(&row_out[x], &f, 4); + } +} + +std::string ModularStreamId::DebugString() const { + std::ostringstream os; + os << (kind == kGlobalData ? "ModularGlobal" + : kind == kVarDCTDC ? "VarDCTDC" + : kind == kModularDC ? "ModularDC" + : kind == kACMetadata ? "ACMeta" + : kind == kQuantTable ? "QuantTable" + : kind == kModularAC ? "ModularAC" + : ""); + if (kind == kVarDCTDC || kind == kModularDC || kind == kACMetadata || + kind == kModularAC) { + os << " group " << group_id; + } + if (kind == kModularAC) { + os << " pass " << pass_id; + } + if (kind == kQuantTable) { + os << " " << quant_table_id; + } + return os.str(); +} + +Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, + const FrameHeader& frame_header, + bool allow_truncated_group) { + bool decode_color = frame_header.encoding == FrameEncoding::kModular; + const auto& metadata = frame_header.nonserialized_metadata->m; + bool is_gray = metadata.color_encoding.IsGray(); + size_t nb_chans = 3; + if (is_gray && frame_header.color_transform == ColorTransform::kNone) { + nb_chans = 1; + } + do_color = decode_color; + size_t nb_extra = metadata.extra_channel_info.size(); + bool has_tree = reader->ReadBits(1); + if (!allow_truncated_group || + reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) { + if (has_tree) { + size_t tree_size_limit = + std::min(static_cast(1 << 22), + 1024 + frame_dim.xsize * frame_dim.ysize * + (nb_chans + nb_extra) / 16); + JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit)); + JXL_RETURN_IF_ERROR( + DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map)); + } + } + if (!do_color) nb_chans = 0; + + bool fp = metadata.bit_depth.floating_point_sample; + + // bits_per_sample is just metadata for XYB images. + if (metadata.bit_depth.bits_per_sample >= 32 && do_color && + frame_header.color_transform != ColorTransform::kXYB) { + if (metadata.bit_depth.bits_per_sample == 32 && fp == false) { + return JXL_FAILURE("uint32_t not supported in dec_modular"); + } else if (metadata.bit_depth.bits_per_sample > 32) { + return JXL_FAILURE("bits_per_sample > 32 not supported"); + } + } + + Image gi(frame_dim.xsize, frame_dim.ysize, metadata.bit_depth.bits_per_sample, + nb_chans + nb_extra); + + all_same_shift = true; + if (frame_header.color_transform == ColorTransform::kYCbCr) { + for (size_t c = 0; c < nb_chans; c++) { + gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c); + gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c); + size_t xsize_shifted = + DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift); + size_t ysize_shifted = + DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift); + gi.channel[c].shrink(xsize_shifted, ysize_shifted); + if (gi.channel[c].hshift != gi.channel[0].hshift || + gi.channel[c].vshift != gi.channel[0].vshift) + all_same_shift = false; + } + } + + for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) { + size_t ecups = frame_header.extra_channel_upsampling[ec]; + gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups), + DivCeil(frame_dim.ysize_upsampled, ecups)); + gi.channel[c].hshift = gi.channel[c].vshift = + CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling); + if (gi.channel[c].hshift != gi.channel[0].hshift || + gi.channel[c].vshift != gi.channel[0].vshift) + all_same_shift = false; + } + + JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s", + gi.DebugString().c_str()); + ModularOptions options; + options.max_chan_size = frame_dim.group_dim; + options.group_dim = frame_dim.group_dim; + Status dec_status = ModularGenericDecompress( + reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim), + &options, + /*undo_transforms=*/false, &tree, &code, &context_map, + allow_truncated_group); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); + if (dec_status.IsFatalError()) { + return JXL_FAILURE("Failed to decode global modular info"); + } + + // TODO(eustas): are we sure this can be done after partial decode? + have_something = false; + for (size_t c = 0; c < gi.channel.size(); c++) { + Channel& gic = gi.channel[c]; + if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim && + gic.h <= frame_dim.group_dim) + have_something = true; + } + // move global transforms to groups if possible + if (!have_something && all_same_shift) { + if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) { + global_transform = gi.transform; + gi.transform.clear(); + // TODO(jon): also move no-delta-palette out (trickier though) + } + } + full_image = std::move(gi); + JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s", + full_image.DebugString().c_str()); + return dec_status; +} + +void ModularFrameDecoder::MaybeDropFullImage() { + if (full_image.transform.empty() && !have_something && all_same_shift) { + use_full_image = false; + JXL_DEBUG_V(6, "Dropping full image"); + for (auto& ch : full_image.channel) { + // keep metadata on channels around, but dealloc their planes + ch.plane = Plane(); + } + } +} + +Status ModularFrameDecoder::DecodeGroup( + const Rect& rect, BitReader* reader, int minShift, int maxShift, + const ModularStreamId& stream, bool zerofill, PassesDecoderState* dec_state, + RenderPipelineInput* render_pipeline_input, bool allow_truncated, + bool* should_run_pipeline) { + JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s", + stream.DebugString().c_str(), Description(rect).c_str(), minShift, + maxShift, zerofill ? "using zerofill" : ""); + JXL_DASSERT(stream.kind == ModularStreamId::kModularDC || + stream.kind == ModularStreamId::kModularAC); + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + Image gi(xsize, ysize, full_image.bitdepth, 0); + // start at the first bigger-than-groupsize non-metachannel + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break; + } + size_t beginc = c; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + if (zerofill && use_full_image) { + for (size_t y = 0; y < r.ysize(); ++y) { + pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y); + memset(row_out, 0, r.xsize() * sizeof(*row_out)); + } + } else { + Channel gc(r.xsize(), r.ysize()); + if (zerofill) ZeroFillImage(&gc.plane); + gc.hshift = fc.hshift; + gc.vshift = fc.vshift; + gi.channel.emplace_back(std::move(gc)); + } + } + if (zerofill && use_full_image) return true; + // Return early if there's nothing to decode. Otherwise there might be + // problems later (in ModularImageToDecodedRect). + if (gi.channel.empty()) { + if (dec_state && should_run_pipeline) { + const auto& frame_header = dec_state->shared->frame_header; + const auto* metadata = frame_header.nonserialized_metadata; + if (do_color || metadata->m.num_extra_channels > 0) { + // Signal to FrameDecoder that we do not have some of the required input + // for the render pipeline. + *should_run_pipeline = false; + } + } + JXL_DEBUG_V(6, "Nothing to decode, returning early."); + return true; + } + ModularOptions options; + if (!zerofill) { + auto status = ModularGenericDecompress( + reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options, + /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated); + if (!allow_truncated) JXL_RETURN_IF_ERROR(status); + if (status.IsFatalError()) return status; + } + // Undo global transforms that have been pushed to the group level + if (!use_full_image) { + JXL_ASSERT(render_pipeline_input); + for (auto t : global_transform) { + JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header)); + } + JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(gi, dec_state, nullptr, + *render_pipeline_input, + Rect(0, 0, gi.w, gi.h))); + return true; + } + int gic = 0; + for (c = beginc; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + JXL_ASSERT(use_full_image); + CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()), + /*from=*/gi.channel[gic].plane, + /*rect_to=*/r, /*to=*/&fc.plane); + gic++; + } + return true; +} + +Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state) { + const Rect r = dec_state->shared->DCGroupRect(group_id); + // TODO(eustas): investigate if we could reduce the impact of + // EvalRationalPolynomial; generally speaking, the limit is + // 2**(128/(3*magic)), where 128 comes from IEEE 754 exponent, + // 3 comes from XybToRgb that cubes the values, and "magic" is + // the sum of all other contributions. 2**18 is known to lead + // to NaN on input found by fuzzing (see commit message). + Image image(r.xsize(), r.ysize(), full_image.bitdepth, 3); + size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim); + reader->Refill(); + size_t extra_precision = reader->ReadFixedBits<2>(); + float mul = 1.0f / (1 << extra_precision); + ModularOptions options; + for (size_t c = 0; c < 3; c++) { + Channel& ch = image.channel[c < 2 ? c ^ 1 : c]; + ch.w >>= dec_state->shared->frame_header.chroma_subsampling.HShift(c); + ch.h >>= dec_state->shared->frame_header.chroma_subsampling.VShift(c); + ch.shrink(); + } + if (!ModularGenericDecompress( + reader, image, /*header=*/nullptr, stream_id, &options, + /*undo_transforms=*/true, &tree, &code, &context_map)) { + return JXL_FAILURE("Failed to decode modular DC group"); + } + DequantDC(r, &dec_state->shared_storage.dc_storage, + &dec_state->shared_storage.quant_dc, image, + dec_state->shared->quantizer.MulDC(), mul, + dec_state->shared->cmap.DCFactors(), + dec_state->shared->frame_header.chroma_subsampling, + dec_state->shared->block_ctx_map); + return true; +} + +Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state) { + const Rect r = dec_state->shared->DCGroupRect(group_id); + size_t upper_bound = r.xsize() * r.ysize(); + reader->Refill(); + size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1; + size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim); + // YToX, YToB, ACS + QF, EPF + Image image(r.xsize(), r.ysize(), full_image.bitdepth, 4); + static_assert(kColorTileDimInBlocks == 8, "Color tile size changed"); + Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3); + image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[2] = Channel(count, 2, 0, 0); + ModularOptions options; + if (!ModularGenericDecompress( + reader, image, /*header=*/nullptr, stream_id, &options, + /*undo_transforms=*/true, &tree, &code, &context_map)) { + return JXL_FAILURE("Failed to decode AC metadata"); + } + ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr, + &dec_state->shared_storage.cmap.ytox_map); + ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr, + &dec_state->shared_storage.cmap.ytob_map); + size_t num = 0; + bool is444 = dec_state->shared->frame_header.chroma_subsampling.Is444(); + auto& ac_strategy = dec_state->shared_storage.ac_strategy; + size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize()); + size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize()); + uint32_t local_used_acs = 0; + for (size_t iy = 0; iy < r.ysize(); iy++) { + size_t y = r.y0() + iy; + int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy); + uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy); + int32_t* row_in_1 = image.channel[2].plane.Row(0); + int32_t* row_in_2 = image.channel[2].plane.Row(1); + int32_t* row_in_3 = image.channel[3].plane.Row(iy); + for (size_t ix = 0; ix < r.xsize(); ix++) { + size_t x = r.x0() + ix; + int sharpness = row_in_3[ix]; + if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) { + return JXL_FAILURE("Corrupted sharpness field"); + } + row_epf[ix] = sharpness; + if (ac_strategy.IsValid(x, y)) { + continue; + } + + if (num >= count) return JXL_FAILURE("Corrupted stream"); + + if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) { + return JXL_FAILURE("Invalid AC strategy"); + } + local_used_acs |= 1u << row_in_1[num]; + AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]); + if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) && + !is444) { + return JXL_FAILURE( + "AC strategy not compatible with chroma subsampling"); + } + // Ensure that blocks do not overflow *AC* groups. + size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks; + size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks; + size_t next_x_dct_block = x + acs.covered_blocks_x(); + size_t next_y_dct_block = y + acs.covered_blocks_y(); + if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) { + return JXL_FAILURE("Invalid AC strategy, x overflow"); + } + if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) { + return JXL_FAILURE("Invalid AC strategy, y overflow"); + } + JXL_RETURN_IF_ERROR( + ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num]))); + row_qf[ix] = 1 + std::max(0, std::min(Quantizer::kQuantMax - 1, + row_in_2[num])); + num++; + } + } + dec_state->used_acs |= local_used_acs; + if (dec_state->shared->frame_header.loop_filter.epf_iters > 0) { + ComputeSigma(r, dec_state); + } + return true; +} + +Status ModularFrameDecoder::ModularImageToDecodedRect( + Image& gi, PassesDecoderState* dec_state, jxl::ThreadPool* pool, + RenderPipelineInput& render_pipeline_input, Rect modular_rect) { + const auto& frame_header = dec_state->shared->frame_header; + const auto* metadata = frame_header.nonserialized_metadata; + JXL_CHECK(gi.transform.empty()); + + auto get_row = [&](size_t c, size_t y) { + const auto& buffer = render_pipeline_input.GetBuffer(c); + return buffer.second.Row(buffer.first, y); + }; + + size_t c = 0; + if (do_color) { + const bool rgb_from_gray = + metadata->m.color_encoding.IsGray() && + frame_header.color_transform == ColorTransform::kNone; + const bool fp = metadata->m.bit_depth.floating_point_sample && + frame_header.color_transform != ColorTransform::kXYB; + for (; c < 3; c++) { + double factor = full_image.bitdepth < 32 + ? 1.0 / ((1u << full_image.bitdepth) - 1) + : 0; + size_t c_in = c; + if (frame_header.color_transform == ColorTransform::kXYB) { + factor = dec_state->shared->matrices.DCQuants()[c]; + // XYB is encoded as YX(B-Y) + if (c < 2) c_in = 1 - c; + } else if (rgb_from_gray) { + c_in = 0; + } + JXL_ASSERT(c_in < gi.channel.size()); + Channel& ch_in = gi.channel[c_in]; + // TODO(eustas): could we detect it on earlier stage? + if (ch_in.w == 0 || ch_in.h == 0) { + return JXL_FAILURE("Empty image"); + } + JXL_CHECK(ch_in.hshift <= 3 && ch_in.vshift <= 3); + Rect r = render_pipeline_input.GetBuffer(c).second; + Rect mr(modular_rect.x0() >> ch_in.hshift, + modular_rect.y0() >> ch_in.vshift, + DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), + DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); + mr = mr.Crop(ch_in.plane); + size_t xsize_shifted = r.xsize(); + size_t ysize_shifted = r.ysize(); + if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { + return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS + "x%" PRIuS + " modular channel into " + "a %" PRIuS "x%" PRIuS " rect", + mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); + } + if (frame_header.color_transform == ColorTransform::kXYB && c == 2) { + JXL_ASSERT(!fp); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); + const pixel_type* const JXL_RESTRICT row_in_Y = + mr.Row(&gi.channel[0].plane, y); + float* const JXL_RESTRICT row_out = get_row(c, y); + HWY_DYNAMIC_DISPATCH(MultiplySum) + (xsize_shifted, row_in, row_in_Y, factor, row_out); + }, + "ModularIntToFloat")); + } else if (fp) { + int bits = metadata->m.bit_depth.bits_per_sample; + int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); + if (rgb_from_gray) { + for (size_t cc = 0; cc < 3; cc++) { + float* const JXL_RESTRICT row_out = get_row(cc, y); + int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits); + } + } else { + float* const JXL_RESTRICT row_out = get_row(c, y); + int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits); + } + }, + "ModularIntToFloat_losslessfloat")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); + if (rgb_from_gray) { + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(RgbFromSingle) + (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y), + get_row(2, y)); + } else { + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(0, y)); + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(1, y)); + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(2, y)); + } + } else { + float* const JXL_RESTRICT row_out = get_row(c, y); + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(SingleFromSingle) + (xsize_shifted, row_in, factor, row_out); + } else { + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + row_out); + } + } + }, + "ModularIntToFloat")); + } + if (rgb_from_gray) { + break; + } + } + if (rgb_from_gray) { + c = 1; + } + } + size_t num_extra_channels = metadata->m.num_extra_channels; + for (size_t ec = 0; ec < num_extra_channels; ec++, c++) { + const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec]; + int bits = eci.bit_depth.bits_per_sample; + int exp_bits = eci.bit_depth.exponent_bits_per_sample; + bool fp = eci.bit_depth.floating_point_sample; + JXL_ASSERT(fp || bits < 32); + const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1)); + JXL_ASSERT(c < gi.channel.size()); + Channel& ch_in = gi.channel[c]; + Rect r = render_pipeline_input.GetBuffer(3 + ec).second; + Rect mr(modular_rect.x0() >> ch_in.hshift, + modular_rect.y0() >> ch_in.vshift, + DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), + DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); + mr = mr.Crop(ch_in.plane); + if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { + return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS + "x%" PRIuS + " modular channel into " + "a %" PRIuS "x%" PRIuS " rect", + mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); + } + for (size_t y = 0; y < r.ysize(); ++y) { + float* const JXL_RESTRICT row_out = + r.Row(render_pipeline_input.GetBuffer(3 + ec).first, y); + const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y); + if (fp) { + int_to_float(row_in, row_out, r.xsize(), bits, exp_bits); + } else { + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(SingleFromSingle) + (r.xsize(), row_in, factor, row_out); + } else { + SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out); + } + } + } + } + return true; +} + +Status ModularFrameDecoder::FinalizeDecoding(PassesDecoderState* dec_state, + jxl::ThreadPool* pool, + bool inplace) { + if (!use_full_image) return true; + Image gi = (inplace ? std::move(full_image) : full_image.clone()); + size_t xsize = gi.w; + size_t ysize = gi.h; + + JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s", + gi.DebugString().c_str()); + + // Don't use threads if total image size is smaller than a group + if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr; + + // Undo the global transforms + gi.undo_transforms(global_header.wp_header, pool); + JXL_DASSERT(global_transform.empty()); + if (gi.error) return JXL_FAILURE("Undoing transforms failed"); + + for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) { + dec_state->render_pipeline->ClearDone(i); + } + std::atomic has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, dec_state->shared->frame_dim.num_groups, + [&](size_t num_threads) { + const auto& frame_header = dec_state->shared->frame_header; + bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT || + (frame_header.flags & FrameHeader::kNoise)); + return dec_state->render_pipeline->PrepareForThreads(num_threads, + use_group_ids); + }, + [&](const uint32_t group, size_t thread_id) { + RenderPipelineInput input = + dec_state->render_pipeline->GetInputBuffers(group, thread_id); + if (!ModularImageToDecodedRect(gi, dec_state, nullptr, input, + dec_state->shared->GroupRect(group))) { + has_error = true; + return; + } + input.Done(); + }, + "ModularToRect")); + if (has_error) { + return JXL_FAILURE("Error producing input to render pipeline"); + } + return true; +} + +static constexpr const float kAlmostZero = 1e-8f; + +Status ModularFrameDecoder::DecodeQuantTable( + size_t required_size_x, size_t required_size_y, BitReader* br, + QuantEncoding* encoding, size_t idx, + ModularFrameDecoder* modular_frame_decoder) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den)); + if (encoding->qraw.qtable_den < kAlmostZero) { + // qtable[] values are already checked for <= 0 so the denominator may not + // be negative. + return JXL_FAILURE("Invalid qtable_den: value too small"); + } + Image image(required_size_x, required_size_y, 8, 3); + ModularOptions options; + if (modular_frame_decoder) { + JXL_RETURN_IF_ERROR(ModularGenericDecompress( + br, image, /*header=*/nullptr, + ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim), + &options, /*undo_transforms=*/true, &modular_frame_decoder->tree, + &modular_frame_decoder->code, &modular_frame_decoder->context_map)); + } else { + JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr, + 0, &options, + /*undo_transforms=*/true)); + } + if (!encoding->qraw.qtable) { + encoding->qraw.qtable = new std::vector(); + } + encoding->qraw.qtable->resize(required_size_x * required_size_y * 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < required_size_y; y++) { + int32_t* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < required_size_x; x++) { + (*encoding->qraw.qtable)[c * required_size_x * required_size_y + + y * required_size_x + x] = row[x]; + if (row[x] <= 0) { + return JXL_FAILURE("Invalid raw quantization table"); + } + } + } + } + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/dec_modular.h b/media/libjxl/src/lib/jxl/dec_modular.h new file mode 100644 index 000000000..ec94b4648 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_modular.h @@ -0,0 +1,141 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_MODULAR_H_ +#define LIB_JXL_DEC_MODULAR_H_ + +#include + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +struct ModularStreamId { + enum Kind { + kGlobalData, + kVarDCTDC, + kModularDC, + kACMetadata, + kQuantTable, + kModularAC + }; + Kind kind; + size_t quant_table_id; + size_t group_id; // DC or AC group id. + size_t pass_id; // Only for kModularAC. + size_t ID(const FrameDimensions& frame_dim) const { + size_t id = 0; + switch (kind) { + case kGlobalData: + id = 0; + break; + case kVarDCTDC: + id = 1 + group_id; + break; + case kModularDC: + id = 1 + frame_dim.num_dc_groups + group_id; + break; + case kACMetadata: + id = 1 + 2 * frame_dim.num_dc_groups + group_id; + break; + case kQuantTable: + id = 1 + 3 * frame_dim.num_dc_groups + quant_table_id; + break; + case kModularAC: + id = 1 + 3 * frame_dim.num_dc_groups + DequantMatrices::kNum + + frame_dim.num_groups * pass_id + group_id; + break; + }; + return id; + } + static ModularStreamId Global() { + return ModularStreamId{kGlobalData, 0, 0, 0}; + } + static ModularStreamId VarDCTDC(size_t group_id) { + return ModularStreamId{kVarDCTDC, 0, group_id, 0}; + } + static ModularStreamId ModularDC(size_t group_id) { + return ModularStreamId{kModularDC, 0, group_id, 0}; + } + static ModularStreamId ACMetadata(size_t group_id) { + return ModularStreamId{kACMetadata, 0, group_id, 0}; + } + static ModularStreamId QuantTable(size_t quant_table_id) { + JXL_ASSERT(quant_table_id < DequantMatrices::kNum); + return ModularStreamId{kQuantTable, quant_table_id, 0, 0}; + } + static ModularStreamId ModularAC(size_t group_id, size_t pass_id) { + return ModularStreamId{kModularAC, 0, group_id, pass_id}; + } + static size_t Num(const FrameDimensions& frame_dim, size_t passes) { + return ModularAC(0, passes).ID(frame_dim); + } + std::string DebugString() const; +}; + +class ModularFrameDecoder { + public: + void Init(const FrameDimensions& frame_dim) { this->frame_dim = frame_dim; } + Status DecodeGlobalInfo(BitReader* reader, const FrameHeader& frame_header, + bool allow_truncated_group); + Status DecodeGroup(const Rect& rect, BitReader* reader, int minShift, + int maxShift, const ModularStreamId& stream, bool zerofill, + PassesDecoderState* dec_state, + RenderPipelineInput* render_pipeline_input, + bool allow_truncated, bool* should_run_pipeline = nullptr); + // Decodes a VarDCT DC group (`group_id`) from the given `reader`. + Status DecodeVarDCTDC(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state); + // Decodes a VarDCT AC Metadata group (`group_id`) from the given `reader`. + Status DecodeAcMetadata(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state); + // Decodes a RAW quant table from `br` into the given `encoding`, of size + // `required_size_x x required_size_y`. If `modular_frame_decoder` is passed, + // its global tree is used, otherwise no global tree is used. + static Status DecodeQuantTable(size_t required_size_x, size_t required_size_y, + BitReader* br, QuantEncoding* encoding, + size_t idx, + ModularFrameDecoder* modular_frame_decoder); + // if inplace is true, this can only be called once + // if it is false, it can be called multiple times (e.g. for progressive + // steps) + Status FinalizeDecoding(PassesDecoderState* dec_state, jxl::ThreadPool* pool, + bool inplace); + bool have_dc() const { return have_something; } + void MaybeDropFullImage(); + bool UsesFullImage() const { return use_full_image; } + + private: + Status ModularImageToDecodedRect(Image& gi, PassesDecoderState* dec_state, + jxl::ThreadPool* pool, + RenderPipelineInput& render_pipeline_input, + Rect modular_rect); + + Image full_image; + std::vector global_transform; + FrameDimensions frame_dim; + bool do_color; + bool have_something; + bool use_full_image = true; + bool all_same_shift; + Tree tree; + ANSCode code; + std::vector context_map; + GroupHeader global_header; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_MODULAR_H_ diff --git a/media/libjxl/src/lib/jxl/dec_noise.cc b/media/libjxl/src/lib/jxl/dec_noise.cc new file mode 100644 index 000000000..f48398b5c --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_noise.cc @@ -0,0 +1,131 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_noise.h" + +#include +#include +#include + +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_noise.cc" +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/xorshift128plus-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Vec; + +using D = HWY_CAPPED(float, kBlockDim); +using DI = hwy::HWY_NAMESPACE::Rebind; +using DI8 = hwy::HWY_NAMESPACE::Repartition; + +// Converts one vector's worth of random bits to floats in [1, 2). +// NOTE: as the convolution kernel sums to 0, it doesn't matter if inputs are in +// [0, 1) or in [1, 2). +void BitsToFloat(const uint32_t* JXL_RESTRICT random_bits, + float* JXL_RESTRICT floats) { + const HWY_FULL(float) df; + const HWY_FULL(uint32_t) du; + + const auto bits = Load(du, random_bits); + // 1.0 + 23 random mantissa bits = [1, 2) + const auto rand12 = BitCast(df, Or(ShiftRight<9>(bits), Set(du, 0x3F800000))); + Store(rand12, df, floats); +} + +void RandomImage(Xorshift128Plus* rng, const Rect& rect, + ImageF* JXL_RESTRICT noise) { + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + + // May exceed the vector size, hence we have two loops over x below. + constexpr size_t kFloatsPerBatch = + Xorshift128Plus::N * sizeof(uint64_t) / sizeof(float); + HWY_ALIGN uint64_t batch[Xorshift128Plus::N] = {}; + + const HWY_FULL(float) df; + const size_t N = Lanes(df); + + for (size_t y = 0; y < ysize; ++y) { + float* JXL_RESTRICT row = rect.Row(noise, y); + + size_t x = 0; + // Only entire batches (avoids exceeding the image padding). + for (; x + kFloatsPerBatch <= xsize; x += kFloatsPerBatch) { + rng->Fill(batch); + for (size_t i = 0; i < kFloatsPerBatch; i += Lanes(df)) { + BitsToFloat(reinterpret_cast(batch) + i, row + x + i); + } + } + + // Any remaining pixels, rounded up to vectors (safe due to padding). + rng->Fill(batch); + size_t batch_pos = 0; // < kFloatsPerBatch + for (; x < xsize; x += N) { + BitsToFloat(reinterpret_cast(batch) + batch_pos, + row + x); + batch_pos += N; + } + } +} +void Random3Planes(size_t visible_frame_index, size_t nonvisible_frame_index, + size_t x0, size_t y0, const std::pair& plane0, + const std::pair& plane1, + const std::pair& plane2) { + HWY_ALIGN Xorshift128Plus rng(visible_frame_index, nonvisible_frame_index, x0, + y0); + RandomImage(&rng, plane0.second, plane0.first); + RandomImage(&rng, plane1.second, plane1.first); + RandomImage(&rng, plane2.second, plane2.first); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Random3Planes); +void Random3Planes(size_t visible_frame_index, size_t nonvisible_frame_index, + size_t x0, size_t y0, const std::pair& plane0, + const std::pair& plane1, + const std::pair& plane2) { + return HWY_DYNAMIC_DISPATCH(Random3Planes)(visible_frame_index, + nonvisible_frame_index, x0, y0, + plane0, plane1, plane2); +} + +void DecodeFloatParam(float precision, float* val, BitReader* br) { + const int absval_quant = br->ReadFixedBits<10>(); + *val = absval_quant / precision; +} + +Status DecodeNoise(BitReader* br, NoiseParams* noise_params) { + for (float& i : noise_params->lut) { + DecodeFloatParam(kNoisePrecision, &i, br); + } + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/dec_noise.h b/media/libjxl/src/lib/jxl/dec_noise.h new file mode 100644 index 000000000..f8c028636 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_noise.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_NOISE_H_ +#define LIB_JXL_DEC_NOISE_H_ + +// Noise synthesis. Currently disabled. + +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" +#include "lib/jxl/noise.h" + +namespace jxl { + +void Random3Planes(size_t visible_frame_index, size_t nonvisible_frame_index, + size_t x0, size_t y0, const std::pair& plane0, + const std::pair& plane1, + const std::pair& plane2); + +// Must only call if FrameHeader.flags.kNoise. +Status DecodeNoise(BitReader* br, NoiseParams* noise_params); + +} // namespace jxl + +#endif // LIB_JXL_DEC_NOISE_H_ diff --git a/media/libjxl/src/lib/jxl/dec_patch_dictionary.cc b/media/libjxl/src/lib/jxl/dec_patch_dictionary.cc new file mode 100644 index 000000000..4f8720922 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_patch_dictionary.cc @@ -0,0 +1,347 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_patch_dictionary.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/blending.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/patch_dictionary_internal.h" + +namespace jxl { + +Status PatchDictionary::Decode(BitReader* br, size_t xsize, size_t ysize, + bool* uses_extra_channels) { + positions_.clear(); + std::vector context_map; + ANSCode code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumPatchDictionaryContexts, &code, &context_map)); + ANSSymbolReader decoder(&code, br); + + auto read_num = [&](size_t context) { + size_t r = decoder.ReadHybridUint(context, br, context_map); + return r; + }; + + size_t num_ref_patch = read_num(kNumRefPatchContext); + // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8 + // bytes per size_t) + const size_t num_pixels = xsize * ysize; + const size_t max_ref_patches = 1024 + num_pixels / 4; + const size_t max_patches = max_ref_patches * 4; + const size_t max_blending_infos = max_patches * 4; + if (num_ref_patch > max_ref_patches) { + return JXL_FAILURE("Too many patches in dictionary"); + } + size_t num_ec = shared_->metadata->m.num_extra_channels; + + size_t total_patches = 0; + size_t next_size = 1; + + for (size_t id = 0; id < num_ref_patch; id++) { + PatchReferencePosition ref_pos; + ref_pos.ref = read_num(kReferenceFrameContext); + if (ref_pos.ref >= kMaxNumReferenceFrames || + shared_->reference_frames[ref_pos.ref].frame->xsize() == 0) { + return JXL_FAILURE("Invalid reference frame ID"); + } + if (!shared_->reference_frames[ref_pos.ref].ib_is_in_xyb) { + return JXL_FAILURE( + "Patches cannot use frames saved post color transforms"); + } + const ImageBundle& ib = *shared_->reference_frames[ref_pos.ref].frame; + ref_pos.x0 = read_num(kPatchReferencePositionContext); + ref_pos.y0 = read_num(kPatchReferencePositionContext); + ref_pos.xsize = read_num(kPatchSizeContext) + 1; + ref_pos.ysize = read_num(kPatchSizeContext) + 1; + if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) { + return JXL_FAILURE("Invalid position specified in reference frame"); + } + if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) { + return JXL_FAILURE("Invalid position specified in reference frame"); + } + size_t id_count = read_num(kPatchCountContext) + 1; + total_patches += id_count; + if (total_patches > max_patches) { + return JXL_FAILURE("Too many patches in dictionary"); + } + if (next_size < total_patches) { + next_size *= 2; + next_size = std::min(next_size, max_patches); + } + if (next_size * (num_ec + 1) > max_blending_infos) { + return JXL_FAILURE("Too many patches in dictionary"); + } + positions_.reserve(next_size); + blendings_.reserve(next_size * (num_ec + 1)); + for (size_t i = 0; i < id_count; i++) { + PatchPosition pos; + pos.ref_pos_idx = ref_positions_.size(); + if (i == 0) { + pos.x = read_num(kPatchPositionContext); + pos.y = read_num(kPatchPositionContext); + } else { + pos.x = + positions_.back().x + UnpackSigned(read_num(kPatchOffsetContext)); + pos.y = + positions_.back().y + UnpackSigned(read_num(kPatchOffsetContext)); + } + if (pos.x + ref_pos.xsize > xsize) { + return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS + " > %" PRIuS, + pos.x, ref_pos.xsize, xsize); + } + if (pos.y + ref_pos.ysize > ysize) { + return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS + " > %" PRIuS, + pos.y, ref_pos.ysize, ysize); + } + for (size_t j = 0; j < num_ec + 1; j++) { + uint32_t blend_mode = read_num(kPatchBlendModeContext); + if (blend_mode >= uint32_t(PatchBlendMode::kNumBlendModes)) { + return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode); + } + PatchBlending info; + info.mode = static_cast(blend_mode); + if (UsesAlpha(info.mode)) { + *uses_extra_channels = true; + } + if (info.mode != PatchBlendMode::kNone && j > 0) { + *uses_extra_channels = true; + } + if (UsesAlpha(info.mode) && + shared_->metadata->m.extra_channel_info.size() > 1) { + info.alpha_channel = read_num(kPatchAlphaChannelContext); + if (info.alpha_channel >= + shared_->metadata->m.extra_channel_info.size()) { + return JXL_FAILURE( + "Invalid alpha channel for blending: %u out of %u\n", + info.alpha_channel, + (uint32_t)shared_->metadata->m.extra_channel_info.size()); + } + } else { + info.alpha_channel = 0; + } + if (UsesClamp(info.mode)) { + info.clamp = read_num(kPatchClampContext); + } else { + info.clamp = false; + } + blendings_.push_back(info); + } + positions_.push_back(std::move(pos)); + } + ref_positions_.emplace_back(std::move(ref_pos)); + } + positions_.shrink_to_fit(); + + if (!decoder.CheckANSFinalState()) { + return JXL_FAILURE("ANS checksum failure."); + } + + ComputePatchTree(); + return true; +} + +int PatchDictionary::GetReferences() const { + int result = 0; + for (size_t i = 0; i < ref_positions_.size(); ++i) { + result |= (1 << static_cast(ref_positions_[i].ref)); + } + return result; +} + +namespace { +struct PatchInterval { + size_t idx; + size_t y0, y1; +}; +} // namespace + +void PatchDictionary::ComputePatchTree() { + patch_tree_.clear(); + num_patches_.clear(); + sorted_patches_y0_.clear(); + sorted_patches_y1_.clear(); + if (positions_.empty()) { + return; + } + // Create a y-interval for each patch. + std::vector intervals(positions_.size()); + for (size_t i = 0; i < positions_.size(); ++i) { + const auto& pos = positions_[i]; + intervals[i].idx = i; + intervals[i].y0 = pos.y; + intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize; + } + auto sort_by_y0 = [&intervals](size_t start, size_t end) { + std::sort(intervals.data() + start, intervals.data() + end, + [](const PatchInterval& i0, const PatchInterval& i1) { + return i0.y0 < i1.y0; + }); + }; + auto sort_by_y1 = [&intervals](size_t start, size_t end) { + std::sort(intervals.data() + start, intervals.data() + end, + [](const PatchInterval& i0, const PatchInterval& i1) { + return i0.y1 < i1.y1; + }); + }; + // Count the number of patches for each row. + sort_by_y1(0, intervals.size()); + num_patches_.resize(intervals.back().y1); + for (auto iv : intervals) { + for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++; + } + PatchTreeNode root; + root.start = 0; + root.num = intervals.size(); + patch_tree_.push_back(root); + size_t next = 0; + while (next < patch_tree_.size()) { + auto& node = patch_tree_[next]; + size_t start = node.start; + size_t end = node.start + node.num; + // Choose the y_center for this node to be the median of interval starts. + sort_by_y0(start, end); + size_t middle_idx = start + node.num / 2; + node.y_center = intervals[middle_idx].y0; + // Divide the intervals in [start, end) into three groups: + // * those completely to the right of y_center: [right_start, end) + // * those overlapping y_center: [left_end, right_start) + // * those completely to the left of y_center: [start, left_end) + size_t right_start = middle_idx; + while (right_start < end && intervals[right_start].y0 == node.y_center) { + ++right_start; + } + sort_by_y1(start, right_start); + size_t left_end = right_start; + while (left_end > start && intervals[left_end - 1].y1 > node.y_center) { + --left_end; + } + // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node. + node.num = right_start - left_end; + node.start = sorted_patches_y0_.size(); + for (ssize_t i = static_cast(right_start) - 1; + i >= static_cast(left_end); --i) { + sorted_patches_y1_.push_back({intervals[i].y1, intervals[i].idx}); + } + sort_by_y0(left_end, right_start); + for (size_t i = left_end; i < right_start; ++i) { + sorted_patches_y0_.push_back({intervals[i].y0, intervals[i].idx}); + } + // Create the left and right nodes (if not empty). + node.left_child = node.right_child = -1; + if (left_end > start) { + PatchTreeNode left; + left.start = start; + left.num = left_end - left.start; + patch_tree_[next].left_child = patch_tree_.size(); + patch_tree_.push_back(left); + } + if (right_start < end) { + PatchTreeNode right; + right.start = right_start; + right.num = end - right.start; + patch_tree_[next].right_child = patch_tree_.size(); + patch_tree_.push_back(right); + } + ++next; + } +} + +std::vector PatchDictionary::GetPatchesForRow(size_t y) const { + std::vector result; + if (y < num_patches_.size() && num_patches_[y] > 0) { + result.reserve(num_patches_[y]); + for (ssize_t tree_idx = 0; tree_idx != -1;) { + JXL_DASSERT(tree_idx < (ssize_t)patch_tree_.size()); + const auto& node = patch_tree_[tree_idx]; + if (y <= node.y_center) { + for (size_t i = 0; i < node.num; ++i) { + const auto& p = sorted_patches_y0_[node.start + i]; + if (y < p.first) break; + result.push_back(p.second); + } + tree_idx = y < node.y_center ? node.left_child : -1; + } else { + for (size_t i = 0; i < node.num; ++i) { + const auto& p = sorted_patches_y1_[node.start + i]; + if (y >= p.first) break; + result.push_back(p.second); + } + tree_idx = node.right_child; + } + } + // Ensure that he relative order of patches that affect the same pixels is + // preserved. This is important for patches that have a blend mode + // different from kAdd. + std::sort(result.begin(), result.end()); + } + return result; +} + +// Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed +// to be located at position (x0, y) in the frame. +void PatchDictionary::AddOneRow(float* const* inout, size_t y, size_t x0, + size_t xsize) const { + size_t num_ec = shared_->metadata->m.num_extra_channels; + std::vector fg_ptrs(3 + num_ec); + for (size_t pos_idx : GetPatchesForRow(y)) { + const size_t blending_idx = pos_idx * (num_ec + 1); + const PatchPosition& pos = positions_[pos_idx]; + const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx]; + size_t by = pos.y; + size_t bx = pos.x; + size_t patch_xsize = ref_pos.xsize; + JXL_DASSERT(y >= by); + JXL_DASSERT(y < by + ref_pos.ysize); + size_t iy = y - by; + size_t ref = ref_pos.ref; + if (bx >= x0 + xsize) continue; + if (bx + patch_xsize < x0) continue; + size_t patch_x0 = std::max(bx, x0); + size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize); + for (size_t c = 0; c < 3; c++) { + fg_ptrs[c] = shared_->reference_frames[ref].frame->color()->ConstPlaneRow( + c, ref_pos.y0 + iy) + + ref_pos.x0 + x0 - bx; + } + for (size_t i = 0; i < num_ec; i++) { + fg_ptrs[3 + i] = + shared_->reference_frames[ref].frame->extra_channels()[i].ConstRow( + ref_pos.y0 + iy) + + ref_pos.x0 + x0 - bx; + } + PerformBlending(inout, fg_ptrs.data(), inout, patch_x0 - x0, + patch_x1 - patch_x0, blendings_[blending_idx], + blendings_.data() + blending_idx + 1, + shared_->metadata->m.extra_channel_info); + } +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_patch_dictionary.h b/media/libjxl/src/lib/jxl/dec_patch_dictionary.h new file mode 100644 index 000000000..a950e83e8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_patch_dictionary.h @@ -0,0 +1,151 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_PATCH_DICTIONARY_H_ +#define LIB_JXL_DEC_PATCH_DICTIONARY_H_ + +// Chooses reference patches, and avoids encoding them once per occurrence. + +#include +#include +#include + +#include +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { + +enum class PatchBlendMode : uint8_t { + // The new values are the old ones. Useful to skip some channels. + kNone = 0, + // The new values (in the crop) replace the old ones: sample = new + kReplace = 1, + // The new values (in the crop) get added to the old ones: sample = old + new + kAdd = 2, + // The new values (in the crop) get multiplied by the old ones: + // sample = old * new + // This blend mode is only supported if BlendColorSpace is kEncoded. The + // range of the new value matters for multiplication purposes, and its + // nominal range of 0..1 is computed the same way as this is done for the + // alpha values in kBlend and kAlphaWeightedAdd. + kMul = 3, + // The new values (in the crop) replace the old ones if alpha>0: + // For first alpha channel: + // alpha = old + new * (1 - old) + // For other channels if !alpha_associated: + // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha + // For other channels if alpha_associated: + // sample = (1 - new_alpha) * old + new + // The alpha formula applies to the alpha used for the division in the other + // channels formula, and applies to the alpha channel itself if its + // blend_channel value matches itself. + // If using kBlendAbove, new is the patch and old is the original image; if + // using kBlendBelow, the meaning is inverted. + kBlendAbove = 4, + kBlendBelow = 5, + // The new values (in the crop) are added to the old ones if alpha>0: + // For first alpha channel: sample = sample = old + new * (1 - old) + // For other channels: sample = old + alpha * new + kAlphaWeightedAddAbove = 6, + kAlphaWeightedAddBelow = 7, + kNumBlendModes, +}; + +inline bool UsesAlpha(PatchBlendMode mode) { + return mode == PatchBlendMode::kBlendAbove || + mode == PatchBlendMode::kBlendBelow || + mode == PatchBlendMode::kAlphaWeightedAddAbove || + mode == PatchBlendMode::kAlphaWeightedAddBelow; +} +inline bool UsesClamp(PatchBlendMode mode) { + return UsesAlpha(mode) || mode == PatchBlendMode::kMul; +} + +struct PatchBlending { + PatchBlendMode mode; + uint32_t alpha_channel; + bool clamp; +}; + +// Position and size of the patch in the reference frame. +struct PatchReferencePosition { + size_t ref, x0, y0, xsize, ysize; +}; + +struct PatchPosition { + // Position of top-left corner of the patch in the image. + size_t x, y; + size_t ref_pos_idx; +}; + +struct PassesSharedState; + +// Encoder-side helper class to encode the PatchesDictionary. +class PatchDictionaryEncoder; + +class PatchDictionary { + public: + PatchDictionary() = default; + + void SetPassesSharedState(const PassesSharedState* shared) { + shared_ = shared; + } + + bool HasAny() const { return !positions_.empty(); } + + Status Decode(BitReader* br, size_t xsize, size_t ysize, + bool* uses_extra_channels); + + void Clear() { + positions_.clear(); + ComputePatchTree(); + } + + // Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed + // to be located at position (x0, y) in the frame. + void AddOneRow(float* const* inout, size_t y, size_t x0, size_t xsize) const; + + // Returns dependencies of this patch dictionary on reference frame ids as a + // bit mask: bits 0-3 indicate reference frame 0-3. + int GetReferences() const; + + std::vector GetPatchesForRow(size_t y) const; + + private: + friend class PatchDictionaryEncoder; + + const PassesSharedState* shared_; + std::vector positions_; + std::vector ref_positions_; + std::vector blendings_; + + // Interval tree on the y coordinates of the patches. + struct PatchTreeNode { + ssize_t left_child; + ssize_t right_child; + size_t y_center; + // Range of patches in sorted_patches_y0_ and sorted_patches_y1_ that + // contain the row y_center. + size_t start; + size_t num; + }; + std::vector patch_tree_; + // Number of patches for each row. + std::vector num_patches_; + std::vector> sorted_patches_y0_; + std::vector> sorted_patches_y1_; + + void ComputePatchTree(); +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_PATCH_DICTIONARY_H_ diff --git a/media/libjxl/src/lib/jxl/dec_tone_mapping-inl.h b/media/libjxl/src/lib/jxl/dec_tone_mapping-inl.h new file mode 100644 index 000000000..a3260372c --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_tone_mapping-inl.h @@ -0,0 +1,232 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_DEC_TONE_MAPPING_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DEC_TONE_MAPPING_INL_H_ +#undef LIB_JXL_DEC_TONE_MAPPING_INL_H_ +#else +#define LIB_JXL_DEC_TONE_MAPPING_INL_H_ +#endif + +#include + +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +template +class Rec2408ToneMapper { + private: + using V = hwy::HWY_NAMESPACE::Vec; + + public: + explicit Rec2408ToneMapper(std::pair source_range, + std::pair target_range, + const float primaries_luminances[3]) + : source_range_(source_range), + target_range_(target_range), + red_Y_(primaries_luminances[0]), + green_Y_(primaries_luminances[1]), + blue_Y_(primaries_luminances[2]) {} + + void ToneMap(V* red, V* green, V* blue) const { + const V luminance = Mul(Set(df_, source_range_.second), + (MulAdd(Set(df_, red_Y_), *red, + MulAdd(Set(df_, green_Y_), *green, + Mul(Set(df_, blue_Y_), *blue))))); + const V pq_mastering_min = Set(df_, pq_mastering_min_); + const V inv_pq_mastering_range = Set(df_, inv_pq_mastering_range_); + const V normalized_pq = Min( + Set(df_, 1.f), + Mul(Sub(InvEOTF(luminance), pq_mastering_min), inv_pq_mastering_range)); + const V ks = Set(df_, ks_); + const V e2 = + IfThenElse(Lt(normalized_pq, ks), normalized_pq, P(normalized_pq)); + const V one_minus_e2 = Sub(Set(df_, 1), e2); + const V one_minus_e2_2 = Mul(one_minus_e2, one_minus_e2); + const V one_minus_e2_4 = Mul(one_minus_e2_2, one_minus_e2_2); + const V b = Set(df_, min_lum_); + const V e3 = MulAdd(b, one_minus_e2_4, e2); + const V pq_mastering_range = Set(df_, pq_mastering_range_); + const V e4 = MulAdd(e3, pq_mastering_range, pq_mastering_min); + const V new_luminance = + Min(Set(df_, target_range_.second), + ZeroIfNegative( + Mul(Set(df_, 10000), TF_PQ().DisplayFromEncoded(df_, e4)))); + + const V ratio = Div(new_luminance, luminance); + + const V normalizer = Set(df_, normalizer_); + for (V* const val : {red, green, blue}) { + *val = Mul(IfThenElse(Le(luminance, Set(df_, 1e-6f)), new_luminance, + Mul(*val, ratio)), + normalizer); + } + } + + private: + V InvEOTF(const V luminance) const { + return TF_PQ().EncodedFromDisplay(df_, + Mul(luminance, Set(df_, 1. / 10000))); + } + float InvEOTF(const float luminance) const { + return TF_PQ().EncodedFromDisplay(luminance / 10000.0f); + } + V T(const V a) const { + const V ks = Set(df_, ks_); + const V inv_one_minus_ks = Set(df_, inv_one_minus_ks_); + return Mul(Sub(a, ks), inv_one_minus_ks); + } + V P(const V b) const { + const V t_b = T(b); + const V t_b_2 = Mul(t_b, t_b); + const V t_b_3 = Mul(t_b_2, t_b); + const V ks = Set(df_, ks_); + const V max_lum = Set(df_, max_lum_); + return MulAdd( + MulAdd(Set(df_, 2), t_b_3, MulAdd(Set(df_, -3), t_b_2, Set(df_, 1))), + ks, + MulAdd(Add(t_b_3, MulAdd(Set(df_, -2), t_b_2, t_b)), + Sub(Set(df_, 1), ks), + MulAdd(Set(df_, -2), t_b_3, + Mul(Mul(Set(df_, 3), t_b_2), max_lum)))); + } + + D df_; + const std::pair source_range_; + const std::pair target_range_; + const float red_Y_; + const float green_Y_; + const float blue_Y_; + + const float pq_mastering_min_ = InvEOTF(source_range_.first); + const float pq_mastering_max_ = InvEOTF(source_range_.second); + const float pq_mastering_range_ = pq_mastering_max_ - pq_mastering_min_; + const float inv_pq_mastering_range_ = 1.0f / pq_mastering_range_; + // TODO(eustas): divide instead of inverse-multiply? + const float min_lum_ = (InvEOTF(target_range_.first) - pq_mastering_min_) * + inv_pq_mastering_range_; + // TODO(eustas): divide instead of inverse-multiply? + const float max_lum_ = (InvEOTF(target_range_.second) - pq_mastering_min_) * + inv_pq_mastering_range_; + const float ks_ = 1.5f * max_lum_ - 0.5f; + const float b_ = min_lum_; + + const float inv_one_minus_ks_ = 1.0f / std::max(1e-6f, 1.0f - ks_); + + const float normalizer_ = source_range_.second / target_range_.second; +}; + +class HlgOOTF { + public: + explicit HlgOOTF(float source_luminance, float target_luminance, + const float primaries_luminances[3]) + : HlgOOTF(/*gamma=*/std::pow( + 1.111f, std::log2(target_luminance / source_luminance)), + primaries_luminances) {} + + static HlgOOTF FromSceneLight(float display_luminance, + const float primaries_luminances[3]) { + return HlgOOTF(/*gamma=*/1.2f * + std::pow(1.111f, std::log2(display_luminance / 1000.f)), + primaries_luminances); + } + + static HlgOOTF ToSceneLight(float display_luminance, + const float primaries_luminances[3]) { + return HlgOOTF( + /*gamma=*/(1 / 1.2f) * + std::pow(1.111f, -std::log2(display_luminance / 1000.f)), + primaries_luminances); + } + + template + void Apply(V* red, V* green, V* blue) const { + hwy::HWY_NAMESPACE::DFromV df; + if (!apply_ootf_) return; + const V luminance = + MulAdd(Set(df, red_Y_), *red, + MulAdd(Set(df, green_Y_), *green, Mul(Set(df, blue_Y_), *blue))); + const V ratio = + Min(FastPowf(df, luminance, Set(df, exponent_)), Set(df, 1e9)); + *red = Mul(*red, ratio); + *green = Mul(*green, ratio); + *blue = Mul(*blue, ratio); + } + + bool WarrantsGamutMapping() const { return apply_ootf_ && exponent_ < 0; } + + private: + explicit HlgOOTF(float gamma, const float luminances[3]) + : exponent_(gamma - 1), + red_Y_(luminances[0]), + green_Y_(luminances[1]), + blue_Y_(luminances[2]) {} + const float exponent_; + const bool apply_ootf_ = exponent_ < -0.01f || 0.01f < exponent_; + const float red_Y_; + const float green_Y_; + const float blue_Y_; +}; + +template +void GamutMap(V* red, V* green, V* blue, const float primaries_luminances[3], + float preserve_saturation = 0.1f) { + hwy::HWY_NAMESPACE::DFromV df; + const V luminance = + MulAdd(Set(df, primaries_luminances[0]), *red, + MulAdd(Set(df, primaries_luminances[1]), *green, + Mul(Set(df, primaries_luminances[2]), *blue))); + + // Desaturate out-of-gamut pixels. This is done by mixing each pixel + // with just enough gray of the target luminance to make all + // components non-negative. + // - For saturation preservation, if a component is still larger than + // 1 then the pixel is normalized to have a maximum component of 1. + // That will reduce its luminance. + // - For luminance preservation, getting all components below 1 is + // done by mixing in yet more gray. That will desaturate it further. + V gray_mix_saturation = Zero(df); + V gray_mix_luminance = Zero(df); + for (const V* ch : {red, green, blue}) { + const V& val = *ch; + const V inv_val_minus_gray = Div(Set(df, 1), (Sub(val, luminance))); + gray_mix_saturation = + IfThenElse(Ge(val, luminance), gray_mix_saturation, + Max(gray_mix_saturation, Mul(val, inv_val_minus_gray))); + gray_mix_luminance = + Max(gray_mix_luminance, + IfThenElse(Le(val, luminance), gray_mix_saturation, + Mul(Sub(val, Set(df, 1)), inv_val_minus_gray))); + } + const V gray_mix = Clamp( + MulAdd(Set(df, preserve_saturation), + Sub(gray_mix_saturation, gray_mix_luminance), gray_mix_luminance), + Zero(df), Set(df, 1)); + for (V* const val : {red, green, blue}) { + *val = MulAdd(gray_mix, Sub(luminance, *val), *val); + } + const V normalizer = + Div(Set(df, 1), Max(Set(df, 1), Max(*red, Max(*green, *blue)))); + for (V* const val : {red, green, blue}) { + *val = Mul(*val, normalizer); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DEC_TONE_MAPPING_INL_H_ diff --git a/media/libjxl/src/lib/jxl/dec_transforms-inl.h b/media/libjxl/src/lib/jxl/dec_transforms-inl.h new file mode 100644 index 000000000..075619b3b --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_transforms-inl.h @@ -0,0 +1,853 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_DEC_TRANSFORMS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DEC_TRANSFORMS_INL_H_ +#undef LIB_JXL_DEC_TRANSFORMS_INL_H_ +#else +#define LIB_JXL_DEC_TRANSFORMS_INL_H_ +#endif + +#include + +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/dct_scales.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::MulAdd; + +// Computes the lowest-frequency LF_ROWSxLF_COLS-sized square in output, which +// is a DCT_ROWS*DCT_COLS-sized DCT block, by doing a ROWS*COLS DCT on the +// input block. +template +JXL_INLINE void ReinterpretingDCT(const float* input, const size_t input_stride, + float* output, const size_t output_stride) { + static_assert(LF_ROWS == ROWS, + "ReinterpretingDCT should only be called with LF == N"); + static_assert(LF_COLS == COLS, + "ReinterpretingDCT should only be called with LF == N"); + HWY_ALIGN float block[ROWS * COLS]; + + // ROWS, COLS <= 8, so we can put scratch space on the stack. + HWY_ALIGN float scratch_space[ROWS * COLS]; + ComputeScaledDCT()(DCTFrom(input, input_stride), block, + scratch_space); + if (ROWS < COLS) { + for (size_t y = 0; y < LF_ROWS; y++) { + for (size_t x = 0; x < LF_COLS; x++) { + output[y * output_stride + x] = + block[y * COLS + x] * DCTTotalResampleScale(y) * + DCTTotalResampleScale(x); + } + } + } else { + for (size_t y = 0; y < LF_COLS; y++) { + for (size_t x = 0; x < LF_ROWS; x++) { + output[y * output_stride + x] = + block[y * ROWS + x] * DCTTotalResampleScale(y) * + DCTTotalResampleScale(x); + } + } + } +} + +template +void IDCT2TopBlock(const float* block, size_t stride_out, float* out) { + static_assert(kBlockDim % S == 0, "S should be a divisor of kBlockDim"); + static_assert(S % 2 == 0, "S should be even"); + float temp[kDCTBlockSize]; + constexpr size_t num_2x2 = S / 2; + for (size_t y = 0; y < num_2x2; y++) { + for (size_t x = 0; x < num_2x2; x++) { + float c00 = block[y * kBlockDim + x]; + float c01 = block[y * kBlockDim + num_2x2 + x]; + float c10 = block[(y + num_2x2) * kBlockDim + x]; + float c11 = block[(y + num_2x2) * kBlockDim + num_2x2 + x]; + float r00 = c00 + c01 + c10 + c11; + float r01 = c00 + c01 - c10 - c11; + float r10 = c00 - c01 + c10 - c11; + float r11 = c00 - c01 - c10 + c11; + temp[y * 2 * kBlockDim + x * 2] = r00; + temp[y * 2 * kBlockDim + x * 2 + 1] = r01; + temp[(y * 2 + 1) * kBlockDim + x * 2] = r10; + temp[(y * 2 + 1) * kBlockDim + x * 2 + 1] = r11; + } + } + for (size_t y = 0; y < S; y++) { + for (size_t x = 0; x < S; x++) { + out[y * stride_out + x] = temp[y * kBlockDim + x]; + } + } +} + +void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels) { + HWY_ALIGN static constexpr float k4x4AFVBasis[16][16] = { + { + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + }, + { + 0.876902929799142f, + 0.2206518106944235f, + -0.10140050393753763f, + -0.1014005039375375f, + 0.2206518106944236f, + -0.10140050393753777f, + -0.10140050393753772f, + -0.10140050393753763f, + -0.10140050393753758f, + -0.10140050393753769f, + -0.1014005039375375f, + -0.10140050393753768f, + -0.10140050393753768f, + -0.10140050393753759f, + -0.10140050393753763f, + -0.10140050393753741f, + }, + { + 0.0, + 0.0, + 0.40670075830260755f, + 0.44444816619734445f, + 0.0, + 0.0, + 0.19574399372042936f, + 0.2929100136981264f, + -0.40670075830260716f, + -0.19574399372042872f, + 0.0, + 0.11379074460448091f, + -0.44444816619734384f, + -0.29291001369812636f, + -0.1137907446044814f, + 0.0, + }, + { + 0.0, + 0.0, + -0.21255748058288748f, + 0.3085497062849767f, + 0.0, + 0.4706702258572536f, + -0.1621205195722993f, + 0.0, + -0.21255748058287047f, + -0.16212051957228327f, + -0.47067022585725277f, + -0.1464291867126764f, + 0.3085497062849487f, + 0.0, + -0.14642918671266536f, + 0.4251149611657548f, + }, + { + 0.0, + -0.7071067811865474f, + 0.0, + 0.0, + 0.7071067811865476f, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + }, + { + -0.4105377591765233f, + 0.6235485373547691f, + -0.06435071657946274f, + -0.06435071657946266f, + 0.6235485373547694f, + -0.06435071657946284f, + -0.0643507165794628f, + -0.06435071657946274f, + -0.06435071657946272f, + -0.06435071657946279f, + -0.06435071657946266f, + -0.06435071657946277f, + -0.06435071657946277f, + -0.06435071657946273f, + -0.06435071657946274f, + -0.0643507165794626f, + }, + { + 0.0, + 0.0, + -0.4517556589999482f, + 0.15854503551840063f, + 0.0, + -0.04038515160822202f, + 0.0074182263792423875f, + 0.39351034269210167f, + -0.45175565899994635f, + 0.007418226379244351f, + 0.1107416575309343f, + 0.08298163094882051f, + 0.15854503551839705f, + 0.3935103426921022f, + 0.0829816309488214f, + -0.45175565899994796f, + }, + { + 0.0, + 0.0, + -0.304684750724869f, + 0.5112616136591823f, + 0.0, + 0.0, + -0.290480129728998f, + -0.06578701549142804f, + 0.304684750724884f, + 0.2904801297290076f, + 0.0, + -0.23889773523344604f, + -0.5112616136592012f, + 0.06578701549142545f, + 0.23889773523345467f, + 0.0, + }, + { + 0.0, + 0.0, + 0.3017929516615495f, + 0.25792362796341184f, + 0.0, + 0.16272340142866204f, + 0.09520022653475037f, + 0.0, + 0.3017929516615503f, + 0.09520022653475055f, + -0.16272340142866173f, + -0.35312385449816297f, + 0.25792362796341295f, + 0.0, + -0.3531238544981624f, + -0.6035859033230976f, + }, + { + 0.0, + 0.0, + 0.40824829046386274f, + 0.0, + 0.0, + 0.0, + 0.0, + -0.4082482904638628f, + -0.4082482904638635f, + 0.0, + 0.0, + -0.40824829046386296f, + 0.0, + 0.4082482904638634f, + 0.408248290463863f, + 0.0, + }, + { + 0.0, + 0.0, + 0.1747866975480809f, + 0.0812611176717539f, + 0.0, + 0.0, + -0.3675398009862027f, + -0.307882213957909f, + -0.17478669754808135f, + 0.3675398009862011f, + 0.0, + 0.4826689115059883f, + -0.08126111767175039f, + 0.30788221395790305f, + -0.48266891150598584f, + 0.0, + }, + { + 0.0, + 0.0, + -0.21105601049335784f, + 0.18567180916109802f, + 0.0, + 0.0, + 0.49215859013738733f, + -0.38525013709251915f, + 0.21105601049335806f, + -0.49215859013738905f, + 0.0, + 0.17419412659916217f, + -0.18567180916109904f, + 0.3852501370925211f, + -0.1741941265991621f, + 0.0, + }, + { + 0.0, + 0.0, + -0.14266084808807264f, + -0.3416446842253372f, + 0.0, + 0.7367497537172237f, + 0.24627107722075148f, + -0.08574019035519306f, + -0.14266084808807344f, + 0.24627107722075137f, + 0.14883399227113567f, + -0.04768680350229251f, + -0.3416446842253373f, + -0.08574019035519267f, + -0.047686803502292804f, + -0.14266084808807242f, + }, + { + 0.0, + 0.0, + -0.13813540350758585f, + 0.3302282550303788f, + 0.0, + 0.08755115000587084f, + -0.07946706605909573f, + -0.4613374887461511f, + -0.13813540350758294f, + -0.07946706605910261f, + 0.49724647109535086f, + 0.12538059448563663f, + 0.3302282550303805f, + -0.4613374887461554f, + 0.12538059448564315f, + -0.13813540350758452f, + }, + { + 0.0, + 0.0, + -0.17437602599651067f, + 0.0702790691196284f, + 0.0, + -0.2921026642334881f, + 0.3623817333531167f, + 0.0, + -0.1743760259965108f, + 0.36238173335311646f, + 0.29210266423348785f, + -0.4326608024727445f, + 0.07027906911962818f, + 0.0, + -0.4326608024727457f, + 0.34875205199302267f, + }, + { + 0.0, + 0.0, + 0.11354987314994337f, + -0.07417504595810355f, + 0.0, + 0.19402893032594343f, + -0.435190496523228f, + 0.21918684838857466f, + 0.11354987314994257f, + -0.4351904965232251f, + 0.5550443808910661f, + -0.25468277124066463f, + -0.07417504595810233f, + 0.2191868483885728f, + -0.25468277124066413f, + 0.1135498731499429f, + }, + }; + + const HWY_CAPPED(float, 16) d; + for (size_t i = 0; i < 16; i += Lanes(d)) { + auto pixel = Zero(d); + for (size_t j = 0; j < 16; j++) { + auto cf = Set(d, coeffs[j]); + auto basis = Load(d, k4x4AFVBasis[j] + i); + pixel = MulAdd(cf, basis, pixel); + } + Store(pixel, d, pixels + i); + } +} + +template +void AFVTransformToPixels(const float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, size_t pixels_stride) { + HWY_ALIGN float scratch_space[4 * 8]; + size_t afv_x = afv_kind & 1; + size_t afv_y = afv_kind / 2; + float dcs[3] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + dcs[0] = (block00 + block10 + block01) * 4.0f; + dcs[1] = (block00 + block10 - block01); + dcs[2] = block00 - block10; + // IAFV: (even, even) positions. + HWY_ALIGN float coeff[4 * 4]; + coeff[0] = dcs[0]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + coeff[iy * 4 + ix] = coefficients[iy * 2 * 8 + ix * 2]; + } + } + HWY_ALIGN float block[4 * 8]; + AFVIDCT4x4(coeff, block); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + pixels[(iy + afv_y * 4) * pixels_stride + afv_x * 4 + ix] = + block[(afv_y == 1 ? 3 - iy : iy) * 4 + (afv_x == 1 ? 3 - ix : ix)]; + } + } + // IDCT4x4 in (odd, even) positions. + block[0] = dcs[1]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 4 + ix] = coefficients[iy * 2 * 8 + ix * 2 + 1]; + } + } + ComputeScaledIDCT<4, 4>()( + block, + DCTTo(pixels + afv_y * 4 * pixels_stride + (afv_x == 1 ? 0 : 4), + pixels_stride), + scratch_space); + // IDCT4x8. + block[0] = dcs[2]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(1 + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<4, 8>()( + block, + DCTTo(pixels + (afv_y == 1 ? 0 : 4) * pixels_stride, pixels_stride), + scratch_space); +} + +HWY_MAYBE_UNUSED void TransformToPixels(const AcStrategy::Type strategy, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, + size_t pixels_stride, + float* scratch_space) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::IDENTITY: { + PROFILER_ZONE("IDCT Identity"); + float dcs[4] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + dcs[0] = block00 + block01 + block10 + block11; + dcs[1] = block00 + block01 - block10 - block11; + dcs[2] = block00 - block01 + block10 - block11; + dcs[3] = block00 - block01 - block10 + block11; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + float block_dc = dcs[y * 2 + x]; + float residual_sum = 0; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + residual_sum += coefficients[(y + iy * 2) * 8 + x + ix * 2]; + } + } + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1] = + block_dc - residual_sum * (1.0f / 16); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 1 && iy == 1) continue; + pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix] = + coefficients[(y + iy * 2) * 8 + x + ix * 2] + + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1]; + } + } + pixels[y * 4 * pixels_stride + x * 4] = + coefficients[(y + 2) * 8 + x + 2] + + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1]; + } + } + break; + } + case Type::DCT8X4: { + PROFILER_ZONE("IDCT 8x4"); + float dcs[2] = {}; + float block0 = coefficients[0]; + float block1 = coefficients[8]; + dcs[0] = block0 + block1; + dcs[1] = block0 - block1; + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 8]; + block[0] = dcs[x]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(x + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<8, 4>()(block, DCTTo(pixels + x * 4, pixels_stride), + scratch_space); + } + break; + } + case Type::DCT4X8: { + PROFILER_ZONE("IDCT 4x8"); + float dcs[2] = {}; + float block0 = coefficients[0]; + float block1 = coefficients[8]; + dcs[0] = block0 + block1; + dcs[1] = block0 - block1; + for (size_t y = 0; y < 2; y++) { + HWY_ALIGN float block[4 * 8]; + block[0] = dcs[y]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(y + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<4, 8>()( + block, DCTTo(pixels + y * 4 * pixels_stride, pixels_stride), + scratch_space); + } + break; + } + case Type::DCT4X4: { + PROFILER_ZONE("IDCT 4"); + float dcs[4] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + dcs[0] = block00 + block01 + block10 + block11; + dcs[1] = block00 + block01 - block10 - block11; + dcs[2] = block00 - block01 + block10 - block11; + dcs[3] = block00 - block01 - block10 + block11; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 4]; + block[0] = dcs[y * 2 + x]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 4 + ix] = coefficients[(y + iy * 2) * 8 + x + ix * 2]; + } + } + ComputeScaledIDCT<4, 4>()( + block, + DCTTo(pixels + y * 4 * pixels_stride + x * 4, pixels_stride), + scratch_space); + } + } + break; + } + case Type::DCT2X2: { + PROFILER_ZONE("IDCT 2"); + HWY_ALIGN float coeffs[kDCTBlockSize]; + memcpy(coeffs, coefficients, sizeof(float) * kDCTBlockSize); + IDCT2TopBlock<2>(coeffs, kBlockDim, coeffs); + IDCT2TopBlock<4>(coeffs, kBlockDim, coeffs); + IDCT2TopBlock<8>(coeffs, kBlockDim, coeffs); + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + pixels[y * pixels_stride + x] = coeffs[y * kBlockDim + x]; + } + } + break; + } + case Type::DCT16X16: { + PROFILER_ZONE("IDCT 16"); + ComputeScaledIDCT<16, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT16X8: { + PROFILER_ZONE("IDCT 16x8"); + ComputeScaledIDCT<16, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT8X16: { + PROFILER_ZONE("IDCT 8x16"); + ComputeScaledIDCT<8, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X8: { + PROFILER_ZONE("IDCT 32x8"); + ComputeScaledIDCT<32, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT8X32: { + PROFILER_ZONE("IDCT 8x32"); + ComputeScaledIDCT<8, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X16: { + PROFILER_ZONE("IDCT 32x16"); + ComputeScaledIDCT<32, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT16X32: { + PROFILER_ZONE("IDCT 16x32"); + ComputeScaledIDCT<16, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X32: { + PROFILER_ZONE("IDCT 32"); + ComputeScaledIDCT<32, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT: { + PROFILER_ZONE("IDCT 8"); + ComputeScaledIDCT<8, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::AFV0: { + PROFILER_ZONE("IAFV0"); + AFVTransformToPixels<0>(coefficients, pixels, pixels_stride); + break; + } + case Type::AFV1: { + PROFILER_ZONE("IAFV1"); + AFVTransformToPixels<1>(coefficients, pixels, pixels_stride); + break; + } + case Type::AFV2: { + PROFILER_ZONE("IAFV2"); + AFVTransformToPixels<2>(coefficients, pixels, pixels_stride); + break; + } + case Type::AFV3: { + PROFILER_ZONE("IAFV3"); + AFVTransformToPixels<3>(coefficients, pixels, pixels_stride); + break; + } + case Type::DCT64X32: { + PROFILER_ZONE("IDCT 64x32"); + ComputeScaledIDCT<64, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X64: { + PROFILER_ZONE("IDCT 32x64"); + ComputeScaledIDCT<32, 64>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT64X64: { + PROFILER_ZONE("IDCT 64"); + ComputeScaledIDCT<64, 64>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT128X64: { + PROFILER_ZONE("IDCT 128x64"); + ComputeScaledIDCT<128, 64>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT64X128: { + PROFILER_ZONE("IDCT 64x128"); + ComputeScaledIDCT<64, 128>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT128X128: { + PROFILER_ZONE("IDCT 128"); + ComputeScaledIDCT<128, 128>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT256X128: { + PROFILER_ZONE("IDCT 256x128"); + ComputeScaledIDCT<256, 128>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT128X256: { + PROFILER_ZONE("IDCT 128x256"); + ComputeScaledIDCT<128, 256>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT256X256: { + PROFILER_ZONE("IDCT 256"); + ComputeScaledIDCT<256, 256>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::kNumValidStrategies: + JXL_ABORT("Invalid strategy"); + } +} + +HWY_MAYBE_UNUSED void LowestFrequenciesFromDC(const AcStrategy::Type strategy, + const float* dc, size_t dc_stride, + float* llf) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::DCT16X8: { + ReinterpretingDCT( + dc, dc_stride, llf, 2 * kBlockDim); + break; + } + case Type::DCT8X16: { + ReinterpretingDCT( + dc, dc_stride, llf, 2 * kBlockDim); + break; + } + case Type::DCT16X16: { + ReinterpretingDCT( + dc, dc_stride, llf, 2 * kBlockDim); + break; + } + case Type::DCT32X8: { + ReinterpretingDCT( + dc, dc_stride, llf, 4 * kBlockDim); + break; + } + case Type::DCT8X32: { + ReinterpretingDCT( + dc, dc_stride, llf, 4 * kBlockDim); + break; + } + case Type::DCT32X16: { + ReinterpretingDCT( + dc, dc_stride, llf, 4 * kBlockDim); + break; + } + case Type::DCT16X32: { + ReinterpretingDCT( + dc, dc_stride, llf, 4 * kBlockDim); + break; + } + case Type::DCT32X32: { + ReinterpretingDCT( + dc, dc_stride, llf, 4 * kBlockDim); + break; + } + case Type::DCT64X32: { + ReinterpretingDCT( + dc, dc_stride, llf, 8 * kBlockDim); + break; + } + case Type::DCT32X64: { + ReinterpretingDCT( + dc, dc_stride, llf, 8 * kBlockDim); + break; + } + case Type::DCT64X64: { + ReinterpretingDCT( + dc, dc_stride, llf, 8 * kBlockDim); + break; + } + case Type::DCT128X64: { + ReinterpretingDCT( + dc, dc_stride, llf, 16 * kBlockDim); + break; + } + case Type::DCT64X128: { + ReinterpretingDCT( + dc, dc_stride, llf, 16 * kBlockDim); + break; + } + case Type::DCT128X128: { + ReinterpretingDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/16, /*ROWS=*/16, /*COLS=*/16>( + dc, dc_stride, llf, 16 * kBlockDim); + break; + } + case Type::DCT256X128: { + ReinterpretingDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/16, /*ROWS=*/32, /*COLS=*/16>( + dc, dc_stride, llf, 32 * kBlockDim); + break; + } + case Type::DCT128X256: { + ReinterpretingDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/32, /*ROWS=*/16, /*COLS=*/32>( + dc, dc_stride, llf, 32 * kBlockDim); + break; + } + case Type::DCT256X256: { + ReinterpretingDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/32, /*ROWS=*/32, /*COLS=*/32>( + dc, dc_stride, llf, 32 * kBlockDim); + break; + } + case Type::DCT: + case Type::DCT2X2: + case Type::DCT4X4: + case Type::DCT4X8: + case Type::DCT8X4: + case Type::AFV0: + case Type::AFV1: + case Type::AFV2: + case Type::AFV3: + case Type::IDENTITY: + llf[0] = dc[0]; + break; + case Type::kNumValidStrategies: + JXL_ABORT("Invalid strategy"); + }; +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DEC_TRANSFORMS_INL_H_ diff --git a/media/libjxl/src/lib/jxl/dec_transforms_testonly.cc b/media/libjxl/src/lib/jxl/dec_transforms_testonly.cc new file mode 100644 index 000000000..9ee80c59d --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_transforms_testonly.cc @@ -0,0 +1,41 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_transforms_testonly.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_transforms_testonly.cc" +#include +#include + +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_transforms-inl.h" + +namespace jxl { + +#if HWY_ONCE +HWY_EXPORT(TransformToPixels); +void TransformToPixels(AcStrategy::Type strategy, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, size_t pixels_stride, + float* scratch_space) { + return HWY_DYNAMIC_DISPATCH(TransformToPixels)(strategy, coefficients, pixels, + pixels_stride, scratch_space); +} + +HWY_EXPORT(LowestFrequenciesFromDC); +void LowestFrequenciesFromDC(const jxl::AcStrategy::Type strategy, + const float* dc, size_t dc_stride, float* llf) { + return HWY_DYNAMIC_DISPATCH(LowestFrequenciesFromDC)(strategy, dc, dc_stride, + llf); +} + +HWY_EXPORT(AFVIDCT4x4); +void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels) { + return HWY_DYNAMIC_DISPATCH(AFVIDCT4x4)(coeffs, pixels); +} +#endif // HWY_ONCE + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/dec_transforms_testonly.h b/media/libjxl/src/lib/jxl/dec_transforms_testonly.h new file mode 100644 index 000000000..97c4ca543 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_transforms_testonly.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_ +#define LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_ + +// Facade for (non-inlined) inverse integral transforms. + +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +void TransformToPixels(AcStrategy::Type strategy, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, size_t pixels_stride, + float* JXL_RESTRICT scratch_space); + +// Equivalent of the above for DC image. +void LowestFrequenciesFromDC(const jxl::AcStrategy::Type strategy, + const float* dc, size_t dc_stride, float* llf); + +void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels); + +} // namespace jxl + +#endif // LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_ diff --git a/media/libjxl/src/lib/jxl/dec_xyb-inl.h b/media/libjxl/src/lib/jxl/dec_xyb-inl.h new file mode 100644 index 000000000..a4f24cd12 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_xyb-inl.h @@ -0,0 +1,346 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// XYB -> linear sRGB helper function. + +#if defined(LIB_JXL_DEC_XYB_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DEC_XYB_INL_H_ +#undef LIB_JXL_DEC_XYB_INL_H_ +#else +#define LIB_JXL_DEC_XYB_INL_H_ +#endif + +#include + +#include "lib/jxl/dec_xyb.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Broadcast; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sub; + +// Inverts the pixel-wise RGB->XYB conversion in OpsinDynamicsImage() (including +// the gamma mixing and simple gamma). Avoids clamping to [0, 1] - out of (sRGB) +// gamut values may be in-gamut after transforming to a wider space. +// "inverse_matrix" points to 9 broadcasted vectors, which are the 3x3 entries +// of the (row-major) opsin absorbance matrix inverse. Pre-multiplying its +// entries by c is equivalent to multiplying linear_* by c afterwards. +template +HWY_INLINE HWY_MAYBE_UNUSED void XybToRgb(D d, const V opsin_x, const V opsin_y, + const V opsin_b, + const OpsinParams& opsin_params, + V* const HWY_RESTRICT linear_r, + V* const HWY_RESTRICT linear_g, + V* const HWY_RESTRICT linear_b) { +#if HWY_TARGET == HWY_SCALAR + const auto neg_bias_r = Set(d, opsin_params.opsin_biases[0]); + const auto neg_bias_g = Set(d, opsin_params.opsin_biases[1]); + const auto neg_bias_b = Set(d, opsin_params.opsin_biases[2]); +#else + const auto neg_bias_rgb = LoadDup128(d, opsin_params.opsin_biases); + const auto neg_bias_r = Broadcast<0>(neg_bias_rgb); + const auto neg_bias_g = Broadcast<1>(neg_bias_rgb); + const auto neg_bias_b = Broadcast<2>(neg_bias_rgb); +#endif + + // Color space: XYB -> RGB + auto gamma_r = Add(opsin_y, opsin_x); + auto gamma_g = Sub(opsin_y, opsin_x); + auto gamma_b = opsin_b; + + gamma_r = Sub(gamma_r, Set(d, opsin_params.opsin_biases_cbrt[0])); + gamma_g = Sub(gamma_g, Set(d, opsin_params.opsin_biases_cbrt[1])); + gamma_b = Sub(gamma_b, Set(d, opsin_params.opsin_biases_cbrt[2])); + + // Undo gamma compression: linear = gamma^3 for efficiency. + const auto gamma_r2 = Mul(gamma_r, gamma_r); + const auto gamma_g2 = Mul(gamma_g, gamma_g); + const auto gamma_b2 = Mul(gamma_b, gamma_b); + const auto mixed_r = MulAdd(gamma_r2, gamma_r, neg_bias_r); + const auto mixed_g = MulAdd(gamma_g2, gamma_g, neg_bias_g); + const auto mixed_b = MulAdd(gamma_b2, gamma_b, neg_bias_b); + + const float* HWY_RESTRICT inverse_matrix = opsin_params.inverse_opsin_matrix; + + // Unmix (multiply by 3x3 inverse_matrix) + // TODO(eustas): ref would be more readable than pointer + *linear_r = Mul(LoadDup128(d, &inverse_matrix[0 * 4]), mixed_r); + *linear_g = Mul(LoadDup128(d, &inverse_matrix[3 * 4]), mixed_r); + *linear_b = Mul(LoadDup128(d, &inverse_matrix[6 * 4]), mixed_r); + *linear_r = MulAdd(LoadDup128(d, &inverse_matrix[1 * 4]), mixed_g, *linear_r); + *linear_g = MulAdd(LoadDup128(d, &inverse_matrix[4 * 4]), mixed_g, *linear_g); + *linear_b = MulAdd(LoadDup128(d, &inverse_matrix[7 * 4]), mixed_g, *linear_b); + *linear_r = MulAdd(LoadDup128(d, &inverse_matrix[2 * 4]), mixed_b, *linear_r); + *linear_g = MulAdd(LoadDup128(d, &inverse_matrix[5 * 4]), mixed_b, *linear_g); + *linear_b = MulAdd(LoadDup128(d, &inverse_matrix[8 * 4]), mixed_b, *linear_b); +} + +static inline HWY_MAYBE_UNUSED bool HasFastXYBTosRGB8() { +#if HWY_TARGET == HWY_NEON + return true; +#else + return false; +#endif +} + +static inline HWY_MAYBE_UNUSED void FastXYBTosRGB8(const float* input[4], + uint8_t* output, + bool is_rgba, size_t xsize) { + // This function is very NEON-specific. As such, it uses intrinsics directly. +#if HWY_TARGET == HWY_NEON + // WARNING: doing fixed point arithmetic correctly is very complicated. + // Changes to this function should be thoroughly tested. + + // Note that the input is assumed to have 13 bits of mantissa, and the output + // will have 14 bits. + auto srgb_tf = [&](int16x8_t v16) { + int16x8_t clz = vclzq_s16(v16); + // Convert to [0.25, 0.5) range. + int16x8_t v025_05_16 = vqshlq_s16(v16, vqsubq_s16(clz, vdupq_n_s16(2))); + + // third degree polynomial approximation between 0.25 and 0.5 + // of 1.055/2^(7/2.4) * x^(1/2.4) / 32. + // poly ~ ((0.95x-1.75)*x+1.72)*x+0.29 + // We actually compute ~ ((0.47x-0.87)*x+0.86)*(2x)+0.29 as 1.75 and 1.72 + // overflow our fixed point representation. + + int16x8_t twov = vqaddq_s16(v025_05_16, v025_05_16); + + // 0.47 * x + int16x8_t step1 = vqrdmulhq_n_s16(v025_05_16, 15706); + // - 0.87 + int16x8_t step2 = vsubq_s16(step1, vdupq_n_s16(28546)); + // * x + int16x8_t step3 = vqrdmulhq_s16(step2, v025_05_16); + // + 0.86 + int16x8_t step4 = vaddq_s16(step3, vdupq_n_s16(28302)); + // * 2x + int16x8_t step5 = vqrdmulhq_s16(step4, twov); + // + 0.29 + int16x8_t mul16 = vaddq_s16(step5, vdupq_n_s16(9485)); + + int16x8_t exp16 = vsubq_s16(vdupq_n_s16(11), clz); + // Compute 2**(1/2.4*exp16)/32. Values of exp16 that would overflow are + // capped to 1. + // Generated with the following Python script: + // a = [] + // b = [] + // + // for i in range(0, 16): + // v = 2**(5/12.*i) + // v /= 16 + // v *= 256 * 128 + // v = int(v) + // a.append(v // 256) + // b.append(v % 256) + // + // print(", ".join("0x%02x" % x for x in a)) + // + // print(", ".join("0x%02x" % x for x in b)) + + HWY_ALIGN constexpr uint8_t k2to512powersm1div32_high[16] = { + 0x08, 0x0a, 0x0e, 0x13, 0x19, 0x21, 0x2d, 0x3c, + 0x50, 0x6b, 0x8f, 0x8f, 0x8f, 0x8f, 0x8f, 0x8f, + }; + HWY_ALIGN constexpr uint8_t k2to512powersm1div32_low[16] = { + 0x00, 0xad, 0x41, 0x06, 0x65, 0xe7, 0x41, 0x68, + 0xa2, 0xa2, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + }; + // Using the highway implementation here since vqtbl1q is aarch64-only. + using hwy::HWY_NAMESPACE::Vec128; + uint8x16_t pow_low = + TableLookupBytes( + Vec128(vld1q_u8(k2to512powersm1div32_low)), + Vec128(vreinterpretq_u8_s16(exp16))) + .raw; + uint8x16_t pow_high = + TableLookupBytes( + Vec128(vld1q_u8(k2to512powersm1div32_high)), + Vec128(vreinterpretq_u8_s16(exp16))) + .raw; + int16x8_t pow16 = vreinterpretq_s16_u16(vsliq_n_u16( + vreinterpretq_u16_u8(pow_low), vreinterpretq_u16_u8(pow_high), 8)); + + // approximation of v * 12.92, divided by 2 + // Note that our input is using 13 mantissa bits instead of 15. + int16x8_t v16_linear = vrshrq_n_s16(vmulq_n_s16(v16, 826), 5); + // 1.055*pow(v, 1/2.4) - 0.055, divided by 2 + auto v16_pow = vsubq_s16(vqrdmulhq_s16(mul16, pow16), vdupq_n_s16(901)); + // > 0.0031308f (note that v16 has 13 mantissa bits) + return vbslq_s16(vcgeq_s16(v16, vdupq_n_s16(26)), v16_pow, v16_linear); + }; + + const float* JXL_RESTRICT row_in_x = input[0]; + const float* JXL_RESTRICT row_in_y = input[1]; + const float* JXL_RESTRICT row_in_b = input[2]; + const float* JXL_RESTRICT row_in_a = input[3]; + for (size_t x = 0; x < xsize; x += 8) { + // Normal ranges for xyb for in-gamut sRGB colors: + // x: -0.015386 0.028100 + // y: 0.000000 0.845308 + // b: 0.000000 0.845308 + + // We actually want x * 8 to have some extra precision. + // TODO(veluca): consider different approaches here, like vld1q_f32_x2. + float32x4_t opsin_x_left = vld1q_f32(row_in_x + x); + int16x4_t opsin_x16_times8_left = + vqmovn_s32(vcvtq_n_s32_f32(opsin_x_left, 18)); + float32x4_t opsin_x_right = + vld1q_f32(row_in_x + x + (x + 4 < xsize ? 4 : 0)); + int16x4_t opsin_x16_times8_right = + vqmovn_s32(vcvtq_n_s32_f32(opsin_x_right, 18)); + int16x8_t opsin_x16_times8 = + vcombine_s16(opsin_x16_times8_left, opsin_x16_times8_right); + + float32x4_t opsin_y_left = vld1q_f32(row_in_y + x); + int16x4_t opsin_y16_left = vqmovn_s32(vcvtq_n_s32_f32(opsin_y_left, 15)); + float32x4_t opsin_y_right = + vld1q_f32(row_in_y + x + (x + 4 < xsize ? 4 : 0)); + int16x4_t opsin_y16_right = vqmovn_s32(vcvtq_n_s32_f32(opsin_y_right, 15)); + int16x8_t opsin_y16 = vcombine_s16(opsin_y16_left, opsin_y16_right); + + float32x4_t opsin_b_left = vld1q_f32(row_in_b + x); + int16x4_t opsin_b16_left = vqmovn_s32(vcvtq_n_s32_f32(opsin_b_left, 15)); + float32x4_t opsin_b_right = + vld1q_f32(row_in_b + x + (x + 4 < xsize ? 4 : 0)); + int16x4_t opsin_b16_right = vqmovn_s32(vcvtq_n_s32_f32(opsin_b_right, 15)); + int16x8_t opsin_b16 = vcombine_s16(opsin_b16_left, opsin_b16_right); + + int16x8_t neg_bias16 = vdupq_n_s16(-124); // -0.0037930732552754493 + int16x8_t neg_bias_cbrt16 = vdupq_n_s16(-5110); // -0.155954201 + int16x8_t neg_bias_half16 = vdupq_n_s16(-62); + + // Color space: XYB -> RGB + // Compute ((y+x-bias_cbrt)^3-(y-x-bias_cbrt)^3)/2, + // ((y+x-bias_cbrt)^3+(y-x-bias_cbrt)^3)/2+bias, (b-bias_cbrt)^3+bias. + // Note that ignoring x2 in the formulas below (as x << y) results in + // errors of at least 3 in the final sRGB values. + int16x8_t opsin_yp16 = vqsubq_s16(opsin_y16, neg_bias_cbrt16); + int16x8_t ysq16 = vqrdmulhq_s16(opsin_yp16, opsin_yp16); + int16x8_t twentyfourx16 = vmulq_n_s16(opsin_x16_times8, 3); + int16x8_t twentyfourxy16 = vqrdmulhq_s16(opsin_yp16, twentyfourx16); + int16x8_t threexsq16 = + vrshrq_n_s16(vqrdmulhq_s16(opsin_x16_times8, twentyfourx16), 6); + + // We can ignore x^3 here. Note that this is multiplied by 8. + int16x8_t mixed_rmg16 = vqrdmulhq_s16(twentyfourxy16, opsin_yp16); + + int16x8_t mixed_rpg_sos_half = vhaddq_s16(ysq16, threexsq16); + int16x8_t mixed_rpg16 = vhaddq_s16( + vqrdmulhq_s16(opsin_yp16, mixed_rpg_sos_half), neg_bias_half16); + + int16x8_t gamma_b16 = vqsubq_s16(opsin_b16, neg_bias_cbrt16); + int16x8_t gamma_bsq16 = vqrdmulhq_s16(gamma_b16, gamma_b16); + int16x8_t gamma_bcb16 = vqrdmulhq_s16(gamma_bsq16, gamma_b16); + int16x8_t mixed_b16 = vqaddq_s16(gamma_bcb16, neg_bias16); + // mixed_rpg and mixed_b are in 0-1 range. + // mixed_rmg has a smaller range (-0.035 to 0.035 for valid sRGB). Note + // that at this point it is already multiplied by 8. + + // We multiply all the mixed values by 1/4 (i.e. shift them to 13-bit + // fixed point) to ensure intermediate quantities are in range. Note that + // r-g is not shifted, and was x8 before here; this corresponds to a x32 + // overall multiplicative factor and ensures that all the matrix constants + // are in 0-1 range. + // Similarly, mixed_rpg16 is already multiplied by 1/4 because of the two + // vhadd + using neg_bias_half. + mixed_b16 = vshrq_n_s16(mixed_b16, 2); + + // Unmix (multiply by 3x3 inverse_matrix) + // For increased precision, we use a matrix for converting from + // ((mixed_r - mixed_g)/2, (mixed_r + mixed_g)/2, mixed_b) to rgb. This + // avoids cancellation effects when computing (y+x)^3-(y-x)^3. + // We compute mixed_rpg - mixed_b because the (1+c)*mixed_rpg - c * + // mixed_b pattern is repeated frequently in the code below. This allows + // us to save a multiply per channel, and removes the presence of + // some constants above 1. Moreover, mixed_rmg - mixed_b is in (-1, 1) + // range, so the subtraction is safe. + // All the magic-looking constants here are derived by computing the + // inverse opsin matrix for the transformation modified as described + // above. + + // Precomputation common to multiple color values. + int16x8_t mixed_rpgmb16 = vqsubq_s16(mixed_rpg16, mixed_b16); + int16x8_t mixed_rpgmb_times_016 = vqrdmulhq_n_s16(mixed_rpgmb16, 5394); + int16x8_t mixed_rg16 = vqaddq_s16(mixed_rpgmb_times_016, mixed_rpg16); + + // R + int16x8_t linear_r16 = + vqaddq_s16(mixed_rg16, vqrdmulhq_n_s16(mixed_rmg16, 21400)); + + // G + int16x8_t linear_g16 = + vqaddq_s16(mixed_rg16, vqrdmulhq_n_s16(mixed_rmg16, -7857)); + + // B + int16x8_t linear_b16 = vqrdmulhq_n_s16(mixed_rpgmb16, -30996); + linear_b16 = vqaddq_s16(linear_b16, mixed_b16); + linear_b16 = vqaddq_s16(linear_b16, vqrdmulhq_n_s16(mixed_rmg16, -6525)); + + // Apply SRGB transfer function. + int16x8_t r = srgb_tf(linear_r16); + int16x8_t g = srgb_tf(linear_g16); + int16x8_t b = srgb_tf(linear_b16); + + uint8x8_t r8 = + vqmovun_s16(vrshrq_n_s16(vsubq_s16(r, vshrq_n_s16(r, 8)), 6)); + uint8x8_t g8 = + vqmovun_s16(vrshrq_n_s16(vsubq_s16(g, vshrq_n_s16(g, 8)), 6)); + uint8x8_t b8 = + vqmovun_s16(vrshrq_n_s16(vsubq_s16(b, vshrq_n_s16(b, 8)), 6)); + + size_t n = xsize - x; + if (is_rgba) { + float32x4_t a_f32_left = + row_in_a ? vld1q_f32(row_in_a + x) : vdupq_n_f32(1.0f); + float32x4_t a_f32_right = + row_in_a ? vld1q_f32(row_in_a + x + (x + 4 < xsize ? 4 : 0)) + : vdupq_n_f32(1.0f); + int16x4_t a16_left = vqmovn_s32(vcvtq_n_s32_f32(a_f32_left, 8)); + int16x4_t a16_right = vqmovn_s32(vcvtq_n_s32_f32(a_f32_right, 8)); + uint8x8_t a8 = vqmovun_s16(vcombine_s16(a16_left, a16_right)); + uint8_t* buf = output + 4 * x; + uint8x8x4_t data = {r8, g8, b8, a8}; + if (n >= 8) { + vst4_u8(buf, data); + } else { + uint8_t tmp[8 * 4]; + vst4_u8(tmp, data); + memcpy(buf, tmp, n * 4); + } + } else { + uint8_t* buf = output + 3 * x; + uint8x8x3_t data = {r8, g8, b8}; + if (n >= 8) { + vst3_u8(buf, data); + } else { + uint8_t tmp[8 * 3]; + vst3_u8(tmp, data); + memcpy(buf, tmp, n * 3); + } + } + } +#else + (void)input; + (void)output; + (void)is_rgba; + (void)xsize; + JXL_ABORT("Unreachable"); +#endif +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DEC_XYB_INL_H_ diff --git a/media/libjxl/src/lib/jxl/dec_xyb.cc b/media/libjxl/src/lib/jxl/dec_xyb.cc new file mode 100644 index 000000000..ef4088f10 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_xyb.cc @@ -0,0 +1,323 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_xyb.h" + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_xyb.cc" +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_group_border.h" +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/sanitizers.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Broadcast; +using hwy::HWY_NAMESPACE::MulAdd; + +void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool, + const OpsinParams& opsin_params) { + PROFILER_FUNC; + JXL_CHECK_IMAGE_INITIALIZED(*inout, Rect(*inout)); + + const size_t xsize = inout->xsize(); // not padded + JXL_CHECK(RunOnPool( + pool, 0, inout->ysize(), ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + + // Faster than adding via ByteOffset at end of loop. + float* JXL_RESTRICT row0 = inout->PlaneRow(0, y); + float* JXL_RESTRICT row1 = inout->PlaneRow(1, y); + float* JXL_RESTRICT row2 = inout->PlaneRow(2, y); + + const HWY_FULL(float) d; + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_opsin_x = Load(d, row0 + x); + const auto in_opsin_y = Load(d, row1 + x); + const auto in_opsin_b = Load(d, row2 + x); + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params, + &linear_r, &linear_g, &linear_b); + + Store(linear_r, d, row0 + x); + Store(linear_g, d, row1 + x); + Store(linear_b, d, row2 + x); + } + }, + "OpsinToLinear")); +} + +// Same, but not in-place. +void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool, + Image3F* JXL_RESTRICT linear, + const OpsinParams& opsin_params) { + PROFILER_FUNC; + + JXL_ASSERT(SameSize(rect, *linear)); + JXL_CHECK_IMAGE_INITIALIZED(opsin, rect); + + JXL_CHECK(RunOnPool( + pool, 0, static_cast(rect.ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast(task); + + // Faster than adding via ByteOffset at end of loop. + const float* JXL_RESTRICT row_opsin_0 = rect.ConstPlaneRow(opsin, 0, y); + const float* JXL_RESTRICT row_opsin_1 = rect.ConstPlaneRow(opsin, 1, y); + const float* JXL_RESTRICT row_opsin_2 = rect.ConstPlaneRow(opsin, 2, y); + float* JXL_RESTRICT row_linear_0 = linear->PlaneRow(0, y); + float* JXL_RESTRICT row_linear_1 = linear->PlaneRow(1, y); + float* JXL_RESTRICT row_linear_2 = linear->PlaneRow(2, y); + + const HWY_FULL(float) d; + + for (size_t x = 0; x < rect.xsize(); x += Lanes(d)) { + const auto in_opsin_x = Load(d, row_opsin_0 + x); + const auto in_opsin_y = Load(d, row_opsin_1 + x); + const auto in_opsin_b = Load(d, row_opsin_2 + x); + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params, + &linear_r, &linear_g, &linear_b); + + Store(linear_r, d, row_linear_0 + x); + Store(linear_g, d, row_linear_1 + x); + Store(linear_b, d, row_linear_2 + x); + } + }, + "OpsinToLinear(Rect)")); + JXL_CHECK_IMAGE_INITIALIZED(*linear, rect); +} + +// Transform YCbCr to RGB. +// Could be performed in-place (i.e. Y, Cb and Cr could alias R, B and B). +void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect) { + JXL_CHECK_IMAGE_INITIALIZED(ycbcr, rect); + const HWY_CAPPED(float, kBlockDim) df; + const size_t S = Lanes(df); // Step. + + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + if ((xsize == 0) || (ysize == 0)) return; + + // Full-range BT.601 as defined by JFIF Clause 7: + // https://www.itu.int/rec/T-REC-T.871-201105-I/en + const auto c128 = Set(df, 128.0f / 255); + const auto crcr = Set(df, 1.402f); + const auto cgcb = Set(df, -0.114f * 1.772f / 0.587f); + const auto cgcr = Set(df, -0.299f * 1.402f / 0.587f); + const auto cbcb = Set(df, 1.772f); + + for (size_t y = 0; y < ysize; y++) { + const float* y_row = rect.ConstPlaneRow(ycbcr, 1, y); + const float* cb_row = rect.ConstPlaneRow(ycbcr, 0, y); + const float* cr_row = rect.ConstPlaneRow(ycbcr, 2, y); + float* r_row = rect.PlaneRow(rgb, 0, y); + float* g_row = rect.PlaneRow(rgb, 1, y); + float* b_row = rect.PlaneRow(rgb, 2, y); + for (size_t x = 0; x < xsize; x += S) { + const auto y_vec = Add(Load(df, y_row + x), c128); + const auto cb_vec = Load(df, cb_row + x); + const auto cr_vec = Load(df, cr_row + x); + const auto r_vec = MulAdd(crcr, cr_vec, y_vec); + const auto g_vec = MulAdd(cgcr, cr_vec, MulAdd(cgcb, cb_vec, y_vec)); + const auto b_vec = MulAdd(cbcb, cb_vec, y_vec); + Store(r_vec, df, r_row + x); + Store(g_vec, df, g_row + x); + Store(b_vec, df, b_row + x); + } + } + JXL_CHECK_IMAGE_INITIALIZED(*rgb, rect); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(OpsinToLinearInplace); +void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool, + const OpsinParams& opsin_params) { + return HWY_DYNAMIC_DISPATCH(OpsinToLinearInplace)(inout, pool, opsin_params); +} + +HWY_EXPORT(OpsinToLinear); +void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool, + Image3F* JXL_RESTRICT linear, + const OpsinParams& opsin_params) { + return HWY_DYNAMIC_DISPATCH(OpsinToLinear)(opsin, rect, pool, linear, + opsin_params); +} + +HWY_EXPORT(YcbcrToRgb); +void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect) { + return HWY_DYNAMIC_DISPATCH(YcbcrToRgb)(ycbcr, rgb, rect); +} + +HWY_EXPORT(HasFastXYBTosRGB8); +bool HasFastXYBTosRGB8() { return HWY_DYNAMIC_DISPATCH(HasFastXYBTosRGB8)(); } + +HWY_EXPORT(FastXYBTosRGB8); +void FastXYBTosRGB8(const float* input[4], uint8_t* output, bool is_rgba, + size_t xsize) { + return HWY_DYNAMIC_DISPATCH(FastXYBTosRGB8)(input, output, is_rgba, xsize); +} + +void OpsinParams::Init(float intensity_target) { + InitSIMDInverseMatrix(GetOpsinAbsorbanceInverseMatrix(), inverse_opsin_matrix, + intensity_target); + memcpy(opsin_biases, kNegOpsinAbsorbanceBiasRGB, + sizeof(kNegOpsinAbsorbanceBiasRGB)); + memcpy(quant_biases, kDefaultQuantBias, sizeof(kDefaultQuantBias)); + for (size_t c = 0; c < 4; c++) { + opsin_biases_cbrt[c] = cbrtf(opsin_biases[c]); + } +} + +bool CanOutputToColorEncoding(const ColorEncoding& c_desired) { + if (!c_desired.HaveFields()) { + return false; + } + // TODO(veluca): keep in sync with dec_reconstruct.cc + if (!c_desired.tf.IsPQ() && !c_desired.tf.IsSRGB() && + !c_desired.tf.IsGamma() && !c_desired.tf.IsLinear() && + !c_desired.tf.IsHLG() && !c_desired.tf.IsDCI() && !c_desired.tf.Is709()) { + return false; + } + if (c_desired.IsGray() && c_desired.white_point != WhitePoint::kD65) { + // TODO(veluca): figure out what should happen here. + return false; + } + return true; +} + +Status OutputEncodingInfo::SetFromMetadata(const CodecMetadata& metadata) { + orig_color_encoding = metadata.m.color_encoding; + orig_intensity_target = metadata.m.IntensityTarget(); + desired_intensity_target = orig_intensity_target; + const auto& im = metadata.transform_data.opsin_inverse_matrix; + memcpy(orig_inverse_matrix, im.inverse_matrix, sizeof(orig_inverse_matrix)); + default_transform = im.all_default; + xyb_encoded = metadata.m.xyb_encoded; + std::copy(std::begin(im.opsin_biases), std::end(im.opsin_biases), + opsin_params.opsin_biases); + for (int i = 0; i < 3; ++i) { + opsin_params.opsin_biases_cbrt[i] = cbrtf(opsin_params.opsin_biases[i]); + } + opsin_params.opsin_biases_cbrt[3] = opsin_params.opsin_biases[3] = 1; + std::copy(std::begin(im.quant_biases), std::end(im.quant_biases), + opsin_params.quant_biases); + bool orig_ok = CanOutputToColorEncoding(orig_color_encoding); + bool orig_grey = orig_color_encoding.IsGray(); + return SetColorEncoding(!xyb_encoded || orig_ok + ? orig_color_encoding + : ColorEncoding::LinearSRGB(orig_grey)); +} + +Status OutputEncodingInfo::MaybeSetColorEncoding( + const ColorEncoding& c_desired) { + if (!xyb_encoded || !CanOutputToColorEncoding(c_desired)) { + return false; + } + return SetColorEncoding(c_desired); +} + +Status OutputEncodingInfo::SetColorEncoding(const ColorEncoding& c_desired) { + color_encoding = c_desired; + color_encoding_is_original = orig_color_encoding.SameColorEncoding(c_desired); + + // Compute the opsin inverse matrix and luminances based on primaries and + // white point. + float inverse_matrix[9]; + bool inverse_matrix_is_default = default_transform; + memcpy(inverse_matrix, orig_inverse_matrix, sizeof(inverse_matrix)); + constexpr float kSRGBLuminances[3] = {0.2126, 0.7152, 0.0722}; + memcpy(luminances, kSRGBLuminances, sizeof(luminances)); + if ((c_desired.primaries != Primaries::kSRGB || + c_desired.white_point != WhitePoint::kD65) && + !c_desired.IsGray()) { + float srgb_to_xyzd50[9]; + const auto& srgb = ColorEncoding::SRGB(/*is_gray=*/false); + JXL_CHECK(PrimariesToXYZD50( + srgb.GetPrimaries().r.x, srgb.GetPrimaries().r.y, + srgb.GetPrimaries().g.x, srgb.GetPrimaries().g.y, + srgb.GetPrimaries().b.x, srgb.GetPrimaries().b.y, + srgb.GetWhitePoint().x, srgb.GetWhitePoint().y, srgb_to_xyzd50)); + float original_to_xyz[3][3]; + JXL_RETURN_IF_ERROR(PrimariesToXYZ( + c_desired.GetPrimaries().r.x, c_desired.GetPrimaries().r.y, + c_desired.GetPrimaries().g.x, c_desired.GetPrimaries().g.y, + c_desired.GetPrimaries().b.x, c_desired.GetPrimaries().b.y, + c_desired.GetWhitePoint().x, c_desired.GetWhitePoint().y, + &original_to_xyz[0][0])); + memcpy(luminances, original_to_xyz[1], sizeof luminances); + if (xyb_encoded) { + float adapt_to_d50[9]; + JXL_RETURN_IF_ERROR(AdaptToXYZD50(c_desired.GetWhitePoint().x, + c_desired.GetWhitePoint().y, + adapt_to_d50)); + float xyzd50_to_original[9]; + MatMul(adapt_to_d50, &original_to_xyz[0][0], 3, 3, 3, xyzd50_to_original); + JXL_RETURN_IF_ERROR(Inv3x3Matrix(xyzd50_to_original)); + float srgb_to_original[9]; + MatMul(xyzd50_to_original, srgb_to_xyzd50, 3, 3, 3, srgb_to_original); + MatMul(srgb_to_original, orig_inverse_matrix, 3, 3, 3, inverse_matrix); + inverse_matrix_is_default = false; + } + } + + if (c_desired.IsGray()) { + float tmp_inv_matrix[9]; + memcpy(tmp_inv_matrix, inverse_matrix, sizeof(inverse_matrix)); + float srgb_to_luma[9]; + memcpy(&srgb_to_luma[0], luminances, sizeof(luminances)); + memcpy(&srgb_to_luma[3], luminances, sizeof(luminances)); + memcpy(&srgb_to_luma[6], luminances, sizeof(luminances)); + MatMul(srgb_to_luma, tmp_inv_matrix, 3, 3, 3, inverse_matrix); + } + + // The internal XYB color space uses absolute luminance, so we scale back the + // opsin inverse matrix to relative luminance where 1.0 corresponds to the + // original intensity target, or to absolute luminance for PQ, where 1.0 + // corresponds to 10000 nits. + if (xyb_encoded) { + float intensity_target = + (c_desired.tf.IsPQ() ? 10000 : orig_intensity_target); + InitSIMDInverseMatrix(inverse_matrix, opsin_params.inverse_opsin_matrix, + intensity_target); + all_default_opsin = (std::abs(intensity_target - 255.0) <= 0.1f && + inverse_matrix_is_default); + } + + // Set the inverse gamma based on color space transfer function. + inverse_gamma = (c_desired.tf.IsGamma() ? c_desired.tf.GetGamma() + : c_desired.tf.IsDCI() ? 1.0f / 2.6f + : 1.0); + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/dec_xyb.h b/media/libjxl/src/lib/jxl/dec_xyb.h new file mode 100644 index 000000000..ebaae9a17 --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_xyb.h @@ -0,0 +1,89 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_XYB_H_ +#define LIB_JXL_DEC_XYB_H_ + +// XYB -> linear sRGB. + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { + +// Parameters for XYB->sRGB conversion. +struct OpsinParams { + float inverse_opsin_matrix[9 * 4]; + float opsin_biases[4]; + float opsin_biases_cbrt[4]; + float quant_biases[4]; + void Init(float intensity_target); +}; + +struct OutputEncodingInfo { + // + // Fields depending only on image metadata + // + ColorEncoding orig_color_encoding; + // Used for the HLG OOTF and PQ tone mapping. + float orig_intensity_target; + // Opsin inverse matrix taken from the metadata. + float orig_inverse_matrix[9]; + bool default_transform; + bool xyb_encoded; + // + // Fields depending on output color encoding + // + ColorEncoding color_encoding; + bool color_encoding_is_original; + // Contains an opsin matrix that converts to the primaries of the output + // encoding. + OpsinParams opsin_params; + bool all_default_opsin; + // Used for Gamma and DCI transfer functions. + float inverse_gamma; + // Luminances of color_encoding's primaries, used for the HLG inverse OOTF and + // for PQ tone mapping. + // Default to sRGB's. + float luminances[3]; + // Used for the HLG inverse OOTF and PQ tone mapping. + float desired_intensity_target; + + Status SetFromMetadata(const CodecMetadata& metadata); + Status MaybeSetColorEncoding(const ColorEncoding& c_desired); + + private: + Status SetColorEncoding(const ColorEncoding& c_desired); +}; + +// Converts `inout` (not padded) from opsin to linear sRGB in-place. Called from +// per-pass postprocessing, hence parallelized. +void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool, + const OpsinParams& opsin_params); + +// Converts `opsin:rect` (opsin may be padded, rect.x0 must be vector-aligned) +// to linear sRGB. Called from whole-frame encoder, hence parallelized. +void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool, + Image3F* JXL_RESTRICT linear, + const OpsinParams& opsin_params); + +// Bt.601 to match JPEG/JFIF. Inputs are _signed_ YCbCr values suitable for DCT, +// see F.1.1.3 of T.81 (because our data type is float, there is no need to add +// a bias to make the values unsigned). +void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect); + +bool HasFastXYBTosRGB8(); +void FastXYBTosRGB8(const float* input[4], uint8_t* output, bool is_rgba, + size_t xsize); + +} // namespace jxl + +#endif // LIB_JXL_DEC_XYB_H_ diff --git a/media/libjxl/src/lib/jxl/decode.cc b/media/libjxl/src/lib/jxl/decode.cc new file mode 100644 index 000000000..1a0facce1 --- /dev/null +++ b/media/libjxl/src/lib/jxl/decode.cc @@ -0,0 +1,2907 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/decode.h" + +#include "jxl/types.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/box_content_decoder.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/decode_to_jpeg.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/memory_manager_internal.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/toc.h" + +namespace { + +// Checks if a + b > size, taking possible integer overflow into account. +bool OutOfBounds(size_t a, size_t b, size_t size) { + size_t pos = a + b; + if (pos > size) return true; + if (pos < a) return true; // overflow happened + return false; +} + +bool SumOverflows(size_t a, size_t b, size_t c) { + size_t sum = a + b; + if (sum < b) return true; + sum += c; + if (sum < c) return true; + return false; +} + +JXL_INLINE size_t InitialBasicInfoSizeHint() { + // Amount of bytes before the start of the codestream in the container format, + // assuming that the codestream is the first box after the signature and + // filetype boxes. 12 bytes signature box + 20 bytes filetype box + 16 bytes + // codestream box length + name + optional XLBox length. + const size_t container_header_size = 48; + + // Worst-case amount of bytes for basic info of the JPEG XL codestream header, + // that is all information up to and including extra_channel_bits. Up to + // around 2 bytes signature + 8 bytes SizeHeader + 31 bytes ColorEncoding + 4 + // bytes rest of ImageMetadata + 5 bytes part of ImageMetadata2. + // TODO(lode): recompute and update this value when alpha_bits is moved to + // extra channels info. + const size_t max_codestream_basic_info_size = 50; + + return container_header_size + max_codestream_basic_info_size; +} + +// Debug-printing failure macro similar to JXL_FAILURE, but for the status code +// JXL_DEC_ERROR +#ifdef JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(format, ...) \ + (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort(), JXL_DEC_ERROR) +#else // JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(format, ...) \ + (((JXL_DEBUG_ON_ERROR) && \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__)), \ + JXL_DEC_ERROR) +#endif // JXL_CRASH_ON_ERROR + +JxlDecoderStatus ConvertStatus(JxlDecoderStatus status) { return status; } + +JxlDecoderStatus ConvertStatus(jxl::Status status) { + return status ? JXL_DEC_SUCCESS : JXL_DEC_ERROR; +} + +JxlSignature ReadSignature(const uint8_t* buf, size_t len, size_t* pos) { + if (*pos >= len) return JXL_SIG_NOT_ENOUGH_BYTES; + + buf += *pos; + len -= *pos; + + // JPEG XL codestream: 0xff 0x0a + if (len >= 1 && buf[0] == 0xff) { + if (len < 2) { + return JXL_SIG_NOT_ENOUGH_BYTES; + } else if (buf[1] == jxl::kCodestreamMarker) { + *pos += 2; + return JXL_SIG_CODESTREAM; + } else { + return JXL_SIG_INVALID; + } + } + + // JPEG XL container + if (len >= 1 && buf[0] == 0) { + if (len < 12) { + return JXL_SIG_NOT_ENOUGH_BYTES; + } else if (buf[1] == 0 && buf[2] == 0 && buf[3] == 0xC && buf[4] == 'J' && + buf[5] == 'X' && buf[6] == 'L' && buf[7] == ' ' && + buf[8] == 0xD && buf[9] == 0xA && buf[10] == 0x87 && + buf[11] == 0xA) { + *pos += 12; + return JXL_SIG_CONTAINER; + } else { + return JXL_SIG_INVALID; + } + } + + return JXL_SIG_INVALID; +} + +} // namespace + +uint32_t JxlDecoderVersion(void) { + return JPEGXL_MAJOR_VERSION * 1000000 + JPEGXL_MINOR_VERSION * 1000 + + JPEGXL_PATCH_VERSION; +} + +JxlSignature JxlSignatureCheck(const uint8_t* buf, size_t len) { + size_t pos = 0; + return ReadSignature(buf, len, &pos); +} + +namespace { + +size_t BitsPerChannel(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + return 32; + case JXL_TYPE_FLOAT16: + return 16; + default: + return 0; // signals unhandled JxlDataType + } +} + +enum class DecoderStage : uint32_t { + kInited, // Decoder created, no JxlDecoderProcessInput called yet + kStarted, // Running JxlDecoderProcessInput calls + kCodestreamFinished, // Codestream done, but other boxes could still occur. + // This stage can also occur before having seen the + // entire codestream if the user didn't subscribe to any + // codestream events at all, e.g. only to box events, + // or, the user only subscribed to basic info, and only + // the header of the codestream was parsed. + kError, // Error occurred, decoder object no longer usable +}; + +enum class FrameStage : uint32_t { + kHeader, // Must parse frame header. + kTOC, // Must parse TOC + kFull, // Must parse full pixels + kFullOutput, // Must output full pixels +}; + +enum class BoxStage : uint32_t { + kHeader, // Parsing box header of the next box, or start of non-container + // stream + kFtyp, // The ftyp box + kSkip, // Box whose contents are skipped + kCodestream, // Handling codestream box contents, or non-container stream + kPartialCodestream, // Handling the extra header of partial codestream box + kJpegRecon, // Handling jpeg reconstruction box +}; + +enum class JpegReconStage : uint32_t { + kNone, // Not outputting + kSettingMetadata, // Ready to output, must set metadata to the jpeg_data + kOutputting, // Currently outputting the JPEG bytes + kFinished, // JPEG reconstruction fully handled +}; + +/* +Given list of frame references to storage slots, and storage slots in which this +frame is saved, computes which frames are required to decode the frame at the +given index and any frames after it. The frames on which this depends are +returned as a vector of their indices, in no particular order. The given index +must be smaller than saved_as.size(), and references.size() must equal +saved_as.size(). Any frames beyond saved_as and references are considered +unknown future frames and must be treated as if something depends on them. +*/ +std::vector GetFrameDependencies(size_t index, + const std::vector& saved_as, + const std::vector& references) { + JXL_ASSERT(references.size() == saved_as.size()); + JXL_ASSERT(index < references.size()); + + std::vector result; + + constexpr size_t kNumStorage = 8; + + // value which indicates nothing is stored in this storage slot + const size_t invalid = references.size(); + // for each of the 8 storage slots, a vector that translates frame index to + // frame stored in this storage slot at this point, that is, the last + // frame that was stored in this slot before or at this index. + std::array, kNumStorage> storage; + for (size_t s = 0; s < kNumStorage; ++s) { + storage[s].resize(saved_as.size()); + int mask = 1 << s; + size_t id = invalid; + for (size_t i = 0; i < saved_as.size(); ++i) { + if (saved_as[i] & mask) { + id = i; + } + storage[s][i] = id; + } + } + + std::vector seen(index + 1, 0); + std::vector stack; + stack.push_back(index); + seen[index] = 1; + + // For frames after index, assume they can depend on any of the 8 storage + // slots, so push the frame for each stored reference to the stack and result. + // All frames after index are treated as having unknown references and with + // the possibility that there are more frames after the last known. + // TODO(lode): take values of saved_as and references after index, and a + // input flag indicating if they are all frames of the image, to further + // optimize this. + for (size_t s = 0; s < kNumStorage; ++s) { + size_t frame_ref = storage[s][index]; + if (frame_ref == invalid) continue; + if (seen[frame_ref]) continue; + stack.push_back(frame_ref); + seen[frame_ref] = 1; + result.push_back(frame_ref); + } + + while (!stack.empty()) { + size_t frame_index = stack.back(); + stack.pop_back(); + if (frame_index == 0) continue; // first frame cannot have references + for (size_t s = 0; s < kNumStorage; ++s) { + int mask = 1 << s; + if (!(references[frame_index] & mask)) continue; + size_t frame_ref = storage[s][frame_index - 1]; + if (frame_ref == invalid) continue; + if (seen[frame_ref]) continue; + stack.push_back(frame_ref); + seen[frame_ref] = 1; + result.push_back(frame_ref); + } + } + + return result; +} + +// Parameters for user-requested extra channel output. +struct ExtraChannelOutput { + JxlPixelFormat format; + void* buffer; + size_t buffer_size; +}; + +} // namespace + +namespace jxl { + +typedef struct JxlDecoderFrameIndexBoxEntryStruct { + // OFFi: offset of start byte of this frame compared to start + // byte of previous frame from this index in the JPEG XL codestream. For the + // first frame, this is the offset from the first byte of the JPEG XL + // codestream. + uint64_t OFFi; + // Ti: duration in ticks between the start of this frame and + // the start of the next frame in the index. If this is the last frame in the + // index, this is the duration in ticks between the start of this frame and + // the end of the stream. A tick lasts TNUM / TDEN seconds. + uint32_t Ti; + // Fi: amount of frames the next frame in the index occurs + // after this frame. If this is the last frame in the index, this is the + // amount of frames after this frame in the remainder of the stream. Only + // frames that are presented by the decoder are counted for this purpose, this + // excludes frames that are not intended for display but for compositing with + // other frames, such as frames that aren't the last frame with a duration of + // 0 ticks. + uint32_t Fi; +} JxlDecoderFrameIndexBoxEntry; + +typedef struct JxlDecoderFrameIndexBoxStruct { + int64_t NF() const { return entries.size(); } + int32_t TNUM = 1; + int32_t TDEN = 1000; + + std::vector entries; + + // That way we can ensure that every index box will have the first frame. + // If the API user decides to mark it as an indexed frame, we call + // the AddFrame again, this time with requested. + void AddFrame(uint64_t OFFi, uint32_t Ti, uint32_t Fi) { + JxlDecoderFrameIndexBoxEntry e; + e.OFFi = OFFi; + e.Ti = Ti; + e.Fi = Fi; + entries.push_back(e); + } +} JxlDecoderFrameIndexBox; + +} // namespace jxl + +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct JxlDecoderStruct { + JxlDecoderStruct() = default; + + JxlMemoryManager memory_manager; + std::unique_ptr thread_pool; + + DecoderStage stage; + + // Status of progression, internal. + bool got_signature; + // Indicates we know that we've seen the last codestream box: either this + // was a jxlc box, or a jxlp box that has its index indicated as last by + // having its most significant bit set, or no boxes are used at all. This + // does not indicate the full codestream has already been seen, only the + // last box of it has been initiated. + bool last_codestream_seen; + bool got_codestream_signature; + bool got_basic_info; + bool got_transform_data; // To skip everything before ICC. + bool got_all_headers; // Codestream metadata headers. + bool post_headers; // Already decoding pixels. + jxl::ICCReader icc_reader; + jxl::JxlDecoderFrameIndexBox frame_index_box; + // This means either we actually got the preview image, or determined we + // cannot get it or there is none. + bool got_preview_image; + bool preview_frame; + + // Position of next_in in the original file including box format if present + // (as opposed to position in the codestream) + size_t file_pos; + + size_t box_contents_begin; + size_t box_contents_end; + size_t box_contents_size; + size_t box_size; + size_t header_size; + // Either a final box that runs until EOF, or the case of no container format + // at all. + bool box_contents_unbounded; + + JxlBoxType box_type; + JxlBoxType box_decoded_type; // Underlying type for brob boxes + // Set to true right after a JXL_DEC_BOX event only. + bool box_event; + bool decompress_boxes; + + bool box_out_buffer_set; + // Whether the out buffer is set for the current box, if the user did not yet + // release the buffer while the next box is encountered, this will be set to + // false. If this is false, no JXL_DEC_NEED_MORE_INPUT is emitted + // (irrespective of the value of box_out_buffer_set), because not setting + // output indicates the user does not wish the data of this box. + bool box_out_buffer_set_current_box; + uint8_t* box_out_buffer; + size_t box_out_buffer_size; + // which byte of the full box content the start of the out buffer points to + size_t box_out_buffer_begin; + // which byte of box_out_buffer to write to next + size_t box_out_buffer_pos; + + // Settings + bool keep_orientation; + bool unpremul_alpha; + bool render_spotcolors; + bool coalescing; + float desired_intensity_target; + + // Bitfield, for which informative events (JXL_DEC_BASIC_INFO, etc...) the + // decoder returns a status. By default, do not return for any of the events, + // only return when the decoder cannot continue because it needs more input or + // output data. + int events_wanted; + int orig_events_wanted; + + // Fields for reading the basic info from the header. + size_t basic_info_size_hint; + bool have_container; + size_t box_count; + + // The level of progressive detail in frame decoding. + JxlProgressiveDetail prog_detail = kDC; + // The progressive detail of the current frame. + JxlProgressiveDetail frame_prog_detail; + // The intended downsampling ratio for the current progression step. + size_t downsampling_target; + + // Whether the preview out buffer was set. It is possible for the buffer to + // be nullptr and buffer_set to be true, indicating it was deliberately + // set to nullptr. + bool preview_out_buffer_set; + // Idem for the image buffer. + // Set to true if either an image out buffer or an image out callback was set. + bool image_out_buffer_set; + + // Owned by the caller, buffers for DC image and full resolution images + void* preview_out_buffer; + void* image_out_buffer; + JxlImageOutInitCallback image_out_init_callback; + JxlImageOutRunCallback image_out_run_callback; + JxlImageOutDestroyCallback image_out_destroy_callback; + void* image_out_init_opaque; + struct SimpleImageOutCallback { + JxlImageOutCallback callback; + void* opaque; + }; + SimpleImageOutCallback simple_image_out_callback; + + size_t preview_out_size; + size_t image_out_size; + + JxlPixelFormat preview_out_format; + JxlPixelFormat image_out_format; + + // For extra channels. Empty if no extra channels are requested, and they are + // reset each frame + std::vector extra_channel_output; + + jxl::CodecMetadata metadata; + // Same as metadata.m, except for the color_encoding, which is set to the + // output encoding. + jxl::ImageMetadata image_metadata; + std::unique_ptr ib; + + std::unique_ptr passes_state; + std::unique_ptr frame_dec; + size_t next_section; + std::vector section_processed; + // The FrameDecoder is initialized, and not yet finalized + bool frame_dec_in_progress; + + // headers and TOC for the current frame. When got_toc is true, this is + // always the frame header of the last frame of the current still series, + // that is, the displayed frame. + std::unique_ptr frame_header; + + size_t remaining_frame_size; + FrameStage frame_stage; + bool dc_frame_progression_done; + // The currently processed frame is the last of the current composite still, + // and so must be returned as pixels + bool is_last_of_still; + // The currently processed frame is the last of the codestream + bool is_last_total; + // How many frames to skip. + size_t skip_frames; + // Skipping the current frame. May be false if skip_frames was just set to + // a positive value while already processing a current frame, then + // skipping_frame will be enabled only for the next frame. + bool skipping_frame; + + // Amount of internal frames and external frames started. External frames are + // user-visible frames, internal frames includes all external frames and + // also invisible frames such as patches, blending-only and dc_level frames. + size_t internal_frames; + size_t external_frames; + + // For each internal frame, which storage locations it references, and which + // storage locations it is stored in, using the bit mask as defined in + // FrameDecoder::References and FrameDecoder::SaveAs. + std::vector frame_references; + std::vector frame_saved_as; + + // Translates external frame index to internal frame index. The external + // index is the index of user-visible frames. The internal index can be larger + // since non-visible frames (such as frames with patches, ...) are included. + std::vector frame_external_to_internal; + + // Whether the frame with internal index is required to decode the frame + // being skipped to or any frames after that. If no skipping is active, + // this vector is ignored. If the current internal frame index is beyond this + // vector, it must be treated as a required frame. + std::vector frame_required; + + // Codestream input data is copied here temporarily when the decoder needs + // more input bytes to process the next part of the stream. We copy the input + // data in order to be able to release it all through the API it when + // returning JXL_DEC_NEED_MORE_INPUT. + std::vector codestream_copy; + // Number of bytes at the end of codestream_copy that were not yet consumed + // by calling AdvanceInput(). + size_t codestream_unconsumed; + // Position in the codestream_copy vector that the decoder already finished + // processing. It can be greater than the current size of codestream_copy in + // case where the decoder skips some parts of the frame that were not yet + // provided. + size_t codestream_pos; + // Number of bits after codestream_pos that were already processed. + size_t codestream_bits_ahead; + + BoxStage box_stage; + + jxl::JxlToJpegDecoder jpeg_decoder; + jxl::JxlBoxContentDecoder box_content_decoder; + // Decodes Exif or XMP metadata for JPEG reconstruction + jxl::JxlBoxContentDecoder metadata_decoder; + std::vector exif_metadata; + std::vector xmp_metadata; + // must store JPEG reconstruction metadata from the current box + // 0 = not stored, 1 = currently storing, 2 = finished + int store_exif; + int store_xmp; + size_t recon_out_buffer_pos; + size_t recon_exif_size; // Expected exif size as read from the jbrd box + size_t recon_xmp_size; // Expected exif size as read from the jbrd box + JpegReconStage recon_output_jpeg; + + bool JbrdNeedMoreBoxes() const { + // jbrd box wants exif but exif box not yet seen + if (store_exif < 2 && recon_exif_size > 0) return true; + // jbrd box wants xmp but xmp box not yet seen + if (store_xmp < 2 && recon_xmp_size > 0) return true; + return false; + } + + // Statistics which CodecInOut can keep + uint64_t dec_pixels; + + const uint8_t* next_in; + size_t avail_in; + bool input_closed; + + void AdvanceInput(size_t size) { + JXL_DASSERT(avail_in >= size); + next_in += size; + avail_in -= size; + file_pos += size; + } + + size_t AvailableCodestream() const { + size_t avail_codestream = avail_in; + if (!box_contents_unbounded) { + avail_codestream = + std::min(avail_codestream, box_contents_end - file_pos); + } + return avail_codestream; + } + + void AdvanceCodestream(size_t size) { + size_t avail_codestream = AvailableCodestream(); + if (codestream_copy.empty()) { + if (size <= avail_codestream) { + AdvanceInput(size); + } else { + codestream_pos = size - avail_codestream; + AdvanceInput(avail_codestream); + } + } else { + codestream_pos += size; + if (codestream_pos + codestream_unconsumed >= codestream_copy.size()) { + size_t advance = std::min( + codestream_unconsumed, + codestream_unconsumed + codestream_pos - codestream_copy.size()); + AdvanceInput(advance); + codestream_pos -= std::min(codestream_pos, codestream_copy.size()); + codestream_unconsumed = 0; + codestream_copy.clear(); + } + } + } + + JxlDecoderStatus RequestMoreInput() { + if (codestream_copy.empty()) { + size_t avail_codestream = AvailableCodestream(); + codestream_copy.insert(codestream_copy.end(), next_in, + next_in + avail_codestream); + AdvanceInput(avail_codestream); + } else { + AdvanceInput(codestream_unconsumed); + codestream_unconsumed = 0; + } + return JXL_DEC_NEED_MORE_INPUT; + } + + JxlDecoderStatus GetCodestreamInput(jxl::Span* span) { + if (codestream_copy.empty() && codestream_pos > 0) { + size_t avail_codestream = AvailableCodestream(); + size_t skip = std::min(codestream_pos, avail_codestream); + AdvanceInput(skip); + codestream_pos -= skip; + if (codestream_pos > 0) { + return RequestMoreInput(); + } + } + JXL_ASSERT(codestream_pos <= codestream_copy.size()); + JXL_ASSERT(codestream_unconsumed <= codestream_copy.size()); + size_t avail_codestream = AvailableCodestream(); + if (codestream_copy.empty()) { + if (avail_codestream == 0) { + return RequestMoreInput(); + } + *span = jxl::Span(next_in, avail_codestream); + return JXL_DEC_SUCCESS; + } else { + codestream_copy.insert(codestream_copy.end(), + next_in + codestream_unconsumed, + next_in + avail_codestream); + codestream_unconsumed = avail_codestream; + *span = jxl::Span(codestream_copy.data() + codestream_pos, + codestream_copy.size() - codestream_pos); + return JXL_DEC_SUCCESS; + } + } + + // Whether the decoder can use more codestream input for a purpose it needs. + // This returns false if the user didn't subscribe to any events that + // require the codestream (e.g. only subscribed to metadata boxes), or all + // parts of the codestream that are subscribed to (e.g. only basic info) have + // already occured. + bool CanUseMoreCodestreamInput() const { + // The decoder can set this to finished early if all relevant events were + // processed, so this check works. + return stage != DecoderStage::kCodestreamFinished; + } + + // If set then some operations will fail, if those would require + // allocating large objects. Actual memory usage might be two orders of + // magnitude bigger. + // TODO(eustas): remove once there is working API for memory / CPU limit. + size_t memory_limit_base = 0; + size_t cpu_limit_base = 0; + size_t used_cpu_base = 0; +}; + +namespace { + +bool CheckSizeLimit(JxlDecoder* dec, size_t xsize, size_t ysize) { + if (!dec->memory_limit_base) return true; + if (xsize == 0 || ysize == 0) return true; + if (xsize >= dec->memory_limit_base || ysize >= dec->memory_limit_base) { + return false; + } + // Rough estimate of real row length. + xsize = jxl::DivCeil(xsize, 32) * 32; + size_t num_pixels = xsize * ysize; + if (num_pixels / xsize != ysize) return false; // overflow + if (num_pixels > dec->memory_limit_base) return false; + return true; +} + +} // namespace + +// TODO(zond): Make this depend on the data loaded into the decoder. +JxlDecoderStatus JxlDecoderDefaultPixelFormat(const JxlDecoder* dec, + JxlPixelFormat* format) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + *format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + return JXL_DEC_SUCCESS; +} + +// Resets the state that must be reset for both Rewind and Reset +void JxlDecoderRewindDecodingState(JxlDecoder* dec) { + dec->stage = DecoderStage::kInited; + dec->got_signature = false; + dec->last_codestream_seen = false; + dec->got_codestream_signature = false; + dec->got_basic_info = false; + dec->got_transform_data = false; + dec->got_all_headers = false; + dec->post_headers = false; + dec->icc_reader.Reset(); + dec->got_preview_image = false; + dec->preview_frame = false; + dec->file_pos = 0; + dec->box_contents_begin = 0; + dec->box_contents_end = 0; + dec->box_contents_size = 0; + dec->box_size = 0; + dec->header_size = 0; + dec->box_contents_unbounded = false; + memset(dec->box_type, 0, sizeof(dec->box_type)); + memset(dec->box_decoded_type, 0, sizeof(dec->box_decoded_type)); + dec->box_event = false; + dec->box_stage = BoxStage::kHeader; + dec->box_out_buffer_set = false; + dec->box_out_buffer_set_current_box = false; + dec->box_out_buffer = nullptr; + dec->box_out_buffer_size = 0; + dec->box_out_buffer_begin = 0; + dec->box_out_buffer_pos = 0; + dec->exif_metadata.clear(); + dec->xmp_metadata.clear(); + dec->store_exif = 0; + dec->store_xmp = 0; + dec->recon_out_buffer_pos = 0; + dec->recon_exif_size = 0; + dec->recon_xmp_size = 0; + dec->recon_output_jpeg = JpegReconStage::kNone; + + dec->events_wanted = 0; + dec->basic_info_size_hint = InitialBasicInfoSizeHint(); + dec->have_container = 0; + dec->box_count = 0; + dec->downsampling_target = 8; + dec->preview_out_buffer_set = false; + dec->image_out_buffer_set = false; + dec->preview_out_buffer = nullptr; + dec->image_out_buffer = nullptr; + dec->image_out_init_callback = nullptr; + dec->image_out_run_callback = nullptr; + dec->image_out_destroy_callback = nullptr; + dec->image_out_init_opaque = nullptr; + dec->preview_out_size = 0; + dec->image_out_size = 0; + dec->extra_channel_output.clear(); + dec->dec_pixels = 0; + dec->next_in = 0; + dec->avail_in = 0; + dec->input_closed = false; + + dec->passes_state.reset(nullptr); + dec->frame_dec.reset(nullptr); + dec->next_section = 0; + dec->section_processed.clear(); + dec->frame_dec_in_progress = false; + + dec->ib.reset(); + dec->metadata = jxl::CodecMetadata(); + dec->image_metadata = dec->metadata.m; + dec->frame_header.reset(new jxl::FrameHeader(&dec->metadata)); + + dec->codestream_copy.clear(); + dec->codestream_unconsumed = 0; + dec->codestream_pos = 0; + dec->codestream_bits_ahead = 0; + + dec->frame_stage = FrameStage::kHeader; + dec->remaining_frame_size = 0; + dec->is_last_of_still = false; + dec->is_last_total = false; + dec->skip_frames = 0; + dec->skipping_frame = false; + dec->internal_frames = 0; + dec->external_frames = 0; +} + +void JxlDecoderReset(JxlDecoder* dec) { + JxlDecoderRewindDecodingState(dec); + + dec->thread_pool.reset(); + dec->keep_orientation = false; + dec->unpremul_alpha = false; + dec->render_spotcolors = true; + dec->coalescing = true; + dec->desired_intensity_target = 0; + dec->orig_events_wanted = 0; + dec->frame_references.clear(); + dec->frame_saved_as.clear(); + dec->frame_external_to_internal.clear(); + dec->frame_required.clear(); + dec->decompress_boxes = false; +} + +JxlDecoder* JxlDecoderCreate(const JxlMemoryManager* memory_manager) { + JxlMemoryManager local_memory_manager; + if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) + return nullptr; + + void* alloc = + jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlDecoder)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + JxlDecoder* dec = new (alloc) JxlDecoder(); + dec->memory_manager = local_memory_manager; + +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + if (!memory_manager) { + dec->memory_limit_base = 53 << 16; + // Allow 5 x max_image_size processing units; every frame is accounted + // as W x H CPU processing units, so there could be numerous small frames + // or few larger ones. + dec->cpu_limit_base = 5 * dec->memory_limit_base; + } +#endif + + JxlDecoderReset(dec); + + return dec; +} + +void JxlDecoderDestroy(JxlDecoder* dec) { + if (dec) { + JxlMemoryManager local_memory_manager = dec->memory_manager; + // Call destructor directly since custom free function is used. + dec->~JxlDecoder(); + jxl::MemoryManagerFree(&local_memory_manager, dec); + } +} + +void JxlDecoderRewind(JxlDecoder* dec) { JxlDecoderRewindDecodingState(dec); } + +void JxlDecoderSkipFrames(JxlDecoder* dec, size_t amount) { + // Increment amount, rather than set it: making the amount smaller is + // impossible because the decoder may already have skipped frames required to + // decode earlier frames, and making the amount larger compared to an existing + // amount is impossible because if JxlDecoderSkipFrames is called in the + // middle of already skipping frames, the user cannot know how many frames + // have already been skipped internally so far so an absolute value cannot + // be defined. + dec->skip_frames += amount; + + dec->frame_required.clear(); + size_t next_frame = dec->external_frames + dec->skip_frames; + + // A frame that has been seen before a rewind + if (next_frame < dec->frame_external_to_internal.size()) { + size_t internal_index = dec->frame_external_to_internal[next_frame]; + if (internal_index < dec->frame_saved_as.size()) { + std::vector deps = GetFrameDependencies( + internal_index, dec->frame_saved_as, dec->frame_references); + + dec->frame_required.resize(internal_index + 1, 0); + for (size_t i = 0; i < deps.size(); i++) { + JXL_ASSERT(deps[i] < dec->frame_required.size()); + dec->frame_required[deps[i]] = 1; + } + } + } +} + +JxlDecoderStatus JxlDecoderSkipCurrentFrame(JxlDecoder* dec) { + if (!dec->frame_dec || !dec->frame_dec_in_progress) { + return JXL_DEC_ERROR; + } + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + dec->frame_dec_in_progress = false; + if (dec->is_last_of_still) { + dec->image_out_buffer_set = false; + } + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetParallelRunner(JxlDecoder* dec, JxlParallelRunner parallel_runner, + void* parallel_runner_opaque) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("parallel_runner must be set before starting"); + } + dec->thread_pool.reset( + new jxl::ThreadPool(parallel_runner, parallel_runner_opaque)); + return JXL_DEC_SUCCESS; +} + +size_t JxlDecoderSizeHintBasicInfo(const JxlDecoder* dec) { + if (dec->got_basic_info) return 0; + return dec->basic_info_size_hint; +} + +JxlDecoderStatus JxlDecoderSubscribeEvents(JxlDecoder* dec, int events_wanted) { + if (dec->stage != DecoderStage::kInited) { + return JXL_DEC_ERROR; // Cannot subscribe to events after having started. + } + if (events_wanted & 63) { + return JXL_DEC_ERROR; // Can only subscribe to informative events. + } + dec->events_wanted = events_wanted; + dec->orig_events_wanted = events_wanted; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetKeepOrientation(JxlDecoder* dec, + JXL_BOOL keep_orientation) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set keep_orientation option before starting"); + } + dec->keep_orientation = !!keep_orientation; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetUnpremultiplyAlpha(JxlDecoder* dec, + JXL_BOOL unpremul_alpha) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set unpremul_alpha option before starting"); + } + dec->unpremul_alpha = !!unpremul_alpha; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetRenderSpotcolors(JxlDecoder* dec, + JXL_BOOL render_spotcolors) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set render_spotcolors option before starting"); + } + dec->render_spotcolors = !!render_spotcolors; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetCoalescing(JxlDecoder* dec, JXL_BOOL coalescing) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set coalescing option before starting"); + } + dec->coalescing = !!coalescing; + return JXL_DEC_SUCCESS; +} + +namespace { +// helper function to get the dimensions of the current image buffer +void GetCurrentDimensions(const JxlDecoder* dec, size_t& xsize, size_t& ysize, + bool oriented) { + if (dec->frame_header->nonserialized_is_preview) { + xsize = dec->metadata.oriented_preview_xsize(dec->keep_orientation); + ysize = dec->metadata.oriented_preview_ysize(dec->keep_orientation); + return; + } + xsize = dec->metadata.oriented_xsize(dec->keep_orientation || !oriented); + ysize = dec->metadata.oriented_ysize(dec->keep_orientation || !oriented); + if (!dec->coalescing) { + const auto frame_dim = dec->frame_header->ToFrameDimensions(); + xsize = frame_dim.xsize_upsampled; + ysize = frame_dim.ysize_upsampled; + if (!dec->keep_orientation && oriented && + static_cast(dec->metadata.m.GetOrientation()) > 4) { + std::swap(xsize, ysize); + } + } +} +} // namespace + +namespace jxl { +namespace { + +template +bool CanRead(Span data, BitReader* reader, T* JXL_RESTRICT t) { + // Use a copy of the bit reader because CanRead advances bits. + BitReader reader2(data); + reader2.SkipBits(reader->TotalBitsConsumed()); + bool result = Bundle::CanRead(&reader2, t); + JXL_ASSERT(reader2.Close()); + return result; +} + +// Returns JXL_DEC_SUCCESS if the full bundle was successfully read, status +// indicating either error or need more input otherwise. +template +JxlDecoderStatus ReadBundle(JxlDecoder* dec, Span data, + BitReader* reader, T* JXL_RESTRICT t) { + if (!CanRead(data, reader, t)) { + return dec->RequestMoreInput(); + } + if (!Bundle::Read(reader, t)) { + return JXL_DEC_ERROR; + } + return JXL_DEC_SUCCESS; +} + +#define JXL_API_RETURN_IF_ERROR(expr) \ + { \ + JxlDecoderStatus status_ = ConvertStatus(expr); \ + if (status_ != JXL_DEC_SUCCESS) return status_; \ + } + +std::unique_ptr> GetBitReader( + Span span) { + BitReader* reader = new BitReader(span); + return std::unique_ptr>( + reader, [](BitReader* reader) { + // We can't allow Close to abort the program if the reader is out of + // bounds, or all return paths in the code, even those that already + // return failure, would have to manually call AllReadsWithinBounds(). + // Invalid JXL codestream should not cause program to quit. + (void)reader->AllReadsWithinBounds(); + (void)reader->Close(); + delete reader; + }); +} + +JxlDecoderStatus JxlDecoderReadBasicInfo(JxlDecoder* dec) { + if (!dec->got_codestream_signature) { + // Check and skip the codestream signature + Span span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + if (span.size() < 2) { + return dec->RequestMoreInput(); + } + if (span.data()[0] != 0xff || span.data()[1] != jxl::kCodestreamMarker) { + return JXL_API_ERROR("invalid signature"); + } + dec->got_codestream_signature = true; + dec->AdvanceCodestream(2); + } + + Span span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + JXL_API_RETURN_IF_ERROR( + ReadBundle(dec, span, reader.get(), &dec->metadata.size)); + JXL_API_RETURN_IF_ERROR( + ReadBundle(dec, span, reader.get(), &dec->metadata.m)); + size_t total_bits = reader->TotalBitsConsumed(); + dec->AdvanceCodestream(total_bits / jxl::kBitsPerByte); + dec->codestream_bits_ahead = total_bits % jxl::kBitsPerByte; + dec->got_basic_info = true; + dec->basic_info_size_hint = 0; + dec->image_metadata = dec->metadata.m; + JXL_DEBUG_V(2, "Decoded BasicInfo: %s", dec->metadata.DebugString().c_str()); + + if (!CheckSizeLimit(dec, dec->metadata.size.xsize(), + dec->metadata.size.ysize())) { + return JXL_API_ERROR("image is too large"); + } + + return JXL_DEC_SUCCESS; +} + +// Reads all codestream headers (but not frame headers) +JxlDecoderStatus JxlDecoderReadAllHeaders(JxlDecoder* dec) { + if (!dec->got_transform_data) { + Span span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + reader->SkipBits(dec->codestream_bits_ahead); + dec->metadata.transform_data.nonserialized_xyb_encoded = + dec->metadata.m.xyb_encoded; + JXL_API_RETURN_IF_ERROR( + ReadBundle(dec, span, reader.get(), &dec->metadata.transform_data)); + size_t total_bits = reader->TotalBitsConsumed(); + dec->AdvanceCodestream(total_bits / jxl::kBitsPerByte); + dec->codestream_bits_ahead = total_bits % jxl::kBitsPerByte; + dec->got_transform_data = true; + } + + Span span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + reader->SkipBits(dec->codestream_bits_ahead); + + if (dec->metadata.m.color_encoding.WantICC()) { + jxl::Status status = + dec->icc_reader.Init(reader.get(), dec->memory_limit_base); + // Always check AllReadsWithinBounds, not all the C++ decoder implementation + // handles reader out of bounds correctly yet (e.g. context map). Not + // checking AllReadsWithinBounds can cause reader->Close() to trigger an + // assert, but we don't want library to quit program for invalid codestream. + if (!reader->AllReadsWithinBounds() || + status.code() == StatusCode::kNotEnoughBytes) { + return dec->RequestMoreInput(); + } + if (!status) { + // Other non-successful status is an error + return JXL_DEC_ERROR; + } + PaddedBytes icc; + status = dec->icc_reader.Process(reader.get(), &icc); + if (status.code() == StatusCode::kNotEnoughBytes) { + return dec->RequestMoreInput(); + } + if (!status) { + // Other non-successful status is an error + return JXL_DEC_ERROR; + } + if (!dec->metadata.m.color_encoding.SetICCRaw(std::move(icc))) { + return JXL_DEC_ERROR; + } + } + + dec->got_all_headers = true; + JXL_API_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + + dec->AdvanceCodestream(reader->TotalBitsConsumed() / jxl::kBitsPerByte); + dec->codestream_bits_ahead = 0; + + if (!dec->passes_state) { + dec->passes_state.reset(new jxl::PassesDecoderState()); + } + + JXL_API_RETURN_IF_ERROR( + dec->passes_state->output_encoding_info.SetFromMetadata(dec->metadata)); + if (dec->desired_intensity_target > 0) { + dec->passes_state->output_encoding_info.desired_intensity_target = + dec->desired_intensity_target; + } + dec->image_metadata = dec->metadata.m; + + return JXL_DEC_SUCCESS; +} + +static size_t GetStride(const JxlDecoder* dec, const JxlPixelFormat& format) { + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize, true); + size_t stride = xsize * (BitsPerChannel(format.data_type) * + format.num_channels / jxl::kBitsPerByte); + if (format.align > 1) { + stride = jxl::DivCeil(stride, format.align) * format.align; + } + return stride; +} + +// Internal wrapper around jxl::ConvertToExternal which converts the stride, +// format and orientation and allows to choose whether to get all RGB(A) +// channels or alternatively get a single extra channel. +// If want_extra_channel, a valid index to a single extra channel must be +// given, the output must be single-channel, and format.num_channels is ignored +// and treated as if it is 1. +static JxlDecoderStatus ConvertImageInternal( + const JxlDecoder* dec, const jxl::ImageBundle& frame, + const JxlPixelFormat& format, bool want_extra_channel, + size_t extra_channel_index, void* out_image, size_t out_size, + const PixelCallback& out_callback) { + // TODO(lode): handle mismatch of RGB/grayscale color profiles and pixel data + // color/grayscale format + const size_t stride = GetStride(dec, format); + + bool float_format = format.data_type == JXL_TYPE_FLOAT || + format.data_type == JXL_TYPE_FLOAT16; + + jxl::Orientation undo_orientation = dec->keep_orientation + ? jxl::Orientation::kIdentity + : dec->metadata.m.GetOrientation(); + + jxl::Status status(true); + if (want_extra_channel) { + JXL_ASSERT(extra_channel_index < frame.extra_channels().size()); + status = jxl::ConvertToExternal(frame.extra_channels()[extra_channel_index], + BitsPerChannel(format.data_type), + float_format, format.endianness, stride, + dec->thread_pool.get(), out_image, out_size, + out_callback, undo_orientation); + } else { + status = jxl::ConvertToExternal( + frame, BitsPerChannel(format.data_type), float_format, + format.num_channels, format.endianness, stride, dec->thread_pool.get(), + out_image, out_size, out_callback, undo_orientation, + dec->unpremul_alpha); + } + + return status ? JXL_DEC_SUCCESS : JXL_DEC_ERROR; +} + +JxlDecoderStatus JxlDecoderProcessSections(JxlDecoder* dec) { + Span span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + const auto& toc = dec->frame_dec->Toc(); + size_t pos = 0; + std::vector section_info; + std::vector section_status; + for (size_t i = dec->next_section; i < toc.size(); ++i) { + if (dec->section_processed[i]) continue; + size_t id = toc[i].id; + size_t size = toc[i].size; + if (OutOfBounds(pos, size, span.size())) { + break; + } + auto br = + new jxl::BitReader(jxl::Span(span.data() + pos, size)); + section_info.emplace_back(jxl::FrameDecoder::SectionInfo{br, id}); + section_status.emplace_back(); + pos += size; + } + jxl::Status status = dec->frame_dec->ProcessSections( + section_info.data(), section_info.size(), section_status.data()); + bool out_of_bounds = false; + for (const auto& info : section_info) { + if (!info.br->AllReadsWithinBounds()) { + // Mark out of bounds section, but keep closing and deleting the next + // ones as well. + out_of_bounds = true; + } + JXL_ASSERT(info.br->Close()); + delete info.br; + } + if (out_of_bounds) { + // If any bit reader indicates out of bounds, it's an error, not just + // needing more input, since we ensure only bit readers containing + // a complete section are provided to the FrameDecoder. + return JXL_API_ERROR("frame out of bounds"); + } + if (!status) { + return JXL_API_ERROR("frame processing failed"); + } + bool found_skipped_section = false; + size_t num_done = 0; + size_t processed_bytes = 0; + for (size_t i = 0; i < section_status.size(); ++i) { + auto status = section_status[i]; + if (status == jxl::FrameDecoder::kDone) { + if (!found_skipped_section) { + processed_bytes += toc[dec->next_section + i].size; + ++num_done; + } + dec->section_processed[dec->next_section + i] = 1; + } else if (status == jxl::FrameDecoder::kSkipped) { + found_skipped_section = true; + } else { + return JXL_API_ERROR("unexpected section status"); + } + } + dec->next_section += num_done; + dec->remaining_frame_size -= processed_bytes; + dec->AdvanceCodestream(processed_bytes); + return JXL_DEC_SUCCESS; +} + +// TODO(eustas): no CodecInOut -> no image size reinforcement -> possible OOM. +JxlDecoderStatus JxlDecoderProcessCodestream(JxlDecoder* dec) { + // If no parallel runner is set, use the default + // TODO(lode): move this initialization to an appropriate location once the + // runner is used to decode pixels. + if (!dec->thread_pool) { + dec->thread_pool.reset(new jxl::ThreadPool(nullptr, nullptr)); + } + + // No matter what events are wanted, the basic info is always required. + if (!dec->got_basic_info) { + JxlDecoderStatus status = JxlDecoderReadBasicInfo(dec); + if (status != JXL_DEC_SUCCESS) return status; + } + + if (dec->events_wanted & JXL_DEC_BASIC_INFO) { + dec->events_wanted &= ~JXL_DEC_BASIC_INFO; + return JXL_DEC_BASIC_INFO; + } + + if (!dec->events_wanted) { + dec->stage = DecoderStage::kCodestreamFinished; + return JXL_DEC_SUCCESS; + } + + if (!dec->got_all_headers) { + JxlDecoderStatus status = JxlDecoderReadAllHeaders(dec); + if (status != JXL_DEC_SUCCESS) return status; + } + + if (dec->events_wanted & JXL_DEC_COLOR_ENCODING) { + dec->events_wanted &= ~JXL_DEC_COLOR_ENCODING; + return JXL_DEC_COLOR_ENCODING; + } + + if (!dec->events_wanted) { + dec->stage = DecoderStage::kCodestreamFinished; + return JXL_DEC_SUCCESS; + } + + dec->post_headers = true; + + if (!dec->got_preview_image && dec->metadata.m.have_preview) { + dec->preview_frame = true; + } + + // Handle frames + for (;;) { + bool parse_frames = + (dec->events_wanted & + (JXL_DEC_PREVIEW_IMAGE | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + if (!parse_frames) { + break; + } + if (dec->frame_stage == FrameStage::kHeader && dec->is_last_total) { + break; + } + if (dec->frame_stage == FrameStage::kHeader) { + if (dec->recon_output_jpeg == JpegReconStage::kSettingMetadata || + dec->recon_output_jpeg == JpegReconStage::kOutputting) { + // The image bundle contains the JPEG reconstruction frame, but the + // decoder is still waiting to decode an EXIF or XMP box. It's not + // implemented to decode additional frames during this, and a JPEG + // reconstruction image should have only one frame. + return JXL_API_ERROR( + "cannot decode a next frame after JPEG reconstruction frame"); + } + if (!dec->ib) { + dec->ib.reset(new jxl::ImageBundle(&dec->image_metadata)); + } + // If JPEG reconstruction is wanted and possible, set the jpeg_data of + // the ImageBundle. + if (!dec->jpeg_decoder.SetImageBundleJpegData(dec->ib.get())) + return JXL_DEC_ERROR; + + dec->frame_dec.reset(new FrameDecoder( + dec->passes_state.get(), dec->metadata, dec->thread_pool.get(), + /*use_slow_rendering_pipeline=*/false)); + dec->frame_header.reset(new FrameHeader(&dec->metadata)); + Span span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + bool output_needed = + (dec->preview_frame ? (dec->events_wanted & JXL_DEC_PREVIEW_IMAGE) + : (dec->events_wanted & JXL_DEC_FULL_IMAGE)); + jxl::Status status = dec->frame_dec->InitFrame( + reader.get(), dec->ib.get(), dec->preview_frame, output_needed); + if (!reader->AllReadsWithinBounds() || + status.code() == StatusCode::kNotEnoughBytes) { + return dec->RequestMoreInput(); + } else if (!status) { + return JXL_API_ERROR("invalid frame header"); + } + dec->AdvanceCodestream(reader->TotalBitsConsumed() / kBitsPerByte); + *dec->frame_header = dec->frame_dec->GetFrameHeader(); + jxl::FrameDimensions frame_dim = dec->frame_header->ToFrameDimensions(); + if (!CheckSizeLimit(dec, frame_dim.xsize_upsampled_padded, + frame_dim.ysize_upsampled_padded)) { + return JXL_API_ERROR("frame is too large"); + } + if (dec->cpu_limit_base != 0) { + // No overflow, checked in CheckSizeLimit. + size_t num_pixels = frame_dim.xsize * frame_dim.ysize; + if (dec->used_cpu_base + num_pixels < dec->used_cpu_base) { + return JXL_API_ERROR("used too much CPU"); + } + dec->used_cpu_base += num_pixels; + if (dec->used_cpu_base > dec->cpu_limit_base) { + return JXL_API_ERROR("used too much CPU"); + } + } + dec->remaining_frame_size = dec->frame_dec->SumSectionSizes(); + + dec->frame_stage = FrameStage::kTOC; + if (dec->preview_frame) { + if (!(dec->events_wanted & JXL_DEC_PREVIEW_IMAGE)) { + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + dec->got_preview_image = true; + dec->preview_frame = false; + } + continue; + } + + int saved_as = FrameDecoder::SavedAs(*dec->frame_header); + // is last in entire codestream + dec->is_last_total = dec->frame_header->is_last; + // is last of current still + dec->is_last_of_still = + dec->is_last_total || dec->frame_header->animation_frame.duration > 0; + // is kRegularFrame and coalescing is disabled + dec->is_last_of_still |= + (!dec->coalescing && + dec->frame_header->frame_type == FrameType::kRegularFrame); + const size_t internal_frame_index = dec->internal_frames; + const size_t external_frame_index = dec->external_frames; + if (dec->is_last_of_still) dec->external_frames++; + dec->internal_frames++; + + if (dec->skip_frames > 0) { + dec->skipping_frame = true; + if (dec->is_last_of_still) { + dec->skip_frames--; + } + } else { + dec->skipping_frame = false; + } + + if (external_frame_index >= dec->frame_external_to_internal.size()) { + dec->frame_external_to_internal.push_back(internal_frame_index); + JXL_ASSERT(dec->frame_external_to_internal.size() == + external_frame_index + 1); + } + + if (internal_frame_index >= dec->frame_saved_as.size()) { + dec->frame_saved_as.push_back(saved_as); + JXL_ASSERT(dec->frame_saved_as.size() == internal_frame_index + 1); + + // add the value 0xff (which means all references) to new slots: we only + // know the references of the frame at FinalizeFrame, and fill in the + // correct values there. As long as this information is not known, the + // worst case where the frame depends on all storage slots is assumed. + dec->frame_references.push_back(0xff); + JXL_ASSERT(dec->frame_references.size() == internal_frame_index + 1); + } + + if (dec->skipping_frame) { + // Whether this frame could be referenced by any future frame: either + // because it's a frame saved for blending or patches, or because it's + // a DC frame. + bool referenceable = + dec->frame_header->CanBeReferenced() || + dec->frame_header->frame_type == FrameType::kDCFrame; + if (internal_frame_index < dec->frame_required.size() && + !dec->frame_required[internal_frame_index]) { + referenceable = false; + } + if (!referenceable) { + // Skip all decoding for this frame, since the user is skipping this + // frame and no future frames can reference it. + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + continue; + } + } + + if ((dec->events_wanted & JXL_DEC_FRAME) && dec->is_last_of_still) { + // Only return this for the last of a series of stills: patches frames + // etc... before this one do not contain the correct information such + // as animation timing, ... + if (!dec->skipping_frame) { + return JXL_DEC_FRAME; + } + } + } + + if (dec->frame_stage == FrameStage::kTOC) { + dec->frame_dec->SetRenderSpotcolors(dec->render_spotcolors); + dec->frame_dec->SetCoalescing(dec->coalescing); + + if (!dec->preview_frame && + (dec->events_wanted & JXL_DEC_FRAME_PROGRESSION)) { + dec->frame_prog_detail = + dec->frame_dec->SetPauseAtProgressive(dec->prog_detail); + } else { + dec->frame_prog_detail = JxlProgressiveDetail::kFrames; + } + dec->dc_frame_progression_done = 0; + + dec->next_section = 0; + dec->section_processed.clear(); + dec->section_processed.resize(dec->frame_dec->Toc().size(), 0); + + // If we don't need pixels, we can skip actually decoding the frames + // (kFull / kFullOut). + if (dec->preview_frame || (dec->events_wanted & JXL_DEC_FULL_IMAGE)) { + dec->frame_dec_in_progress = true; + dec->frame_stage = FrameStage::kFull; + } else if (!dec->is_last_total) { + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + continue; + } else { + break; + } + } + + bool return_full_image = false; + + if (dec->frame_stage == FrameStage::kFull) { + if (dec->preview_frame) { + if (!dec->preview_out_buffer_set) { + return JXL_DEC_NEED_PREVIEW_OUT_BUFFER; + } + } else if (dec->events_wanted & JXL_DEC_FULL_IMAGE) { + if (!dec->image_out_buffer_set && + (!dec->jpeg_decoder.IsOutputSet() || + dec->ib->jpeg_data == nullptr) && + dec->is_last_of_still) { + // TODO(lode): remove the dec->is_last_of_still condition if the + // frame decoder needs the image buffer as working space for decoding + // non-visible or blending frames too + if (!dec->skipping_frame) { + return JXL_DEC_NEED_IMAGE_OUT_BUFFER; + } + } + } + + dec->frame_dec->MaybeSetUnpremultiplyAlpha(dec->unpremul_alpha); + + if (!dec->preview_frame && dec->image_out_buffer_set && + !!dec->image_out_buffer && + dec->image_out_format.data_type == JXL_TYPE_UINT8 && + dec->image_out_format.num_channels >= 3 && + dec->extra_channel_output.empty()) { + bool is_rgba = dec->image_out_format.num_channels == 4; + dec->frame_dec->MaybeSetRGB8OutputBuffer( + reinterpret_cast(dec->image_out_buffer), + GetStride(dec, dec->image_out_format), is_rgba, + !dec->keep_orientation); + } + + const bool little_endian = + dec->image_out_format.endianness == JXL_LITTLE_ENDIAN || + (dec->image_out_format.endianness == JXL_NATIVE_ENDIAN && + IsLittleEndian()); + bool swap_endianness = little_endian != IsLittleEndian(); + + // TODO(lode): Support more formats than just native endian float32 for + // the low-memory callback path + if (!dec->preview_frame && dec->image_out_buffer_set && + !!dec->image_out_init_callback && !!dec->image_out_run_callback && + dec->image_out_format.data_type == JXL_TYPE_FLOAT && + dec->image_out_format.num_channels >= 3 && + dec->extra_channel_output.empty() && !swap_endianness && + dec->frame_dec_in_progress) { + bool is_rgba = dec->image_out_format.num_channels == 4; + dec->frame_dec->MaybeSetFloatCallback( + PixelCallback{ + dec->image_out_init_callback, dec->image_out_run_callback, + dec->image_out_destroy_callback, dec->image_out_init_opaque}, + is_rgba, dec->unpremul_alpha, !dec->keep_orientation); + } + + size_t next_num_passes_to_pause = dec->frame_dec->NextNumPassesToPause(); + + JXL_API_RETURN_IF_ERROR(JxlDecoderProcessSections(dec)); + + bool all_sections_done = dec->frame_dec->HasDecodedAll(); + bool got_dc_only = !all_sections_done && dec->frame_dec->HasDecodedDC(); + + if (dec->frame_prog_detail >= JxlProgressiveDetail::kDC && + !dec->dc_frame_progression_done && got_dc_only) { + dec->dc_frame_progression_done = true; + dec->downsampling_target = 8; + return JXL_DEC_FRAME_PROGRESSION; + } + + bool new_progression_step_done = + dec->frame_dec->NumCompletePasses() >= next_num_passes_to_pause; + + if (!all_sections_done && + dec->frame_prog_detail >= JxlProgressiveDetail::kLastPasses && + new_progression_step_done) { + dec->downsampling_target = + dec->frame_header->passes.GetDownsamplingTargetForCompletedPasses( + dec->frame_dec->NumCompletePasses()); + return JXL_DEC_FRAME_PROGRESSION; + } + + if (!all_sections_done) { + // Not all sections have been processed yet + return dec->RequestMoreInput(); + } + + if (!dec->preview_frame) { + size_t internal_index = dec->internal_frames - 1; + JXL_ASSERT(dec->frame_references.size() > internal_index); + // Always fill this in, even if it was already written, it could be that + // this frame was skipped before and set to 255, while only now we know + // the true value. + dec->frame_references[internal_index] = dec->frame_dec->References(); + // Copy exif/xmp metadata from their boxes into the jpeg_data, if + // JPEG reconstruction is requested. + if (dec->jpeg_decoder.IsOutputSet() && dec->ib->jpeg_data != nullptr) { + } + } + + if (!dec->frame_dec->FinalizeFrame()) { + return JXL_API_ERROR("decoding frame failed"); + } + + dec->frame_dec_in_progress = false; + dec->frame_stage = FrameStage::kFullOutput; + } + + bool output_jpeg_reconstruction = false; + + if (dec->frame_stage == FrameStage::kFullOutput) { + if (dec->preview_frame) { + JxlDecoderStatus status = + ConvertImageInternal(dec, *dec->ib, dec->preview_out_format, + /*want_extra_channel=*/false, + /*extra_channel_index=*/0, + dec->preview_out_buffer, dec->preview_out_size, + /*out_callback=*/{}); + if (status != JXL_DEC_SUCCESS) return status; + } else if (dec->is_last_of_still) { + if (dec->events_wanted & JXL_DEC_FULL_IMAGE) { + dec->events_wanted &= ~JXL_DEC_FULL_IMAGE; + return_full_image = true; + } + + // Frame finished, restore the events_wanted with the per-frame events + // from orig_events_wanted, in case there is a next frame. + dec->events_wanted |= + (dec->orig_events_wanted & + (JXL_DEC_FULL_IMAGE | JXL_DEC_FRAME | JXL_DEC_FRAME_PROGRESSION)); + + // If no output buffer was set, we merely return the JXL_DEC_FULL_IMAGE + // status without outputting pixels. + if (dec->jpeg_decoder.IsOutputSet() && dec->ib->jpeg_data != nullptr) { + output_jpeg_reconstruction = true; + } else if (return_full_image && dec->image_out_buffer_set) { + if (!dec->frame_dec->HasRGBBuffer()) { + // Copy pixels if desired. + JxlDecoderStatus status = ConvertImageInternal( + dec, *dec->ib, dec->image_out_format, + /*want_extra_channel=*/false, + /*extra_channel_index=*/0, dec->image_out_buffer, + dec->image_out_size, + PixelCallback{dec->image_out_init_callback, + dec->image_out_run_callback, + dec->image_out_destroy_callback, + dec->image_out_init_opaque}); + if (status != JXL_DEC_SUCCESS) return status; + } + dec->image_out_buffer_set = false; + + bool has_ec = !dec->ib->extra_channels().empty(); + for (size_t i = 0; i < dec->extra_channel_output.size(); ++i) { + void* buffer = dec->extra_channel_output[i].buffer; + // buffer nullptr indicates this extra channel is not requested + if (!buffer) continue; + if (!has_ec) { + JXL_WARNING( + "Extra channels are not supported when callback is used"); + return JXL_DEC_ERROR; + } + const JxlPixelFormat* format = &dec->extra_channel_output[i].format; + JxlDecoderStatus status = ConvertImageInternal( + dec, *dec->ib, *format, + /*want_extra_channel=*/true, /*extra_channel_index=*/i, buffer, + dec->extra_channel_output[i].buffer_size, /*out_callback=*/{}); + if (status != JXL_DEC_SUCCESS) return status; + } + + dec->extra_channel_output.clear(); + } + } + } + + dec->frame_stage = FrameStage::kHeader; + + if (output_jpeg_reconstruction) { + dec->recon_output_jpeg = JpegReconStage::kSettingMetadata; + return JXL_DEC_FULL_IMAGE; + } else { + // The pixels have been output or are not needed, do not keep them in + // memory here. + dec->ib.reset(); + if (dec->preview_frame) { + dec->got_preview_image = true; + dec->preview_frame = false; + dec->events_wanted &= ~JXL_DEC_PREVIEW_IMAGE; + return JXL_DEC_PREVIEW_IMAGE; + } else if (return_full_image && !dec->skipping_frame) { + return JXL_DEC_FULL_IMAGE; + } + } + } + + dec->stage = DecoderStage::kCodestreamFinished; + // Return success, this means there is nothing more to do. + return JXL_DEC_SUCCESS; +} + +} // namespace +} // namespace jxl + +JxlDecoderStatus JxlDecoderSetInput(JxlDecoder* dec, const uint8_t* data, + size_t size) { + if (dec->next_in) { + return JXL_API_ERROR("already set input, use JxlDecoderReleaseInput first"); + } + if (dec->input_closed) { + return JXL_API_ERROR("input already closed"); + } + + dec->next_in = data; + dec->avail_in = size; + return JXL_DEC_SUCCESS; +} + +size_t JxlDecoderReleaseInput(JxlDecoder* dec) { + size_t result = dec->avail_in; + dec->next_in = nullptr; + dec->avail_in = 0; + return result; +} + +void JxlDecoderCloseInput(JxlDecoder* dec) { dec->input_closed = true; } + +JxlDecoderStatus JxlDecoderSetJPEGBuffer(JxlDecoder* dec, uint8_t* data, + size_t size) { + // JPEG reconstruction buffer can only set and updated before or during the + // first frame, the reconstruction box refers to the first frame and in + // theory multi-frame images should not be used with a jbrd box. + if (dec->internal_frames > 1) { + return JXL_API_ERROR("JPEG reconstruction only works for the first frame"); + } + if (dec->jpeg_decoder.IsOutputSet()) { + return JXL_API_ERROR("Already set JPEG buffer"); + } + return dec->jpeg_decoder.SetOutputBuffer(data, size); +} + +size_t JxlDecoderReleaseJPEGBuffer(JxlDecoder* dec) { + return dec->jpeg_decoder.ReleaseOutputBuffer(); +} + +// Parses the header of the box, outputting the 4-character type and the box +// size, including header size, as stored in the box header. +// @param in current input bytes. +// @param size available input size. +// @param pos position in the input, must begin at the header of the box. +// @param file_pos position of pos since the start of the JXL file, rather than +// the current input, used for integer overflow checking. +// @param type the output box type. +// @param box_size output the total box size, including header, in bytes, or 0 +// if it's a final unbounded box. +// @param header_size output size of the box header. +// @return JXL_DEC_SUCCESS if the box header was fully parsed. In that case the +// parsing position must be incremented by header_size bytes. +// JXL_DEC_NEED_MORE_INPUT if not enough input bytes available, in that case +// header_size indicates a lower bound for the known size the header has to be +// at least. JXL_DEC_ERROR if the box header is invalid. +static JxlDecoderStatus ParseBoxHeader(const uint8_t* in, size_t size, + size_t pos, size_t file_pos, + JxlBoxType type, uint64_t* box_size, + uint64_t* header_size) { + if (OutOfBounds(pos, 8, size)) { + *header_size = 8; + return JXL_DEC_NEED_MORE_INPUT; + } + size_t box_start = pos; + // Box size, including this header itself. + *box_size = LoadBE32(in + pos); + pos += 4; + if (*box_size == 1) { + *header_size = 16; + if (OutOfBounds(pos, 12, size)) return JXL_DEC_NEED_MORE_INPUT; + *box_size = LoadBE64(in + pos); + pos += 8; + } + memcpy(type, in + pos, 4); + pos += 4; + *header_size = pos - box_start; + if (*box_size > 0 && *box_size < *header_size) { + return JXL_API_ERROR("invalid box size"); + } + if (SumOverflows(file_pos, pos, *box_size)) { + return JXL_API_ERROR("Box size overflow"); + } + return JXL_DEC_SUCCESS; +} + +// This includes handling the codestream if it is not a box-based jxl file. +static JxlDecoderStatus HandleBoxes(JxlDecoder* dec) { + // Box handling loop + for (;;) { + if (dec->box_stage != BoxStage::kHeader) { + dec->AdvanceInput(dec->header_size); + dec->header_size = 0; + if ((dec->events_wanted & JXL_DEC_BOX) && + dec->box_out_buffer_set_current_box) { + uint8_t* next_out = dec->box_out_buffer + dec->box_out_buffer_pos; + size_t avail_out = dec->box_out_buffer_size - dec->box_out_buffer_pos; + + JxlDecoderStatus box_result = dec->box_content_decoder.Process( + dec->next_in, dec->avail_in, + dec->file_pos - dec->box_contents_begin, &next_out, &avail_out); + size_t produced = + next_out - (dec->box_out_buffer + dec->box_out_buffer_pos); + dec->box_out_buffer_pos += produced; + + // Don't return JXL_DEC_NEED_MORE_INPUT: the box stages below, instead, + // handle the input progression, and the above only outputs the part of + // the box seen so far. + if (box_result != JXL_DEC_SUCCESS && + box_result != JXL_DEC_NEED_MORE_INPUT) { + return box_result; + } + } + + if (dec->store_exif == 1 || dec->store_xmp == 1) { + std::vector& metadata = + (dec->store_exif == 1) ? dec->exif_metadata : dec->xmp_metadata; + for (;;) { + if (metadata.empty()) metadata.resize(64); + uint8_t* orig_next_out = metadata.data() + dec->recon_out_buffer_pos; + uint8_t* next_out = orig_next_out; + size_t avail_out = metadata.size() - dec->recon_out_buffer_pos; + JxlDecoderStatus box_result = dec->metadata_decoder.Process( + dec->next_in, dec->avail_in, + dec->file_pos - dec->box_contents_begin, &next_out, &avail_out); + size_t produced = next_out - orig_next_out; + dec->recon_out_buffer_pos += produced; + if (box_result == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + metadata.resize(metadata.size() * 2); + } else if (box_result == JXL_DEC_NEED_MORE_INPUT) { + break; // box stage handling below will handle this instead + } else if (box_result == JXL_DEC_SUCCESS) { + size_t needed_size = (dec->store_exif == 1) ? dec->recon_exif_size + : dec->recon_xmp_size; + if (dec->box_contents_unbounded && + dec->recon_out_buffer_pos < needed_size) { + // Unbounded box, but we know the expected size due to the jbrd + // box's data. Treat this as the JXL_DEC_NEED_MORE_INPUT case. + break; + } else { + metadata.resize(dec->recon_out_buffer_pos); + if (dec->store_exif == 1) dec->store_exif = 2; + if (dec->store_xmp == 1) dec->store_xmp = 2; + break; + } + } else { + // error + return box_result; + } + } + } + } + + if (dec->recon_output_jpeg == JpegReconStage::kSettingMetadata && + !dec->JbrdNeedMoreBoxes()) { + jxl::jpeg::JPEGData* jpeg_data = dec->ib->jpeg_data.get(); + if (dec->recon_exif_size) { + JxlDecoderStatus status = jxl::JxlToJpegDecoder::SetExif( + dec->exif_metadata.data(), dec->exif_metadata.size(), jpeg_data); + if (status != JXL_DEC_SUCCESS) return status; + } + if (dec->recon_xmp_size) { + JxlDecoderStatus status = jxl::JxlToJpegDecoder::SetXmp( + dec->xmp_metadata.data(), dec->xmp_metadata.size(), jpeg_data); + if (status != JXL_DEC_SUCCESS) return status; + } + dec->recon_output_jpeg = JpegReconStage::kOutputting; + } + + if (dec->recon_output_jpeg == JpegReconStage::kOutputting && + !dec->JbrdNeedMoreBoxes()) { + JxlDecoderStatus status = + dec->jpeg_decoder.WriteOutput(*dec->ib->jpeg_data); + if (status != JXL_DEC_SUCCESS) return status; + dec->recon_output_jpeg = JpegReconStage::kFinished; + dec->ib.reset(); + if (dec->events_wanted & JXL_DEC_FULL_IMAGE) { + // Return the full image event here now, this may be delayed if this + // could only be done after decoding an exif or xmp box after the + // codestream. + return JXL_DEC_FULL_IMAGE; + } + } + + if (dec->box_stage == BoxStage::kHeader) { + if (!dec->have_container) { + if (dec->stage == DecoderStage::kCodestreamFinished) + return JXL_DEC_SUCCESS; + dec->box_stage = BoxStage::kCodestream; + dec->box_contents_unbounded = true; + continue; + } + if (dec->avail_in == 0) { + if (dec->stage != DecoderStage::kCodestreamFinished) { + // Not yet seen (all) codestream boxes. + return JXL_DEC_NEED_MORE_INPUT; + } + if (dec->JbrdNeedMoreBoxes()) { + return JXL_DEC_NEED_MORE_INPUT; + } + if (dec->input_closed) { + return JXL_DEC_SUCCESS; + } + if (!(dec->events_wanted & JXL_DEC_BOX)) { + // All codestream and jbrd metadata boxes finished, and no individual + // boxes requested by user, so no need to request any more input. + // This returns success for backwards compatibility, when + // JxlDecoderCloseInput and JXL_DEC_BOX did not exist, as well + // as for efficiency. + return JXL_DEC_SUCCESS; + } + // Even though we are exactly at a box end, there still may be more + // boxes. The user may call JxlDecoderCloseInput to indicate the input + // is finished and get success instead. + return JXL_DEC_NEED_MORE_INPUT; + } + + bool boxed_codestream_done = + ((dec->events_wanted & JXL_DEC_BOX) && + dec->stage == DecoderStage::kCodestreamFinished && + dec->last_codestream_seen && !dec->JbrdNeedMoreBoxes()); + if (boxed_codestream_done && dec->avail_in >= 2 && + dec->next_in[0] == 0xff && + dec->next_in[1] == jxl::kCodestreamMarker) { + // We detected the start of the next naked codestream, so we can return + // success here. + return JXL_DEC_SUCCESS; + } + + uint64_t box_size, header_size; + JxlDecoderStatus status = + ParseBoxHeader(dec->next_in, dec->avail_in, 0, dec->file_pos, + dec->box_type, &box_size, &header_size); + if (status != JXL_DEC_SUCCESS) { + if (status == JXL_DEC_NEED_MORE_INPUT) { + dec->basic_info_size_hint = + InitialBasicInfoSizeHint() + header_size - dec->file_pos; + } + return status; + } + if (memcmp(dec->box_type, "brob", 4) == 0) { + if (dec->avail_in < header_size + 4) { + return JXL_DEC_NEED_MORE_INPUT; + } + memcpy(dec->box_decoded_type, dec->next_in + header_size, + sizeof(dec->box_decoded_type)); + } else { + memcpy(dec->box_decoded_type, dec->box_type, + sizeof(dec->box_decoded_type)); + } + + // Box order validity checks + // The signature box at box_count == 1 is not checked here since that's + // already done at the beginning. + dec->box_count++; + if (boxed_codestream_done && memcmp(dec->box_type, "JXL ", 4) == 0) { + // We detected the start of the next boxed stream, so we can return + // success here. + return JXL_DEC_SUCCESS; + } + if (dec->box_count == 2 && memcmp(dec->box_type, "ftyp", 4) != 0) { + return JXL_API_ERROR("the second box must be the ftyp box"); + } + if (memcmp(dec->box_type, "ftyp", 4) == 0 && dec->box_count != 2) { + return JXL_API_ERROR("the ftyp box must come second"); + } + + dec->box_contents_unbounded = (box_size == 0); + dec->box_contents_begin = dec->file_pos + header_size; + dec->box_contents_end = + dec->box_contents_unbounded ? 0 : (dec->file_pos + box_size); + dec->box_contents_size = + dec->box_contents_unbounded ? 0 : (box_size - header_size); + dec->box_size = box_size; + dec->header_size = header_size; + + if (dec->orig_events_wanted & JXL_DEC_JPEG_RECONSTRUCTION) { + // Initiate storing of Exif or XMP data for JPEG reconstruction + if (dec->store_exif == 0 && + memcmp(dec->box_decoded_type, "Exif", 4) == 0) { + dec->store_exif = 1; + dec->recon_out_buffer_pos = 0; + } + if (dec->store_xmp == 0 && + memcmp(dec->box_decoded_type, "xml ", 4) == 0) { + dec->store_xmp = 1; + dec->recon_out_buffer_pos = 0; + } + } + + if (dec->events_wanted & JXL_DEC_BOX) { + bool decompress = + dec->decompress_boxes && memcmp(dec->box_type, "brob", 4) == 0; + dec->box_content_decoder.StartBox( + decompress, dec->box_contents_unbounded, dec->box_contents_size); + } + if (dec->store_exif == 1 || dec->store_xmp == 1) { + bool brob = memcmp(dec->box_type, "brob", 4) == 0; + dec->metadata_decoder.StartBox(brob, dec->box_contents_unbounded, + dec->box_contents_size); + } + + if (memcmp(dec->box_type, "ftyp", 4) == 0) { + dec->box_stage = BoxStage::kFtyp; + } else if (memcmp(dec->box_type, "jxlc", 4) == 0) { + if (dec->last_codestream_seen) { + return JXL_API_ERROR("there can only be one jxlc box"); + } + dec->last_codestream_seen = true; + dec->box_stage = BoxStage::kCodestream; + } else if (memcmp(dec->box_type, "jxlp", 4) == 0) { + dec->box_stage = BoxStage::kPartialCodestream; + } else if ((dec->orig_events_wanted & JXL_DEC_JPEG_RECONSTRUCTION) && + memcmp(dec->box_type, "jbrd", 4) == 0) { + if (!(dec->events_wanted & JXL_DEC_JPEG_RECONSTRUCTION)) { + return JXL_API_ERROR( + "multiple JPEG reconstruction boxes not supported"); + } + dec->box_stage = BoxStage::kJpegRecon; + } else { + dec->box_stage = BoxStage::kSkip; + } + + if (dec->events_wanted & JXL_DEC_BOX) { + dec->box_event = true; + dec->box_out_buffer_set_current_box = false; + return JXL_DEC_BOX; + } + } else if (dec->box_stage == BoxStage::kFtyp) { + if (dec->box_contents_size < 12) { + return JXL_API_ERROR("file type box too small"); + } + if (dec->avail_in < 4) return JXL_DEC_NEED_MORE_INPUT; + if (memcmp(dec->next_in, "jxl ", 4) != 0) { + return JXL_API_ERROR("file type box major brand must be \"jxl \""); + } + dec->AdvanceInput(4); + dec->box_stage = BoxStage::kSkip; + } else if (dec->box_stage == BoxStage::kPartialCodestream) { + if (dec->last_codestream_seen) { + return JXL_API_ERROR("cannot have jxlp box after last jxlp box"); + } + // TODO(lode): error if box is unbounded but last bit not set + if (dec->avail_in < 4) return JXL_DEC_NEED_MORE_INPUT; + if (!dec->box_contents_unbounded && dec->box_contents_size < 4) { + return JXL_API_ERROR("jxlp box too small to contain index"); + } + size_t jxlp_index = LoadBE32(dec->next_in); + // The high bit of jxlp_index indicates whether this is the last + // jxlp box. + if (jxlp_index & 0x80000000) { + dec->last_codestream_seen = true; + } + dec->AdvanceInput(4); + dec->box_stage = BoxStage::kCodestream; + } else if (dec->box_stage == BoxStage::kCodestream) { + JxlDecoderStatus status = jxl::JxlDecoderProcessCodestream(dec); + if (status == JXL_DEC_FULL_IMAGE) { + if (dec->recon_output_jpeg != JpegReconStage::kNone) { + continue; + } + } + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (dec->file_pos == dec->box_contents_end && + !dec->box_contents_unbounded) { + dec->box_stage = BoxStage::kHeader; + continue; + } + } + + if (status == JXL_DEC_SUCCESS) { + if (dec->JbrdNeedMoreBoxes()) { + dec->box_stage = BoxStage::kSkip; + continue; + } + if (dec->box_contents_unbounded) { + // Last box reached and codestream done, nothing more to do. + break; + } + if (dec->events_wanted & JXL_DEC_BOX) { + // Codestream done, but there may be more other boxes. + dec->box_stage = BoxStage::kSkip; + continue; + } + } + return status; + } else if (dec->box_stage == BoxStage::kJpegRecon) { + if (!dec->jpeg_decoder.IsParsingBox()) { + // This is a new JPEG reconstruction metadata box. + dec->jpeg_decoder.StartBox(dec->box_contents_unbounded, + dec->box_contents_size); + } + const uint8_t* next_in = dec->next_in; + size_t avail_in = dec->avail_in; + JxlDecoderStatus recon_result = + dec->jpeg_decoder.Process(&next_in, &avail_in); + size_t consumed = next_in - dec->next_in; + dec->AdvanceInput(consumed); + if (recon_result == JXL_DEC_JPEG_RECONSTRUCTION) { + jxl::jpeg::JPEGData* jpeg_data = dec->jpeg_decoder.GetJpegData(); + size_t num_exif = jxl::JxlToJpegDecoder::NumExifMarkers(*jpeg_data); + size_t num_xmp = jxl::JxlToJpegDecoder::NumXmpMarkers(*jpeg_data); + if (num_exif) { + if (num_exif > 1) { + return JXL_API_ERROR( + "multiple exif markers for JPEG reconstruction not supported"); + } + if (JXL_DEC_SUCCESS != jxl::JxlToJpegDecoder::ExifBoxContentSize( + *jpeg_data, &dec->recon_exif_size)) { + return JXL_API_ERROR("invalid jbrd exif size"); + } + } + if (num_xmp) { + if (num_xmp > 1) { + return JXL_API_ERROR( + "multiple XMP markers for JPEG reconstruction not supported"); + } + if (JXL_DEC_SUCCESS != jxl::JxlToJpegDecoder::XmlBoxContentSize( + *jpeg_data, &dec->recon_xmp_size)) { + return JXL_API_ERROR("invalid jbrd XMP size"); + } + } + + dec->box_stage = BoxStage::kHeader; + // If successful JPEG reconstruction, return the success if the user + // cares about it, otherwise continue. + if (dec->events_wanted & recon_result) { + dec->events_wanted &= ~recon_result; + return recon_result; + } + } else { + // If anything else, return the result. + return recon_result; + } + } else if (dec->box_stage == BoxStage::kSkip) { + if (dec->box_contents_unbounded) { + if (dec->input_closed) { + return JXL_DEC_SUCCESS; + } + if (!(dec->box_out_buffer_set)) { + // An unbounded box is always the last box. Not requesting box data, + // so return success even if JxlDecoderCloseInput was not called for + // backwards compatibility as well as efficiency since this box is + // being skipped. + return JXL_DEC_SUCCESS; + } + // Arbitrarily more bytes may follow, only JxlDecoderCloseInput can + // mark the end. + dec->AdvanceInput(dec->avail_in); + return JXL_DEC_NEED_MORE_INPUT; + } + // Amount of remaining bytes in the box that is being skipped. + size_t remaining = dec->box_contents_end - dec->file_pos; + if (dec->avail_in < remaining) { + // Indicate how many more bytes needed starting from next_in. + dec->basic_info_size_hint = + InitialBasicInfoSizeHint() + dec->box_contents_end - dec->file_pos; + // Don't have the full box yet, skip all we have so far + dec->AdvanceInput(dec->avail_in); + return JXL_DEC_NEED_MORE_INPUT; + } else { + // Full box available, skip all its remaining bytes + dec->AdvanceInput(remaining); + dec->box_stage = BoxStage::kHeader; + } + } else { + JXL_DASSERT(false); // unknown box stage + } + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderProcessInput(JxlDecoder* dec) { + if (dec->stage == DecoderStage::kInited) { + dec->stage = DecoderStage::kStarted; + } + if (dec->stage == DecoderStage::kError) { + return JXL_API_ERROR( + "Cannot keep using decoder after it encountered an error, use " + "JxlDecoderReset to reset it"); + } + + if (!dec->got_signature) { + JxlSignature sig = JxlSignatureCheck(dec->next_in, dec->avail_in); + if (sig == JXL_SIG_INVALID) return JXL_API_ERROR("invalid signature"); + if (sig == JXL_SIG_NOT_ENOUGH_BYTES) { + if (dec->input_closed) { + return JXL_API_ERROR("file too small for signature"); + } + return JXL_DEC_NEED_MORE_INPUT; + } + + dec->got_signature = true; + + if (sig == JXL_SIG_CONTAINER) { + dec->have_container = 1; + } else { + dec->last_codestream_seen = true; + } + } + + JxlDecoderStatus status = HandleBoxes(dec); + + if (status == JXL_DEC_NEED_MORE_INPUT && dec->input_closed) { + return JXL_API_ERROR("missing input"); + } + + // Even if the box handling returns success, certain types of + // data may be missing. + if (status == JXL_DEC_SUCCESS) { + if (dec->CanUseMoreCodestreamInput()) { + return JXL_API_ERROR("codestream never finished"); + } + if (dec->JbrdNeedMoreBoxes()) { + return JXL_API_ERROR("missing metadata boxes for jpeg reconstruction"); + } + } + + return status; +} + +// To ensure ABI forward-compatibility, this struct has a constant size. +static_assert(sizeof(JxlBasicInfo) == 204, + "JxlBasicInfo struct size should remain constant"); + +JxlDecoderStatus JxlDecoderGetBasicInfo(const JxlDecoder* dec, + JxlBasicInfo* info) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + + if (info) { + memset(info, 0, sizeof(*info)); + + const jxl::ImageMetadata& meta = dec->metadata.m; + + info->have_container = dec->have_container; + info->xsize = dec->metadata.size.xsize(); + info->ysize = dec->metadata.size.ysize(); + info->uses_original_profile = !meta.xyb_encoded; + + info->bits_per_sample = meta.bit_depth.bits_per_sample; + info->exponent_bits_per_sample = meta.bit_depth.exponent_bits_per_sample; + + info->have_preview = meta.have_preview; + info->have_animation = meta.have_animation; + info->orientation = static_cast(meta.orientation); + + if (!dec->keep_orientation) { + if (info->orientation >= JXL_ORIENT_TRANSPOSE) { + std::swap(info->xsize, info->ysize); + } + info->orientation = JXL_ORIENT_IDENTITY; + } + + info->intensity_target = meta.IntensityTarget(); + if (dec->desired_intensity_target > 0) { + info->intensity_target = dec->desired_intensity_target; + } + info->min_nits = meta.tone_mapping.min_nits; + info->relative_to_max_display = meta.tone_mapping.relative_to_max_display; + info->linear_below = meta.tone_mapping.linear_below; + + const jxl::ExtraChannelInfo* alpha = meta.Find(jxl::ExtraChannel::kAlpha); + if (alpha != nullptr) { + info->alpha_bits = alpha->bit_depth.bits_per_sample; + info->alpha_exponent_bits = alpha->bit_depth.exponent_bits_per_sample; + info->alpha_premultiplied = alpha->alpha_associated; + } else { + info->alpha_bits = 0; + info->alpha_exponent_bits = 0; + info->alpha_premultiplied = 0; + } + + info->num_color_channels = + meta.color_encoding.GetColorSpace() == jxl::ColorSpace::kGray ? 1 : 3; + + info->num_extra_channels = meta.num_extra_channels; + + if (info->have_preview) { + info->preview.xsize = dec->metadata.m.preview_size.xsize(); + info->preview.ysize = dec->metadata.m.preview_size.ysize(); + } + + if (info->have_animation) { + info->animation.tps_numerator = dec->metadata.m.animation.tps_numerator; + info->animation.tps_denominator = + dec->metadata.m.animation.tps_denominator; + info->animation.num_loops = dec->metadata.m.animation.num_loops; + info->animation.have_timecodes = dec->metadata.m.animation.have_timecodes; + } + + if (meta.have_intrinsic_size) { + info->intrinsic_xsize = dec->metadata.m.intrinsic_size.xsize(); + info->intrinsic_ysize = dec->metadata.m.intrinsic_size.ysize(); + } else { + info->intrinsic_xsize = info->xsize; + info->intrinsic_ysize = info->ysize; + } + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetExtraChannelInfo(const JxlDecoder* dec, + size_t index, + JxlExtraChannelInfo* info) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + + const std::vector& channels = + dec->metadata.m.extra_channel_info; + + if (index >= channels.size()) return JXL_DEC_ERROR; // out of bounds + const jxl::ExtraChannelInfo& channel = channels[index]; + + info->type = static_cast(channel.type); + info->bits_per_sample = channel.bit_depth.bits_per_sample; + info->exponent_bits_per_sample = + channel.bit_depth.floating_point_sample + ? channel.bit_depth.exponent_bits_per_sample + : 0; + info->dim_shift = channel.dim_shift; + info->name_length = channel.name.size(); + info->alpha_premultiplied = channel.alpha_associated; + info->spot_color[0] = channel.spot_color[0]; + info->spot_color[1] = channel.spot_color[1]; + info->spot_color[2] = channel.spot_color[2]; + info->spot_color[3] = channel.spot_color[3]; + info->cfa_channel = channel.cfa_channel; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetExtraChannelName(const JxlDecoder* dec, + size_t index, char* name, + size_t size) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + + const std::vector& channels = + dec->metadata.m.extra_channel_info; + + if (index >= channels.size()) return JXL_DEC_ERROR; // out of bounds + const jxl::ExtraChannelInfo& channel = channels[index]; + + // Also need null-termination character + if (channel.name.size() + 1 > size) return JXL_DEC_ERROR; + + memcpy(name, channel.name.c_str(), channel.name.size() + 1); + + return JXL_DEC_SUCCESS; +} + +namespace { + +// Gets the jxl::ColorEncoding for the desired target, and checks errors. +// Returns the object regardless of whether the actual color space is in ICC, +// but ensures that if the color encoding is not the encoding from the +// codestream header metadata, it cannot require ICC profile. +JxlDecoderStatus GetColorEncodingForTarget( + const JxlDecoder* dec, JxlColorProfileTarget target, + const jxl::ColorEncoding** encoding) { + if (!dec->got_all_headers) return JXL_DEC_NEED_MORE_INPUT; + *encoding = nullptr; + if (target == JXL_COLOR_PROFILE_TARGET_DATA && dec->metadata.m.xyb_encoded) { + *encoding = &dec->passes_state->output_encoding_info.color_encoding; + } else { + *encoding = &dec->metadata.m.color_encoding; + } + return JXL_DEC_SUCCESS; +} +} // namespace + +JxlDecoderStatus JxlDecoderGetColorAsEncodedProfile( + const JxlDecoder* dec, const JxlPixelFormat* unused_format, + JxlColorProfileTarget target, JxlColorEncoding* color_encoding) { + const jxl::ColorEncoding* jxl_color_encoding = nullptr; + JxlDecoderStatus status = + GetColorEncodingForTarget(dec, target, &jxl_color_encoding); + if (status) return status; + + if (jxl_color_encoding->WantICC()) + return JXL_DEC_ERROR; // Indicate no encoded profile available. + + if (color_encoding) { + ConvertInternalToExternalColorEncoding(*jxl_color_encoding, color_encoding); + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetICCProfileSize( + const JxlDecoder* dec, const JxlPixelFormat* unused_format, + JxlColorProfileTarget target, size_t* size) { + const jxl::ColorEncoding* jxl_color_encoding = nullptr; + JxlDecoderStatus status = + GetColorEncodingForTarget(dec, target, &jxl_color_encoding); + if (status != JXL_DEC_SUCCESS) return status; + + if (jxl_color_encoding->WantICC()) { + jxl::ColorSpace color_space = + dec->metadata.m.color_encoding.GetColorSpace(); + if (color_space == jxl::ColorSpace::kUnknown || + color_space == jxl::ColorSpace::kXYB) { + // This indicates there's no ICC profile available + // TODO(lode): for the XYB case, do we want to craft an ICC profile that + // represents XYB as an RGB profile? It may be possible, but not with + // only 1D transfer functions. + return JXL_DEC_ERROR; + } + } + + if (size) { + *size = jxl_color_encoding->ICC().size(); + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetColorAsICCProfile( + const JxlDecoder* dec, const JxlPixelFormat* unused_format, + JxlColorProfileTarget target, uint8_t* icc_profile, size_t size) { + size_t wanted_size; + // This also checks the NEED_MORE_INPUT and the unknown/xyb cases + JxlDecoderStatus status = + JxlDecoderGetICCProfileSize(dec, nullptr, target, &wanted_size); + if (status != JXL_DEC_SUCCESS) return status; + if (size < wanted_size) return JXL_API_ERROR("ICC profile output too small"); + + const jxl::ColorEncoding* jxl_color_encoding = nullptr; + status = GetColorEncodingForTarget(dec, target, &jxl_color_encoding); + if (status != JXL_DEC_SUCCESS) return status; + + memcpy(icc_profile, jxl_color_encoding->ICC().data(), + jxl_color_encoding->ICC().size()); + + return JXL_DEC_SUCCESS; +} + +namespace { + +// Returns the amount of bits needed for getting memory buffer size, and does +// all error checking required for size checking and format validity. +JxlDecoderStatus PrepareSizeCheck(const JxlDecoder* dec, + const JxlPixelFormat* format, size_t* bits) { + if (!dec->got_basic_info) { + // Don't know image dimensions yet, cannot check for valid size. + return JXL_DEC_NEED_MORE_INPUT; + } + if (!dec->coalescing && + (!dec->frame_header || dec->frame_stage == FrameStage::kHeader)) { + return JXL_API_ERROR("Don't know frame dimensions yet"); + } + if (format->num_channels > 4) { + return JXL_API_ERROR("More than 4 channels not supported"); + } + + *bits = BitsPerChannel(format->data_type); + + if (*bits == 0) { + return JXL_API_ERROR("Invalid/unsupported data type"); + } + + return JXL_DEC_SUCCESS; +} + +} // namespace + +size_t JxlDecoderGetIntendedDownsamplingRatio(JxlDecoder* dec) { + return dec->downsampling_target; +} + +JxlDecoderStatus JxlDecoderFlushImage(JxlDecoder* dec) { + if (!dec->image_out_buffer_set) return JXL_DEC_ERROR; + if (!dec->frame_dec || !dec->frame_dec_in_progress) { + return JXL_DEC_ERROR; + } + if (!dec->frame_dec->HasDecodedDC()) { + // FrameDecoder::Flush currently requires DC to have been decoded already + // to work correctly. + return JXL_DEC_ERROR; + } + + if (!dec->frame_dec->Flush()) { + return JXL_DEC_ERROR; + } + + if (dec->jpeg_decoder.IsOutputSet() && dec->ib->jpeg_data != nullptr) { + return JXL_DEC_SUCCESS; + } + + if (dec->frame_dec->HasRGBBuffer()) { + return JXL_DEC_SUCCESS; + } + + // Temporarily shrink `dec->ib` to the actual size of the full image to call + // ConvertImageInternal. + size_t xsize = dec->ib->xsize(); + size_t ysize = dec->ib->ysize(); + size_t xsize_nopadding, ysize_nopadding; + GetCurrentDimensions(dec, xsize_nopadding, ysize_nopadding, false); + dec->ib->ShrinkTo(xsize_nopadding, ysize_nopadding); + JxlDecoderStatus status = jxl::ConvertImageInternal( + dec, *dec->ib, dec->image_out_format, + /*want_extra_channel=*/false, + /*extra_channel_index=*/0, dec->image_out_buffer, dec->image_out_size, + jxl::PixelCallback{ + dec->image_out_init_callback, dec->image_out_run_callback, + dec->image_out_destroy_callback, dec->image_out_init_opaque}); + dec->ib->ShrinkTo(xsize, ysize); + if (status != JXL_DEC_SUCCESS) return status; + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderPreviewOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) { + size_t bits; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits); + if (status != JXL_DEC_SUCCESS) return status; + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + + size_t xsize = dec->metadata.oriented_preview_xsize(dec->keep_orientation); + size_t ysize = dec->metadata.oriented_preview_ysize(dec->keep_orientation); + + size_t row_size = + jxl::DivCeil(xsize * format->num_channels * bits, jxl::kBitsPerByte); + size_t last_row_size = row_size; + if (format->align > 1) { + row_size = jxl::DivCeil(row_size, format->align) * format->align; + } + *size = row_size * (ysize - 1) + last_row_size; + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderSetPreviewOutBuffer( + JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size) { + if (!dec->got_basic_info || !dec->metadata.m.have_preview || + !(dec->orig_events_wanted & JXL_DEC_PREVIEW_IMAGE)) { + return JXL_API_ERROR("No preview out buffer needed at this time"); + } + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + + size_t min_size; + // This also checks whether the format is valid and supported and basic info + // is available. + JxlDecoderStatus status = + JxlDecoderPreviewOutBufferSize(dec, format, &min_size); + if (status != JXL_DEC_SUCCESS) return status; + + if (size < min_size) return JXL_DEC_ERROR; + + dec->preview_out_buffer_set = true; + dec->preview_out_buffer = buffer; + dec->preview_out_size = size; + dec->preview_out_format = *format; + + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderDCOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) { + size_t bits; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits); + if (status != JXL_DEC_SUCCESS) return status; + + size_t xsize = jxl::DivCeil( + dec->metadata.oriented_xsize(dec->keep_orientation), jxl::kBlockDim); + size_t ysize = jxl::DivCeil( + dec->metadata.oriented_ysize(dec->keep_orientation), jxl::kBlockDim); + + size_t row_size = + jxl::DivCeil(xsize * format->num_channels * bits, jxl::kBitsPerByte); + size_t last_row_size = row_size; + if (format->align > 1) { + row_size = jxl::DivCeil(row_size, format->align) * format->align; + } + *size = row_size * (ysize - 1) + last_row_size; + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderSetDCOutBuffer( + JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size) { + // No buffer set: this feature is deprecated + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderImageOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) { + size_t bits; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits); + if (status != JXL_DEC_SUCCESS) return status; + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize, true); + size_t row_size = + jxl::DivCeil(xsize * format->num_channels * bits, jxl::kBitsPerByte); + if (format->align > 1) { + row_size = jxl::DivCeil(row_size, format->align) * format->align; + } + *size = row_size * ysize; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetImageOutBuffer(JxlDecoder* dec, + const JxlPixelFormat* format, + void* buffer, size_t size) { + if (!dec->got_basic_info || !(dec->orig_events_wanted & JXL_DEC_FULL_IMAGE)) { + return JXL_API_ERROR("No image out buffer needed at this time"); + } + if (dec->image_out_buffer_set && !!dec->image_out_run_callback) { + return JXL_API_ERROR( + "Cannot change from image out callback to image out buffer"); + } + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + size_t min_size; + // This also checks whether the format is valid and supported and basic info + // is available. + JxlDecoderStatus status = + JxlDecoderImageOutBufferSize(dec, format, &min_size); + if (status != JXL_DEC_SUCCESS) return status; + + if (size < min_size) return JXL_DEC_ERROR; + + dec->image_out_buffer_set = true; + dec->image_out_buffer = buffer; + dec->image_out_size = size; + dec->image_out_format = *format; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderExtraChannelBufferSize(const JxlDecoder* dec, + const JxlPixelFormat* format, + size_t* size, + uint32_t index) { + if (!dec->got_basic_info || !(dec->orig_events_wanted & JXL_DEC_FULL_IMAGE)) { + return JXL_API_ERROR("No extra channel buffer needed at this time"); + } + + if (index >= dec->metadata.m.num_extra_channels) { + return JXL_API_ERROR("Invalid extra channel index"); + } + + size_t num_channels = 1; // Do not use format's num_channels + + size_t bits; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits); + if (status != JXL_DEC_SUCCESS) return status; + + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize, true); + size_t row_size = + jxl::DivCeil(xsize * num_channels * bits, jxl::kBitsPerByte); + if (format->align > 1) { + row_size = jxl::DivCeil(row_size, format->align) * format->align; + } + *size = row_size * ysize; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetExtraChannelBuffer(JxlDecoder* dec, + const JxlPixelFormat* format, + void* buffer, size_t size, + uint32_t index) { + size_t min_size; + // This also checks whether the format and index are valid and supported and + // basic info is available. + JxlDecoderStatus status = + JxlDecoderExtraChannelBufferSize(dec, format, &min_size, index); + if (status != JXL_DEC_SUCCESS) return status; + + if (size < min_size) return JXL_DEC_ERROR; + + if (dec->extra_channel_output.size() <= index) { + dec->extra_channel_output.resize(dec->metadata.m.num_extra_channels, + {{}, nullptr, 0}); + } + // Guaranteed correct thanks to check in JxlDecoderExtraChannelBufferSize. + JXL_ASSERT(index < dec->extra_channel_output.size()); + + dec->extra_channel_output[index].format = *format; + dec->extra_channel_output[index].format.num_channels = 1; + dec->extra_channel_output[index].buffer = buffer; + dec->extra_channel_output[index].buffer_size = size; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetImageOutCallback(JxlDecoder* dec, + const JxlPixelFormat* format, + JxlImageOutCallback callback, + void* opaque) { + dec->simple_image_out_callback.callback = callback; + dec->simple_image_out_callback.opaque = opaque; + const auto init_callback = + +[](void* init_opaque, size_t num_threads, size_t num_pixels_per_thread) { + // No initialization to do, just reuse init_opaque as run_opaque. + return init_opaque; + }; + const auto run_callback = + +[](void* run_opaque, size_t thread_id, size_t x, size_t y, + size_t num_pixels, const void* pixels) { + const auto* const simple_callback = + static_cast(run_opaque); + simple_callback->callback(simple_callback->opaque, x, y, num_pixels, + pixels); + }; + const auto destroy_callback = +[](void* run_opaque) {}; + return JxlDecoderSetMultithreadedImageOutCallback( + dec, format, init_callback, run_callback, + /*destroy_callback=*/destroy_callback, &dec->simple_image_out_callback); +} + +JxlDecoderStatus JxlDecoderSetMultithreadedImageOutCallback( + JxlDecoder* dec, const JxlPixelFormat* format, + JxlImageOutInitCallback init_callback, JxlImageOutRunCallback run_callback, + JxlImageOutDestroyCallback destroy_callback, void* init_opaque) { + if (dec->image_out_buffer_set && !!dec->image_out_buffer) { + return JXL_API_ERROR( + "Cannot change from image out buffer to image out callback"); + } + + if (init_callback == nullptr || run_callback == nullptr || + destroy_callback == nullptr) { + return JXL_API_ERROR("All callbacks are required"); + } + + // Perform error checking for invalid format. + size_t bits_dummy; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits_dummy); + if (status != JXL_DEC_SUCCESS) return status; + + dec->image_out_buffer_set = true; + dec->image_out_init_callback = init_callback; + dec->image_out_run_callback = run_callback; + dec->image_out_destroy_callback = destroy_callback; + dec->image_out_init_opaque = init_opaque; + dec->image_out_format = *format; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetFrameHeader(const JxlDecoder* dec, + JxlFrameHeader* header) { + if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) { + return JXL_API_ERROR("no frame header available"); + } + const auto& metadata = dec->metadata.m; + memset(header, 0, sizeof(*header)); + if (metadata.have_animation) { + header->duration = dec->frame_header->animation_frame.duration; + if (metadata.animation.have_timecodes) { + header->timecode = dec->frame_header->animation_frame.timecode; + } + } + header->name_length = dec->frame_header->name.size(); + header->is_last = dec->frame_header->is_last; + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize, true); + header->layer_info.xsize = xsize; + header->layer_info.ysize = ysize; + if (!dec->coalescing && dec->frame_header->custom_size_or_origin) { + header->layer_info.crop_x0 = dec->frame_header->frame_origin.x0; + header->layer_info.crop_y0 = dec->frame_header->frame_origin.y0; + header->layer_info.have_crop = JXL_TRUE; + } else { + header->layer_info.crop_x0 = 0; + header->layer_info.crop_y0 = 0; + header->layer_info.have_crop = JXL_FALSE; + } + if (!dec->keep_orientation && !dec->coalescing) { + // orient the crop offset + size_t W = dec->metadata.oriented_xsize(false); + size_t H = dec->metadata.oriented_ysize(false); + if (metadata.orientation > 4) { + std::swap(header->layer_info.crop_x0, header->layer_info.crop_y0); + } + size_t o = (metadata.orientation - 1) & 3; + if (o > 0 && o < 3) { + header->layer_info.crop_x0 = W - xsize - header->layer_info.crop_x0; + } + if (o > 1) { + header->layer_info.crop_y0 = H - ysize - header->layer_info.crop_y0; + } + } + if (dec->coalescing) { + header->layer_info.blend_info.blendmode = JXL_BLEND_REPLACE; + header->layer_info.blend_info.source = 0; + header->layer_info.blend_info.alpha = 0; + header->layer_info.blend_info.clamp = JXL_FALSE; + header->layer_info.save_as_reference = 0; + } else { + header->layer_info.blend_info.blendmode = + static_cast(dec->frame_header->blending_info.mode); + header->layer_info.blend_info.source = + dec->frame_header->blending_info.source; + header->layer_info.blend_info.alpha = + dec->frame_header->blending_info.alpha_channel; + header->layer_info.blend_info.clamp = + dec->frame_header->blending_info.clamp; + header->layer_info.save_as_reference = dec->frame_header->save_as_reference; + } + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetExtraChannelBlendInfo(const JxlDecoder* dec, + size_t index, + JxlBlendInfo* blend_info) { + if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) { + return JXL_API_ERROR("no frame header available"); + } + const auto& metadata = dec->metadata.m; + if (index >= metadata.num_extra_channels) { + return JXL_API_ERROR("Invalid extra channel index"); + } + blend_info->blendmode = static_cast( + dec->frame_header->extra_channel_blending_info[index].mode); + blend_info->source = + dec->frame_header->extra_channel_blending_info[index].source; + blend_info->alpha = + dec->frame_header->extra_channel_blending_info[index].alpha_channel; + blend_info->clamp = + dec->frame_header->extra_channel_blending_info[index].clamp; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetFrameName(const JxlDecoder* dec, char* name, + size_t size) { + if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) { + return JXL_API_ERROR("no frame header available"); + } + if (size < dec->frame_header->name.size() + 1) { + return JXL_API_ERROR("too small frame name output buffer"); + } + memcpy(name, dec->frame_header->name.c_str(), + dec->frame_header->name.size() + 1); + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetPreferredColorProfile( + JxlDecoder* dec, const JxlColorEncoding* color_encoding) { + if (!dec->got_all_headers) { + return JXL_API_ERROR("color info not yet available"); + } + if (dec->post_headers) { + return JXL_API_ERROR("too late to set the color encoding"); + } + if (dec->image_metadata.color_encoding.IsGray() && + color_encoding->color_space != JXL_COLOR_SPACE_GRAY && + ((dec->preview_out_buffer_set && + dec->preview_out_format.num_channels < 3) || + (dec->image_out_buffer_set && dec->image_out_format.num_channels < 3))) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + if (color_encoding->color_space == JXL_COLOR_SPACE_UNKNOWN || + color_encoding->color_space == JXL_COLOR_SPACE_XYB) { + return JXL_API_ERROR("only RGB or grayscale output supported"); + } + + jxl::ColorEncoding c_out; + JXL_API_RETURN_IF_ERROR( + ConvertExternalToInternalColorEncoding(*color_encoding, &c_out)); + auto& output_encoding = dec->passes_state->output_encoding_info; + if (!c_out.SameColorEncoding(output_encoding.color_encoding)) { + JXL_API_RETURN_IF_ERROR(output_encoding.MaybeSetColorEncoding(c_out)); + dec->image_metadata.color_encoding = output_encoding.color_encoding; + } + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetDesiredIntensityTarget( + JxlDecoder* dec, float desired_intensity_target) { + if (desired_intensity_target < 0) { + return JXL_API_ERROR("negative intensity target requested"); + } + dec->desired_intensity_target = desired_intensity_target; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetBoxBuffer(JxlDecoder* dec, uint8_t* data, + size_t size) { + if (dec->box_out_buffer_set) { + return JXL_API_ERROR("must release box buffer before setting it again"); + } + if (!dec->box_event) { + return JXL_API_ERROR("can only set box buffer after box event"); + } + + dec->box_out_buffer_set = true; + dec->box_out_buffer_set_current_box = true; + dec->box_out_buffer = data; + dec->box_out_buffer_size = size; + dec->box_out_buffer_pos = 0; + return JXL_DEC_SUCCESS; +} + +size_t JxlDecoderReleaseBoxBuffer(JxlDecoder* dec) { + if (!dec->box_out_buffer_set) { + return 0; + } + size_t result = dec->box_out_buffer_size - dec->box_out_buffer_pos; + dec->box_out_buffer_set = false; + dec->box_out_buffer = nullptr; + dec->box_out_buffer_size = 0; + if (!dec->box_out_buffer_set_current_box) { + dec->box_out_buffer_begin = 0; + } else { + dec->box_out_buffer_begin += dec->box_out_buffer_pos; + } + dec->box_out_buffer_set_current_box = false; + return result; +} + +JxlDecoderStatus JxlDecoderSetDecompressBoxes(JxlDecoder* dec, + JXL_BOOL decompress) { + // TODO(lode): return error if libbrotli is not compiled in the jxl decoding + // library + dec->decompress_boxes = decompress; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetBoxType(JxlDecoder* dec, JxlBoxType type, + JXL_BOOL decompressed) { + if (!dec->box_event) { + return JXL_API_ERROR("can only get box info after JXL_DEC_BOX event"); + } + if (decompressed) { + memcpy(type, dec->box_decoded_type, sizeof(dec->box_decoded_type)); + } else { + memcpy(type, dec->box_type, sizeof(dec->box_type)); + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetBoxSizeRaw(const JxlDecoder* dec, + uint64_t* size) { + if (!dec->box_event) { + return JXL_API_ERROR("can only get box info after JXL_DEC_BOX event"); + } + if (size) { + *size = dec->box_size; + } + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetProgressiveDetail(JxlDecoder* dec, + JxlProgressiveDetail detail) { + if (detail != kDC && detail != kLastPasses && detail != kPasses) { + return JXL_API_ERROR( + "Values other than kDC (%d), kLastPasses (%d) and kPasses (%d), " + "like %d are not implemented.", + kDC, kLastPasses, kPasses, detail); + } + dec->prog_detail = detail; + return JXL_DEC_SUCCESS; +} diff --git a/media/libjxl/src/lib/jxl/decode_test.cc b/media/libjxl/src/lib/jxl/decode_test.cc new file mode 100644 index 000000000..5b9b735e1 --- /dev/null +++ b/media/libjxl/src/lib/jxl/decode_test.cc @@ -0,0 +1,5493 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/decode.h" + +#include +#include + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "jxl/decode_cxx.h" +#include "jxl/resizable_parallel_runner_cxx.h" +#include "jxl/thread_parallel_runner_cxx.h" +#include "jxl/types.h" +#include "lib/extras/codec.h" +#include "lib/extras/dec/color_description.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_icc_codec.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/progressive_split.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" +#include "lib/jxl/toc.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +void AppendU32BE(uint32_t u32, jxl::PaddedBytes* bytes) { + bytes->push_back(u32 >> 24); + bytes->push_back(u32 >> 16); + bytes->push_back(u32 >> 8); + bytes->push_back(u32 >> 0); +} + +// What type of codestream format in the boxes to use for testing +enum CodeStreamBoxFormat { + // Do not use box format at all, only pure codestream + kCSBF_None, + // Have a single codestream box, with its actual size given in the box + kCSBF_Single, + // Have a single codestream box, with box size 0 (final box running to end) + kCSBF_Single_Zero_Terminated, + // Single codestream box, with another unknown box behind it + kCSBF_Single_Other, + // Have multiple partial codestream boxes + kCSBF_Multi, + // Have multiple partial codestream boxes, with final box size 0 (running + // to end) + kCSBF_Multi_Zero_Terminated, + // Have multiple partial codestream boxes, terminated by non-codestream box + kCSBF_Multi_Other_Terminated, + // Have multiple partial codestream boxes, terminated by non-codestream box + // that has its size set to 0 (running to end) + kCSBF_Multi_Other_Zero_Terminated, + // Have multiple partial codestream boxes, and the first one has a content + // of zero length + kCSBF_Multi_First_Empty, + // Have multiple partial codestream boxes, and the last one has a content + // of zero length and there is an unknown empty box at the end + kCSBF_Multi_Last_Empty_Other, + // Have a compressed exif box before a regular codestream box + kCSBF_Brob_Exif, + // Not a value but used for counting amount of enum entries + kCSBF_NUM_ENTRIES, +}; + +// Unknown boxes for testing +static const char* unk1_box_type = "unk1"; +static const char* unk1_box_contents = "abcdefghijklmnopqrstuvwxyz"; +static const size_t unk1_box_size = strlen(unk1_box_contents); +static const char* unk2_box_type = "unk2"; +static const char* unk2_box_contents = "0123456789"; +static const size_t unk2_box_size = strlen(unk2_box_contents); +static const char* unk3_box_type = "unk3"; +static const char* unk3_box_contents = "ABCDEF123456"; +static const size_t unk3_box_size = strlen(unk3_box_contents); +// Box with brob-compressed exif, including header +static const uint8_t* box_brob_exif = reinterpret_cast( + "\0\0\0@brobExif\241\350\2\300\177\244v\2525\304\360\27=?\267{" + "\33\37\314\332\214QX17PT\"\256\0\0\202s\214\313t\333\310\320k\20\276\30" + "\204\277l$\326c#\1\b"); +size_t box_brob_exif_size = 64; +// The uncompressed Exif data from the brob box +static const uint8_t* exif_uncompressed = reinterpret_cast( + "\0\0\0\0MM\0*" + "\0\0\0\b\0\5\1\22\0\3\0\0\0\1\0\5\0\0\1\32\0\5\0\0\0\1\0\0\0J\1\33\0\5\0\0" + "\0\1\0\0\0R\1(" + "\0\3\0\0\0\1\0\1\0\0\2\23\0\3\0\0\0\1\0\1\0\0\0\0\0\0\0\0\0\1\0\0\0\1\0\0" + "\0\1\0\0\0\1"); +size_t exif_uncompressed_size = 94; + +// Returns an ICC profile output by the JPEG XL decoder for RGB_D65_SRG_Rel_Lin, +// but with, on purpose, rXYZ, bXYZ and gXYZ (the RGB primaries) switched to a +// different order to ensure the profile does not match any known profile, so +// the encoder cannot encode it in a compact struct instead. +jxl::PaddedBytes GetIccTestProfile() { + const uint8_t* profile = reinterpret_cast( + "\0\0\3\200lcms\0040\0\0mntrRGB XYZ " + "\a\344\0\a\0\27\0\21\0$" + "\0\37acspAPPL\0\0\0\1\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1\0\0\366" + "\326\0\1\0\0\0\0\323-lcms\372c\207\36\227\200{" + "\2\232s\255\327\340\0\n\26\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0\rdesc\0\0\1 " + "\0\0\0Bcprt\0\0\1d\0\0\1\0wtpt\0\0\2d\0\0\0\24chad\0\0\2x\0\0\0," + "bXYZ\0\0\2\244\0\0\0\24gXYZ\0\0\2\270\0\0\0\24rXYZ\0\0\2\314\0\0\0\24rTR" + "C\0\0\2\340\0\0\0 gTRC\0\0\2\340\0\0\0 bTRC\0\0\2\340\0\0\0 " + "chrm\0\0\3\0\0\0\0$dmnd\0\0\3$\0\0\0(" + "dmdd\0\0\3L\0\0\0002mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0&" + "\0\0\0\34\0R\0G\0B\0_\0D\0006\0005\0_\0S\0R\0G\0_\0R\0e\0l\0_" + "\0L\0i\0n\0\0mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\344\0\0\0\34\0C\0o\0" + "p\0y\0r\0i\0g\0h\0t\0 \0002\0000\0001\08\0 \0G\0o\0o\0g\0l\0e\0 " + "\0L\0L\0C\0,\0 \0C\0C\0-\0B\0Y\0-\0S\0A\0 \0003\0.\0000\0 " + "\0U\0n\0p\0o\0r\0t\0e\0d\0 " + "\0l\0i\0c\0e\0n\0s\0e\0(\0h\0t\0t\0p\0s\0:\0/\0/" + "\0c\0r\0e\0a\0t\0i\0v\0e\0c\0o\0m\0m\0o\0n\0s\0.\0o\0r\0g\0/" + "\0l\0i\0c\0e\0n\0s\0e\0s\0/\0b\0y\0-\0s\0a\0/\0003\0.\0000\0/" + "\0l\0e\0g\0a\0l\0c\0o\0d\0e\0)XYZ " + "\0\0\0\0\0\0\366\326\0\1\0\0\0\0\323-" + "sf32\0\0\0\0\0\1\fB\0\0\5\336\377\377\363%" + "\0\0\a\223\0\0\375\220\377\377\373\241\377\377\375\242\0\0\3\334\0\0\300" + "nXYZ \0\0\0\0\0\0o\240\0\08\365\0\0\3\220XYZ " + "\0\0\0\0\0\0$\237\0\0\17\204\0\0\266\304XYZ " + "\0\0\0\0\0\0b\227\0\0\267\207\0\0\30\331para\0\0\0\0\0\3\0\0\0\1\0\0\0\1" + "\0\0\0\0\0\0\0\1\0\0\0\0\0\0chrm\0\0\0\0\0\3\0\0\0\0\243\327\0\0T|" + "\0\0L\315\0\0\231\232\0\0&" + "g\0\0\17\\mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\f\0\0\0\34\0G\0o\0o\0g" + "\0l\0emluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\26\0\0\0\34\0I\0m\0a\0g\0e" + "\0 \0c\0o\0d\0e\0c\0\0"); + size_t profile_size = 896; + jxl::PaddedBytes icc_profile; + icc_profile.assign(profile, profile + profile_size); + return icc_profile; +} + +} // namespace + +namespace jxl { +namespace { + +void AppendTestBox(const char* type, const char* contents, size_t contents_size, + bool unbounded, PaddedBytes* bytes) { + AppendU32BE(contents_size + 8, bytes); + bytes->push_back(type[0]); + bytes->push_back(type[1]); + bytes->push_back(type[2]); + bytes->push_back(type[3]); + const uint8_t* contents_u = reinterpret_cast(contents); + bytes->append(contents_u, contents_u + contents_size); +} + +struct TestCodestreamParams { + CompressParams cparams; + CodeStreamBoxFormat box_format = kCSBF_None; + JxlOrientation orientation = JXL_ORIENT_IDENTITY; + bool add_preview = false; + bool add_intrinsic_size = false; + bool add_icc_profile = false; + float intensity_target = 0.0; + std::string color_space; + PaddedBytes* jpeg_codestream = nullptr; + const ProgressiveMode* progressive_mode = nullptr; +}; + +// Input pixels always given as 16-bit RGBA, 8 bytes per pixel. +// include_alpha determines if the encoded image should contain the alpha +// channel. +// add_icc_profile: if false, encodes the image as sRGB using the JXL fields, +// for grayscale or RGB images. If true, encodes the image using the ICC profile +// returned by GetIccTestProfile, without the JXL fields, this requires the +// image is RGB, not grayscale. +// Providing jpeg_codestream will populate the jpeg_codestream with compressed +// JPEG bytes, and make it possible to reconstruct those exact JPEG bytes using +// the return value _if_ add_container indicates a box format. +PaddedBytes CreateTestJXLCodestream(Span pixels, size_t xsize, + size_t ysize, size_t num_channels, + const TestCodestreamParams& params) { + // Compress the pixels with JPEG XL. + bool grayscale = (num_channels <= 2); + bool include_alpha = !(num_channels & 1) && params.jpeg_codestream == nullptr; + size_t bitdepth = params.jpeg_codestream == nullptr ? 16 : 8; + CodecInOut io; + io.SetSize(xsize, ysize); + ColorEncoding color_encoding; + if (params.add_icc_profile) { + // the hardcoded ICC profile we attach requires RGB. + EXPECT_EQ(false, grayscale); + EXPECT_TRUE(params.color_space.empty()); + EXPECT_TRUE(color_encoding.SetICC(GetIccTestProfile())); + } else if (!params.color_space.empty()) { + JxlColorEncoding c; + EXPECT_TRUE(jxl::ParseDescription(params.color_space, &c)); + EXPECT_TRUE(ConvertExternalToInternalColorEncoding(c, &color_encoding)); + EXPECT_EQ(color_encoding.IsGray(), grayscale); + } else { + color_encoding = jxl::ColorEncoding::SRGB(/*is_gray=*/grayscale); + } + ThreadPool pool(nullptr, nullptr); + io.metadata.m.SetUintSamples(bitdepth); + if (include_alpha) { + io.metadata.m.SetAlphaBits(bitdepth); + } + if (params.intensity_target != 0) { + io.metadata.m.SetIntensityTarget(params.intensity_target); + } + // Make the grayscale-ness of the io metadata color_encoding and the packed + // image match. + io.metadata.m.color_encoding = color_encoding; + EXPECT_TRUE(ConvertFromExternal( + pixels, xsize, ysize, color_encoding, num_channels, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, JXL_BIG_ENDIAN, + &pool, &io.Main(), /*float_in=*/false, /*align=*/0)); + jxl::PaddedBytes jpeg_data; + if (params.jpeg_codestream != nullptr) { +#if JPEGXL_ENABLE_JPEG + std::vector jpeg_bytes; + io.jpeg_quality = 70; + EXPECT_TRUE(Encode(io, extras::Codec::kJPG, io.metadata.m.color_encoding, + /*bits_per_sample=*/8, &jpeg_bytes, &pool)); + params.jpeg_codestream->append(jpeg_bytes.data(), + jpeg_bytes.data() + jpeg_bytes.size()); + EXPECT_TRUE(jxl::jpeg::DecodeImageJPG( + jxl::Span(jpeg_bytes.data(), jpeg_bytes.size()), &io)); + EXPECT_TRUE( + EncodeJPEGData(*io.Main().jpeg_data, &jpeg_data, params.cparams)); + io.metadata.m.xyb_encoded = false; +#else // JPEGXL_ENABLE_JPEG + JXL_ABORT( + "unable to create reconstructible JPEG without JPEG support enabled"); +#endif // JPEGXL_ENABLE_JPEG + } + if (params.add_preview) { + io.preview_frame = io.Main().Copy(); + io.preview_frame.ShrinkTo(xsize / 7, ysize / 7); + io.metadata.m.have_preview = true; + EXPECT_TRUE(io.metadata.m.preview_size.Set(io.preview_frame.xsize(), + io.preview_frame.ysize())); + } + if (params.add_intrinsic_size) { + EXPECT_TRUE(io.metadata.m.intrinsic_size.Set(xsize / 3, ysize / 3)); + } + io.metadata.m.orientation = params.orientation; + AuxOut aux_out; + PaddedBytes compressed; + PassesEncoderState enc_state; + if (params.progressive_mode) { + enc_state.progressive_splitter.SetProgressiveMode(*params.progressive_mode); + } + EXPECT_TRUE(EncodeFile(params.cparams, &io, &enc_state, &compressed, + GetJxlCms(), &aux_out, &pool)); + CodeStreamBoxFormat add_container = params.box_format; + if (add_container != kCSBF_None) { + // Header with signature box and ftyp box. + const uint8_t header[] = {0, 0, 0, 0xc, 0x4a, 0x58, 0x4c, 0x20, + 0xd, 0xa, 0x87, 0xa, 0, 0, 0, 0x14, + 0x66, 0x74, 0x79, 0x70, 0x6a, 0x78, 0x6c, 0x20, + 0, 0, 0, 0, 0x6a, 0x78, 0x6c, 0x20}; + + bool is_multi = add_container == kCSBF_Multi || + add_container == kCSBF_Multi_Zero_Terminated || + add_container == kCSBF_Multi_Other_Terminated || + add_container == kCSBF_Multi_Other_Zero_Terminated || + add_container == kCSBF_Multi_First_Empty || + add_container == kCSBF_Multi_Last_Empty_Other; + + if (is_multi) { + size_t third = compressed.size() / 3; + std::vector compressed0(compressed.data(), + compressed.data() + third); + std::vector compressed1(compressed.data() + third, + compressed.data() + 2 * third); + std::vector compressed2(compressed.data() + 2 * third, + compressed.data() + compressed.size()); + + PaddedBytes c; + c.append(header, header + sizeof(header)); + if (params.jpeg_codestream != nullptr) { + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, + &c); + c.append(jpeg_data.data(), jpeg_data.data() + jpeg_data.size()); + } + uint32_t jxlp_index = 0; + if (add_container == kCSBF_Multi_First_Empty) { + // Dummy (empty) codestream part + AppendU32BE(12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + } + // First codestream part + AppendU32BE(compressed0.size() + 12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + c.append(compressed0.data(), compressed0.data() + compressed0.size()); + // A few non-codestream boxes in between + AppendTestBox(unk1_box_type, unk1_box_contents, unk1_box_size, false, &c); + AppendTestBox(unk2_box_type, unk2_box_contents, unk2_box_size, false, &c); + // Dummy (empty) codestream part + AppendU32BE(12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + // Second codestream part + AppendU32BE(compressed1.size() + 12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + c.append(compressed1.data(), compressed1.data() + compressed1.size()); + // Third (last) codestream part + AppendU32BE(add_container == kCSBF_Multi_Zero_Terminated + ? 0 + : (compressed2.size() + 12), + &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + if (add_container != kCSBF_Multi_Last_Empty_Other) { + AppendU32BE(jxlp_index++ | 0x80000000, &c); + } else { + AppendU32BE(jxlp_index++, &c); + } + c.append(compressed2.data(), compressed2.data() + compressed2.size()); + if (add_container == kCSBF_Multi_Last_Empty_Other) { + // Dummy (empty) codestream part + AppendU32BE(12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++ | 0x80000000, &c); + AppendTestBox(unk3_box_type, unk3_box_contents, unk3_box_size, false, + &c); + } + if (add_container == kCSBF_Multi_Other_Terminated) { + AppendTestBox(unk3_box_type, unk3_box_contents, unk3_box_size, false, + &c); + } + if (add_container == kCSBF_Multi_Other_Zero_Terminated) { + AppendTestBox(unk3_box_type, unk3_box_contents, unk3_box_size, true, + &c); + } + compressed.swap(c); + } else { + PaddedBytes c; + c.append(header, header + sizeof(header)); + if (params.jpeg_codestream != nullptr) { + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, + &c); + c.append(jpeg_data.data(), jpeg_data.data() + jpeg_data.size()); + } + if (add_container == kCSBF_Brob_Exif) { + c.append(box_brob_exif, box_brob_exif + box_brob_exif_size); + } + AppendU32BE(add_container == kCSBF_Single_Zero_Terminated + ? 0 + : (compressed.size() + 8), + &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('c'); + c.append(compressed.data(), compressed.data() + compressed.size()); + if (add_container == kCSBF_Single_Other) { + AppendTestBox(unk1_box_type, unk1_box_contents, unk1_box_size, false, + &c); + } + compressed.swap(c); + } + } + + return compressed; +} + +JxlDecoderStatus ProcessInputIgnoreBoxes(JxlDecoder* dec) { + JxlDecoderStatus status; + while ((status = JxlDecoderProcessInput(dec)) == JXL_DEC_BOX) { + continue; + } + return status; +} + +// Decodes one-shot with the API for non-streaming decoding tests. +std::vector DecodeWithAPI(JxlDecoder* dec, + Span compressed, + const JxlPixelFormat& format, + bool use_callback, bool set_buffer_early, + bool use_resizable_runner, + bool require_boxes, bool expect_success, + PaddedBytes* icc = nullptr) { + JxlThreadParallelRunnerPtr runner_fixed; + JxlResizableParallelRunnerPtr runner_resizable; + JxlParallelRunner runner_fn; + void* runner; + + if (use_resizable_runner) { + runner_resizable = JxlResizableParallelRunnerMake(nullptr); + runner = runner_resizable.get(); + runner_fn = JxlResizableParallelRunner; + } else { + size_t hw_threads = JxlThreadParallelRunnerDefaultNumWorkerThreads(); + runner_fixed = + JxlThreadParallelRunnerMake(nullptr, std::min(hw_threads, 16)); + runner = runner_fixed.get(); + runner_fn = JxlThreadParallelRunner; + } + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, runner_fn, runner)); + + auto process_input = + require_boxes ? ProcessInputIgnoreBoxes : JxlDecoderProcessInput; + + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | (set_buffer_early ? JXL_DEC_FRAME : 0) | + JXL_DEC_PREVIEW_IMAGE | JXL_DEC_FULL_IMAGE | + (require_boxes ? JXL_DEC_BOX : 0) | + (icc != nullptr ? JXL_DEC_COLOR_ENCODING : 0))); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, process_input(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + if (use_resizable_runner) { + JxlResizableParallelRunnerSetThreads( + runner, + JxlResizableParallelRunnerSuggestThreads(info.xsize, info.ysize)); + } + + std::vector pixels(buffer_size); + size_t bytes_per_pixel = format.num_channels * + test::GetDataBits(format.data_type) / + jxl::kBitsPerByte; + size_t stride = bytes_per_pixel * info.xsize; + if (format.align > 1) { + stride = jxl::DivCeil(stride, format.align) * format.align; + } + auto callback = [&](size_t x, size_t y, size_t num_pixels, + const void* pixels_row) { + memcpy(pixels.data() + stride * y + bytes_per_pixel * x, pixels_row, + num_pixels * bytes_per_pixel); + }; + + JxlDecoderStatus status = process_input(dec); + + if (status == JXL_DEC_COLOR_ENCODING) { + size_t icc_size = 0; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, &icc_size)); + icc->resize(icc_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, + icc->data(), icc_size)); + + status = process_input(dec); + } + + std::vector preview; + if (status == JXL_DEC_NEED_PREVIEW_OUT_BUFFER) { + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + preview.resize(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, preview.data(), + preview.size())); + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, process_input(dec)); + + status = process_input(dec); + } + + if (set_buffer_early) { + EXPECT_EQ(JXL_DEC_FRAME, status); + } else { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, status); + } + + if (use_callback) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutCallback( + dec, &format, + [](void* opaque, size_t x, size_t y, size_t xsize, + const void* pixels_row) { + auto cb = static_cast(opaque); + (*cb)(x, y, xsize, pixels_row); + }, + /*opaque=*/&callback)); + } else { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + } + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, process_input(dec)); + + // After the full image was output, JxlDecoderProcessInput should return + // success to indicate all is done, unless we requested boxes and the last + // box was not a terminal unbounded box, in which case it should ask for + // more input. + JxlDecoderStatus expected_status = + expect_success ? JXL_DEC_SUCCESS : JXL_DEC_NEED_MORE_INPUT; + EXPECT_EQ(expected_status, process_input(dec)); + + return pixels; +} + +// Decodes one-shot with the API for non-streaming decoding tests. +std::vector DecodeWithAPI(Span compressed, + const JxlPixelFormat& format, + bool use_callback, bool set_buffer_early, + bool use_resizable_runner, + bool require_boxes, bool expect_success) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + std::vector pixels = + DecodeWithAPI(dec, compressed, format, use_callback, set_buffer_early, + use_resizable_runner, require_boxes, expect_success); + JxlDecoderDestroy(dec); + return pixels; +} + +} // namespace +} // namespace jxl + +//////////////////////////////////////////////////////////////////////////////// + +TEST(DecodeTest, JxlSignatureCheckTest) { + std::vector>> tests = { + // No JPEGXL header starts with 'a'. + {JXL_SIG_INVALID, {'a'}}, + {JXL_SIG_INVALID, {'a', 'b', 'c', 'd', 'e', 'f'}}, + + // Empty file is not enough bytes. + {JXL_SIG_NOT_ENOUGH_BYTES, {}}, + + // JPEGXL headers. + {JXL_SIG_NOT_ENOUGH_BYTES, {0xff}}, // Part of a signature. + {JXL_SIG_INVALID, {0xff, 0xD8}}, // JPEG-1 + {JXL_SIG_CODESTREAM, {0xff, 0x0a}}, + + // JPEGXL container file. + {JXL_SIG_CONTAINER, + {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87, 0xA}}, + // Ending with invalid byte. + {JXL_SIG_INVALID, {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87, 0}}, + // Part of signature. + {JXL_SIG_NOT_ENOUGH_BYTES, + {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87}}, + {JXL_SIG_NOT_ENOUGH_BYTES, {0}}, + }; + for (const auto& test : tests) { + EXPECT_EQ(test.first, + JxlSignatureCheck(test.second.data(), test.second.size())) + << "Where test data is " << ::testing::PrintToString(test.second); + } +} + +TEST(DecodeTest, DefaultAllocTest) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, CustomAllocTest) { + struct CalledCounters { + int allocs = 0; + int frees = 0; + } counters; + + JxlMemoryManager mm; + mm.opaque = &counters; + mm.alloc = [](void* opaque, size_t size) { + reinterpret_cast(opaque)->allocs++; + return malloc(size); + }; + mm.free = [](void* opaque, void* address) { + reinterpret_cast(opaque)->frees++; + free(address); + }; + + JxlDecoder* dec = JxlDecoderCreate(&mm); + EXPECT_NE(nullptr, dec); + EXPECT_LE(1, counters.allocs); + EXPECT_EQ(0, counters.frees); + JxlDecoderDestroy(dec); + EXPECT_LE(1, counters.frees); +} + +// TODO(lode): add multi-threaded test when multithreaded pixel decoding from +// API is implemented. +TEST(DecodeTest, DefaultParallelRunnerTest) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, nullptr, nullptr)); + JxlDecoderDestroy(dec); +} + +// Creates the header of a JPEG XL file with various custom parameters for +// testing. +// xsize, ysize: image dimensions to store in the SizeHeader, max 512. +// bits_per_sample, orientation: a selection of header parameters to test with. +// orientation: image orientation to set in the metadata +// alpha_bits: if non-0, alpha extra channel bits to set in the metadata. Also +// gives the alpha channel the name "alpha_test" +// have_container: add box container format around the codestream. +// metadata_default: if true, ImageMetadata is set to default and +// bits_per_sample, orientation and alpha_bits are ignored. +// insert_box: insert an extra box before the codestream box, making the header +// farther away from the front than is ideal. Only used if have_container. +std::vector GetTestHeader(size_t xsize, size_t ysize, + size_t bits_per_sample, size_t orientation, + size_t alpha_bits, bool xyb_encoded, + bool have_container, bool metadata_default, + bool insert_extra_box, + const jxl::PaddedBytes& icc_profile) { + jxl::BitWriter writer; + jxl::BitWriter::Allotment allotment(&writer, 65536); // Large enough + + if (have_container) { + const std::vector signature_box = {0, 0, 0, 0xc, 'J', 'X', + 'L', ' ', 0xd, 0xa, 0x87, 0xa}; + const std::vector filetype_box = { + 0, 0, 0, 0x14, 'f', 't', 'y', 'p', 'j', 'x', + 'l', ' ', 0, 0, 0, 0, 'j', 'x', 'l', ' '}; + const std::vector extra_box_header = {0, 0, 0, 0xff, + 't', 'e', 's', 't'}; + // Beginning of codestream box, with an arbitrary size certainly large + // enough to contain the header + const std::vector codestream_box_header = {0, 0, 0, 0xff, + 'j', 'x', 'l', 'c'}; + + for (size_t i = 0; i < signature_box.size(); i++) { + writer.Write(8, signature_box[i]); + } + for (size_t i = 0; i < filetype_box.size(); i++) { + writer.Write(8, filetype_box[i]); + } + if (insert_extra_box) { + for (size_t i = 0; i < extra_box_header.size(); i++) { + writer.Write(8, extra_box_header[i]); + } + for (size_t i = 0; i < 255 - 8; i++) { + writer.Write(8, 0); + } + } + for (size_t i = 0; i < codestream_box_header.size(); i++) { + writer.Write(8, codestream_box_header[i]); + } + } + + // JXL signature + writer.Write(8, 0xff); + writer.Write(8, 0x0a); + + // SizeHeader + jxl::CodecMetadata metadata; + EXPECT_TRUE(metadata.size.Set(xsize, ysize)); + EXPECT_TRUE(WriteSizeHeader(metadata.size, &writer, 0, nullptr)); + + if (!metadata_default) { + metadata.m.SetUintSamples(bits_per_sample); + metadata.m.orientation = orientation; + metadata.m.SetAlphaBits(alpha_bits); + metadata.m.xyb_encoded = xyb_encoded; + if (alpha_bits != 0) { + metadata.m.extra_channel_info[0].name = "alpha_test"; + } + } + + if (!icc_profile.empty()) { + jxl::PaddedBytes copy = icc_profile; + EXPECT_TRUE(metadata.m.color_encoding.SetICC(std::move(copy))); + } + + EXPECT_TRUE(jxl::Bundle::Write(metadata.m, &writer, 0, nullptr)); + metadata.transform_data.nonserialized_xyb_encoded = metadata.m.xyb_encoded; + EXPECT_TRUE(jxl::Bundle::Write(metadata.transform_data, &writer, 0, nullptr)); + + if (!icc_profile.empty()) { + EXPECT_TRUE(metadata.m.color_encoding.WantICC()); + EXPECT_TRUE(jxl::WriteICC(icc_profile, &writer, 0, nullptr)); + } + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + return std::vector( + writer.GetSpan().data(), + writer.GetSpan().data() + writer.GetSpan().size()); +} + +TEST(DecodeTest, BasicInfoTest) { + size_t xsize[2] = {50, 33}; + size_t ysize[2] = {50, 77}; + size_t bits_per_sample[2] = {8, 23}; + size_t orientation[2] = {3, 5}; + size_t alpha_bits[2] = {0, 8}; + JXL_BOOL have_container[2] = {0, 1}; + bool xyb_encoded = false; + + std::vector> test_samples; + // Test with direct codestream + test_samples.push_back(GetTestHeader( + xsize[0], ysize[0], bits_per_sample[0], orientation[0], alpha_bits[0], + xyb_encoded, have_container[0], /*metadata_default=*/false, + /*insert_extra_box=*/false, {})); + // Test with container and different parameters + test_samples.push_back(GetTestHeader( + xsize[1], ysize[1], bits_per_sample[1], orientation[1], alpha_bits[1], + xyb_encoded, have_container[1], /*metadata_default=*/false, + /*insert_extra_box=*/false, {})); + + for (size_t i = 0; i < test_samples.size(); ++i) { + const std::vector& data = test_samples[i]; + // Test decoding too small header first, until we reach the final byte. + for (size_t size = 0; size <= data.size(); ++size) { + // Test with a new decoder for each tested byte size. + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + const uint8_t* next_in = data.data(); + size_t avail_in = size; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + + JxlBasicInfo info; + bool have_basic_info = !JxlDecoderGetBasicInfo(dec, &info); + + if (size == data.size()) { + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + + // All header bytes given so the decoder must have the basic info. + EXPECT_EQ(true, have_basic_info); + EXPECT_EQ(have_container[i], info.have_container); + EXPECT_EQ(alpha_bits[i], info.alpha_bits); + // Orientations 5..8 swap the dimensions + if (orientation[i] >= 5) { + EXPECT_EQ(xsize[i], info.ysize); + EXPECT_EQ(ysize[i], info.xsize); + } else { + EXPECT_EQ(xsize[i], info.xsize); + EXPECT_EQ(ysize[i], info.ysize); + } + // The API should set the orientation to identity by default since it + // already applies the transformation internally by default. + EXPECT_EQ(1u, info.orientation); + + EXPECT_EQ(3u, info.num_color_channels); + + if (alpha_bits[i] != 0) { + // Expect an extra channel + EXPECT_EQ(1u, info.num_extra_channels); + JxlExtraChannelInfo extra; + EXPECT_EQ(0, JxlDecoderGetExtraChannelInfo(dec, 0, &extra)); + EXPECT_EQ(alpha_bits[i], extra.bits_per_sample); + EXPECT_EQ(JXL_CHANNEL_ALPHA, extra.type); + EXPECT_EQ(0, extra.alpha_premultiplied); + // Verify the name "alpha_test" given to the alpha channel + EXPECT_EQ(10u, extra.name_length); + char name[11]; + EXPECT_EQ(0, + JxlDecoderGetExtraChannelName(dec, 0, name, sizeof(name))); + EXPECT_EQ(std::string("alpha_test"), std::string(name)); + } else { + EXPECT_EQ(0u, info.num_extra_channels); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + } else { + // If we did not give the full header, the basic info should not be + // available. Allow a few bytes of slack due to some bits for default + // opsinmatrix/extension bits. + if (size + 2 < data.size()) { + EXPECT_EQ(false, have_basic_info); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, status); + } + } + + // Test that decoder doesn't allow setting a setting required at beginning + // unless it's reset + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + JxlDecoderReset(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + + JxlDecoderDestroy(dec); + } + } +} + +TEST(DecodeTest, BufferSizeTest) { + size_t xsize = 33; + size_t ysize = 77; + size_t bits_per_sample = 8; + size_t orientation = 1; + size_t alpha_bits = 8; + bool have_container = false; + bool xyb_encoded = false; + + std::vector header = + GetTestHeader(xsize, ysize, bits_per_sample, orientation, alpha_bits, + xyb_encoded, have_container, /*metadata_default=*/false, + /*insert_extra_box=*/false, {}); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + const uint8_t* next_in = header.data(); + size_t avail_in = header.size(); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + + JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + size_t image_out_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &image_out_size)); + EXPECT_EQ(xsize * ysize * 4, image_out_size); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, BasicInfoSizeHintTest) { + // Test on a file where the size hint is too small initially due to inserting + // a box before the codestream (something that is normally not recommended) + size_t xsize = 50; + size_t ysize = 50; + size_t bits_per_sample = 16; + size_t orientation = 1; + size_t alpha_bits = 0; + bool xyb_encoded = false; + std::vector data = GetTestHeader( + xsize, ysize, bits_per_sample, orientation, alpha_bits, xyb_encoded, + /*have_container=*/true, /*metadata_default=*/false, + /*insert_extra_box=*/true, {}); + + JxlDecoderStatus status; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + + size_t hint0 = JxlDecoderSizeHintBasicInfo(dec); + // Test that the test works as intended: we construct a file on purpose to + // be larger than the first hint by having that extra box. + EXPECT_LT(hint0, data.size()); + const uint8_t* next_in = data.data(); + // Do as if we have only as many bytes as indicated by the hint available + size_t avail_in = std::min(hint0, data.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + status = JxlDecoderProcessInput(dec); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, status); + // Basic info cannot be available yet due to the extra inserted box. + EXPECT_EQ(false, !JxlDecoderGetBasicInfo(dec, nullptr)); + + size_t num_read = avail_in - JxlDecoderReleaseInput(dec); + EXPECT_LT(num_read, data.size()); + + size_t hint1 = JxlDecoderSizeHintBasicInfo(dec); + // The hint must be larger than the previous hint (taking already processed + // bytes into account, the hint is a hint for the next avail_in) since the + // decoder now knows there is a box in between. + EXPECT_GT(hint1 + num_read, hint0); + avail_in = std::min(hint1, data.size() - num_read); + next_in += num_read; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + status = JxlDecoderProcessInput(dec); + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + JxlBasicInfo info; + // We should have the basic info now, since we only added one box in-between, + // and the decoder should have known its size, its implementation can return + // a correct hint. + EXPECT_EQ(true, !JxlDecoderGetBasicInfo(dec, &info)); + + // Also test if the basic info is correct. + EXPECT_EQ(1, info.have_container); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_EQ(orientation, info.orientation); + EXPECT_EQ(bits_per_sample, info.bits_per_sample); + + JxlDecoderDestroy(dec); +} + +std::vector GetIccTestHeader(const jxl::PaddedBytes& icc_profile, + bool xyb_encoded) { + size_t xsize = 50; + size_t ysize = 50; + size_t bits_per_sample = 16; + size_t orientation = 1; + size_t alpha_bits = 0; + return GetTestHeader(xsize, ysize, bits_per_sample, orientation, alpha_bits, + xyb_encoded, + /*have_container=*/false, /*metadata_default=*/false, + /*insert_extra_box=*/false, icc_profile); +} + +// Tests the case where pixels and metadata ICC profile are the same +TEST(DecodeTest, IccProfileTestOriginal) { + jxl::PaddedBytes icc_profile = GetIccTestProfile(); + bool xyb_encoded = false; + std::vector data = GetIccTestHeader(icc_profile, xyb_encoded); + JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Expect the opposite of xyb_encoded for uses_original_profile + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(JXL_TRUE, info.uses_original_profile); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + // the encoded color profile expected to be not available, since the image + // has an ICC profile instead + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + size_t dec_profile_size; + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, &dec_profile_size)); + + // Check that can get return status with NULL size + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + // The profiles must be equal. This requires they have equal size, and if + // they do, we can get the profile and compare the contents. + EXPECT_EQ(icc_profile.size(), dec_profile_size); + if (icc_profile.size() == dec_profile_size) { + jxl::PaddedBytes icc_profile2(icc_profile.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + icc_profile2.data(), icc_profile2.size())); + EXPECT_EQ(icc_profile, icc_profile2); + } + + // the data is not xyb_encoded, so same result expected for the pixel data + // color profile + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, nullptr)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, + &dec_profile_size)); + EXPECT_EQ(icc_profile.size(), dec_profile_size); + + JxlDecoderDestroy(dec); +} + +// Tests the case where pixels and metadata ICC profile are different +TEST(DecodeTest, IccProfileTestXybEncoded) { + jxl::PaddedBytes icc_profile = GetIccTestProfile(); + bool xyb_encoded = true; + std::vector data = GetIccTestHeader(icc_profile, xyb_encoded); + JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + JxlPixelFormat format_int = {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Expect the opposite of xyb_encoded for uses_original_profile + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(JXL_FALSE, info.uses_original_profile); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + // the encoded color profile expected to be not available, since the image + // has an ICC profile instead + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + // Check that can get return status with NULL size + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + size_t dec_profile_size; + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, &dec_profile_size)); + + // The profiles must be equal. This requires they have equal size, and if + // they do, we can get the profile and compare the contents. + EXPECT_EQ(icc_profile.size(), dec_profile_size); + if (icc_profile.size() == dec_profile_size) { + jxl::PaddedBytes icc_profile2(icc_profile.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + icc_profile2.data(), icc_profile2.size())); + EXPECT_EQ(icc_profile, icc_profile2); + } + + // Data is xyb_encoded, so the data profile is a different profile, encoded + // as structured profile. + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, nullptr)); + JxlColorEncoding pixel_encoding; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_PRIMARIES_SRGB, pixel_encoding.primaries); + // The API returns LINEAR by default when the colorspace cannot be represented + // by enum values. + EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function); + + // Test the same but with integer format. + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, &format_int, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_PRIMARIES_SRGB, pixel_encoding.primaries); + EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function); + + // Test after setting the preferred color profile to non-linear sRGB: + // for XYB images with ICC profile, this setting is expected to take effect. + jxl::ColorEncoding temp_jxl_srgb = jxl::ColorEncoding::SRGB(false); + JxlColorEncoding pixel_encoding_srgb; + ConvertInternalToExternalColorEncoding(temp_jxl_srgb, &pixel_encoding_srgb); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreferredColorProfile(dec, &pixel_encoding_srgb)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_TRANSFER_FUNCTION_SRGB, pixel_encoding.transfer_function); + + // The decoder can also output this as a generated ICC profile anyway, and + // we're certain that it will differ from the above defined profile since + // the sRGB data should not have swapped R/G/B primaries. + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, + &dec_profile_size)); + // We don't need to dictate exactly what size the generated ICC profile + // must be (since there are many ways to represent the same color space), + // but it should not be zero. + EXPECT_NE(0u, dec_profile_size); + jxl::PaddedBytes icc_profile2(dec_profile_size); + if (0 != dec_profile_size) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile2.data(), icc_profile2.size())); + // expected not equal + EXPECT_NE(icc_profile, icc_profile2); + } + + // Test setting another different preferred profile, to verify that the + // returned JXL_COLOR_PROFILE_TARGET_DATA ICC profile is correctly + // updated. + + jxl::ColorEncoding temp_jxl_linear = jxl::ColorEncoding::LinearSRGB(false); + JxlColorEncoding pixel_encoding_linear; + ConvertInternalToExternalColorEncoding(temp_jxl_linear, + &pixel_encoding_linear); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreferredColorProfile(dec, &pixel_encoding_linear)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, + &dec_profile_size)); + EXPECT_NE(0u, dec_profile_size); + jxl::PaddedBytes icc_profile3(dec_profile_size); + if (0 != dec_profile_size) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile3.data(), icc_profile3.size())); + // expected not equal to the previously set preferred profile. + EXPECT_NE(icc_profile2, icc_profile3); + } + + JxlDecoderDestroy(dec); +} + +// Test decoding ICC from partial files byte for byte. +// This test must pass also if JXL_CRASH_ON_ERROR is enabled, that is, the +// decoding of the ANS histogram and stream of the encoded ICC profile must also +// handle the case of not enough input bytes with StatusCode::kNotEnoughBytes +// rather than fatal error status codes. +TEST(DecodeTest, ICCPartialTest) { + jxl::PaddedBytes icc_profile = GetIccTestProfile(); + std::vector data = GetIccTestHeader(icc_profile, false); + JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + const uint8_t* next_in = data.data(); + size_t avail_in = 0; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + + bool seen_basic_info = false; + bool seen_color_encoding = false; + size_t total_size = 0; + + for (;;) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_size >= data.size()) { + // End of partial codestream with codestrema headers and ICC profile + // reached, it should not require more input since full image is not + // requested + FAIL(); + break; + } + size_t increment = 1; + if (total_size + increment > data.size()) { + increment = data.size() - total_size; + } + total_size += increment; + avail_in += increment; + } else if (status == JXL_DEC_BASIC_INFO) { + EXPECT_FALSE(seen_basic_info); + seen_basic_info = true; + } else if (status == JXL_DEC_COLOR_ENCODING) { + EXPECT_TRUE(seen_basic_info); + EXPECT_FALSE(seen_color_encoding); + seen_color_encoding = true; + + // Sanity check that the ICC profile was decoded correctly + size_t dec_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, &format, + JXL_COLOR_PROFILE_TARGET_ORIGINAL, + &dec_profile_size)); + EXPECT_EQ(icc_profile.size(), dec_profile_size); + + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_TRUE(seen_color_encoding); + break; + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_TRUE(seen_basic_info); + EXPECT_TRUE(seen_color_encoding); + + JxlDecoderDestroy(dec); +} + +struct PixelTestConfig { + // Input image definition. + bool grayscale; + bool include_alpha; + size_t xsize; + size_t ysize; + bool add_preview; + bool add_intrinsic_size; + // Output format. + JxlEndianness endianness; + JxlDataType data_type; + uint32_t output_channels; + // Container options. + CodeStreamBoxFormat add_container; + // Decoding mode. + bool use_callback; + bool set_buffer_early; + bool use_resizable_runner; + // Exif orientation, 1-8 + JxlOrientation orientation; + bool keep_orientation; + size_t upsampling; +}; + +class DecodeTestParam : public ::testing::TestWithParam {}; + +TEST_P(DecodeTestParam, PixelTest) { + PixelTestConfig config = GetParam(); + JxlDecoder* dec = JxlDecoderCreate(NULL); + + if (config.keep_orientation) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetKeepOrientation(dec, JXL_TRUE)); + } + + size_t num_pixels = config.xsize * config.ysize; + uint32_t orig_channels = + (config.grayscale ? 1 : 3) + (config.include_alpha ? 1 : 0); + std::vector pixels = + jxl::test::GetSomeTestImage(config.xsize, config.ysize, orig_channels, 0); + JxlPixelFormat format_orig = {orig_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, + 0}; + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.cparams.resampling = config.upsampling; + params.cparams.ec_resampling = config.upsampling; + params.box_format = config.add_container; + params.orientation = config.orientation; + params.add_preview = config.add_preview; + params.add_intrinsic_size = config.add_intrinsic_size; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), config.xsize, + config.ysize, orig_channels, params); + + JxlPixelFormat format = {config.output_channels, config.data_type, + config.endianness, 0}; + + bool swap_xy = !config.keep_orientation && (config.orientation > 4); + size_t xsize = swap_xy ? config.ysize : config.xsize; + size_t ysize = swap_xy ? config.xsize : config.ysize; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, config.use_callback, config.set_buffer_early, + config.use_resizable_runner, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * config.output_channels * + jxl::test::GetDataBits(config.data_type) / jxl::kBitsPerByte, + pixels2.size()); + + // If an orientation transformation is expected, to compare the pixels, also + // apply this transformation to the original pixels. ConvertToExternal is + // used to achieve this, with a temporary conversion to CodecInOut and back. + if (config.orientation > 1 && !config.keep_orientation) { + jxl::Span bytes(pixels.data(), pixels.size()); + jxl::ColorEncoding color_encoding = + jxl::ColorEncoding::SRGB(config.grayscale); + + jxl::CodecInOut io; + if (config.include_alpha) io.metadata.m.SetAlphaBits(16); + io.SetSize(config.xsize, config.ysize); + + EXPECT_TRUE(ConvertFromExternal(bytes, config.xsize, config.ysize, + color_encoding, config.output_channels, + /*alpha_is_premultiplied=*/false, 16, + JXL_BIG_ENDIAN, nullptr, &io.Main(), + /*float_in=*/false, + /*align=*/0)); + + for (size_t i = 0; i < pixels.size(); i++) pixels[i] = 0; + EXPECT_TRUE(ConvertToExternal( + io.Main(), 16, + /*float_out=*/false, orig_channels, JXL_BIG_ENDIAN, + xsize * 2 * orig_channels, nullptr, pixels.data(), pixels.size(), + /*out_callback=*/{}, + static_cast(config.orientation))); + } + if (config.upsampling == 1) { + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } else { + // resampling is of course not lossless, so as a rough check: + // count pixels that are more than off-by-25 in the 8-bit value of one of + // the channels + EXPECT_LE( + jxl::test::ComparePixels( + pixels.data(), pixels2.data(), xsize, ysize, format_orig, format, + 50.0 * (config.data_type == JXL_TYPE_UINT8 ? 1.0 : 256.0)), + 300u); + } + + JxlDecoderDestroy(dec); +} + +std::vector GeneratePixelTests() { + std::vector all_tests; + struct ChannelInfo { + bool grayscale; + bool include_alpha; + size_t output_channels; + }; + ChannelInfo ch_info[] = { + {false, true, 4}, // RGBA -> RGBA + {true, false, 1}, // G -> G + {true, true, 1}, // GA -> G + {true, true, 2}, // GA -> GA + {false, false, 3}, // RGB -> RGB + {false, true, 3}, // RGBA -> RGB + {false, false, 4}, // RGB -> RGBA + }; + + struct OutputFormat { + JxlEndianness endianness; + JxlDataType data_type; + }; + OutputFormat out_formats[] = { + {JXL_NATIVE_ENDIAN, JXL_TYPE_UINT8}, + {JXL_LITTLE_ENDIAN, JXL_TYPE_UINT16}, + {JXL_BIG_ENDIAN, JXL_TYPE_UINT16}, + {JXL_NATIVE_ENDIAN, JXL_TYPE_FLOAT16}, + {JXL_LITTLE_ENDIAN, JXL_TYPE_FLOAT}, + {JXL_BIG_ENDIAN, JXL_TYPE_FLOAT}, + }; + + auto make_test = [&](ChannelInfo ch, size_t xsize, size_t ysize, bool preview, + bool intrinsic_size, CodeStreamBoxFormat box, + JxlOrientation orientation, bool keep_orientation, + OutputFormat format, bool use_callback, + bool set_buffer_early, bool resizable_runner, + size_t upsampling) { + PixelTestConfig c; + c.grayscale = ch.grayscale; + c.include_alpha = ch.include_alpha; + c.add_preview = preview; + c.add_intrinsic_size = intrinsic_size; + c.xsize = xsize; + c.ysize = ysize; + c.add_container = (CodeStreamBoxFormat)box; + c.output_channels = ch.output_channels; + c.data_type = format.data_type; + c.endianness = format.endianness; + c.use_callback = use_callback; + c.set_buffer_early = set_buffer_early; + c.use_resizable_runner = resizable_runner; + c.orientation = orientation; + c.keep_orientation = keep_orientation; + c.upsampling = upsampling; + all_tests.push_back(c); + }; + + // Test output formats and methods. + for (ChannelInfo ch : ch_info) { + for (int use_callback = 0; use_callback <= 1; use_callback++) { + for (size_t upsampling : {1, 2, 4, 8}) { + for (OutputFormat fmt : out_formats) { + make_test(ch, 301, 33, /*add_preview=*/false, + /*add_intrinsic_size=*/false, + CodeStreamBoxFormat::kCSBF_None, JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, fmt, use_callback, + /*set_buffer_early=*/false, /*resizable_runner=*/false, + upsampling); + } + } + } + } + // Test codestream formats. + for (size_t box = 1; box < kCSBF_NUM_ENTRIES; ++box) { + make_test(ch_info[0], 77, 33, /*add_preview=*/false, + /*add_intrinsic_size=*/false, (CodeStreamBoxFormat)box, + JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, + /*set_buffer_early=*/false, /*resizable_runner=*/false, 1); + } + // Test previews. + for (int add_preview = 0; add_preview <= 1; add_preview++) { + make_test(ch_info[0], 77, 33, add_preview, /*add_intrinsic_size=*/false, + CodeStreamBoxFormat::kCSBF_None, JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/false, + /*resizable_runner=*/false, 1); + } + // Test intrinsic sizes. + for (int add_intrinsic_size = 0; add_intrinsic_size <= 1; + add_intrinsic_size++) { + make_test(ch_info[0], 55, 34, /*add_preview=*/false, add_intrinsic_size, + CodeStreamBoxFormat::kCSBF_None, JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/false, + /*resizable_runner=*/false, 1); + } + // Test setting buffers early. + make_test(ch_info[0], 300, 33, /*add_preview=*/false, + /*add_intrinsic_size=*/false, CodeStreamBoxFormat::kCSBF_None, + JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/true, + /*resizable_runner=*/false, 1); + + // Test using the resizable runner + for (size_t i = 0; i < 4; i++) { + make_test(ch_info[0], 300 << i, 33 << i, /*add_preview=*/false, + /*add_intrinsic_size=*/false, CodeStreamBoxFormat::kCSBF_None, + JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/false, + /*resizable_runner=*/true, 1); + } + + // Test orientations. + for (int orientation = 1; orientation <= 8; ++orientation) { + make_test(ch_info[0], 280, 12, /*add_preview=*/false, + /*add_intrinsic_size=*/false, CodeStreamBoxFormat::kCSBF_None, + static_cast(orientation), + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/true, + /*resizable_runner=*/false, 1); + make_test(ch_info[0], 280, 12, /*add_preview=*/false, + /*add_intrinsic_size=*/false, CodeStreamBoxFormat::kCSBF_None, + static_cast(orientation), + /*keep_orientation=*/true, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/true, + /*resizable_runner=*/false, 1); + } + + return all_tests; +} + +std::ostream& operator<<(std::ostream& os, const PixelTestConfig& c) { + os << c.xsize << "x" << c.ysize; + const char* colors[] = {"", "G", "GA", "RGB", "RGBA"}; + os << colors[(c.grayscale ? 1 : 3) + (c.include_alpha ? 1 : 0)]; + os << "to"; + os << colors[c.output_channels]; + switch (c.data_type) { + case JXL_TYPE_UINT8: + os << "u8"; + break; + case JXL_TYPE_UINT16: + os << "u16"; + break; + case JXL_TYPE_FLOAT: + os << "f32"; + break; + case JXL_TYPE_FLOAT16: + os << "f16"; + break; + default: + JXL_ASSERT(false); + }; + if (jxl::test::GetDataBits(c.data_type) > jxl::kBitsPerByte) { + if (c.endianness == JXL_NATIVE_ENDIAN) { + // add nothing + } else if (c.endianness == JXL_BIG_ENDIAN) { + os << "BE"; + } else if (c.endianness == JXL_LITTLE_ENDIAN) { + os << "LE"; + } + } + if (c.add_container != CodeStreamBoxFormat::kCSBF_None) { + os << "Box"; + os << (size_t)c.add_container; + } + if (c.add_preview) os << "Preview"; + if (c.add_intrinsic_size) os << "IntrinicSize"; + if (c.use_callback) os << "Callback"; + if (c.set_buffer_early) os << "EarlyBuffer"; + if (c.use_resizable_runner) os << "ResizableRunner"; + if (c.orientation != 1) os << "O" << c.orientation; + if (c.keep_orientation) os << "Keep"; + if (c.upsampling > 1) os << "x" << c.upsampling; + return os; +} + +std::string PixelTestDescription( + const testing::TestParamInfo& info) { + std::stringstream name; + name << info.param; + return name.str(); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(DecodeTest, DecodeTestParam, + testing::ValuesIn(GeneratePixelTests()), + PixelTestDescription); + +TEST(DecodeTest, PixelTestWithICCProfileLossless) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.add_icc_profile = true; + // For variation: some have container and no preview, others have preview + // and no container. + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + params); + + for (uint32_t channels = 3; channels <= 4; ++channels) { + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/false, /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + EXPECT_EQ(0u, + jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}; + + // Test with the container for one of the pixel formats. + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/true, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 2, pixels2.size()); + EXPECT_EQ(0u, + jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + + { + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/false, /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*reuqire_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + EXPECT_EQ(0u, + jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + } + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, PixelTestWithICCProfileLossy) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + params.add_icc_profile = true; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params); + uint32_t channels = 3; + + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + jxl::PaddedBytes icc; + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true, /*icc=*/&icc); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + + // The input pixels use the profile matching GetIccTestProfile, since we set + // add_icc_profile for CreateTestJXLCodestream to true. + jxl::ColorEncoding color_encoding0; + EXPECT_TRUE(color_encoding0.SetICC(GetIccTestProfile())); + jxl::Span span0(pixels.data(), pixels.size()); + jxl::CodecInOut io0; + io0.SetSize(xsize, ysize); + EXPECT_TRUE( + ConvertFromExternal(span0, xsize, ysize, color_encoding0, /*channels=*/3, + /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/16, format_orig.endianness, + /*pool=*/nullptr, &io0.Main(), /*float_in=*/false, + /*align=*/0)); + + jxl::ColorEncoding color_encoding1; + EXPECT_TRUE(color_encoding1.SetICC(std::move(icc))); + jxl::Span span1(pixels2.data(), pixels2.size()); + jxl::CodecInOut io1; + io1.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal(span1, xsize, ysize, color_encoding1, + channels, /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/32, format.endianness, + /*pool=*/nullptr, &io1.Main(), + /*float_in=*/true, /*align=*/0)); + + jxl::ButteraugliParams ba; + EXPECT_THAT(ButteraugliDistance(io0, io1, ba, jxl::GetJxlCms(), + /*distmap=*/nullptr, nullptr), + IsSlightlyBelow(0.785f)); + + JxlDecoderDestroy(dec); +} + +std::string ColorDescription(JxlColorEncoding c) { + jxl::ColorEncoding color_encoding; + EXPECT_TRUE(ConvertExternalToInternalColorEncoding(c, &color_encoding)); + return Description(color_encoding); +} + +std::string GetOrigProfile(JxlDecoder* dec) { + JxlColorEncoding c; + JxlColorProfileTarget target = JXL_COLOR_PROFILE_TARGET_ORIGINAL; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile(dec, nullptr, target, &c)); + return ColorDescription(c); +} + +std::string GetDataProfile(JxlDecoder* dec) { + JxlColorEncoding c; + JxlColorProfileTarget target = JXL_COLOR_PROFILE_TARGET_DATA; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile(dec, nullptr, target, &c)); + return ColorDescription(c); +} + +double ButteraugliDistance(size_t xsize, size_t ysize, + const std::vector& pixels_in, + const jxl::ColorEncoding& color_in, + float intensity_in, + const std::vector& pixels_out, + const jxl::ColorEncoding& color_out, + float intensity_out) { + jxl::CodecInOut in; + in.metadata.m.color_encoding = color_in; + in.metadata.m.SetIntensityTarget(intensity_in); + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Span(pixels_in.data(), pixels_in.size()), xsize, + ysize, color_in, color_in.Channels(), + /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/16, JXL_BIG_ENDIAN, + /*pool=*/nullptr, &in.Main(), /*float_in=*/false, /*align=*/0)); + jxl::CodecInOut out; + out.metadata.m.color_encoding = color_out; + out.metadata.m.SetIntensityTarget(intensity_out); + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Span(pixels_out.data(), pixels_out.size()), xsize, + ysize, color_out, color_out.Channels(), + /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/16, JXL_BIG_ENDIAN, + /*pool=*/nullptr, &out.Main(), /*float_in=*/false, /*align=*/0)); + return ButteraugliDistance(in, out, jxl::ButteraugliParams(), + jxl::GetJxlCms(), nullptr, nullptr); +} + +class DecodeAllEncodingsTest + : public ::testing::TestWithParam {}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + DecodeAllEncodingsTestInstantiation, DecodeAllEncodingsTest, + ::testing::ValuesIn(jxl::test::AllEncodings())); +TEST_P(DecodeAllEncodingsTest, PreserveOriginalProfileTest) { + size_t xsize = 123, ysize = 77; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + int events = JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_FULL_IMAGE; + const auto& cdesc = GetParam(); + jxl::ColorEncoding c_in = jxl::test::ColorEncodingFromDescriptor(cdesc); + if (c_in.rendering_intent != jxl::RenderingIntent::kRelative) return; + std::string color_space_in = Description(c_in); + float intensity_in = c_in.tf.IsPQ() ? 10000 : 255; + printf("Testing input color space %s\n", color_space_in.c_str()); + jxl::TestCodestreamParams params; + params.color_space = color_space_in; + params.intensity_target = intensity_in; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params); + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_FALSE(info.uses_original_profile); + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + EXPECT_EQ(GetOrigProfile(dec), color_space_in); + EXPECT_EQ(GetDataProfile(dec), color_space_in); + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + std::vector out(pixels.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, out.data(), out.size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + double dist = ButteraugliDistance(xsize, ysize, pixels, c_in, intensity_in, + out, c_in, intensity_in); + EXPECT_LT(dist, 1.2); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); +} + +namespace { +void SetPreferredColorProfileTest( + const jxl::test::ColorEncodingDescriptor& from) { + size_t xsize = 123, ysize = 77; + int events = JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_FULL_IMAGE; + jxl::ColorEncoding c_in = jxl::test::ColorEncodingFromDescriptor(from); + if (c_in.rendering_intent != jxl::RenderingIntent::kRelative) return; + if (c_in.white_point != jxl::WhitePoint::kD65) return; + uint32_t num_channels = c_in.Channels(); + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::string color_space_in = Description(c_in); + float intensity_in = c_in.tf.IsPQ() ? 10000 : 255; + jxl::TestCodestreamParams params; + params.color_space = color_space_in; + params.intensity_target = intensity_in; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + for (const auto& c1 : jxl::test::AllEncodings()) { + jxl::ColorEncoding c_out = jxl::test::ColorEncodingFromDescriptor(c1); + float intensity_out = intensity_in; + if (c_out.rendering_intent != jxl::RenderingIntent::kRelative) continue; + if ((c_in.primaries == jxl::Primaries::k2100 && + c_out.primaries != jxl::Primaries::k2100) || + (c_in.primaries == jxl::Primaries::kP3 && + c_out.primaries == jxl::Primaries::kSRGB)) { + // Converting to a narrower gamut does not work without gammut mapping. + continue; + } + if (c_out.tf.IsHLG() && intensity_out > 300) { + // The Linear->HLG OOTF function at this intensity level can push + // saturated colors out of gamut, so we would need gamut mapping in + // this case too. + continue; + } + std::string color_space_out = Description(c_out); + if (color_space_in == color_space_out) continue; + printf("Testing input color space %s with output color space %s\n", + color_space_in.c_str(), color_space_out.c_str()); + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_FALSE(info.uses_original_profile); + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + EXPECT_EQ(GetOrigProfile(dec), color_space_in); + EXPECT_EQ(GetDataProfile(dec), color_space_in); + JxlColorEncoding encoding_out; + EXPECT_TRUE(jxl::ParseDescription(color_space_out, &encoding_out)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreferredColorProfile(dec, &encoding_out)); + EXPECT_EQ(GetOrigProfile(dec), color_space_in); + EXPECT_EQ(GetDataProfile(dec), color_space_out); + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + JxlPixelFormat out_format = format; + out_format.num_channels = c_out.Channels(); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &out_format, &buffer_size)); + std::vector out(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &out_format, out.data(), out.size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + double dist = ButteraugliDistance(xsize, ysize, pixels, c_in, intensity_in, + out, c_out, intensity_out); + if (c_in.white_point == c_out.white_point) { + EXPECT_LT(dist, 1.2); + } else { + EXPECT_LT(dist, 4.0); + } + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + } +} +} // namespace + +TEST(DecodeTest, SetPreferredColorProfileTestFromGray) { + jxl::test::ColorEncodingDescriptor gray = { + jxl::ColorSpace::kGray, jxl::WhitePoint::kD65, jxl::Primaries::kSRGB, + jxl::TransferFunction::kSRGB, jxl::RenderingIntent::kRelative}; + SetPreferredColorProfileTest(gray); +} + +TEST_P(DecodeAllEncodingsTest, SetPreferredColorProfileTest) { + const auto& from = GetParam(); + SetPreferredColorProfileTest(from); +} + +// Tests the case of lossy sRGB image without alpha channel, decoded to RGB8 +// and to RGBA8 +TEST(DecodeTest, PixelTestOpaqueSrgbLossy) { + for (unsigned channels = 3; channels <= 4; channels++) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + jxl::TestCodestreamParams()); + + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/true, /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success*/ true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + + jxl::ColorEncoding color_encoding0 = jxl::ColorEncoding::SRGB(false); + jxl::Span span0(pixels.data(), pixels.size()); + jxl::CodecInOut io0; + io0.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal( + span0, xsize, ysize, color_encoding0, /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + format_orig.endianness, + /*pool=*/nullptr, &io0.Main(), /*float_in=*/false, + /*align=*/0)); + + jxl::ColorEncoding color_encoding1 = jxl::ColorEncoding::SRGB(false); + jxl::Span span1(pixels2.data(), pixels2.size()); + jxl::CodecInOut io1; + EXPECT_TRUE(ConvertFromExternal(span1, xsize, ysize, color_encoding1, + channels, /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/8, format.endianness, + /*pool=*/nullptr, &io1.Main(), + /*float_in=*/false, + /*align=*/0)); + + jxl::ButteraugliParams ba; + EXPECT_THAT(ButteraugliDistance(io0, io1, ba, jxl::GetJxlCms(), + /*distmap=*/nullptr, nullptr), + IsSlightlyBelow(0.8f)); + + JxlDecoderDestroy(dec); + } +} + +// Opaque image with noise enabled, decoded to RGB8 and RGBA8. +TEST(DecodeTest, PixelTestOpaqueSrgbLossyNoise) { + for (unsigned channels = 3; channels <= 4; channels++) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 512, ysize = 300; + size_t num_pixels = xsize * ysize; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + params.cparams.noise = jxl::Override::kOn; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params); + + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + + jxl::ColorEncoding color_encoding0 = jxl::ColorEncoding::SRGB(false); + jxl::Span span0(pixels.data(), pixels.size()); + jxl::CodecInOut io0; + io0.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal( + span0, xsize, ysize, color_encoding0, /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + format_orig.endianness, + /*pool=*/nullptr, &io0.Main(), /*float_in=*/false, + /*align=*/0)); + + jxl::ColorEncoding color_encoding1 = jxl::ColorEncoding::SRGB(false); + jxl::Span span1(pixels2.data(), pixels2.size()); + jxl::CodecInOut io1; + EXPECT_TRUE(ConvertFromExternal(span1, xsize, ysize, color_encoding1, + channels, /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/8, format.endianness, + /*pool=*/nullptr, &io1.Main(), + /*float_in=*/false, + /*align=*/0)); + + jxl::ButteraugliParams ba; + EXPECT_THAT(ButteraugliDistance(io0, io1, ba, jxl::GetJxlCms(), + /*distmap=*/nullptr, nullptr), + IsSlightlyBelow(2.6f)); + + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ProcessEmptyInputWithBoxes) { + size_t xsize = 123, ysize = 77; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + jxl::TestCodestreamParams params; + params.box_format = (CodeStreamBoxFormat)i; + printf("Testing empty input with box format %d\n", (int)params.box_format); + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params); + const int events = + JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | JXL_DEC_COLOR_ENCODING; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events)); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + const size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, compressed.size()); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ExtraBytesAfterCompressedStream) { + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + CodeStreamBoxFormat box_format = (CodeStreamBoxFormat)i; + if (box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + printf("Testing with box format %d\n", (int)box_format); + size_t last_unknown_box_size = 0; + if (box_format == kCSBF_Single_Other) { + last_unknown_box_size = unk1_box_size + 8; + } else if (box_format == kCSBF_Multi_Other_Terminated) { + last_unknown_box_size = unk3_box_size + 8; + } else if (box_format == kCSBF_Multi_Last_Empty_Other) { + // If boxes are not required, the decoder wont consume the last empty + // jxlp box. + last_unknown_box_size = 12 + unk3_box_size + 8; + } + jxl::TestCodestreamParams params; + params.box_format = box_format; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params); + // Add some more bytes after compressed data. + compressed.push_back(0); + compressed.push_back(1); + compressed.push_back(2); + JxlDecoder* dec = JxlDecoderCreate(NULL); + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + size_t unconsumed_bytes = JxlDecoderReleaseInput(dec); + EXPECT_EQ(last_unknown_box_size + 3, unconsumed_bytes); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ExtraBytesAfterCompressedStreamRequireBoxes) { + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + CodeStreamBoxFormat box_format = (CodeStreamBoxFormat)i; + if (box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + printf("Testing with box format %d\n", (int)box_format); + bool expect_success = (box_format == kCSBF_None || + box_format == kCSBF_Single_Zero_Terminated || + box_format == kCSBF_Multi_Zero_Terminated); + jxl::TestCodestreamParams params; + params.box_format = box_format; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params); + // Add some more bytes after compressed data. + compressed.push_back(0); + compressed.push_back(1); + compressed.push_back(2); + JxlDecoder* dec = JxlDecoderCreate(NULL); + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format, /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/true, expect_success); + size_t unconsumed_bytes = JxlDecoderReleaseInput(dec); + EXPECT_EQ(3, unconsumed_bytes); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ConcatenatedCompressedStreams) { + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + CodeStreamBoxFormat first_box_format = (CodeStreamBoxFormat)i; + if (first_box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + jxl::TestCodestreamParams params1; + params1.box_format = first_box_format; + jxl::PaddedBytes compressed1 = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params1); + for (int j = 0; j < kCSBF_NUM_ENTRIES; ++j) { + CodeStreamBoxFormat second_box_format = (CodeStreamBoxFormat)j; + if (second_box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + printf("Testing with box format pair %d, %d\n", (int)first_box_format, + (int)second_box_format); + jxl::TestCodestreamParams params2; + params2.box_format = second_box_format; + jxl::PaddedBytes compressed2 = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + 3, params2); + jxl::PaddedBytes concat; + concat.append(compressed1); + concat.append(compressed2); + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + size_t remaining = concat.size(); + for (int part = 0; part < 2; ++part) { + printf(" Decoding part %d\n", part + 1); + JxlDecoder* dec = JxlDecoderCreate(NULL); + size_t pos = concat.size() - remaining; + bool expect_success = + (part == 0 || second_box_format == kCSBF_None || + second_box_format == kCSBF_Single_Zero_Terminated || + second_box_format == kCSBF_Multi_Zero_Terminated); + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(concat.data() + pos, remaining), + format, /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/true, + expect_success); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + remaining = JxlDecoderReleaseInput(dec); + JxlDecoderDestroy(dec); + } + EXPECT_EQ(0, remaining); + } + } +} + +void TestPartialStream(bool reconstructible_jpeg) { + size_t xsize = 123, ysize = 77; + uint32_t channels = 4; + if (reconstructible_jpeg) { + channels = 3; + } + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, 0); + JxlPixelFormat format_orig = {channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + if (reconstructible_jpeg) { + params.cparams.color_transform = jxl::ColorTransform::kNone; + } else { + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + } + + std::vector pixels2; + pixels2.resize(pixels.size()); + + jxl::PaddedBytes jpeg_output(64); + size_t used_jpeg_output = 0; + + std::vector codestreams(kCSBF_NUM_ENTRIES); + std::vector jpeg_codestreams(kCSBF_NUM_ENTRIES); + for (size_t i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + params.box_format = (CodeStreamBoxFormat)i; + if (reconstructible_jpeg) { + params.jpeg_codestream = &jpeg_codestreams[i]; + } + codestreams[i] = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + channels, params); + } + + // Test multiple step sizes, to test different combinations of the streaming + // box parsing. + std::vector increments = {1, 3, 17, 23, 120, 700, 1050}; + + for (size_t index = 0; index < increments.size(); index++) { + for (size_t i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + if (reconstructible_jpeg && + (CodeStreamBoxFormat)i == CodeStreamBoxFormat::kCSBF_None) { + continue; + } + const jxl::PaddedBytes& data = codestreams[i]; + const uint8_t* next_in = data.data(); + size_t avail_in = 0; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | + JXL_DEC_JPEG_RECONSTRUCTION)); + + bool seen_basic_info = false; + bool seen_full_image = false; + bool seen_jpeg_recon = false; + + size_t total_size = 0; + + for (;;) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_size >= data.size()) { + // End of test data reached, it should have successfully decoded the + // image now. + FAIL(); + break; + } + + size_t increment = increments[index]; + // End of the file reached, should be the final test. + if (total_size + increment > data.size()) { + increment = data.size() - total_size; + } + total_size += increment; + avail_in += increment; + } else if (status == JXL_DEC_BASIC_INFO) { + // This event should happen exactly once + EXPECT_FALSE(seen_basic_info); + if (seen_basic_info) break; + seen_basic_info = true; + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + } else if (status == JXL_DEC_JPEG_RECONSTRUCTION) { + EXPECT_FALSE(seen_basic_info); + EXPECT_FALSE(seen_full_image); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec, jpeg_output.data(), + jpeg_output.size())); + seen_jpeg_recon = true; + } else if (status == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + EXPECT_TRUE(seen_jpeg_recon); + used_jpeg_output = + jpeg_output.size() - JxlDecoderReleaseJPEGBuffer(dec); + jpeg_output.resize(jpeg_output.size() * 2); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer( + dec, jpeg_output.data() + used_jpeg_output, + jpeg_output.size() - used_jpeg_output)); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer( + dec, &format_orig, pixels2.data(), pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + // This event should happen exactly once + EXPECT_FALSE(seen_full_image); + if (seen_full_image) break; + // This event should happen after basic info + EXPECT_TRUE(seen_basic_info); + seen_full_image = true; + if (reconstructible_jpeg) { + used_jpeg_output = + jpeg_output.size() - JxlDecoderReleaseJPEGBuffer(dec); + EXPECT_EQ(used_jpeg_output, jpeg_codestreams[i].size()); + EXPECT_EQ(0, memcmp(jpeg_output.data(), jpeg_codestreams[i].data(), + used_jpeg_output)); + } else { + EXPECT_EQ(pixels, pixels2); + } + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_TRUE(seen_full_image); + break; + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + // Ensure the decoder emitted the basic info and full image events + EXPECT_TRUE(seen_basic_info); + EXPECT_TRUE(seen_full_image); + + JxlDecoderDestroy(dec); + } + } +} + +// Tests the return status when trying to decode pixels on incomplete file: it +// should return JXL_DEC_NEED_MORE_INPUT, not error. +TEST(DecodeTest, PixelPartialTest) { TestPartialStream(false); } + +#if JPEGXL_ENABLE_JPEG +// Tests the return status when trying to decode JPEG bytes on incomplete file. +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGPartialTest)) { + TestPartialStream(true); +} +#endif // JPEGXL_ENABLE_JPEG + +// The DC event still exists, but is no longer implemented, it is deprecated. +TEST(DecodeTest, DCNotGettableTest) { + // 1x1 pixel JXL image + std::string compressed( + "\377\n\0\20\260\23\0H\200(" + "\0\334\0U\17\0\0\250P\31e\334\340\345\\\317\227\37:," + "\246m\\gh\253m\vK\22E\306\261I\252C&pH\22\353 " + "\363\6\22\bp\0\200\237\34\231W2d\255$\1", + 68); + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_DC_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput( + dec, reinterpret_cast(compressed.data()), + compressed.size())); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Since the image is only 1x1 pixel, there is only 1 group, the decoder is + // unable to get DC size from this, and will not return the DC at all. Since + // no full image is requested either, it is expected to return success. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, PreviewTest) { + size_t xsize = 77, ysize = 120; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + + jxl::TestCodestreamParams params; + params.add_preview = true; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + params); + + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_PREVIEW_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + + // GetSomeTestImage is hardcoded to use a top-left cropped preview with + // floor of 1/7th of the size + size_t xsize_preview = (xsize / 7); + size_t ysize_preview = (ysize / 7); + EXPECT_EQ(xsize_preview, info.preview.xsize); + EXPECT_EQ(ysize_preview, info.preview.ysize); + EXPECT_EQ(xsize_preview * ysize_preview * 3, buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_PREVIEW_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + std::vector preview(xsize_preview * ysize_preview * 3); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetPreviewOutBuffer( + dec, &format, preview.data(), preview.size())); + + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, JxlDecoderProcessInput(dec)); + + jxl::Image3F preview0(xsize_preview, ysize_preview); + jxl::Image3F preview1(xsize_preview, ysize_preview); + + // For preview0, the original: top-left crop the preview image the way + // GetSomeTestImage does. + for (size_t y = 0; y < ysize_preview; y++) { + for (size_t x = 0; x < xsize_preview; x++) { + preview0.PlaneRow(0, y)[x] = + (1.f / 255) * (pixels[(y * xsize + x) * 6 + 0]); + preview0.PlaneRow(1, y)[x] = + (1.f / 255) * (pixels[(y * xsize + x) * 6 + 2]); + preview0.PlaneRow(2, y)[x] = + (1.f / 255) * (pixels[(y * xsize + x) * 6 + 4]); + preview1.PlaneRow(0, y)[x] = + (1.f / 255) * (preview[(y * xsize_preview + x) * 3 + 0]); + preview1.PlaneRow(1, y)[x] = + (1.f / 255) * (preview[(y * xsize_preview + x) * 3 + 1]); + preview1.PlaneRow(2, y)[x] = + (1.f / 255) * (preview[(y * xsize_preview + x) * 3 + 2]); + } + } + + jxl::CodecInOut io0; + io0.SetFromImage(std::move(preview0), jxl::ColorEncoding::SRGB(false)); + jxl::CodecInOut io1; + io1.SetFromImage(std::move(preview1), jxl::ColorEncoding::SRGB(false)); + + jxl::ButteraugliParams ba; + // TODO(lode): this ButteraugliDistance silently returns 0 (dangerous for + // tests) if xsize or ysize is < 8, no matter how different the images, a tiny + // size that could happen for a preview. ButteraugliDiffmap does support + // smaller than 8x8, but jxl's ButteraugliDistance does not. Perhaps move + // butteraugli's <8x8 handling from ButteraugliDiffmap to + // ButteraugliComparator::Diffmap in butteraugli.cc. + EXPECT_LE(ButteraugliDistance(io0, io1, ba, jxl::GetJxlCms(), + /*distmap=*/nullptr, nullptr), + 0.6f); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, AlignTest) { + size_t xsize = 123, ysize = 77; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + params); + + size_t align = 17; + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, align}; + // On purpose not using jxl::RoundUpTo to test it independently. + size_t expected_line_bytes = (1 * 3 * xsize + align - 1) / align * align; + + for (int use_callback = 0; use_callback <= 1; ++use_callback) { + std::vector pixels2 = jxl::DecodeWithAPI( + jxl::Span(compressed.data(), compressed.size()), format, + use_callback, /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_succes=*/true); + EXPECT_EQ(expected_line_bytes * ysize, pixels2.size()); + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } +} + +TEST(DecodeTest, AnimationTest) { + size_t xsize = 123, ysize = 77; + static const size_t num_frames = 2; + std::vector frames[2]; + frames[0] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + frames[1] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 1); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frames[i].data(), frames[i].size()), xsize, + ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle, + /*float_in=*/false, /*align=*/0)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, + jxl::GetJxlCms(), &aux_out, nullptr)); + + // Decode and test the animation frames + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + std::vector pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + EXPECT_EQ(0u, frame_header.name_length); + // For now, test with empty name, there's currently no easy way to encode + // a jxl file with a frame name because ImageBundle doesn't have a + // jxl::FrameHeader to set the name in. We can test the null termination + // character though. + char name; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameName(dec, &name, 1)); + EXPECT_EQ(0, name); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, AnimationTestStreaming) { + size_t xsize = 123, ysize = 77; + static const size_t num_frames = 2; + std::vector frames[2]; + frames[0] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + frames[1] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 1); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frames[i].data(), frames[i].size()), xsize, + ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle, + /*float_in=*/false, /*align=*/0)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, + jxl::GetJxlCms(), &aux_out, nullptr)); + + // Decode and test the animation frames + + const size_t step_size = 16; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = 0; + size_t frame_headers_seen = 0; + size_t frames_seen = 0; + bool seen_basic_info = false; + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + std::vector frames2[2]; + for (size_t i = 0; i < num_frames; ++i) { + frames2[i].resize(frames[i].size()); + } + + size_t total_in = 0; + size_t loop_count = 0; + + for (;;) { + if (loop_count++ > compressed.size()) { + fprintf(stderr, "Too many loops\n"); + FAIL(); + break; + } + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + auto status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + + if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_in >= compressed.size()) { + fprintf(stderr, "Already gave all input data\n"); + FAIL(); + break; + } + size_t amount = step_size; + if (total_in + amount > compressed.size()) { + amount = compressed.size() - total_in; + } + avail_in += amount; + total_in += amount; + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, frames2[frames_seen].data(), + frames2[frames_seen].size())); + } else if (status == JXL_DEC_BASIC_INFO) { + EXPECT_EQ(false, seen_basic_info); + seen_basic_info = true; + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + } else if (status == JXL_DEC_FRAME) { + EXPECT_EQ(true, seen_basic_info); + frame_headers_seen++; + } else if (status == JXL_DEC_FULL_IMAGE) { + frames_seen++; + EXPECT_EQ(frame_headers_seen, frames_seen); + } else { + fprintf(stderr, "Unexpected status: %d\n", (int)status); + FAIL(); + } + } + + EXPECT_EQ(true, seen_basic_info); + EXPECT_EQ(num_frames, frames_seen); + EXPECT_EQ(num_frames, frame_headers_seen); + for (size_t i = 0; i < num_frames; ++i) { + EXPECT_EQ(frames[i], frames2[i]); + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, ExtraChannelTest) { + size_t xsize = 55, ysize = 257; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + params); + + size_t align = 17; + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, align}; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(1u, info.num_extra_channels); + EXPECT_EQ(JXL_FALSE, info.alpha_premultiplied); + + JxlExtraChannelInfo extra_info; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelInfo(dec, 0, &extra_info)); + EXPECT_EQ(0, extra_info.type); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + size_t extra_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderExtraChannelBufferSize(dec, &format, &extra_size, 0)); + + std::vector image(buffer_size); + std::vector extra(extra_size); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, image.data(), image.size())); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetExtraChannelBuffer( + dec, &format, extra.data(), extra.size(), 0)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + // After the full image was output, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), image.data(), xsize, + ysize, format_orig, format)); + + // Compare the extracted extra channel with the original alpha channel + + std::vector alpha(pixels.size() / 4); + for (size_t i = 0; i < pixels.size(); i += 8) { + size_t index_alpha = i / 4; + alpha[index_alpha + 0] = pixels[i + 6]; + alpha[index_alpha + 1] = pixels[i + 7]; + } + JxlPixelFormat format_alpha = format; + format_alpha.num_channels = 1; + JxlPixelFormat format_orig_alpha = format_orig; + format_orig_alpha.num_channels = 1; + + EXPECT_EQ(0u, + jxl::test::ComparePixels(alpha.data(), extra.data(), xsize, ysize, + format_orig_alpha, format_alpha)); +} + +TEST(DecodeTest, SkipCurrentFrameTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 7; + std::vector frames[num_frames]; + for (size_t i = 0; i < num_frames; i++) { + frames[i] = jxl::test::GetSomeTestImage(xsize, ysize, 3, i); + } + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + if (i & 1) { + // Mark some frames as referenceable, others not. + bundle.use_for_next_frame = true; + } + + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frames[i].data(), frames[i].size()), xsize, + ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle, + /*float_in=*/false, /*align=*/0)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + jxl::PassDefinition passes[] = { + {2, 0, false, 4}, {4, 0, false, 4}, {8, 2, false, 2}, {8, 0, false, 1}}; + jxl::ProgressiveMode progressive_mode{passes}; + enc_state.progressive_splitter.SetProgressiveMode(progressive_mode); + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, + jxl::GetJxlCms(), &aux_out, nullptr)); + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | + JXL_DEC_FRAME_PROGRESSION | + JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetProgressiveDetail(dec, kLastPasses)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + printf("Decoding frame %d\n", (int)i); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSkipCurrentFrame(dec)); + std::vector pixels(buffer_size); + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSkipCurrentFrame(dec)); + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + if (i == 2) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + EXPECT_EQ(8, JxlDecoderGetIntendedDownsamplingRatio(dec)); + if (i == 3) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + EXPECT_EQ(4, JxlDecoderGetIntendedDownsamplingRatio(dec)); + if (i == 4) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + EXPECT_EQ(2, JxlDecoderGetIntendedDownsamplingRatio(dec)); + if (i == 5) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSkipCurrentFrame(dec)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, SkipFrameTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 16; + std::vector frames[num_frames]; + for (size_t i = 0; i < num_frames; i++) { + frames[i] = jxl::test::GetSomeTestImage(xsize, ysize, 3, i); + } + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + if (i & 1) { + // Mark some frames as referenceable, others not. + bundle.use_for_next_frame = true; + } + + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frames[i].data(), frames[i].size()), xsize, + ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle, + /*float_in=*/false, /*align=*/0)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, + jxl::GetJxlCms(), &aux_out, nullptr)); + + // Decode and test the animation frames + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + if (i == 3) { + JxlDecoderSkipFrames(dec, 5); + i += 5; + } + std::vector pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + // Test rewinding the decoder and skipping different frames + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames; ++i) { + int test_skipping = (i == 9) ? 3 : 0; + std::vector pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Since this is after JXL_DEC_FRAME but before JXL_DEC_FULL_IMAGE, this + // should only skip the next frame, not the currently processed one. + if (test_skipping) JxlDecoderSkipFrames(dec, test_skipping); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + + if (test_skipping) i += test_skipping; + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, SkipFrameWithBlendingTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 16; + std::vector frames[num_frames]; + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector frame_durations(num_frames); + + for (size_t i = 0; i < num_frames; ++i) { + if (i < 5) { + std::vector frame_internal = + jxl::test::GetSomeTestImage(xsize, ysize, 3, i * 2 + 1); + // An internal frame with 0 duration, and use_for_next_frame, this is a + // frame that is not rendered and not output by the API, but on which the + // rendered frames depend + jxl::ImageBundle bundle_internal(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frame_internal.data(), + frame_internal.size()), + xsize, ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle_internal, + /*float_in=*/false, /*align=*/0)); + bundle_internal.duration = 0; + bundle_internal.use_for_next_frame = true; + io.frames.push_back(std::move(bundle_internal)); + } + + std::vector frame = + jxl::test::GetSomeTestImage(xsize, ysize, 3, i * 2); + // Actual rendered frame + frame_durations[i] = 5 + i; + jxl::ImageBundle bundle(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frame.data(), frame.size()), xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*channels=*/3, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle, + /*float_in=*/false, /*align=*/0)); + bundle.duration = frame_durations[i]; + // Create some variation in which frames depend on which. + if (i != 3 && i != 9 && i != 10) { + bundle.use_for_next_frame = true; + } + if (i != 12) { + bundle.blend = true; + // Choose a blend mode that depends on the pixels of the saved frame and + // doesn't use alpha + bundle.blendmode = jxl::BlendMode::kMul; + } + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, + jxl::GetJxlCms(), &aux_out, nullptr)); + + // Independently decode all frames without any skipping, to create the + // expected blended frames, for the actual tests below to compare with. + { + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + for (size_t i = 0; i < num_frames; ++i) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + frames[i].resize(xsize * ysize * 6); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, frames[i].data(), + frames[i].size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + std::vector pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + + // Test rewinding mid-way, not decoding all frames. + if (i == 8) { + break; + } + } + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames; ++i) { + if (i == 3) { + JxlDecoderSkipFrames(dec, 5); + i += 5; + } + std::vector pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + // Test rewinding the decoder and skipping different frames + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames; ++i) { + int test_skipping = (i == 9) ? 3 : 0; + std::vector pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Since this is after JXL_DEC_FRAME but before JXL_DEC_FULL_IMAGE, this + // should only skip the next frame, not the currently processed one. + if (test_skipping) JxlDecoderSkipFrames(dec, test_skipping); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + + if (test_skipping) i += test_skipping; + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, SkipFrameWithAlphaBlendingTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 16; + std::vector frames[num_frames + 5]; + JxlPixelFormat format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames + 5); + io.SetSize(xsize, ysize); + + std::vector frame_durations_c; + std::vector frame_durations_nc; + std::vector frame_xsize, frame_ysize, frame_x0, frame_y0; + + for (size_t i = 0; i < num_frames; ++i) { + size_t cropxsize = 1 + xsize * 2 / (i + 1); + size_t cropysize = 1 + ysize * 3 / (i + 2); + int cropx0 = i * 3 - 8; + int cropy0 = i * 4 - 7; + if (i < 5) { + std::vector frame_internal = + jxl::test::GetSomeTestImage(xsize / 2, ysize / 2, 4, i * 2 + 1); + // An internal frame with 0 duration, and use_for_next_frame, this is a + // frame that is not rendered and not output by default by the API, but on + // which the rendered frames depend + jxl::ImageBundle bundle_internal(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frame_internal.data(), + frame_internal.size()), + xsize / 2, ysize / 2, jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*channels=*/4, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle_internal, + /*float_in=*/false, /*align=*/0)); + bundle_internal.duration = 0; + bundle_internal.use_for_next_frame = true; + bundle_internal.origin = {13, 17}; + io.frames.push_back(std::move(bundle_internal)); + frame_durations_nc.push_back(0); + frame_xsize.push_back(xsize / 2); + frame_ysize.push_back(ysize / 2); + frame_x0.push_back(13); + frame_y0.push_back(17); + } + + std::vector frame = + jxl::test::GetSomeTestImage(cropxsize, cropysize, 4, i * 2); + // Actual rendered frame + jxl::ImageBundle bundle(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frame.data(), frame.size()), cropxsize, + cropysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*channels=*/4, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle, + /*float_in=*/false, /*align=*/0)); + bundle.duration = 5 + i; + frame_durations_nc.push_back(5 + i); + frame_durations_c.push_back(5 + i); + frame_xsize.push_back(cropxsize); + frame_ysize.push_back(cropysize); + frame_x0.push_back(cropx0); + frame_y0.push_back(cropy0); + bundle.origin = {cropx0, cropy0}; + // Create some variation in which frames depend on which. + if (i != 3 && i != 9 && i != 10) { + bundle.use_for_next_frame = true; + } + if (i != 12) { + bundle.blend = true; + bundle.blendmode = jxl::BlendMode::kBlend; + } + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, + jxl::GetJxlCms(), &aux_out, nullptr)); + // try both with and without coalescing + for (auto coalescing : {JXL_TRUE, JXL_FALSE}) { + // Independently decode all frames without any skipping, to create the + // expected blended frames, for the actual tests below to compare with. + { + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec, coalescing)); + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + for (size_t i = 0; i < num_frames + (coalescing ? 0 : 5); ++i) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + if (coalescing) { + EXPECT_EQ(xsize * ysize * 8, buffer_size); + } else { + EXPECT_EQ(frame_xsize[i] * frame_ysize[i] * 8, buffer_size); + } + frames[i].resize(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, frames[i].data(), + frames[i].size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec, coalescing)); + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | + JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + std::vector pixels(buffer_size); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ((coalescing ? frame_durations_c[i] : frame_durations_nc[i]), + frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels.data(), + pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.xsize, xsize); + } else { + EXPECT_EQ(frame_header.layer_info.xsize, frame_xsize[i]); + } + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.ysize, ysize); + } else { + EXPECT_EQ(frame_header.layer_info.ysize, frame_ysize[i]); + } + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + frame_header.layer_info.xsize, + frame_header.layer_info.ysize, + format, format)); + + // Test rewinding mid-way, not decoding all frames. + if (i == 8) { + break; + } + } + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames + (coalescing ? 0 : 5); ++i) { + if (i == 3) { + JxlDecoderSkipFrames(dec, 5); + i += 5; + } + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + std::vector pixels(buffer_size); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ((coalescing ? frame_durations_c[i] : frame_durations_nc[i]), + frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames + (coalescing ? 0 : 5), + frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels.data(), + pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.xsize, xsize); + EXPECT_EQ(frame_header.layer_info.ysize, ysize); + EXPECT_EQ(frame_header.layer_info.crop_x0, 0); + EXPECT_EQ(frame_header.layer_info.crop_y0, 0); + } else { + EXPECT_EQ(frame_header.layer_info.xsize, frame_xsize[i]); + EXPECT_EQ(frame_header.layer_info.ysize, frame_ysize[i]); + EXPECT_EQ(frame_header.layer_info.crop_x0, frame_x0[i]); + EXPECT_EQ(frame_header.layer_info.crop_y0, frame_y0[i]); + EXPECT_EQ(frame_header.layer_info.blend_info.blendmode, + i != 12 + 5 && frame_header.duration != 0 + ? 2 + : 0); // kBlend or the default kReplace + } + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + frame_header.layer_info.xsize, + frame_header.layer_info.ysize, + format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + // Test rewinding the decoder and skipping different frames + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames + (coalescing ? 0 : 5); ++i) { + int test_skipping = (i == 9) ? 3 : 0; + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + std::vector pixels(buffer_size); + + // Since this is after JXL_DEC_FRAME but before JXL_DEC_FULL_IMAGE, this + // should only skip the next frame, not the currently processed one. + if (test_skipping) JxlDecoderSkipFrames(dec, test_skipping); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ((coalescing ? frame_durations_c[i] : frame_durations_nc[i]), + frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames + (coalescing ? 0 : 5), + frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels.data(), + pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + frame_header.layer_info.xsize, + frame_header.layer_info.ysize, + format, format)); + + if (test_skipping) i += test_skipping; + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, OrientedCroppedFrameTest) { + const auto test = [](bool keep_orientation, uint32_t orientation, + uint32_t resampling) { + size_t xsize = 90, ysize = 120; + JxlPixelFormat format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + size_t oxsize = (!keep_orientation && orientation > 4 ? ysize : xsize); + size_t oysize = (!keep_orientation && orientation > 4 ? xsize : ysize); + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.orientation = orientation; + io.frames.clear(); + io.SetSize(xsize, ysize); + + for (size_t i = 0; i < 3; ++i) { + size_t cropxsize = 1 + xsize * 2 / (i + 1); + size_t cropysize = 1 + ysize * 3 / (i + 2); + int cropx0 = i * 3 - 8; + int cropy0 = i * 4 - 7; + + std::vector frame = + jxl::test::GetSomeTestImage(cropxsize, cropysize, 4, i * 2); + jxl::ImageBundle bundle(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frame.data(), frame.size()), cropxsize, + cropysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*channels=*/4, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*pool=*/nullptr, &bundle, + /*float_in=*/false, /*align=*/0)); + bundle.origin = {cropx0, cropy0}; + bundle.use_for_next_frame = true; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams + .SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + cparams.resampling = resampling; + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, + jxl::GetJxlCms(), &aux_out, nullptr)); + + // 0 is merged frame as decoded with coalescing enabled (default) + // 1-3 are non-coalesced frames as decoded with coalescing disabled + // 4 is the manually merged frame + std::vector frames[5]; + frames[4].resize(xsize * ysize * 8, 0); + + // try both with and without coalescing + for (auto coalescing : {JXL_TRUE, JXL_FALSE}) { + // Independently decode all frames without any skipping, to create the + // expected blended frames, for the actual tests below to compare with. + { + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec, coalescing)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetKeepOrientation(dec, keep_orientation)); + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + for (size_t i = (coalescing ? 0 : 1); i < (coalescing ? 1 : 4); ++i) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetFrameHeader(dec, &frame_header)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + if (coalescing) { + EXPECT_EQ(xsize * ysize * 8, buffer_size); + } else { + EXPECT_EQ(frame_header.layer_info.xsize * + frame_header.layer_info.ysize * 8, + buffer_size); + } + frames[i].resize(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, frames[i].data(), + frames[i].size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(frame_header.layer_info.blend_info.blendmode, + JXL_BLEND_REPLACE); + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.xsize, oxsize); + EXPECT_EQ(frame_header.layer_info.ysize, oysize); + EXPECT_EQ(frame_header.layer_info.crop_x0, 0); + EXPECT_EQ(frame_header.layer_info.crop_y0, 0); + } else { + // manually merge this layer + int x0 = frame_header.layer_info.crop_x0; + int y0 = frame_header.layer_info.crop_y0; + int w = frame_header.layer_info.xsize; + int h = frame_header.layer_info.ysize; + for (int y = 0; y < static_cast(oysize); y++) { + if (y < y0 || y >= y0 + h) continue; + // pointers do whole 16-bit RGBA pixels at a time + uint64_t* row_merged = static_cast( + (void*)(frames[4].data() + y * oxsize * 8)); + uint64_t* row_layer = static_cast( + (void*)(frames[i].data() + (y - y0) * w * 8)); + for (int x = 0; x < static_cast(oxsize); x++) { + if (x < x0 || x >= x0 + w) continue; + row_merged[x] = row_layer[x - x0]; + } + } + } + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } + } + + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[0].data(), frames[4].data(), + oxsize, oysize, format, format)); + }; + + for (bool keep_orientation : {true, false}) { + for (uint32_t orientation = 1; orientation <= 8; orientation++) { + for (uint32_t resampling : {1, 2, 4, 8}) { + SCOPED_TRACE(testing::Message() + << "keep_orientation: " << keep_orientation << ", " + << "orientation: " << orientation << ", " + << "resampling: " << resampling); + test(keep_orientation, orientation, resampling); + } + } + } +} + +struct FramePositions { + size_t frame_start; + size_t header_end; + size_t toc_end; + std::vector section_end; +}; + +struct StreamPositions { + size_t codestream_start; + size_t codestream_end; + size_t basic_info; + size_t jbrd_end = 0; + std::vector box_start; + std::vector frames; +}; + +void AnalyzeCodestream(const jxl::PaddedBytes& data, + StreamPositions* streampos) { + // Unbox data to codestream and mark where it is broken up by boxes. + std::vector codestream; + std::vector> breakpoints; + bool codestream_end = false; + ASSERT_LE(2, data.size()); + if (data[0] == 0xff && data[1] == 0x0a) { + codestream = std::vector(data.begin(), data.end()); + streampos->codestream_start = 0; + } else { + const uint8_t* in = data.data(); + size_t pos = 0; + while (pos < data.size()) { + ASSERT_LE(pos + 8, data.size()); + streampos->box_start.push_back(pos); + size_t box_size = LoadBE32(in + pos); + if (box_size == 0) box_size = data.size() - pos; + ASSERT_LE(pos + box_size, data.size()); + if (memcmp(in + pos + 4, "jxlc", 4) == 0) { + EXPECT_TRUE(codestream.empty()); + streampos->codestream_start = pos + 8; + codestream.insert(codestream.end(), in + pos + 8, in + pos + box_size); + codestream_end = true; + } else if (memcmp(in + pos + 4, "jxlp", 4) == 0) { + codestream_end = (LoadBE32(in + pos + 8) & 0x80000000); + if (codestream.empty()) { + streampos->codestream_start = pos + 12; + } else if (box_size > 12 || !codestream_end) { + breakpoints.push_back({codestream.size(), 12}); + } + codestream.insert(codestream.end(), in + pos + 12, in + pos + box_size); + } else if (memcmp(in + pos + 4, "jbrd", 4) == 0) { + EXPECT_TRUE(codestream.empty()); + streampos->jbrd_end = pos + box_size; + } else if (!codestream.empty() && !codestream_end) { + breakpoints.push_back({codestream.size(), box_size}); + } + pos += box_size; + } + ASSERT_EQ(pos, data.size()); + } + // Translate codestream positions to boxed stream positions. + size_t offset = streampos->codestream_start; + size_t bp = 0; + auto add_offset = [&](size_t pos) { + while (bp < breakpoints.size() && pos >= breakpoints[bp].first) { + offset += breakpoints[bp++].second; + } + return pos + offset; + }; + // Analyze the unboxed codestream. + jxl::BitReader br( + jxl::Span(codestream.data(), codestream.size())); + ASSERT_EQ(br.ReadFixedBits<16>(), 0x0AFF); + jxl::CodecMetadata metadata; + EXPECT_TRUE(ReadSizeHeader(&br, &metadata.size)); + EXPECT_TRUE(ReadImageMetadata(&br, &metadata.m)); + streampos->basic_info = + add_offset(br.TotalBitsConsumed() / jxl::kBitsPerByte); + metadata.transform_data.nonserialized_xyb_encoded = metadata.m.xyb_encoded; + EXPECT_TRUE(jxl::Bundle::Read(&br, &metadata.transform_data)); + EXPECT_TRUE(br.JumpToByteBoundary()); + bool has_preview = metadata.m.have_preview; + while (br.TotalBitsConsumed() < br.TotalBytes() * jxl::kBitsPerByte) { + FramePositions p; + p.frame_start = add_offset(br.TotalBitsConsumed() / jxl::kBitsPerByte); + jxl::FrameHeader frame_header(&metadata); + if (has_preview) { + frame_header.nonserialized_is_preview = true; + has_preview = false; + } + EXPECT_TRUE(ReadFrameHeader(&br, &frame_header)); + p.header_end = + add_offset(jxl::DivCeil(br.TotalBitsConsumed(), jxl::kBitsPerByte)); + jxl::FrameDimensions frame_dim = frame_header.ToFrameDimensions(); + uint64_t groups_total_size; + const size_t toc_entries = jxl::NumTocEntries( + frame_dim.num_groups, frame_dim.num_dc_groups, + frame_header.passes.num_passes, /*has_ac_global=*/true); + std::vector section_offsets; + std::vector section_sizes; + EXPECT_TRUE(ReadGroupOffsets(toc_entries, &br, §ion_offsets, + §ion_sizes, &groups_total_size)); + EXPECT_EQ(br.TotalBitsConsumed() % jxl::kBitsPerByte, 0); + size_t sections_start = br.TotalBitsConsumed() / jxl::kBitsPerByte; + p.toc_end = add_offset(sections_start); + for (size_t i = 0; i < toc_entries; ++i) { + size_t end = sections_start + section_offsets[i] + section_sizes[i]; + p.section_end.push_back(add_offset(end)); + } + br.SkipBits(groups_total_size * jxl::kBitsPerByte); + streampos->frames.push_back(p); + } + streampos->codestream_end = add_offset(codestream.size()); + EXPECT_EQ(br.TotalBitsConsumed(), br.TotalBytes() * jxl::kBitsPerByte); + EXPECT_TRUE(br.Close()); +} + +enum ExpectedFlushState { NO_FLUSH, SAME_FLUSH, NEW_FLUSH }; +struct Breakpoint { + size_t file_pos; + ExpectedFlushState expect_flush; +}; + +void VerifyProgression(size_t xsize, size_t ysize, uint32_t num_channels, + const std::vector& pixels, + const jxl::PaddedBytes& data, + std::vector breakpoints) { + // Size large enough for multiple groups, required to have progressive stages. + ASSERT_LT(256, xsize); + ASSERT_LT(256, ysize); + std::vector pixels2; + pixels2.resize(pixels.size()); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + int bp = 0; + const uint8_t* next_in = data.data(); + size_t avail_in = breakpoints[bp].file_pos; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + double prev_dist = 1.0; + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + printf("bp: %d status: 0x%x\n", bp, (int)status); + if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + // Output buffer/callback not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FRAME) { + // Nothing to do. + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_EQ(bp + 1, breakpoints.size()); + break; + } else if (status == JXL_DEC_NEED_MORE_INPUT || + status == JXL_DEC_FULL_IMAGE) { + if (breakpoints[bp].expect_flush == NO_FLUSH) { + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + } else { + if (status != JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + } + double dist = jxl::test::DistanceRMS(pixels2.data(), pixels.data(), + xsize, ysize, format); + if (breakpoints[bp].expect_flush == NEW_FLUSH) { + EXPECT_LT(dist, prev_dist); + prev_dist = dist; + } else { + EXPECT_EQ(dist, prev_dist); + } + } + if (status == JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(bp + 1, breakpoints.size()); + continue; + } + ASSERT_LT(++bp, breakpoints.size()); + next_in += avail_in - JxlDecoderReleaseInput(dec); + avail_in = breakpoints[bp].file_pos - (next_in - data.data()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + } else { + printf("Unexpected status: 0x%x\n", (int)status); + FAIL(); // unexpected returned status + } + } + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, ProgressionTest) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.progressive_dc = 1; + params.add_preview = true; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(3, fp.size()); + EXPECT_EQ(7, fp[2].section_end.size()); + EXPECT_EQ(data.size(), fp[2].section_end[6]); + std::vector breakpoints{ + {fp[0].frame_start, NO_FLUSH}, // headers + {fp[1].frame_start, NO_FLUSH}, // preview + {fp[2].frame_start, NO_FLUSH}, // dc frame + {fp[2].section_end[0], NO_FLUSH}, // DC global + {fp[2].section_end[1] - 1, NO_FLUSH}, // partial DC group + {fp[2].section_end[1], NEW_FLUSH}, // DC group + {fp[2].section_end[2], SAME_FLUSH}, // AC global + {fp[2].section_end[3], NEW_FLUSH}, // AC group 0 + {fp[2].section_end[4] - 1, SAME_FLUSH}, // partial AC group 1 + {fp[2].section_end[4], NEW_FLUSH}, // AC group 1 + {fp[2].section_end[5], NEW_FLUSH}, // AC group 2 + {data.size() - 1, SAME_FLUSH}, // partial AC group 3 + {data.size(), NEW_FLUSH}}; // full image + VerifyProgression(xsize, ysize, num_channels, pixels, data, breakpoints); +} + +TEST(DecodeTest, ProgressionTestLosslessAlpha) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 4; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.cparams.responsive = 1; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(1, fp.size()); + EXPECT_EQ(7, fp[0].section_end.size()); + EXPECT_EQ(data.size(), fp[0].section_end[6]); + std::vector breakpoints{ + {fp[0].frame_start, NO_FLUSH}, // headers + {fp[0].section_end[0] - 1, NO_FLUSH}, // partial DC global + {fp[0].section_end[0], NEW_FLUSH}, // DC global + {fp[0].section_end[1], SAME_FLUSH}, // DC group + {fp[0].section_end[2], SAME_FLUSH}, // AC global + {fp[0].section_end[3], NEW_FLUSH}, // AC group 0 + {fp[0].section_end[4] - 1, SAME_FLUSH}, // partial AC group 1 + {fp[0].section_end[4], NEW_FLUSH}, // AC group 1 + {fp[0].section_end[5], NEW_FLUSH}, // AC group 2 + {data.size() - 1, SAME_FLUSH}, // partial AC group 3 + {data.size(), NEW_FLUSH}}; // full image + VerifyProgression(xsize, ysize, num_channels, pixels, data, breakpoints); +} + +void VerifyFilePosition(size_t expected_pos, const jxl::PaddedBytes& data, + JxlDecoder* dec) { + size_t remaining = JxlDecoderReleaseInput(dec); + size_t pos = data.size() - remaining; + EXPECT_EQ(expected_pos, pos); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data() + pos, remaining)); +} + +TEST(DecodeTest, InputHandlingTestOneShot) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + printf("Testing with box format %d\n", i); + jxl::TestCodestreamParams params; + params.cparams.progressive_dc = 1; + params.add_preview = true; + params.box_format = (CodeStreamBoxFormat)i; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(3, fp.size()); + + std::vector pixels2; + pixels2.resize(pixels.size()); + + int kNumEvents = 6; + int events[] = { + JXL_DEC_BASIC_INFO, JXL_DEC_COLOR_ENCODING, JXL_DEC_PREVIEW_IMAGE, + JXL_DEC_FRAME, JXL_DEC_FULL_IMAGE, JXL_DEC_FRAME_PROGRESSION, + }; + size_t end_positions[] = { + streampos.basic_info, fp[0].frame_start, + fp[1].frame_start, fp[2].toc_end, + streampos.codestream_end, streampos.codestream_end}; + int events_wanted = 0; + for (int j = 0; j < kNumEvents; ++j) { + events_wanted |= events[j]; + size_t end_pos = end_positions[j]; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events_wanted)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.basic_info, data, dec); + if (j >= 1) { + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].frame_start, data, dec); + } + if (j >= 2) { + EXPECT_EQ(JXL_DEC_NEED_PREVIEW_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + EXPECT_GE(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, pixels2.data(), + buffer_size)); + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].frame_start, data, dec); + } + if (j >= 3) { + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[2].toc_end, data, dec); + if (j >= 5) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetProgressiveDetail(dec, kDC)); + } + } + if (j >= 4) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[2].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + if (j >= 5) { + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[2].section_end[1], data, dec); + } + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.codestream_end, data, dec); + } + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + VerifyFilePosition(end_pos, data, dec); + JxlDecoderDestroy(dec); + } + } +} + +#if JPEGXL_ENABLE_JPEG +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(InputHandlingTestJPEGOneshot)) { + size_t xsize = 123; + size_t ysize = 77; + size_t channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, /*seed=*/0); + for (int i = 1; i < kCSBF_NUM_ENTRIES; ++i) { + printf("Testing with box format %d\n", i); + jxl::PaddedBytes jpeg_codestream; + jxl::TestCodestreamParams params; + params.cparams.color_transform = jxl::ColorTransform::kNone; + params.jpeg_codestream = &jpeg_codestream; + params.add_preview = true; + params.box_format = (CodeStreamBoxFormat)i; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + channels, params); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector& fp = streampos.frames; + // We have preview and regular frame. + EXPECT_EQ(2, fp.size()); + EXPECT_LT(0, streampos.jbrd_end); + + std::vector pixels2; + pixels2.resize(pixels.size()); + + int kNumEvents = 6; + int events[] = {JXL_DEC_BASIC_INFO, JXL_DEC_JPEG_RECONSTRUCTION, + JXL_DEC_COLOR_ENCODING, JXL_DEC_PREVIEW_IMAGE, + JXL_DEC_FRAME, JXL_DEC_FULL_IMAGE}; + size_t end_positions[] = {streampos.basic_info, streampos.basic_info, + fp[0].frame_start, fp[1].frame_start, + fp[1].toc_end, streampos.codestream_end}; + int events_wanted = 0; + for (int j = 0; j < kNumEvents; ++j) { + printf("j = %d\n", j); + events_wanted |= events[j]; + size_t end_pos = end_positions[j]; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events_wanted)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data(), data.size())); + if (j >= 1) { + EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.jbrd_end, data, dec); + } + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.basic_info, data, dec); + if (j >= 2) { + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].frame_start, data, dec); + } + if (j >= 3) { + EXPECT_EQ(JXL_DEC_NEED_PREVIEW_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + EXPECT_GE(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, pixels2.data(), + buffer_size)); + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].frame_start, data, dec); + } + if (j >= 4) { + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].toc_end, data, dec); + } + if (j >= 5) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.codestream_end, data, dec); + } + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + VerifyFilePosition(end_pos, data, dec); + JxlDecoderDestroy(dec); + } + } +} +#endif // JPEGXL_ENABLE_JPEG + +TEST(DecodeTest, InputHandlingTestStreaming) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + printf("Testing with box format %d\n", i); + fflush(stdout); + jxl::TestCodestreamParams params; + params.cparams.progressive_dc = 1; + params.box_format = (CodeStreamBoxFormat)i; + params.add_preview = true; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(3, fp.size()); + std::vector pixels2; + pixels2.resize(pixels.size()); + int events_wanted = + (JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_PREVIEW_IMAGE | + JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE | JXL_DEC_FRAME_PROGRESSION | + JXL_DEC_BOX); + for (size_t increment : {1, 7, 27, 1024}) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events_wanted)); + size_t file_pos = 0; + size_t box_index = 0; + size_t avail_in = 0; + for (;;) { + const uint8_t* next_in = data.data() + file_pos; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + size_t consumed = avail_in - remaining; + file_pos += consumed; + avail_in += increment; + avail_in = std::min(avail_in, data.size() - file_pos); + if (status == JXL_DEC_BASIC_INFO) { + EXPECT_EQ(file_pos, streampos.basic_info); + } else if (status == JXL_DEC_COLOR_ENCODING) { + EXPECT_EQ(file_pos, streampos.frames[0].frame_start); + } else if (status == JXL_DEC_NEED_PREVIEW_OUT_BUFFER) { + EXPECT_EQ(file_pos, streampos.frames[0].toc_end); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + EXPECT_GE(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, pixels2.data(), + buffer_size)); + } else if (status == JXL_DEC_PREVIEW_IMAGE) { + EXPECT_EQ(file_pos, streampos.frames[1].frame_start); + } else if (status == JXL_DEC_FRAME) { + EXPECT_EQ(file_pos, streampos.frames[2].toc_end); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetProgressiveDetail(dec, kDC)); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(file_pos, streampos.frames[2].toc_end); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FRAME_PROGRESSION) { + EXPECT_EQ(file_pos, streampos.frames[2].section_end[1]); + } else if (status == JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(file_pos, streampos.codestream_end); + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_EQ(file_pos, streampos.codestream_end); + break; + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + EXPECT_LT(remaining, 12); + if ((i == kCSBF_None && file_pos >= 2) || + (box_index > 0 && box_index < streampos.box_start.size() && + file_pos >= streampos.box_start[box_index - 1] + 12 && + file_pos < streampos.box_start[box_index])) { + EXPECT_EQ(remaining, 0); + } + if (file_pos == data.size()) break; + } else if (status == JXL_DEC_BOX) { + ASSERT_LT(box_index, streampos.box_start.size()); + EXPECT_EQ(file_pos, streampos.box_start[box_index++]); + } else { + printf("Unexpected status: 0x%x\n", (int)status); + FAIL(); + } + } + JxlDecoderDestroy(dec); + } + } +} + +TEST(DecodeTest, FlushTest) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.add_preview = true; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() - 1; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + // Crude test of actual pixel data: pixel threshold of about 4% (2560/65535). + // 29000 pixels can be above the threshold + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 29000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + // Lower threshold for the final (still lossy) image + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 11000u); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, FlushTestImageOutCallback) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.add_preview = true; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + size_t bytes_per_pixel = format.num_channels * 2; + size_t stride = bytes_per_pixel * xsize; + auto callback = [&](size_t x, size_t y, size_t num_pixels, + const void* pixels_row) { + memcpy(pixels2.data() + stride * y + bytes_per_pixel * x, pixels_row, + num_pixels * bytes_per_pixel); + }; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() - 1; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output callback not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutCallback( + dec, &format, + [](void* opaque, size_t x, size_t y, + size_t xsize, const void* pixels_row) { + auto cb = + static_cast(opaque); + (*cb)(x, y, xsize, pixels_row); + }, + /*opaque=*/&callback)); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + // Crude test of actual pixel data: pixel threshold of about 4% (2560/65535). + // 29000 pixels can be above the threshold + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 29000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + // Lower threshold for the final (still lossy) image + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 11000u); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, FlushTestLossyProgressiveAlpha) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 4; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.add_preview = true; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() - 1; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 30000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 11000u); + + JxlDecoderDestroy(dec); +} +TEST(DecodeTest, FlushTestLossyProgressiveAlphaUpsampling) { + size_t xsize = 533, ysize = 401; + uint32_t num_channels = 4; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.resampling = 2; + params.cparams.ec_resampling = 4; + params.add_preview = true; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() * 2 / 3; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 125000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 70000u); + + JxlDecoderDestroy(dec); +} +TEST(DecodeTest, FlushTestLosslessProgressiveAlpha) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 4; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.cparams.responsive = 1; + params.add_preview = true; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() / 2; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 2700u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format), + 0u); + + JxlDecoderDestroy(dec); +} + +class DecodeProgressiveTest : public ::testing::TestWithParam {}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(DecodeProgressiveTestInstantiation, + DecodeProgressiveTest, + ::testing::Range(0, 8)); +TEST_P(DecodeProgressiveTest, ProgressiveEventTest) { + const int params = GetParam(); + int single_group = params & 1; + int lossless = (params >> 1) & 1; + uint32_t num_channels = 3 + ((params >> 2) & 1); + std::set progressive_details = {kDC, kLastPasses, + kPasses}; + for (auto prog_detail : progressive_details) { + // Only few combinations are expected to support outputting + // intermediate flushes for complete DC and complete passes. + // The test can be updated if more cases are expected to support it. + bool expect_flush = (num_channels & 1) && !lossless; + size_t xsize, ysize; + if (single_group) { + // An image smaller than 256x256 ensures it contains only 1 group. + xsize = 99; + ysize = 100; + } else { + xsize = 277; + ysize = 280; + } + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::ColorEncoding color_encoding = jxl::ColorEncoding::SRGB(false); + jxl::CodecInOut io; + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + color_encoding, num_channels, + /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/16, JXL_BIG_ENDIAN, + /*pool=*/nullptr, &io.Main(), /*float_in=*/false, /*align=*/0)); + jxl::TestCodestreamParams params; + if (lossless) { + params.cparams.SetLossless(); + } else { + params.cparams.butteraugli_distance = 0.5f; + } + jxl::PassDefinition passes[] = {{2, 0, false, 4}, + {4, 0, false, 4}, + {8, 2, false, 2}, + {8, 1, false, 2}, + {8, 0, false, 1}}; + const int kNumPasses = 5; + jxl::ProgressiveMode progressive_mode{passes}; + params.progressive_mode = &progressive_mode; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + for (size_t increment : {(size_t)1, data.size()}) { + printf( + "Testing with single_group=%d, lossless=%d, " + "num_channels=%d, prog_detail=%d, increment=%d\n", + single_group, lossless, (int)num_channels, (int)prog_detail, + (int)increment); + std::vector> passes(kNumPasses + 1); + for (int i = 0; i <= kNumPasses; ++i) { + passes[i].resize(pixels.size()); + } + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | + JXL_DEC_FULL_IMAGE | JXL_DEC_FRAME_PROGRESSION)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSetProgressiveDetail(dec, kFrames)); + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderSetProgressiveDetail(dec, kDCProgressive)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSetProgressiveDetail(dec, kDCGroups)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSetProgressiveDetail(dec, kGroups)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetProgressiveDetail(dec, prog_detail)); + + uint8_t* next_in = data.data(); + size_t avail_in = 0; + size_t pos = 0; + + auto process_input = [&]() { + for (;;) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + if (status == JXL_DEC_NEED_MORE_INPUT && pos < data.size()) { + size_t chunk = std::min(increment, data.size() - pos); + pos += chunk; + avail_in += chunk; + continue; + } + return status; + } + }; + + EXPECT_EQ(JXL_DEC_BASIC_INFO, process_input()); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, process_input()); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, passes[kNumPasses].data(), + passes[kNumPasses].size())); + + auto next_pass = [&](int pass) { + if (prog_detail <= kDC) return kNumPasses; + if (prog_detail <= kLastPasses) { + return std::min(pass + 2, kNumPasses); + } + return pass + 1; + }; + + if (expect_flush) { + // Return a particular downsampling ratio only after the last + // pass for that downsampling was processed. + int expected_downsampling_ratios[] = {8, 8, 4, 4, 2}; + for (int p = 0; p < kNumPasses; p = next_pass(p)) { + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, process_input()); + EXPECT_EQ(expected_downsampling_ratios[p], + JxlDecoderGetIntendedDownsamplingRatio(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + passes[p] = passes[kNumPasses]; + } + } + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, process_input()); + EXPECT_EQ(JXL_DEC_SUCCESS, process_input()); + + JxlDecoderDestroy(dec); + + if (!expect_flush) { + continue; + } + jxl::ButteraugliParams ba; + std::vector distances(kNumPasses + 1); + for (int p = 0;; p = next_pass(p)) { + jxl::CodecInOut io1; + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Span(passes[p].data(), passes[p].size()), xsize, + ysize, color_encoding, num_channels, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, + /*pool=*/nullptr, &io1.Main(), /*float_in=*/false, + /*align=*/0)); + distances[p] = ButteraugliDistance(io, io1, ba, jxl::GetJxlCms(), + nullptr, nullptr); + if (p == kNumPasses) break; + } + const float kMaxDistance[kNumPasses + 1] = {30.0f, 20.0f, 10.0f, + 5.0f, 3.0f, 2.0f}; + EXPECT_LT(distances[kNumPasses], kMaxDistance[kNumPasses]); + for (int p = 0; p < kNumPasses;) { + int next_p = next_pass(p); + EXPECT_LT(distances[p], kMaxDistance[p]); + // Verify that the returned pass image is actually not the + // same as the next pass image, by checking that it has a bit + // worse butteraugli score. + EXPECT_LT(distances[next_p] * 1.2f, distances[p]); + p = next_p; + } + } + } +} + +void VerifyJPEGReconstruction(const jxl::PaddedBytes& container, + const jxl::PaddedBytes& jpeg_bytes) { + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec.get(), JXL_DEC_JPEG_RECONSTRUCTION | JXL_DEC_FULL_IMAGE)); + JxlDecoderSetInput(dec.get(), container.data(), container.size()); + EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec.get())); + std::vector reconstructed_buffer(128); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data(), + reconstructed_buffer.size())); + size_t used = 0; + JxlDecoderStatus process_result = JXL_DEC_JPEG_NEED_MORE_OUTPUT; + while (process_result == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + reconstructed_buffer.resize(reconstructed_buffer.size() * 2); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data() + used, + reconstructed_buffer.size() - used)); + process_result = JxlDecoderProcessInput(dec.get()); + } + ASSERT_EQ(JXL_DEC_FULL_IMAGE, process_result); + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + ASSERT_EQ(used, jpeg_bytes.size()); + EXPECT_EQ(0, memcmp(reconstructed_buffer.data(), jpeg_bytes.data(), used)); +} + +#if JPEGXL_ENABLE_JPEG +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructTestCodestream)) { + size_t xsize = 123; + size_t ysize = 77; + size_t channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, /*seed=*/0); + jxl::PaddedBytes jpeg_codestream; + jxl::TestCodestreamParams params; + params.cparams.color_transform = jxl::ColorTransform::kNone; + params.box_format = kCSBF_Single; + params.jpeg_codestream = &jpeg_codestream; + params.add_preview = true; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + channels, params); + VerifyJPEGReconstruction(compressed, jpeg_codestream); +} +#endif // JPEGXL_ENABLE_JPEG + +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructionTest)) { + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path); + jxl::CodecInOut orig_io; + ASSERT_TRUE( + jxl::jpeg::DecodeImageJPG(jxl::Span(orig), &orig_io)); + orig_io.metadata.m.xyb_encoded = false; + jxl::BitWriter writer; + ASSERT_TRUE(WriteHeaders(&orig_io.metadata, &writer, nullptr)); + writer.ZeroPadToByte(); + jxl::PassesEncoderState enc_state; + jxl::CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kNone; + ASSERT_TRUE(jxl::EncodeFrame(cparams, jxl::FrameInfo{}, &orig_io.metadata, + orig_io.Main(), &enc_state, jxl::GetJxlCms(), + /*pool=*/nullptr, &writer, + /*aux_out=*/nullptr)); + + jxl::PaddedBytes jpeg_data; + ASSERT_TRUE( + EncodeJPEGData(*orig_io.Main().jpeg_data.get(), &jpeg_data, cparams)); + jxl::PaddedBytes container; + container.append(jxl::kContainerHeader, + jxl::kContainerHeader + sizeof(jxl::kContainerHeader)); + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, + &container); + container.append(jpeg_data.data(), jpeg_data.data() + jpeg_data.size()); + jxl::AppendBoxHeader(jxl::MakeBoxType("jxlc"), 0, true, &container); + jxl::PaddedBytes codestream = std::move(writer).TakeBytes(); + container.append(codestream.data(), codestream.data() + codestream.size()); + VerifyJPEGReconstruction(container, orig); +} + +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructionMetadataTest)) { + const std::string jpeg_path = "jxl/jpeg_reconstruction/1x1_exif_xmp.jpg"; + const std::string jxl_path = "jxl/jpeg_reconstruction/1x1_exif_xmp.jxl"; + const jxl::PaddedBytes jpeg = jxl::ReadTestData(jpeg_path); + const jxl::PaddedBytes jxl = jxl::ReadTestData(jxl_path); + VerifyJPEGReconstruction(jxl, jpeg); +} + +TEST(DecodeTest, ContinueFinalNonEssentialBoxTest) { + size_t xsize = 80, ysize = 90; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + jxl::TestCodestreamParams params; + params.box_format = kCSBF_Multi_Other_Terminated; + params.add_icc_profile = true; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + params); + StreamPositions streampos; + AnalyzeCodestream(compressed, &streampos); + + // The non-essential final box size including 8-byte header + size_t final_box_size = unk3_box_size + 8; + size_t last_box_begin = compressed.size() - final_box_size; + // Verify that the test is indeed setup correctly to be at the beginning of + // the 'unkn' box header. + ASSERT_EQ(compressed[last_box_begin + 3], final_box_size); + ASSERT_EQ(compressed[last_box_begin + 4], 'u'); + ASSERT_EQ(compressed[last_box_begin + 5], 'n'); + ASSERT_EQ(compressed[last_box_begin + 6], 'k'); + ASSERT_EQ(compressed[last_box_begin + 7], '3'); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), last_box_begin)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + // The decoder returns success despite not having seen the final unknown box + // yet. This is because calling JxlDecoderCloseInput is not mandatory for + // backwards compatibility, so it doesn't know more bytes follow, the current + // bytes ended at a perfectly valid place. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + size_t remaining = JxlDecoderReleaseInput(dec); + // Since the test was set up to end exactly at the boundary of the final + // codestream box, and the decoder returned success, all bytes are expected to + // be consumed until the end of the frame header. + EXPECT_EQ(remaining, last_box_begin - streampos.frames[0].toc_end); + + // Now set the remaining non-codestream box as input. + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data() + last_box_begin, + compressed.size() - last_box_begin)); + // Even though JxlDecoderProcessInput already returned JXL_DEC_SUCCESS before, + // when calling it again now after setting more input, success is expected, no + // event occurs but the box has been successfully skipped. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +namespace { +bool BoxTypeEquals(const std::string& type_string, JxlBoxType type) { + return type_string.size() == 4 && type_string[0] == type[0] && + type_string[1] == type[1] && type_string[2] == type[2] && + type_string[3] == type[3]; +} +} // namespace + +TEST(DecodeTest, ExtentedBoxSizeTest) { + const std::string jxl_path = "jxl/boxes/square-extended-size-container.jxl"; + const jxl::PaddedBytes orig = jxl::ReadTestData(jxl_path); + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + + JxlBoxType type; + uint64_t box_size; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, orig.data(), orig.size())); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_TRUE(BoxTypeEquals("JXL ", type)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_EQ(12, box_size); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_TRUE(BoxTypeEquals("ftyp", type)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_EQ(20, box_size); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_TRUE(BoxTypeEquals("jxlc", type)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_EQ(72, box_size); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, BoxTest) { + size_t xsize = 1, ysize = 1; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + jxl::TestCodestreamParams params; + params.box_format = kCSBF_Multi_Other_Terminated; + params.add_icc_profile = true; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + params); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + + std::vector expected_box_types = { + "JXL ", "ftyp", "jxlp", "unk1", "unk2", "jxlp", "jxlp", "jxlp", "unk3"}; + + // Value 0 means to not test the size: codestream is not required to be a + // particular exact size. + std::vector expected_box_sizes = {12, 20, 0, 34, 18, 0, 0, 0, 20}; + + JxlBoxType type; + uint64_t box_size; + std::vector contents(50); + size_t expected_release_size = 0; + + // Cannot get these when decoding didn't start yet + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + + uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + for (size_t i = 0; i < expected_box_types.size(); i++) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_TRUE(BoxTypeEquals(expected_box_types[i], type)); + if (expected_box_sizes[i]) { + EXPECT_EQ(expected_box_sizes[i], box_size); + } + + if (expected_release_size > 0) { + EXPECT_EQ(expected_release_size, JxlDecoderReleaseBoxBuffer(dec)); + expected_release_size = 0; + } + + if (type[0] == 'u' && type[1] == 'n' && type[2] == 'k') { + JxlDecoderSetBoxBuffer(dec, contents.data(), contents.size()); + size_t expected_box_contents_size = + type[3] == '1' ? unk1_box_size + : (type[3] == '2' ? unk2_box_size : unk3_box_size); + expected_release_size = contents.size() - expected_box_contents_size; + } + size_t consumed = avail_in - JxlDecoderReleaseInput(dec); + next_in += consumed; + avail_in -= consumed; + } + + // After the last DEC_BOX event, check that the input position is exactly at + // the stat of the box header. + EXPECT_EQ(avail_in, expected_box_sizes.back()); + + // Even though all input is given, the decoder cannot assume there aren't + // more boxes if the input was not closed. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + JxlDecoderCloseInput(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, ExifBrobBoxTest) { + size_t xsize = 1, ysize = 1; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.box_format = kCSBF_Brob_Exif; + params.add_icc_profile = true; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + params); + + // Test raw brob box, not brotli-decompressing + for (int streaming = 0; streaming < 2; ++streaming) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + if (!streaming) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + JxlDecoderCloseInput(dec); + } + // for streaming input case + const uint8_t* next_in = compressed.data(); + size_t avail_in = 0; + size_t total_in = 0; + size_t step_size = 64; + + std::vector box_buffer; + size_t box_num_output; + bool seen_brob_begin = false; + bool seen_brob_end = false; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (streaming) { + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + size_t amount = step_size; + if (total_in + amount > compressed.size()) { + amount = compressed.size() - total_in; + } + avail_in += amount; + total_in += amount; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, next_in, avail_in)); + if (total_in == compressed.size()) JxlDecoderCloseInput(dec); + } else { + FAIL(); + break; + } + } else if (status == JXL_DEC_BOX || status == JXL_DEC_SUCCESS) { + if (!box_buffer.empty()) { + EXPECT_EQ(false, seen_brob_end); + seen_brob_end = true; + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + EXPECT_EQ(box_num_output, box_brob_exif_size - 8); + EXPECT_EQ( + 0, memcmp(box_buffer.data(), box_brob_exif + 8, box_num_output)); + box_buffer.clear(); + } + if (status == JXL_DEC_SUCCESS) break; + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + if (BoxTypeEquals("brob", type)) { + EXPECT_EQ(false, seen_brob_begin); + seen_brob_begin = true; + box_buffer.resize(8); + JxlDecoderSetBoxBuffer(dec, box_buffer.data(), box_buffer.size()); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_EQ(true, seen_brob_begin); + EXPECT_EQ(true, seen_brob_end); + + JxlDecoderDestroy(dec); + } + + // Test decompressed brob box + for (int streaming = 0; streaming < 2; ++streaming) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + if (!streaming) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + JxlDecoderCloseInput(dec); + } + // for streaming input case + const uint8_t* next_in = compressed.data(); + size_t avail_in = 0; + size_t total_in = 0; + size_t step_size = 64; + + std::vector box_buffer; + size_t box_num_output; + bool seen_exif_begin = false; + bool seen_exif_end = false; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetDecompressBoxes(dec, JXL_TRUE)); + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (streaming) { + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + size_t amount = step_size; + if (total_in + amount > compressed.size()) { + amount = compressed.size() - total_in; + } + avail_in += amount; + total_in += amount; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, next_in, avail_in)); + if (total_in == compressed.size()) JxlDecoderCloseInput(dec); + } else { + FAIL(); + break; + } + } else if (status == JXL_DEC_BOX || status == JXL_DEC_SUCCESS) { + if (!box_buffer.empty()) { + EXPECT_EQ(false, seen_exif_end); + seen_exif_end = true; + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + // Expect that the output has the same size and contents as the + // uncompressed exif data. Only check contents if the sizes match to + // avoid comparing uninitialized memory in the test. + EXPECT_EQ(box_num_output, exif_uncompressed_size); + if (box_num_output == exif_uncompressed_size) { + EXPECT_EQ(0, memcmp(box_buffer.data(), exif_uncompressed, + exif_uncompressed_size)); + } + box_buffer.clear(); + } + if (status == JXL_DEC_SUCCESS) break; + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_TRUE)); + if (BoxTypeEquals("Exif", type)) { + EXPECT_EQ(false, seen_exif_begin); + seen_exif_begin = true; + box_buffer.resize(8); + JxlDecoderSetBoxBuffer(dec, box_buffer.data(), box_buffer.size()); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_EQ(true, seen_exif_begin); + EXPECT_EQ(true, seen_exif_end); + + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, PartialCodestreamBoxTest) { + size_t xsize = 23, ysize = 81; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + // Lossless to verify pixels exactly after roundtrip. + jxl::TestCodestreamParams params; + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.box_format = kCSBF_Multi; + params.add_icc_profile = true; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + params); + + std::vector extracted_codestream; + + { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | JXL_DEC_BOX)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + JxlDecoderCloseInput(dec); + + size_t num_jxlp = 0; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + std::vector box_buffer; + size_t box_num_output; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + FAIL(); + break; + } else if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format_orig, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + continue; + } else if (status == JXL_DEC_BOX || status == JXL_DEC_SUCCESS) { + if (!box_buffer.empty()) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + EXPECT_GE(box_num_output, 4); + // Do not insert the first 4 bytes, which are not part of the + // codestream, but the partial codestream box index + extracted_codestream.insert(extracted_codestream.end(), + box_buffer.begin() + 4, + box_buffer.begin() + box_num_output); + box_buffer.clear(); + } + if (status == JXL_DEC_SUCCESS) break; + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + if (BoxTypeEquals("jxlp", type)) { + num_jxlp++; + box_buffer.resize(8); + JxlDecoderSetBoxBuffer(dec, box_buffer.data(), box_buffer.size()); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + // The test file created with kCSBF_Multi is expected to have 4 jxlp boxes. + EXPECT_EQ(4, num_jxlp); + + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format_orig)); + + JxlDecoderDestroy(dec); + } + + // Now test whether the codestream extracted from the jxlp boxes can itself + // also be decoded and gives the same pixels + { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | JXL_DEC_BOX)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, extracted_codestream.data(), + extracted_codestream.size())); + JxlDecoderCloseInput(dec); + + size_t num_boxes = 0; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + std::vector box_buffer; + size_t box_num_output; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + FAIL(); + break; + } else if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format_orig, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + continue; + } else if (status == JXL_DEC_BOX) { + num_boxes++; + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_EQ(0, num_boxes); // The data does not use the container format. + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format_orig)); + + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, SpotColorTest) { + jxl::ThreadPool* pool = nullptr; + jxl::CodecInOut io; + size_t xsize = 55, ysize = 257; + io.metadata.m.color_encoding = jxl::ColorEncoding::LinearSRGB(); + jxl::Image3F main(xsize, ysize); + jxl::ImageF spot(xsize, ysize); + jxl::ZeroFillImage(&main); + jxl::ZeroFillImage(&spot); + + for (size_t y = 0; y < ysize; y++) { + float* JXL_RESTRICT rowm = main.PlaneRow(1, y); + float* JXL_RESTRICT rows = spot.Row(y); + for (size_t x = 0; x < xsize; x++) { + rowm[x] = (x + y) * (1.f / 255.f); + rows[x] = ((x ^ y) & 255) * (1.f / 255.f); + } + } + io.SetFromImage(std::move(main), jxl::ColorEncoding::LinearSRGB()); + jxl::ExtraChannelInfo info; + info.bit_depth.bits_per_sample = 8; + info.dim_shift = 0; + info.type = jxl::ExtraChannel::kSpotColor; + info.spot_color[0] = 0.5f; + info.spot_color[1] = 0.2f; + info.spot_color[2] = 1.f; + info.spot_color[3] = 0.5f; + + io.metadata.m.extra_channel_info.push_back(info); + std::vector ec; + ec.push_back(std::move(spot)); + io.frames[0].SetExtraChannels(std::move(ec)); + + jxl::CompressParams cparams; + cparams.speed_tier = jxl::SpeedTier::kLightning; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + + jxl::PaddedBytes compressed; + std::unique_ptr enc_state = + jxl::make_unique(); + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, enc_state.get(), &compressed, + jxl::GetJxlCms(), nullptr, pool)); + + for (size_t render_spot = 0; render_spot < 2; render_spot++) { + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE)); + if (!render_spot) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetRenderSpotcolors(dec, JXL_FALSE)); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo binfo; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &binfo)); + EXPECT_EQ(1u, binfo.num_extra_channels); + EXPECT_EQ(xsize, binfo.xsize); + EXPECT_EQ(ysize, binfo.ysize); + + JxlExtraChannelInfo extra_info; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelInfo(dec, 0, &extra_info)); + EXPECT_EQ((unsigned int)jxl::ExtraChannel::kSpotColor, extra_info.type); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + size_t extra_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderExtraChannelBufferSize(dec, &format, &extra_size, 0)); + + std::vector image(buffer_size); + std::vector extra(extra_size); + size_t bytes_per_pixel = format.num_channels * + jxl::test::GetDataBits(format.data_type) / + jxl::kBitsPerByte; + size_t stride = bytes_per_pixel * binfo.xsize; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, image.data(), image.size())); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetExtraChannelBuffer(dec, &format, extra.data(), + extra.size(), 0)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + // After the full image was output, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + + for (size_t y = 0; y < ysize; y++) { + uint8_t* JXL_RESTRICT rowm = image.data() + stride * y; + uint8_t* JXL_RESTRICT rows = extra.data() + xsize * y; + for (size_t x = 0; x < xsize; x++) { + if (!render_spot) { + // if spot color isn't rendered, main image should be as we made it + // (red and blue are all zeroes) + + EXPECT_EQ(rowm[x * 3 + 0], 0); + EXPECT_EQ(rowm[x * 3 + 1], (x + y > 255 ? 255 : x + y)); + EXPECT_EQ(rowm[x * 3 + 2], 0); + } + if (render_spot) { + // if spot color is rendered, expect red and blue to look like the + // spot color channel + EXPECT_LT(abs(rowm[x * 3 + 0] - (rows[x] * 0.25f)), 1); + EXPECT_LT(abs(rowm[x * 3 + 2] - (rows[x] * 0.5f)), 1); + } + EXPECT_EQ(rows[x], ((x ^ y) & 255)); + } + } + } +} + +TEST(DecodeTest, CloseInput) { + std::vector partial_file = {0xff}; + + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), + JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec.get(), partial_file.data(), + partial_file.size())); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec.get())); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec.get())); + JxlDecoderCloseInput(dec.get()); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderProcessInput(dec.get())); +} diff --git a/media/libjxl/src/lib/jxl/decode_to_jpeg.cc b/media/libjxl/src/lib/jxl/decode_to_jpeg.cc new file mode 100644 index 000000000..aa57b2723 --- /dev/null +++ b/media/libjxl/src/lib/jxl/decode_to_jpeg.cc @@ -0,0 +1,169 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/decode_to_jpeg.h" + +namespace jxl { + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +JxlDecoderStatus JxlToJpegDecoder::Process(const uint8_t** next_in, + size_t* avail_in) { + if (!inside_box_) { + JXL_ABORT( + "processing of JPEG reconstruction data outside JPEG reconstruction " + "box"); + } + Span to_decode; + if (box_until_eof_) { + // Until EOF means consume all data. + to_decode = Span(*next_in, *avail_in); + *next_in += *avail_in; + *avail_in = 0; + } else { + // Defined size means consume min(available, needed). + size_t avail_recon_in = + std::min(*avail_in, box_size_ - buffer_.size()); + to_decode = Span(*next_in, avail_recon_in); + *next_in += avail_recon_in; + *avail_in -= avail_recon_in; + } + bool old_data_exists = !buffer_.empty(); + if (old_data_exists) { + // Append incoming data to buffer if we already had data in the buffer. + buffer_.insert(buffer_.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + to_decode = Span(buffer_.data(), buffer_.size()); + } + if (!box_until_eof_ && to_decode.size() > box_size_) { + JXL_ABORT("JPEG reconstruction data to decode larger than expected"); + } + if (box_until_eof_ || to_decode.size() == box_size_) { + // If undefined size, or the right size, try to decode. + jpeg_data_ = make_unique(); + const auto status = jpeg::DecodeJPEGData(to_decode, jpeg_data_.get()); + if (status.IsFatalError()) return JXL_DEC_ERROR; + if (status) { + // Successful decoding, emit event after updating state to track that we + // are no longer parsing JPEG reconstruction data. + inside_box_ = false; + return JXL_DEC_JPEG_RECONSTRUCTION; + } + if (box_until_eof_) { + // Unsuccessful decoding and undefined size, assume incomplete data. Copy + // the data if we haven't already. + if (!old_data_exists) { + buffer_.insert(buffer_.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + } + } else { + // Unsuccessful decoding of correct amount of data, assume error. + return JXL_DEC_ERROR; + } + } else { + // Not enough data, copy the data if we haven't already. + if (!old_data_exists) { + buffer_.insert(buffer_.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + } + } + return JXL_DEC_NEED_MORE_INPUT; +} + +size_t JxlToJpegDecoder::NumExifMarkers(const jpeg::JPEGData& jpeg_data) { + size_t num = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kExif) { + num++; + } + } + return num; +} + +size_t JxlToJpegDecoder::NumXmpMarkers(const jpeg::JPEGData& jpeg_data) { + size_t num = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kXMP) { + num++; + } + } + return num; +} + +JxlDecoderStatus JxlToJpegDecoder::ExifBoxContentSize( + const jpeg::JPEGData& jpeg_data, size_t* size) { + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kExif) { + if (jpeg_data.app_data[i].size() < 3 + sizeof(jpeg::kExifTag)) { + // too small for app marker header + return JXL_DEC_ERROR; + } + // The first 4 bytes are the TIFF header from the box contents, and are + // not included in the JPEG + *size = jpeg_data.app_data[i].size() + 4 - 3 - sizeof(jpeg::kExifTag); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} + +JxlDecoderStatus JxlToJpegDecoder::XmlBoxContentSize( + const jpeg::JPEGData& jpeg_data, size_t* size) { + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kXMP) { + if (jpeg_data.app_data[i].size() < 3 + sizeof(jpeg::kXMPTag)) { + // too small for app marker header + return JXL_DEC_ERROR; + } + *size = jpeg_data.app_data[i].size() - 3 - sizeof(jpeg::kXMPTag); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} + +JxlDecoderStatus JxlToJpegDecoder::SetExif(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data) { + for (size_t i = 0; i < jpeg_data->app_data.size(); ++i) { + if (jpeg_data->app_marker_type[i] == jxl::jpeg::AppMarkerType::kExif) { + if (jpeg_data->app_data[i].size() != + size + 3 + sizeof(jpeg::kExifTag) - 4) + return JXL_DEC_ERROR; + // The first 9 bytes are used for JPEG marker header. + jpeg_data->app_data[i][0] = 0xE1; + // The second and third byte are already filled in correctly + memcpy(jpeg_data->app_data[i].data() + 3, jpeg::kExifTag, + sizeof(jpeg::kExifTag)); + // The first 4 bytes are the TIFF header from the box contents, and are + // not included in the JPEG + memcpy(jpeg_data->app_data[i].data() + 3 + sizeof(jpeg::kExifTag), + data + 4, size - 4); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} +JxlDecoderStatus JxlToJpegDecoder::SetXmp(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data) { + for (size_t i = 0; i < jpeg_data->app_data.size(); ++i) { + if (jpeg_data->app_marker_type[i] == jxl::jpeg::AppMarkerType::kXMP) { + if (jpeg_data->app_data[i].size() != size + 3 + sizeof(jpeg::kXMPTag)) + return JXL_DEC_ERROR; + // The first 9 bytes are used for JPEG marker header. + jpeg_data->app_data[i][0] = 0xE1; + // The second and third byte are already filled in correctly + memcpy(jpeg_data->app_data[i].data() + 3, jpeg::kXMPTag, + sizeof(jpeg::kXMPTag)); + memcpy(jpeg_data->app_data[i].data() + 3 + sizeof(jpeg::kXMPTag), data, + size); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/decode_to_jpeg.h b/media/libjxl/src/lib/jxl/decode_to_jpeg.h new file mode 100644 index 000000000..68fd06e66 --- /dev/null +++ b/media/libjxl/src/lib/jxl/decode_to_jpeg.h @@ -0,0 +1,217 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DECODE_TO_JPEG_H_ +#define LIB_JXL_DECODE_TO_JPEG_H_ + +// JPEG XL to JPEG bytes decoder logic. The JxlToJpegDecoder class keeps track +// of the decoder state needed to parse the JPEG reconstruction box and provide +// the reconstructed JPEG to the output buffer. + +#include +#include + +#include +#include + +#include "jxl/decode.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" // JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/jpeg/dec_jpeg_data.h" +#if JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +namespace jxl { + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +class JxlToJpegDecoder { + public: + // Returns whether an output buffer is set. + bool IsOutputSet() const { return next_out_ != nullptr; } + + // Returns whether the decoder is parsing a boxa JPEG box was parsed. + bool IsParsingBox() const { return inside_box_; } + + // Sets the output buffer used when producing JPEG output. + JxlDecoderStatus SetOutputBuffer(uint8_t* data, size_t size) { + if (next_out_) return JXL_DEC_ERROR; + next_out_ = data; + avail_size_ = size; + return JXL_DEC_SUCCESS; + } + + // Releases the buffer set with SetOutputBuffer(). + size_t ReleaseOutputBuffer() { + size_t result = avail_size_; + next_out_ = nullptr; + avail_size_ = 0; + return result; + } + + void StartBox(bool box_until_eof, size_t contents_size) { + // A new box implies that we clear the buffer. + buffer_.clear(); + inside_box_ = true; + if (box_until_eof) { + box_until_eof_ = true; + } else { + box_size_ = contents_size; + } + } + + // Consumes data from next_in/avail_in to reconstruct JPEG data. + // Uses box_size_, inside_box_ and box_until_eof_ to calculate how much to + // consume. Potentially stores unparsed data in buffer_. + // Potentially populates jpeg_data_. Potentially updates inside_box_. + // Returns JXL_DEC_JPEG_RECONSTRUCTION when finished, JXL_DEC_NEED_MORE_INPUT + // if more input is needed, JXL_DEC_ERROR on parsing error. + JxlDecoderStatus Process(const uint8_t** next_in, size_t* avail_in); + + // Returns non-owned copy of the JPEGData, only after Process finished and + // the JPEGData was not yet moved to an image bundle with + // SetImageBundleJpegData. + jpeg::JPEGData* GetJpegData() { return jpeg_data_.get(); } + + // Returns how many exif or xmp app markers are present in the JPEG data. A + // return value higher than 1 would require multiple exif boxes or multiple + // xmp boxes in the container format, and this is not supported by the API and + // considered an error. May only be called after Process returned success. + static size_t NumExifMarkers(const jpeg::JPEGData& jpeg_data); + static size_t NumXmpMarkers(const jpeg::JPEGData& jpeg_data); + + // Returns box content size for metadata, using the known data from the app + // markers. + static JxlDecoderStatus ExifBoxContentSize(const jpeg::JPEGData& jpeg_data, + size_t* size); + static JxlDecoderStatus XmlBoxContentSize(const jpeg::JPEGData& jpeg_data, + size_t* size); + + // Returns JXL_DEC_ERROR if there is no exif/XMP marker or the data size + // does not match, or this function is called before Process returned + // success, JXL_DEC_SUCCESS otherwise. As input, provide the full box contents + // but not the box header. In case of exif, this includes the 4-byte TIFF + // header, even though it won't be copied into the JPEG. + static JxlDecoderStatus SetExif(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data); + static JxlDecoderStatus SetXmp(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data); + + // Sets the JpegData of the ImageBundle passed if there is anything to set. + // Releases the JpegData from this decoder if set. + Status SetImageBundleJpegData(ImageBundle* ib) { + if (IsOutputSet() && jpeg_data_ != nullptr) { + if (!jpeg::SetJPEGDataFromICC(ib->metadata()->color_encoding.ICC(), + jpeg_data_.get())) { + return false; + } + ib->jpeg_data.reset(jpeg_data_.release()); + } + return true; + } + + JxlDecoderStatus WriteOutput(const jpeg::JPEGData& jpeg_data) { + // Copy JPEG bytestream if desired. + uint8_t* tmp_next_out = next_out_; + size_t tmp_avail_size = avail_size_; + auto write = [&tmp_next_out, &tmp_avail_size](const uint8_t* buf, + size_t len) { + size_t to_write = std::min(tmp_avail_size, len); + if (to_write != 0) memcpy(tmp_next_out, buf, to_write); + tmp_next_out += to_write; + tmp_avail_size -= to_write; + return to_write; + }; + Status write_result = jpeg::WriteJpeg(jpeg_data, write); + if (!write_result) { + if (tmp_avail_size == 0) { + return JXL_DEC_JPEG_NEED_MORE_OUTPUT; + } + return JXL_DEC_ERROR; + } + next_out_ = tmp_next_out; + avail_size_ = tmp_avail_size; + return JXL_DEC_SUCCESS; + } + + private: + // Content of the most recently parsed JPEG reconstruction box if any. + std::vector buffer_; + + // Decoded content of the most recently parsed JPEG reconstruction box is + // stored here. + std::unique_ptr jpeg_data_; + + // True if the decoder is currently reading bytes inside a JPEG reconstruction + // box. + bool inside_box_ = false; + + // True if the JPEG reconstruction box had undefined size (all remaining + // bytes). + bool box_until_eof_ = false; + // Size of most recently parsed JPEG reconstruction box contents. + size_t box_size_ = 0; + + // Next bytes to write JPEG reconstruction to. + uint8_t* next_out_ = nullptr; + // Available bytes to write JPEG reconstruction to. + size_t avail_size_ = 0; +}; + +#else + +// Fake class that disables support for decoding JPEG XL to JPEG. +class JxlToJpegDecoder { + public: + bool IsOutputSet() const { return false; } + bool IsParsingBox() const { return false; } + + JxlDecoderStatus SetOutputBuffer(uint8_t* /* data */, size_t /* size */) { + return JXL_DEC_ERROR; + } + size_t ReleaseOutputBuffer() { return 0; } + + void StartBox(bool /* box_until_eof */, size_t /* contents_size */) {} + + JxlDecoderStatus Process(const uint8_t** next_in, size_t* avail_in) { + return JXL_DEC_ERROR; + } + jpeg::JPEGData* GetJpegData() { return nullptr; } + + Status SetImageBundleJpegData(ImageBundle* /* ib */) { return true; } + + static size_t NumExifMarkers(const jpeg::JPEGData& /*jpeg_data*/) { + return 0; + } + static size_t NumXmpMarkers(const jpeg::JPEGData& /*jpeg_data*/) { return 0; } + static size_t ExifBoxContentSize(const jpeg::JPEGData& /*jpeg_data*/, + size_t* /*size*/) { + return JXL_DEC_ERROR; + } + static size_t XmlBoxContentSize(const jpeg::JPEGData& /*jpeg_data*/, + size_t* /*size*/) { + return JXL_DEC_ERROR; + } + static JxlDecoderStatus SetExif(const uint8_t* /*data*/, size_t /*size*/, + jpeg::JPEGData* /*jpeg_data*/) { + return JXL_DEC_ERROR; + } + static JxlDecoderStatus SetXmp(const uint8_t* /*data*/, size_t /*size*/, + jpeg::JPEGData* /*jpeg_data*/) { + return JXL_DEC_ERROR; + } + + JxlDecoderStatus WriteOutput(const jpeg::JPEGData& /* jpeg_data */) { + return JXL_DEC_SUCCESS; + } +}; + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jxl + +#endif // LIB_JXL_DECODE_TO_JPEG_H_ diff --git a/media/libjxl/src/lib/jxl/enc_ac_strategy.cc b/media/libjxl/src/lib/jxl/enc_ac_strategy.cc new file mode 100644 index 000000000..a6de18fbd --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_ac_strategy.cc @@ -0,0 +1,1123 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_ac_strategy.h" + +#include +#include + +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_ac_strategy.cc" +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/fast_math-inl.h" + +// Some of the floating point constants in this file and in other +// files in the libjxl project have been obtained using the +// tools/optimizer/simplex_fork.py tool. It is a variation of +// Nelder-Mead optimization, and we generally try to minimize +// BPP * pnorm aggregate as reported by the benchmark_xl tool, +// but occasionally the values are optimized by using additional +// constraints such as maintaining a certain density, or ratio of +// popularity of integral transforms. Jyrki visually reviews all +// such changes and often makes manual changes to maintain good +// visual quality to changes where butteraugli was not sufficiently +// sensitive to some kind of degradation. Unfortunately image quality +// is still more of an art than science. + +// This must come before the begin/end_target, but HWY_ONCE is only true +// after that, so use an "include guard". +#ifndef LIB_JXL_ENC_AC_STRATEGY_ +#define LIB_JXL_ENC_AC_STRATEGY_ +// Parameters of the heuristic are marked with a OPTIMIZE comment. +namespace jxl { + +// Debugging utilities. + +// Returns a linear sRGB color (as bytes) for each AC strategy. +const uint8_t* TypeColor(const uint8_t& raw_strategy) { + JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy)); + static_assert(AcStrategy::kNumValidStrategies == 27, "Change colors"); + static constexpr uint8_t kColors[][3] = { + {0xFF, 0xFF, 0x00}, // DCT8 + {0xFF, 0x80, 0x80}, // HORNUSS + {0xFF, 0x80, 0x80}, // DCT2x2 + {0xFF, 0x80, 0x80}, // DCT4x4 + {0x80, 0xFF, 0x00}, // DCT16x16 + {0x00, 0xC0, 0x00}, // DCT32x32 + {0xC0, 0xFF, 0x00}, // DCT16x8 + {0xC0, 0xFF, 0x00}, // DCT8x16 + {0x00, 0xFF, 0x00}, // DCT32x8 + {0x00, 0xFF, 0x00}, // DCT8x32 + {0x00, 0xFF, 0x00}, // DCT32x16 + {0x00, 0xFF, 0x00}, // DCT16x32 + {0xFF, 0x80, 0x00}, // DCT4x8 + {0xFF, 0x80, 0x00}, // DCT8x4 + {0xFF, 0xFF, 0x80}, // AFV0 + {0xFF, 0xFF, 0x80}, // AFV1 + {0xFF, 0xFF, 0x80}, // AFV2 + {0xFF, 0xFF, 0x80}, // AFV3 + {0x00, 0xC0, 0xFF}, // DCT64x64 + {0x00, 0xFF, 0xFF}, // DCT64x32 + {0x00, 0xFF, 0xFF}, // DCT32x64 + {0x00, 0x40, 0xFF}, // DCT128x128 + {0x00, 0x80, 0xFF}, // DCT128x64 + {0x00, 0x80, 0xFF}, // DCT64x128 + {0x00, 0x00, 0xC0}, // DCT256x256 + {0x00, 0x00, 0xFF}, // DCT256x128 + {0x00, 0x00, 0xFF}, // DCT128x256 + }; + return kColors[raw_strategy]; +} + +const uint8_t* TypeMask(const uint8_t& raw_strategy) { + JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy)); + static_assert(AcStrategy::kNumValidStrategies == 27, "Add masks"); + // implicitly, first row and column is made dark + static constexpr uint8_t kMask[][64] = { + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // DCT8 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 1, 1, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // HORNUSS + { + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + }, // 2x2 + { + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + }, // 4x4 + {}, // DCT16x16 (unused) + {}, // DCT32x32 (unused) + {}, // DCT16x8 (unused) + {}, // DCT8x16 (unused) + {}, // DCT32x8 (unused) + {}, // DCT8x32 (unused) + {}, // DCT32x16 (unused) + {}, // DCT16x32 (unused) + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // DCT4x8 + { + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + }, // DCT8x4 + { + 1, 1, 1, 1, 1, 0, 0, 0, // + 1, 1, 1, 1, 0, 0, 0, 0, // + 1, 1, 1, 0, 0, 0, 0, 0, // + 1, 1, 0, 0, 0, 0, 0, 0, // + 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // AFV0 + { + 0, 0, 0, 0, 1, 1, 1, 1, // + 0, 0, 0, 0, 0, 1, 1, 1, // + 0, 0, 0, 0, 0, 0, 1, 1, // + 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // AFV1 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 0, 0, 0, 0, // + }, // AFV2 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 1, 1, // + 0, 0, 0, 0, 0, 1, 1, 1, // + }, // AFV3 + }; + return kMask[raw_strategy]; +} + +void DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize, + size_t ysize, const char* tag, AuxOut* aux_out) { + Image3F color_acs(xsize, ysize); + for (size_t y = 0; y < ysize; y++) { + float* JXL_RESTRICT rows[3] = { + color_acs.PlaneRow(0, y), + color_acs.PlaneRow(1, y), + color_acs.PlaneRow(2, y), + }; + const AcStrategyRow acs_row = ac_strategy.ConstRow(y / kBlockDim); + for (size_t x = 0; x < xsize; x++) { + AcStrategy acs = acs_row[x / kBlockDim]; + const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy()); + for (size_t c = 0; c < 3; c++) { + rows[c][x] = color[c] / 255.f; + } + } + } + size_t stride = color_acs.PixelsPerRow(); + for (size_t c = 0; c < 3; c++) { + for (size_t by = 0; by < DivCeil(ysize, kBlockDim); by++) { + float* JXL_RESTRICT row = color_acs.PlaneRow(c, by * kBlockDim); + const AcStrategyRow acs_row = ac_strategy.ConstRow(by); + for (size_t bx = 0; bx < DivCeil(xsize, kBlockDim); bx++) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy()); + const uint8_t* JXL_RESTRICT mask = TypeMask(acs.RawStrategy()); + if (acs.covered_blocks_x() == 1 && acs.covered_blocks_y() == 1) { + for (size_t iy = 0; iy < kBlockDim && by * kBlockDim + iy < ysize; + iy++) { + for (size_t ix = 0; ix < kBlockDim && bx * kBlockDim + ix < xsize; + ix++) { + if (mask[iy * kBlockDim + ix]) { + row[iy * stride + bx * kBlockDim + ix] = color[c] / 800.f; + } + } + } + } + // draw block edges + for (size_t ix = 0; ix < kBlockDim * acs.covered_blocks_x() && + bx * kBlockDim + ix < xsize; + ix++) { + row[0 * stride + bx * kBlockDim + ix] = color[c] / 350.f; + } + for (size_t iy = 0; iy < kBlockDim * acs.covered_blocks_y() && + by * kBlockDim + iy < ysize; + iy++) { + row[iy * stride + bx * kBlockDim + 0] = color[c] / 350.f; + } + } + } + } + aux_out->DumpImage(tag, color_acs); +} + +} // namespace jxl +#endif // LIB_JXL_ENC_AC_STRATEGY_ + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::IfThenZeroElse; +using hwy::HWY_NAMESPACE::Round; +using hwy::HWY_NAMESPACE::Sqrt; + +bool MultiBlockTransformCrossesHorizontalBoundary( + const AcStrategyImage& ac_strategy, size_t start_x, size_t y, + size_t end_x) { + if (start_x >= ac_strategy.xsize() || y >= ac_strategy.ysize()) { + return false; + } + if (y % 8 == 0) { + // Nothing crosses 64x64 boundaries, and the memory on the other side + // of the 64x64 block may still uninitialized. + return false; + } + end_x = std::min(end_x, ac_strategy.xsize()); + // The first multiblock might be before the start_x, let's adjust it + // to point to the first IsFirstBlock() == true block we find by backward + // tracing. + AcStrategyRow row = ac_strategy.ConstRow(y); + const size_t start_x_limit = start_x & ~7; + while (start_x != start_x_limit && !row[start_x].IsFirstBlock()) { + --start_x; + } + for (size_t x = start_x; x < end_x;) { + if (row[x].IsFirstBlock()) { + x += row[x].covered_blocks_x(); + } else { + return true; + } + } + return false; +} + +bool MultiBlockTransformCrossesVerticalBoundary( + const AcStrategyImage& ac_strategy, size_t x, size_t start_y, + size_t end_y) { + if (x >= ac_strategy.xsize() || start_y >= ac_strategy.ysize()) { + return false; + } + if (x % 8 == 0) { + // Nothing crosses 64x64 boundaries, and the memory on the other side + // of the 64x64 block may still uninitialized. + return false; + } + end_y = std::min(end_y, ac_strategy.ysize()); + // The first multiblock might be before the start_y, let's adjust it + // to point to the first IsFirstBlock() == true block we find by backward + // tracing. + const size_t start_y_limit = start_y & ~7; + while (start_y != start_y_limit && + !ac_strategy.ConstRow(start_y)[x].IsFirstBlock()) { + --start_y; + } + + for (size_t y = start_y; y < end_y;) { + AcStrategyRow row = ac_strategy.ConstRow(y); + if (row[x].IsFirstBlock()) { + y += row[x].covered_blocks_y(); + } else { + return true; + } + } + return false; +} + +float EstimateEntropy(const AcStrategy& acs, size_t x, size_t y, + const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, float* block, + float* scratch_space, uint32_t* quantized) { + const size_t size = (1 << acs.log2_covered_blocks()) * kDCTBlockSize; + + // Apply transform. + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT block_c = block + size * c; + TransformFromPixels(acs.Strategy(), &config.Pixel(c, x, y), + config.src_stride, block_c, scratch_space); + } + + HWY_FULL(float) df; + + const size_t num_blocks = acs.covered_blocks_x() * acs.covered_blocks_y(); + float quant_norm8 = 0; + float masking = 0; + if (num_blocks == 1) { + // When it is only one 8x8, we don't need aggregation of values. + quant_norm8 = config.Quant(x / 8, y / 8); + masking = 2.0f * config.Masking(x / 8, y / 8); + } else if (num_blocks == 2) { + // Taking max instead of 8th norm seems to work + // better for smallest blocks up to 16x8. Jyrki couldn't get + // improvements in trying the same for 16x16 blocks. + if (acs.covered_blocks_y() == 2) { + quant_norm8 = + std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8, y / 8 + 1)); + masking = 2.0f * std::max(config.Masking(x / 8, y / 8), + config.Masking(x / 8, y / 8 + 1)); + } else { + quant_norm8 = + std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8 + 1, y / 8)); + masking = 2.0f * std::max(config.Masking(x / 8, y / 8), + config.Masking(x / 8 + 1, y / 8)); + } + } else { + float masking_norm2 = 0; + float masking_max = 0; + // Load QF value, calculate empirical heuristic on masking field + // for weighting the information loss. Information loss manifests + // itself as ringing, and masking could hide it. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + float qval = config.Quant(x / 8 + ix, y / 8 + iy); + qval *= qval; + qval *= qval; + quant_norm8 += qval * qval; + float maskval = config.Masking(x / 8 + ix, y / 8 + iy); + masking_max = std::max(masking_max, maskval); + masking_norm2 += maskval * maskval; + } + } + quant_norm8 /= num_blocks; + quant_norm8 = FastPowf(quant_norm8, 1.0f / 8.0f); + masking_norm2 = sqrt(masking_norm2 / num_blocks); + // This is a highly empirical formula. + masking = (masking_norm2 + masking_max); + } + const auto q = Set(df, quant_norm8); + + // Compute entropy. + float entropy = config.base_entropy; + auto info_loss = Zero(df); + auto info_loss2 = Zero(df); + + for (size_t c = 0; c < 3; c++) { + const float* inv_matrix = config.dequant->InvMatrix(acs.RawStrategy(), c); + const auto cmap_factor = Set(df, cmap_factors[c]); + + auto entropy_v = Zero(df); + auto nzeros_v = Zero(df); + auto cost1 = Set(df, config.cost1); + auto cost2 = Set(df, config.cost2); + auto cost_delta = Set(df, config.cost_delta); + for (size_t i = 0; i < num_blocks * kDCTBlockSize; i += Lanes(df)) { + const auto in = Load(df, block + c * size + i); + const auto in_y = Mul(Load(df, block + size + i), cmap_factor); + const auto im = Load(df, inv_matrix + i); + const auto val = Mul(Sub(in, in_y), Mul(im, q)); + const auto rval = Round(val); + const auto diff = AbsDiff(val, rval); + info_loss = Add(info_loss, diff); + info_loss2 = MulAdd(diff, diff, info_loss2); + const auto q = Abs(rval); + const auto q_is_zero = Eq(q, Zero(df)); + entropy_v = Add(entropy_v, IfThenElseZero(Ge(q, Set(df, 1.5f)), cost2)); + // We used to have q * C here, but that cost model seems to + // be punishing large values more than necessary. Sqrt tries + // to avoid large values less aggressively. Having high accuracy + // around zero is most important at low qualities, and there + // we have directly specified costs for 0, 1, and 2. + entropy_v = MulAdd(Sqrt(q), cost_delta, entropy_v); + nzeros_v = Add(nzeros_v, IfThenZeroElse(q_is_zero, Set(df, 1.0f))); + } + entropy_v = MulAdd(nzeros_v, cost1, entropy_v); + + entropy += GetLane(SumOfLanes(df, entropy_v)); + size_t num_nzeros = GetLane(SumOfLanes(df, nzeros_v)); + // Add #bit of num_nonzeros, as an estimate of the cost for encoding the + // number of non-zeros of the block. + size_t nbits = CeilLog2Nonzero(num_nzeros + 1) + 1; + // Also add #bit of #bit of num_nonzeros, to estimate the ANS cost, with a + // bias. + entropy += config.zeros_mul * (CeilLog2Nonzero(nbits + 17) + nbits); + } + float ret = + entropy + + masking * + ((config.info_loss_multiplier * GetLane(SumOfLanes(df, info_loss))) + + (config.info_loss_multiplier2 * + sqrt(num_blocks * GetLane(SumOfLanes(df, info_loss2))))); + return ret; +} + +uint8_t FindBest8x8Transform(size_t x, size_t y, int encoding_speed_tier, + const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, + AcStrategyImage* JXL_RESTRICT ac_strategy, + float* block, float* scratch_space, + uint32_t* quantized, float* entropy_out) { + struct TransformTry8x8 { + AcStrategy::Type type; + int encoding_speed_tier_max_limit; + float entropy_add; + float entropy_mul; + }; + static const TransformTry8x8 kTransforms8x8[] = { + { + AcStrategy::Type::DCT, + 9, + 3.0f, + 0.745f, + }, + { + AcStrategy::Type::DCT4X4, + 5, + 4.0f, + 1.0179946967008329f, + }, + { + AcStrategy::Type::DCT2X2, + 4, + 4.0f, + 0.76721119707580943f, + }, + { + AcStrategy::Type::DCT4X8, + 5, + 0.0f, + 0.700754622182473063f, + }, + { + AcStrategy::Type::DCT8X4, + 5, + 0.0f, + 0.700754622182473063f, + }, + { + AcStrategy::Type::IDENTITY, + 5, + 8.0f, + 0.81217614513585534f, + }, + { + AcStrategy::Type::AFV0, + 4, + 3.0f, + 0.70086131125719425f, + }, + { + AcStrategy::Type::AFV1, + 4, + 3.0f, + 0.70086131125719425f, + }, + { + AcStrategy::Type::AFV2, + 4, + 3.0f, + 0.70086131125719425f, + }, + { + AcStrategy::Type::AFV3, + 4, + 3.0f, + 0.70086131125719425f, + }, + }; + double best = 1e30; + uint8_t best_tx = kTransforms8x8[0].type; + for (auto tx : kTransforms8x8) { + if (tx.encoding_speed_tier_max_limit < encoding_speed_tier) { + continue; + } + AcStrategy acs = AcStrategy::FromRawStrategy(tx.type); + float entropy = EstimateEntropy(acs, x, y, config, cmap_factors, block, + scratch_space, quantized); + entropy = tx.entropy_add + tx.entropy_mul * entropy; + if (entropy < best) { + best_tx = tx.type; + best = entropy; + } + } + *entropy_out = best; + return best_tx; +} + +// bx, by addresses the 64x64 block at 8x8 subresolution +// cx, cy addresses the left, upper 8x8 block position of the candidate +// transform. +void TryMergeAcs(AcStrategy::Type acs_raw, size_t bx, size_t by, size_t cx, + size_t cy, const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, + AcStrategyImage* JXL_RESTRICT ac_strategy, + const float entropy_mul, const uint8_t candidate_priority, + uint8_t* priority, float* JXL_RESTRICT entropy_estimate, + float* block, float* scratch_space, uint32_t* quantized) { + AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw); + float entropy_current = 0; + for (size_t iy = 0; iy < acs.covered_blocks_y(); ++iy) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ++ix) { + if (priority[(cy + iy) * 8 + (cx + ix)] >= candidate_priority) { + // Transform would reuse already allocated blocks and + // lead to invalid overlaps, for example DCT64X32 vs. + // DCT32X64. + return; + } + entropy_current += entropy_estimate[(cy + iy) * 8 + (cx + ix)]; + } + } + float entropy_candidate = + entropy_mul * EstimateEntropy(acs, (bx + cx) * 8, (by + cy) * 8, config, + cmap_factors, block, scratch_space, + quantized); + if (entropy_candidate >= entropy_current) return; + // Accept the candidate. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + entropy_estimate[(cy + iy) * 8 + cx + ix] = 0; + priority[(cy + iy) * 8 + cx + ix] = candidate_priority; + } + } + ac_strategy->Set(bx + cx, by + cy, acs_raw); + entropy_estimate[cy * 8 + cx] = entropy_candidate; +} + +static void SetEntropyForTransform(size_t cx, size_t cy, + const AcStrategy::Type acs_raw, + float entropy, + float* JXL_RESTRICT entropy_estimate) { + const AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw); + for (size_t dy = 0; dy < acs.covered_blocks_y(); ++dy) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); ++dx) { + entropy_estimate[(cy + dy) * 8 + cx + dx] = 0.0; + } + } + entropy_estimate[cy * 8 + cx] = entropy; +} + +AcStrategy::Type AcsSquare(size_t blocks) { + if (blocks == 2) { + return AcStrategy::Type::DCT16X16; + } else if (blocks == 4) { + return AcStrategy::Type::DCT32X32; + } else { + return AcStrategy::Type::DCT64X64; + } +} + +AcStrategy::Type AcsVerticalSplit(size_t blocks) { + if (blocks == 2) { + return AcStrategy::Type::DCT16X8; + } else if (blocks == 4) { + return AcStrategy::Type::DCT32X16; + } else { + return AcStrategy::Type::DCT64X32; + } +} + +AcStrategy::Type AcsHorizontalSplit(size_t blocks) { + if (blocks == 2) { + return AcStrategy::Type::DCT8X16; + } else if (blocks == 4) { + return AcStrategy::Type::DCT16X32; + } else { + return AcStrategy::Type::DCT32X64; + } +} + +// The following function tries to merge smaller transforms into +// squares and the rectangles originating from a single middle division +// (horizontal or vertical) fairly. +// +// This is now generalized to concern about squares +// of blocks X blocks size, where a block is 8x8 pixels. +void FindBestFirstLevelDivisionForSquare( + size_t blocks, bool allow_square_transform, size_t bx, size_t by, size_t cx, + size_t cy, const ACSConfig& config, const float* JXL_RESTRICT cmap_factors, + AcStrategyImage* JXL_RESTRICT ac_strategy, const float entropy_mul_JXK, + const float entropy_mul_JXJ, float* JXL_RESTRICT entropy_estimate, + float* block, float* scratch_space, uint32_t* quantized) { + // We denote J for the larger dimension here, and K for the smaller. + // For example, for 32x32 block splitting, J would be 32, K 16. + const size_t blocks_half = blocks / 2; + const AcStrategy::Type acs_rawJXK = AcsVerticalSplit(blocks); + const AcStrategy::Type acs_rawKXJ = AcsHorizontalSplit(blocks); + const AcStrategy::Type acs_rawJXJ = AcsSquare(blocks); + const AcStrategy acsJXK = AcStrategy::FromRawStrategy(acs_rawJXK); + const AcStrategy acsKXJ = AcStrategy::FromRawStrategy(acs_rawKXJ); + const AcStrategy acsJXJ = AcStrategy::FromRawStrategy(acs_rawJXJ); + AcStrategyRow row0 = ac_strategy->ConstRow(by + cy + 0); + AcStrategyRow row1 = ac_strategy->ConstRow(by + cy + blocks_half); + // Let's check if we can consider a JXJ block here at all. + // This is not necessary in the basic use of hierarchically merging + // blocks in the simplest possible way, but is needed when we try other + // 'floating' options of merging, possibly after a simple hierarchical + // merge has been explored. + if (MultiBlockTransformCrossesHorizontalBoundary(*ac_strategy, bx + cx, + by + cy, bx + cx + blocks) || + MultiBlockTransformCrossesHorizontalBoundary( + *ac_strategy, bx + cx, by + cy + blocks, bx + cx + blocks) || + MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx, by + cy, + by + cy + blocks) || + MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx + blocks, + by + cy, by + cy + blocks)) { + return; // not suitable for JxJ analysis, some transforms leak out. + } + // For floating transforms there may be + // already blocks selected that make either or both JXK and + // KXJ not feasible for this location. + const bool allow_JXK = !MultiBlockTransformCrossesVerticalBoundary( + *ac_strategy, bx + cx + blocks_half, by + cy, by + cy + blocks); + const bool allow_KXJ = !MultiBlockTransformCrossesHorizontalBoundary( + *ac_strategy, bx + cx, by + cy + blocks_half, bx + cx + blocks); + // Current entropies aggregated on NxN resolution. + float entropy[2][2] = {}; + for (size_t dy = 0; dy < blocks; ++dy) { + for (size_t dx = 0; dx < blocks; ++dx) { + entropy[dy / blocks_half][dx / blocks_half] += + entropy_estimate[(cy + dy) * 8 + (cx + dx)]; + } + } + float entropy_JXK_left = std::numeric_limits::max(); + float entropy_JXK_right = std::numeric_limits::max(); + float entropy_KXJ_top = std::numeric_limits::max(); + float entropy_KXJ_bottom = std::numeric_limits::max(); + float entropy_JXJ = std::numeric_limits::max(); + if (allow_JXK) { + if (row0[bx + cx + 0].RawStrategy() != acs_rawJXK) { + entropy_JXK_left = + entropy_mul_JXK * + EstimateEntropy(acsJXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config, + cmap_factors, block, scratch_space, quantized); + } + if (row0[bx + cx + blocks_half].RawStrategy() != acs_rawJXK) { + entropy_JXK_right = + entropy_mul_JXK * EstimateEntropy(acsJXK, (bx + cx + blocks_half) * 8, + (by + cy + 0) * 8, config, + cmap_factors, block, scratch_space, + quantized); + } + } + if (allow_KXJ) { + if (row0[bx + cx].RawStrategy() != acs_rawKXJ) { + entropy_KXJ_top = + entropy_mul_JXK * + EstimateEntropy(acsKXJ, (bx + cx + 0) * 8, (by + cy + 0) * 8, config, + cmap_factors, block, scratch_space, quantized); + } + if (row1[bx + cx].RawStrategy() != acs_rawKXJ) { + entropy_KXJ_bottom = + entropy_mul_JXK * EstimateEntropy(acsKXJ, (bx + cx + 0) * 8, + (by + cy + blocks_half) * 8, config, + cmap_factors, block, scratch_space, + quantized); + } + } + if (allow_square_transform) { + // We control the exploration of the square transform separately so that + // we can turn it off at high decoding speeds for 32x32, but still allow + // exploring 16x32 and 32x16. + entropy_JXJ = entropy_mul_JXJ * EstimateEntropy(acsJXJ, (bx + cx + 0) * 8, + (by + cy + 0) * 8, config, + cmap_factors, block, + scratch_space, quantized); + } + + // Test if this block should have JXK or KXJ transforms, + // because it can have only one or the other. + float costJxN = std::min(entropy_JXK_left, entropy[0][0] + entropy[1][0]) + + std::min(entropy_JXK_right, entropy[0][1] + entropy[1][1]); + float costNxJ = std::min(entropy_KXJ_top, entropy[0][0] + entropy[0][1]) + + std::min(entropy_KXJ_bottom, entropy[1][0] + entropy[1][1]); + if (entropy_JXJ < costJxN && entropy_JXJ < costNxJ) { + ac_strategy->Set(bx + cx, by + cy, acs_rawJXJ); + SetEntropyForTransform(cx, cy, acs_rawJXJ, entropy_JXJ, entropy_estimate); + } else if (costJxN < costNxJ) { + if (entropy_JXK_left < entropy[0][0] + entropy[1][0]) { + ac_strategy->Set(bx + cx, by + cy, acs_rawJXK); + SetEntropyForTransform(cx, cy, acs_rawJXK, entropy_JXK_left, + entropy_estimate); + } + if (entropy_JXK_right < entropy[0][1] + entropy[1][1]) { + ac_strategy->Set(bx + cx + blocks_half, by + cy, acs_rawJXK); + SetEntropyForTransform(cx + blocks_half, cy, acs_rawJXK, + entropy_JXK_right, entropy_estimate); + } + } else { + if (entropy_KXJ_top < entropy[0][0] + entropy[0][1]) { + ac_strategy->Set(bx + cx, by + cy, acs_rawKXJ); + SetEntropyForTransform(cx, cy, acs_rawKXJ, entropy_KXJ_top, + entropy_estimate); + } + if (entropy_KXJ_bottom < entropy[1][0] + entropy[1][1]) { + ac_strategy->Set(bx + cx, by + cy + blocks_half, acs_rawKXJ); + SetEntropyForTransform(cx, cy + blocks_half, acs_rawKXJ, + entropy_KXJ_bottom, entropy_estimate); + } + } +} + +void ProcessRectACS(PassesEncoderState* JXL_RESTRICT enc_state, + const ACSConfig& config, const Rect& rect) { + // Main philosophy here: + // 1. First find best 8x8 transform for each area. + // 2. Merging them into larger transforms where possibly, but + // starting from the smallest transforms (16x8 and 8x16). + // Additional complication: 16x8 and 8x16 are considered + // simultanouesly and fairly against each other. + // We are looking at 64x64 squares since the YtoX and YtoB + // maps happen to be at that resolution, and having + // integral transforms cross these boundaries leads to + // additional complications. + const CompressParams& cparams = enc_state->cparams; + const float butteraugli_target = cparams.butteraugli_distance; + AcStrategyImage* ac_strategy = &enc_state->shared.ac_strategy; + // TODO(veluca): reuse allocations + auto mem = hwy::AllocateAligned(5 * AcStrategy::kMaxCoeffArea); + auto qmem = hwy::AllocateAligned(AcStrategy::kMaxCoeffArea); + uint32_t* JXL_RESTRICT quantized = qmem.get(); + float* JXL_RESTRICT block = mem.get(); + float* JXL_RESTRICT scratch_space = mem.get() + 3 * AcStrategy::kMaxCoeffArea; + size_t bx = rect.x0(); + size_t by = rect.y0(); + JXL_ASSERT(rect.xsize() <= 8); + JXL_ASSERT(rect.ysize() <= 8); + size_t tx = bx / kColorTileDimInBlocks; + size_t ty = by / kColorTileDimInBlocks; + const float cmap_factors[3] = { + enc_state->shared.cmap.YtoXRatio( + enc_state->shared.cmap.ytox_map.ConstRow(ty)[tx]), + 0.0f, + enc_state->shared.cmap.YtoBRatio( + enc_state->shared.cmap.ytob_map.ConstRow(ty)[tx]), + }; + if (cparams.speed_tier > SpeedTier::kHare) return; + // First compute the best 8x8 transform for each square. Later, we do not + // experiment with different combinations, but only use the best of the 8x8s + // when DCT8X8 is specified in the tree search. + // 8x8 transforms have 10 variants, but every larger transform is just a DCT. + float entropy_estimate[64] = {}; + // Favor all 8x8 transforms (against 16x8 and larger transforms)) at + // low butteraugli_target distances. + static const float k8x8mul1 = -0.55; + static const float k8x8mul2 = 1.0735757687292623f; + static const float k8x8base = 1.4; + const float mul8x8 = k8x8mul2 + k8x8mul1 / (butteraugli_target + k8x8base); + for (size_t iy = 0; iy < rect.ysize(); iy++) { + for (size_t ix = 0; ix < rect.xsize(); ix++) { + float entropy = 0.0; + const uint8_t best_of_8x8s = FindBest8x8Transform( + 8 * (bx + ix), 8 * (by + iy), static_cast(cparams.speed_tier), + config, cmap_factors, ac_strategy, block, scratch_space, quantized, + &entropy); + ac_strategy->Set(bx + ix, by + iy, + static_cast(best_of_8x8s)); + entropy_estimate[iy * 8 + ix] = entropy * mul8x8; + } + } + // Merge when a larger transform is better than the previously + // searched best combination of 8x8 transforms. + struct MergeTry { + AcStrategy::Type type; + uint8_t priority; + uint8_t decoding_speed_tier_max_limit; + uint8_t encoding_speed_tier_max_limit; + float entropy_mul; + }; + static const float k8X16mul1 = -0.55; + static const float k8X16mul2 = 0.9019587899705066; + static const float k8X16base = 1.6; + const float entropy_mul16X8 = + k8X16mul2 + k8X16mul1 / (butteraugli_target + k8X16base); + // const float entropy_mul16X8 = mul8X16 * 0.91195782912371126f; + + static const float k16X16mul1 = -0.35; + static const float k16X16mul2 = 0.82; + static const float k16X16base = 2.0; + const float entropy_mul16X16 = + k16X16mul2 + k16X16mul1 / (butteraugli_target + k16X16base); + // const float entropy_mul16X16 = mul16X16 * 0.83183417727960129f; + + static const float k32X16mul1 = -0.1; + static const float k32X16mul2 = 0.84; + static const float k32X16base = 2.5; + const float entropy_mul16X32 = + k32X16mul2 + k32X16mul1 / (butteraugli_target + k32X16base); + + const float entropy_mul32X32 = 0.9; + const float entropy_mul64X64 = 1.43f; + // TODO(jyrki): Consider this feedback in further changes: + // Also effectively when the multipliers for smaller blocks are + // below 1, this raises the bar for the bigger blocks even higher + // in that sense these constants are not independent (e.g. changing + // the constant for DCT16x32 by -5% (making it more likely) also + // means that DCT32x32 becomes harder to do when starting from + // two DCT16x32s). It might be better to make them more independent, + // e.g. by not applying the multiplier when storing the new entropy + // estimates in TryMergeToACSCandidate(). + const MergeTry kTransformsForMerge[9] = { + {AcStrategy::Type::DCT16X8, 2, 4, 5, entropy_mul16X8}, + {AcStrategy::Type::DCT8X16, 2, 4, 5, entropy_mul16X8}, + // FindBestFirstLevelDivisionForSquare looks for DCT16X16 and its + // subdivisions. {AcStrategy::Type::DCT16X16, 3, entropy_mul16X16}, + {AcStrategy::Type::DCT16X32, 4, 4, 4, entropy_mul16X32}, + {AcStrategy::Type::DCT32X16, 4, 4, 4, entropy_mul16X32}, + // FindBestFirstLevelDivisionForSquare looks for DCT32X32 and its + // subdivisions. {AcStrategy::Type::DCT32X32, 5, 1, 5, + // 0.9822994906548809f}, + // TODO(jyrki): re-enable 64x32 and 64x64 if/when possible. + {AcStrategy::Type::DCT64X32, 6, 1, 3, 1.26f}, + {AcStrategy::Type::DCT32X64, 6, 1, 3, 1.26f}, + // {AcStrategy::Type::DCT64X64, 8, 1, 3, 2.0846542128012948f}, + }; + /* + These sizes not yet included in merge heuristic: + set(AcStrategy::Type::DCT32X8, 0.0f, 2.261390410971102f); + set(AcStrategy::Type::DCT8X32, 0.0f, 2.261390410971102f); + set(AcStrategy::Type::DCT128X128, 0.0f, 1.0f); + set(AcStrategy::Type::DCT128X64, 0.0f, 0.73f); + set(AcStrategy::Type::DCT64X128, 0.0f, 0.73f); + set(AcStrategy::Type::DCT256X256, 0.0f, 1.0f); + set(AcStrategy::Type::DCT256X128, 0.0f, 0.73f); + set(AcStrategy::Type::DCT128X256, 0.0f, 0.73f); + */ + + // Priority is a tricky kludge to avoid collisions so that transforms + // don't overlap. + uint8_t priority[64] = {}; + for (auto tx : kTransformsForMerge) { + if (tx.decoding_speed_tier_max_limit < cparams.decoding_speed_tier) { + continue; + } + AcStrategy acs = AcStrategy::FromRawStrategy(tx.type); + + for (size_t cy = 0; cy + acs.covered_blocks_y() - 1 < rect.ysize(); + cy += acs.covered_blocks_y()) { + for (size_t cx = 0; cx + acs.covered_blocks_x() - 1 < rect.xsize(); + cx += acs.covered_blocks_x()) { + if (cy + 7 < rect.ysize() && cx + 7 < rect.xsize()) { + if (cparams.decoding_speed_tier < 4 && + tx.type == AcStrategy::Type::DCT32X64) { + // We handle both DCT8X16 and DCT16X8 at the same time. + if ((cy | cx) % 8 == 0) { + FindBestFirstLevelDivisionForSquare( + 8, true, bx, by, cx, cy, config, cmap_factors, ac_strategy, + tx.entropy_mul, entropy_mul64X64, entropy_estimate, block, + scratch_space, quantized); + } + continue; + } else if (tx.type == AcStrategy::Type::DCT32X16) { + // We handled both DCT8X16 and DCT16X8 at the same time, + // and that is above. The last column and last row, + // when the last column or last row is odd numbered, + // are still handled by TryMergeAcs. + continue; + } + } + if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) || + (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) { + // already covered by FindBest32X32 + continue; + } + + if (cy + 3 < rect.ysize() && cx + 3 < rect.xsize()) { + if (tx.type == AcStrategy::Type::DCT16X32) { + // We handle both DCT8X16 and DCT16X8 at the same time. + bool enable_32x32 = cparams.decoding_speed_tier < 4; + if ((cy | cx) % 4 == 0) { + FindBestFirstLevelDivisionForSquare( + 4, enable_32x32, bx, by, cx, cy, config, cmap_factors, + ac_strategy, tx.entropy_mul, entropy_mul32X32, + entropy_estimate, block, scratch_space, quantized); + } + continue; + } else if (tx.type == AcStrategy::Type::DCT32X16) { + // We handled both DCT8X16 and DCT16X8 at the same time, + // and that is above. The last column and last row, + // when the last column or last row is odd numbered, + // are still handled by TryMergeAcs. + continue; + } + } + if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) || + (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) { + // already covered by FindBest32X32 + continue; + } + if (cy + 1 < rect.ysize() && cx + 1 < rect.xsize()) { + if (tx.type == AcStrategy::Type::DCT8X16) { + // We handle both DCT8X16 and DCT16X8 at the same time. + if ((cy | cx) % 2 == 0) { + FindBestFirstLevelDivisionForSquare( + 2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy, + tx.entropy_mul, entropy_mul16X16, entropy_estimate, block, + scratch_space, quantized); + } + continue; + } else if (tx.type == AcStrategy::Type::DCT16X8) { + // We handled both DCT8X16 and DCT16X8 at the same time, + // and that is above. The last column and last row, + // when the last column or last row is odd numbered, + // are still handled by TryMergeAcs. + continue; + } + } + if ((tx.type == AcStrategy::Type::DCT8X16 && cy % 2 == 1) || + (tx.type == AcStrategy::Type::DCT16X8 && cx % 2 == 1)) { + // already covered by FindBestFirstLevelDivisionForSquare + continue; + } + // All other merge sizes are handled here. + // Some of the DCT16X8s and DCT8X16s will still leak through here + // when there is an odd number of 8x8 blocks, then the last row + // and column will get their DCT16X8s and DCT8X16s through the + // normal integral transform merging process. + TryMergeAcs(tx.type, bx, by, cx, cy, config, cmap_factors, ac_strategy, + tx.entropy_mul, tx.priority, &priority[0], entropy_estimate, + block, scratch_space, quantized); + } + } + } + // Here we still try to do some non-aligned matching, find a few more + // 16X8, 8X16 and 16X16s between the non-2-aligned blocks. + if (cparams.speed_tier >= SpeedTier::kHare) { + return; + } + for (int ii = 0; ii < 3; ++ii) { + for (size_t cy = 1 - (ii == 1); cy + 1 < rect.ysize(); cy += 2) { + for (size_t cx = 1 - (ii == 2); cx + 1 < rect.xsize(); cx += 2) { + FindBestFirstLevelDivisionForSquare( + 2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy, + entropy_mul16X8, entropy_mul16X16, entropy_estimate, block, + scratch_space, quantized); + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ProcessRectACS); + +void AcStrategyHeuristics::Init(const Image3F& src, + PassesEncoderState* enc_state) { + this->enc_state = enc_state; + config.dequant = &enc_state->shared.matrices; + const CompressParams& cparams = enc_state->cparams; + const float butteraugli_target = cparams.butteraugli_distance; + + if (cparams.speed_tier >= SpeedTier::kCheetah) { + JXL_CHECK(enc_state->shared.matrices.EnsureComputed(1)); // DCT8 only + } else { + uint32_t acs_mask = 0; + // All transforms up to 64x64. + for (size_t i = 0; i < AcStrategy::DCT128X128; i++) { + acs_mask |= (1 << i); + } + JXL_CHECK(enc_state->shared.matrices.EnsureComputed(acs_mask)); + } + + // Image row pointers and strides. + config.quant_field_row = enc_state->initial_quant_field.Row(0); + config.quant_field_stride = enc_state->initial_quant_field.PixelsPerRow(); + auto& mask = enc_state->initial_quant_masking; + if (mask.xsize() > 0 && mask.ysize() > 0) { + config.masking_field_row = mask.Row(0); + config.masking_field_stride = mask.PixelsPerRow(); + } + + config.src_rows[0] = src.ConstPlaneRow(0, 0); + config.src_rows[1] = src.ConstPlaneRow(1, 0); + config.src_rows[2] = src.ConstPlaneRow(2, 0); + config.src_stride = src.PixelsPerRow(); + + // Entropy estimate is composed of two factors: + // - estimate of the number of bits that will be used by the block + // - information loss due to quantization + // The following constant controls the relative weights of these components. + config.info_loss_multiplier = 138.0f; + config.info_loss_multiplier2 = 50.46839691767866; + // TODO(jyrki): explore base_entropy setting more. + // A small value (0?) works better at high distance, while a larger value + // may be more effective at low distance/high bpp. + config.base_entropy = 0.0; + config.zeros_mul = 7.565053364251793f; + // Lots of +1 and -1 coefficients at high quality, it is + // beneficial to favor them. At low qualities zeros matter more + // and +1 / -1 coefficients are already quite harmful. + float slope = std::min(1.0f, butteraugli_target * (1.0f / 3)); + config.cost1 = 1 + slope * 8.8703248061477744f; + config.cost2 = 4.4628149885273363f; + config.cost_delta = 5.3359184934516337f; + JXL_ASSERT(enc_state->shared.ac_strategy.xsize() == + enc_state->shared.frame_dim.xsize_blocks); + JXL_ASSERT(enc_state->shared.ac_strategy.ysize() == + enc_state->shared.frame_dim.ysize_blocks); +} + +void AcStrategyHeuristics::ProcessRect(const Rect& rect) { + PROFILER_FUNC; + const CompressParams& cparams = enc_state->cparams; + // In Falcon mode, use DCT8 everywhere and uniform quantization. + if (cparams.speed_tier >= SpeedTier::kCheetah) { + enc_state->shared.ac_strategy.FillDCT8(rect); + return; + } + HWY_DYNAMIC_DISPATCH(ProcessRectACS) + (enc_state, config, rect); +} + +void AcStrategyHeuristics::Finalize(AuxOut* aux_out) { + const auto& ac_strategy = enc_state->shared.ac_strategy; + // Accounting and debug output. + if (aux_out != nullptr) { + aux_out->num_small_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::IDENTITY) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT2X2) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT4X4); + aux_out->num_dct4x8_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT4X8) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X4); + aux_out->num_afv_blocks = ac_strategy.CountBlocks(AcStrategy::Type::AFV0) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV1) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV2) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV3); + aux_out->num_dct8_blocks = ac_strategy.CountBlocks(AcStrategy::Type::DCT); + aux_out->num_dct8x16_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X16) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X8); + aux_out->num_dct8x32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X32) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X8); + aux_out->num_dct16_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X16); + aux_out->num_dct16x32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X32) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X16); + aux_out->num_dct32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X32); + aux_out->num_dct32x64_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X64) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT64X32); + aux_out->num_dct64_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT64X64); + } + + if (WantDebugOutput(aux_out)) { + DumpAcStrategy(ac_strategy, enc_state->shared.frame_dim.xsize, + enc_state->shared.frame_dim.ysize, "ac_strategy", aux_out); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_ac_strategy.h b/media/libjxl/src/lib/jxl/enc_ac_strategy.h new file mode 100644 index 000000000..409f18b89 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_ac_strategy.h @@ -0,0 +1,75 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_AC_STRATEGY_H_ +#define LIB_JXL_ENC_AC_STRATEGY_H_ + +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" + +// `FindBestAcStrategy` uses heuristics to choose which AC strategy should be +// used in each block, as well as the initial quantization field. + +namespace jxl { + +// AC strategy selection: utility struct. + +struct ACSConfig { + const DequantMatrices* JXL_RESTRICT dequant; + float info_loss_multiplier; + float info_loss_multiplier2; + float* JXL_RESTRICT quant_field_row; + size_t quant_field_stride; + float* JXL_RESTRICT masking_field_row; + size_t masking_field_stride; + const float* JXL_RESTRICT src_rows[3]; + size_t src_stride; + // Cost for 1 (-1), 2 (-2) explicitly, cost for others computed with cost1 + + // cost2 + sqrt(q) * cost_delta. + float cost1; + float cost2; + float cost_delta; + float base_entropy; + float zeros_mul; + const float& Pixel(size_t c, size_t x, size_t y) const { + return src_rows[c][y * src_stride + x]; + } + float Masking(size_t bx, size_t by) const { + JXL_DASSERT(masking_field_row[by * masking_field_stride + bx] > 0); + return masking_field_row[by * masking_field_stride + bx]; + } + float Quant(size_t bx, size_t by) const { + JXL_DASSERT(quant_field_row[by * quant_field_stride + bx] > 0); + return quant_field_row[by * quant_field_stride + bx]; + } +}; + +struct AcStrategyHeuristics { + void Init(const Image3F& src, PassesEncoderState* enc_state); + void ProcessRect(const Rect& rect); + void Finalize(AuxOut* aux_out); + ACSConfig config; + PassesEncoderState* enc_state; +}; + +// Debug. +void DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize, + size_t ysize, const char* tag, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_AC_STRATEGY_H_ diff --git a/media/libjxl/src/lib/jxl/enc_adaptive_quantization.cc b/media/libjxl/src/lib/jxl/enc_adaptive_quantization.cc new file mode 100644 index 000000000..4d245b41d --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_adaptive_quantization.cc @@ -0,0 +1,1151 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_adaptive_quantization.h" + +#include +#include +#include + +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc" +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_group.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/gauss_blur.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quant_weights.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Sqrt; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +// The following functions modulate an exponent (out_val) and return the updated +// value. Their descriptor is limited to 8 lanes for 8x8 blocks. + +// Hack for mask estimation. Eventually replace this code with butteraugli's +// masking. +float ComputeMaskForAcStrategyUse(const float out_val) { + const float kMul = 1.0f; + const float kOffset = 0.001f; + return kMul / (out_val + kOffset); +} + +template +V ComputeMask(const D d, const V out_val) { + const auto kBase = Set(d, -0.74174993f); + const auto kMul4 = Set(d, 3.2353257320940401f); + const auto kMul2 = Set(d, 12.906028311180409f); + const auto kOffset2 = Set(d, 305.04035728311436f); + const auto kMul3 = Set(d, 5.0220313103171232f); + const auto kOffset3 = Set(d, 2.1925739705298404f); + const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3); + const auto kMul0 = Set(d, 0.74760422233706747f); + const auto k1 = Set(d, 1.0f); + + // Avoid division by zero. + const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f)); + const auto v2 = Div(k1, Add(v1, kOffset2)); + const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3)); + const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4)); + // TODO(jyrki): + // A log or two here could make sense. In butteraugli we have effectively + // log(log(x + C)) for this kind of use, as a single log is used in + // saturating visual masking and here the modulation values are exponential, + // another log would counter that. + return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3)))); +} + +// For converting full vectors to a subset. Assumes `vfull` lanes are identical. +template +Vec CapTo(const D d, VFull vfull) { + using T = typename D::T; + const HWY_FULL(T) dfull; + HWY_ALIGN T lanes[MaxLanes(dfull)]; + Store(vfull, dfull, lanes); + return Load(d, lanes); +} + +// mul and mul2 represent a scaling difference between jxl and butteraugli. +static const float kSGmul = 226.0480446705883f; +static const float kSGmul2 = 1.0f / 73.377132366608819f; +static const float kLog2 = 0.693147181f; +// Includes correction factor for std::log -> log2. +static const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2; +static const float kSGVOffset = 7.14672470003f; + +template +V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) { + // The opsin space in jxl is the cubic root of photons, i.e., v * v * v + // is related to the number of photons. + // + // SimpleGamma(v * v * v) is the psychovisual space in butteraugli. + // This ratio allows quantization to move from jxl's opsin space to + // butteraugli's log-gamma space. + float kEpsilon = 1e-2; + v = ZeroIfNegative(v); + const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul); + const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon); + const auto kDenMul = Set(d, kLog2 * kSGmul); + + const auto v2 = Mul(v, v); + + const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon)); + const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset); + return invert ? Div(num, den) : Div(den, num); +} + +template +static float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) { + using DScalar = HWY_CAPPED(float, 1); + auto vscalar = Load(DScalar(), &v); + return GetLane( + RatioOfDerivativesOfCubicRootToSimpleGamma(DScalar(), vscalar)); +} + +// TODO(veluca): this function computes an approximation of the derivative of +// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or +// exact derivatives. For reference, SimpleGamma was: +/* +template +V SimpleGamma(const D d, V v) { + // A simple HDR compatible gamma function. + const auto mul = Set(d, kSGmul); + const auto kRetMul = Set(d, kSGRetMul); + const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f); + const auto kVOffset = Set(d, kSGVOffset); + + v *= mul; + + // This should happen rarely, but may lead to a NaN, which is rather + // undesirable. Since negative photons don't exist we solve the NaNs by + // clamping here. + // TODO(veluca): with FastLog2f, this no longer leads to NaNs. + v = ZeroIfNegative(v); + return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd; +} +*/ + +template +V GammaModulation(const D d, const size_t x, const size_t y, + const ImageF& xyb_x, const ImageF& xyb_y, const V out_val) { + const float kBias = 0.16f; + JXL_DASSERT(kBias > kOpsinAbsorbanceBias[0]); + JXL_DASSERT(kBias > kOpsinAbsorbanceBias[1]); + JXL_DASSERT(kBias > kOpsinAbsorbanceBias[2]); + auto overall_ratio = Zero(d); + auto bias = Set(d, kBias); + auto half = Set(d, 0.5f); + for (size_t dy = 0; dy < 8; ++dy) { + const float* const JXL_RESTRICT row_in_x = xyb_x.Row(y + dy); + const float* const JXL_RESTRICT row_in_y = xyb_y.Row(y + dy); + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { + const auto iny = Add(Load(d, row_in_y + x + dx), bias); + const auto inx = Load(d, row_in_x + x + dx); + const auto r = Sub(iny, inx); + const auto g = Add(iny, inx); + const auto ratio_r = + RatioOfDerivativesOfCubicRootToSimpleGamma(d, r); + const auto ratio_g = + RatioOfDerivativesOfCubicRootToSimpleGamma(d, g); + const auto avg_ratio = Mul(half, Add(ratio_r, ratio_g)); + + overall_ratio = Add(overall_ratio, avg_ratio); + } + } + overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 1.0f / 64)); + // ideally -1.0, but likely optimal correction adds some entropy, so slightly + // less than that. + // ln(2) constant folded in because we want std::log but have FastLog2f. + const auto kGam = Set(d, -0.15526878023684174f * 0.693147180559945f); + return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val); +} + +template +V ColorModulation(const D d, const size_t x, const size_t y, + const ImageF& xyb_x, const ImageF& xyb_y, const ImageF& xyb_b, + const double butteraugli_target, V out_val) { + static const float kStrengthMul = 2.177823400325309; + static const float kRedRampStart = 0.0073200141118951231; + static const float kRedRampLength = 0.019421555948474039; + static const float kBlueRampLength = 0.086890611400405895; + static const float kBlueRampStart = 0.26973418507870539; + const float strength = kStrengthMul * (1.0f - 0.25f * butteraugli_target); + if (strength < 0) { + return out_val; + } + // x values are smaller than y and b values, need to take the difference into + // account. + const float red_strength = strength * 5.992297772961519f; + const float blue_strength = strength; + { + // Reduce some bits from areas not blue or red. + const float offset = strength * -0.009174542291185913f; + out_val = Add(out_val, Set(d, offset)); + } + // Calculate how much of the 8x8 block is covered with blue or red. + auto blue_coverage = Zero(d); + auto red_coverage = Zero(d); + for (size_t dy = 0; dy < 8; ++dy) { + const float* const JXL_RESTRICT row_in_x = xyb_x.Row(y + dy); + const float* const JXL_RESTRICT row_in_y = xyb_y.Row(y + dy); + const float* const JXL_RESTRICT row_in_b = xyb_b.Row(y + dy); + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { + const auto pixel_x = Max( + Set(d, 0.0f), Sub(Load(d, row_in_x + x + dx), Set(d, kRedRampStart))); + const auto pixel_y = Load(d, row_in_y + x + dx); + const auto pixel_b = + Max(Set(d, 0.0f), Sub(Load(d, row_in_b + x + dx), + Add(pixel_y, Set(d, kBlueRampStart)))); + const auto blue_slope = Min(pixel_b, Set(d, kBlueRampLength)); + const auto red_slope = Min(pixel_x, Set(d, kRedRampLength)); + red_coverage = Add(red_coverage, red_slope); + blue_coverage = Add(blue_coverage, blue_slope); + } + } + + // Saturate when the high red or high blue coverage is above a level. + // The idea here is that if a certain fraction of the block is red or + // blue we consider as if it was fully red or blue. + static const float ratio = 30.610615782142737f; // out of 64 pixels. + + auto overall_red_coverage = SumOfLanes(d, red_coverage); + overall_red_coverage = + Min(overall_red_coverage, Set(d, ratio * kRedRampLength)); + overall_red_coverage = + Mul(overall_red_coverage, Set(d, red_strength / ratio)); + + auto overall_blue_coverage = SumOfLanes(d, blue_coverage); + overall_blue_coverage = + Min(overall_blue_coverage, Set(d, ratio * kBlueRampLength)); + overall_blue_coverage = + Mul(overall_blue_coverage, Set(d, blue_strength / ratio)); + + return Add(overall_red_coverage, Add(overall_blue_coverage, out_val)); +} + +// Change precision in 8x8 blocks that have high frequency content. +template +V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb, + const V out_val) { + // Zero out the invalid differences for the rightmost value per row. + const Rebind du; + HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, 0}; + + auto sum = Zero(d); // sum of absolute differences with right and below + + for (size_t dy = 0; dy < 8; ++dy) { + const float* JXL_RESTRICT row_in = xyb.Row(y + dy) + x; + const float* JXL_RESTRICT row_in_next = + dy == 7 ? row_in : xyb.Row(y + dy + 1) + x; + + // In SCALAR, there is no guarantee of having extra row padding. + // Hence, we need to ensure we don't access pixels outside the row itself. + // In SIMD modes, however, rows are padded, so it's safe to access one + // garbage value after the row. The vector then gets masked with kMaskRight + // to remove the influence of that value. +#if HWY_TARGET != HWY_SCALAR + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { +#else + for (size_t dx = 0; dx < 7; dx += Lanes(d)) { +#endif + const auto p = Load(d, row_in + dx); + const auto pr = LoadU(d, row_in + dx + 1); + const auto mask = BitCast(d, Load(du, kMaskRight + dx)); + sum = Add(sum, And(mask, AbsDiff(p, pr))); + + const auto pd = Load(d, row_in_next + dx); + sum = Add(sum, AbsDiff(p, pd)); + } + } + + sum = SumOfLanes(d, sum); + return MulAdd(sum, Set(d, -2.0052193233688884f / 112), out_val); +} + +void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x, + const ImageF& xyb_y, const ImageF& xyb_b, + const float scale, const Rect& rect, ImageF* out) { + JXL_ASSERT(SameSize(xyb_x, xyb_y)); + JXL_ASSERT(DivCeil(xyb_x.xsize(), kBlockDim) == out->xsize()); + JXL_ASSERT(DivCeil(xyb_x.ysize(), kBlockDim) == out->ysize()); + + float base_level = 0.5f * scale; + float kDampenRampStart = 7.0f; + float kDampenRampEnd = 14.0f; + float dampen = 1.0f; + if (butteraugli_target >= kDampenRampStart) { + dampen = 1.0f - ((butteraugli_target - kDampenRampStart) / + (kDampenRampEnd - kDampenRampStart)); + if (dampen < 0) { + dampen = 0; + } + } + const float mul = scale * dampen; + const float add = (1.0f - dampen) * base_level; + for (size_t iy = rect.y0(); iy < rect.y0() + rect.ysize(); iy++) { + const size_t y = iy * 8; + float* const JXL_RESTRICT row_out = out->Row(iy); + const HWY_CAPPED(float, kBlockDim) df; + for (size_t ix = rect.x0(); ix < rect.x0() + rect.xsize(); ix++) { + size_t x = ix * 8; + auto out_val = Set(df, row_out[ix]); + out_val = ComputeMask(df, out_val); + out_val = HfModulation(df, x, y, xyb_y, out_val); + out_val = ColorModulation(df, x, y, xyb_x, xyb_y, xyb_b, + butteraugli_target, out_val); + out_val = GammaModulation(df, x, y, xyb_x, xyb_y, out_val); + // We want multiplicative quantization field, so everything + // until this point has been modulating the exponent. + row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add; + } + } +} + +template +V MaskingSqrt(const D d, V v) { + static const float kLogOffset = 26.481471032459346f; + static const float kMul = 211.50759899638012f; + const auto mul_v = Set(d, kMul * 1e8); + const auto offset_v = Set(d, kLogOffset); + return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v))); +} + +float MaskingSqrt(const float v) { + using DScalar = HWY_CAPPED(float, 1); + auto vscalar = Load(DScalar(), &v); + return GetLane(MaskingSqrt(DScalar(), vscalar)); +} + +void StoreMin4(const float v, float& min0, float& min1, float& min2, + float& min3) { + if (v < min3) { + if (v < min0) { + min3 = min2; + min2 = min1; + min1 = min0; + min0 = v; + } else if (v < min1) { + min3 = min2; + min2 = min1; + min1 = v; + } else if (v < min2) { + min3 = min2; + min2 = v; + } else { + min3 = v; + } + } +} + +// Look for smooth areas near the area of degradation. +// If the areas are generally smooth, don't do masking. +// Output is downsampled 2x. +void FuzzyErosion(const Rect& from_rect, const ImageF& from, + const Rect& to_rect, ImageF* to) { + const size_t xsize = from.xsize(); + const size_t ysize = from.ysize(); + constexpr int kStep = 1; + static_assert(kStep == 1, "Step must be 1"); + JXL_ASSERT(to_rect.xsize() * 2 == from_rect.xsize()); + JXL_ASSERT(to_rect.ysize() * 2 == from_rect.ysize()); + for (size_t fy = 0; fy < from_rect.ysize(); ++fy) { + size_t y = fy + from_rect.y0(); + size_t ym1 = y >= kStep ? y - kStep : y; + size_t yp1 = y + kStep < ysize ? y + kStep : y; + const float* rowt = from.Row(ym1); + const float* row = from.Row(y); + const float* rowb = from.Row(yp1); + float* row_out = to_rect.Row(to, fy / 2); + for (size_t fx = 0; fx < from_rect.xsize(); ++fx) { + size_t x = fx + from_rect.x0(); + size_t xm1 = x >= kStep ? x - kStep : x; + size_t xp1 = x + kStep < xsize ? x + kStep : x; + float min0 = row[x]; + float min1 = row[xm1]; + float min2 = row[xp1]; + float min3 = rowt[xm1]; + // Sort the first four values. + if (min0 > min1) std::swap(min0, min1); + if (min0 > min2) std::swap(min0, min2); + if (min0 > min3) std::swap(min0, min3); + if (min1 > min2) std::swap(min1, min2); + if (min1 > min3) std::swap(min1, min3); + if (min2 > min3) std::swap(min2, min3); + // The remaining five values of a 3x3 neighbourhood. + StoreMin4(rowt[x], min0, min1, min2, min3); + StoreMin4(rowt[xp1], min0, min1, min2, min3); + StoreMin4(rowb[xm1], min0, min1, min2, min3); + StoreMin4(rowb[x], min0, min1, min2, min3); + StoreMin4(rowb[xp1], min0, min1, min2, min3); + static const float kMulC = 0.05f; + static const float kMul0 = 0.05f; + static const float kMul1 = 0.05f; + static const float kMul2 = 0.05f; + static const float kMul3 = 0.05f; + float v = kMulC * row[x] + kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + + kMul3 * min3; + if (fx % 2 == 0 && fy % 2 == 0) { + row_out[fx / 2] = v; + } else { + row_out[fx / 2] += v; + } + } + } +} + +struct AdaptiveQuantizationImpl { + void Init(const Image3F& xyb) { + JXL_DASSERT(xyb.xsize() % kBlockDim == 0); + JXL_DASSERT(xyb.ysize() % kBlockDim == 0); + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + aq_map = ImageF(xsize / kBlockDim, ysize / kBlockDim); + } + void PrepareBuffers(size_t num_threads) { + diff_buffer = ImageF(kEncTileDim + 8, num_threads); + for (size_t i = pre_erosion.size(); i < num_threads; i++) { + pre_erosion.emplace_back(kEncTileDimInBlocks * 2 + 2, + kEncTileDimInBlocks * 2 + 2); + } + } + + void ComputeTile(float butteraugli_target, float scale, const Image3F& xyb, + const Rect& rect, const int thread, ImageF* mask) { + PROFILER_ZONE("aq DiffPrecompute"); + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + + // The XYB gamma is 3.0 to be able to decode faster with two muls. + // Butteraugli's gamma is matching the gamma of human eye, around 2.6. + // We approximate the gamma difference by adding one cubic root into + // the adaptive quantization. This gives us a total gamma of 2.6666 + // for quantization uses. + const float match_gamma_offset = 0.019; + + const HWY_FULL(float) df; + const float kXMul = 23.426802998210313f; + const auto kXMulv = Set(df, kXMul); + + size_t y_start = rect.y0() * 8; + size_t y_end = y_start + rect.ysize() * 8; + + size_t x0 = rect.x0() * 8; + size_t x1 = x0 + rect.xsize() * 8; + if (x0 != 0) x0 -= 4; + if (x1 != xyb.xsize()) x1 += 4; + if (y_start != 0) y_start -= 4; + if (y_end != xyb.ysize()) y_end += 4; + pre_erosion[thread].ShrinkTo((x1 - x0) / 4, (y_end - y_start) / 4); + + // Computes image (padded to multiple of 8x8) of local pixel differences. + // Subsample both directions by 4. + for (size_t y = y_start; y < y_end; ++y) { + size_t y2 = y + 1 < ysize ? y + 1 : y; + size_t y1 = y > 0 ? y - 1 : y; + + const float* row_in = xyb.PlaneRow(1, y); + const float* row_in1 = xyb.PlaneRow(1, y1); + const float* row_in2 = xyb.PlaneRow(1, y2); + const float* row_x_in = xyb.PlaneRow(0, y); + const float* row_x_in1 = xyb.PlaneRow(0, y1); + const float* row_x_in2 = xyb.PlaneRow(0, y2); + float* JXL_RESTRICT row_out = diff_buffer.Row(thread); + + auto scalar_pixel = [&](size_t x) { + const size_t x2 = x + 1 < xsize ? x + 1 : x; + const size_t x1 = x > 0 ? x - 1 : x; + const float base = + 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]); + const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma( + row_in[x] + match_gamma_offset); + float diff = gammac * (row_in[x] - base); + diff *= diff; + const float base_x = + 0.25f * (row_x_in2[x] + row_x_in1[x] + row_x_in[x1] + row_x_in[x2]); + float diff_x = gammac * (row_x_in[x] - base_x); + diff_x *= diff_x; + diff += kXMul * diff_x; + diff = MaskingSqrt(diff); + if ((y % 4) != 0) { + row_out[x - x0] += diff; + } else { + row_out[x - x0] = diff; + } + }; + + size_t x = x0; + // First pixel of the row. + if (x0 == 0) { + scalar_pixel(x0); + ++x; + } + // SIMD + const auto match_gamma_offset_v = Set(df, match_gamma_offset); + const auto quarter = Set(df, 0.25f); + for (; x + 1 + Lanes(df) < x1; x += Lanes(df)) { + const auto in = LoadU(df, row_in + x); + const auto in_r = LoadU(df, row_in + x + 1); + const auto in_l = LoadU(df, row_in + x - 1); + const auto in_t = LoadU(df, row_in2 + x); + const auto in_b = LoadU(df, row_in1 + x); + auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b))); + auto gammacv = + RatioOfDerivativesOfCubicRootToSimpleGamma( + df, Add(in, match_gamma_offset_v)); + auto diff = Mul(gammacv, Sub(in, base)); + diff = Mul(diff, diff); + + const auto in_x = LoadU(df, row_x_in + x); + const auto in_x_r = LoadU(df, row_x_in + x + 1); + const auto in_x_l = LoadU(df, row_x_in + x - 1); + const auto in_x_t = LoadU(df, row_x_in2 + x); + const auto in_x_b = LoadU(df, row_x_in1 + x); + auto base_x = + Mul(quarter, Add(Add(in_x_r, in_x_l), Add(in_x_t, in_x_b))); + auto diff_x = Mul(gammacv, Sub(in_x, base_x)); + diff_x = Mul(diff_x, diff_x); + diff = MulAdd(kXMulv, diff_x, diff); + diff = MaskingSqrt(df, diff); + if ((y & 3) != 0) { + diff = Add(diff, LoadU(df, row_out + x - x0)); + } + StoreU(diff, df, row_out + x - x0); + } + // Scalar + for (; x < x1; ++x) { + scalar_pixel(x); + } + if (y % 4 == 3) { + float* row_dout = pre_erosion[thread].Row((y - y_start) / 4); + for (size_t x = 0; x < (x1 - x0) / 4; x++) { + row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] + + row_out[x * 4 + 2] + row_out[x * 4 + 3]) * + 0.25f; + } + } + } + Rect from_rect(x0 % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1, + rect.xsize() * 2, rect.ysize() * 2); + FuzzyErosion(from_rect, pre_erosion[thread], rect, &aq_map); + for (size_t y = 0; y < rect.ysize(); ++y) { + const float* aq_map_row = rect.ConstRow(aq_map, y); + float* mask_row = rect.Row(mask, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]); + } + } + PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1), + xyb.Plane(2), scale, rect, &aq_map); + } + std::vector pre_erosion; + ImageF aq_map; + ImageF diff_buffer; +}; + +ImageF AdaptiveQuantizationMap(const float butteraugli_target, + const Image3F& xyb, + const FrameDimensions& frame_dim, float scale, + ThreadPool* pool, ImageF* mask) { + PROFILER_ZONE("aq AdaptiveQuantMap"); + + AdaptiveQuantizationImpl impl; + impl.Init(xyb); + *mask = ImageF(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + JXL_CHECK(RunOnPool( + pool, 0, + DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) * + DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks), + [&](const size_t num_threads) { + impl.PrepareBuffers(num_threads); + return true; + }, + [&](const uint32_t tid, const size_t thread) { + size_t n_enc_tiles = + DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks); + size_t tx = tid % n_enc_tiles; + size_t ty = tid / n_enc_tiles; + size_t by0 = ty * kEncTileDimInBlocks; + size_t by1 = + std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks); + size_t bx0 = tx * kEncTileDimInBlocks; + size_t bx1 = + std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks); + Rect r(bx0, by0, bx1 - bx0, by1 - by0); + impl.ComputeTile(butteraugli_target, scale, xyb, r, thread, mask); + }, + "AQ DiffPrecompute")); + + return std::move(impl).aq_map; +} + +} // namespace + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(AdaptiveQuantizationMap); + +namespace { +// If true, prints the quantization maps at each iteration. +bool FLAGS_dump_quant_state = false; + +void DumpHeatmap(const AuxOut* aux_out, const std::string& label, + const ImageF& image, float good_threshold, + float bad_threshold) { + Image3F heatmap = CreateHeatMapImage(image, good_threshold, bad_threshold); + char filename[200]; + snprintf(filename, sizeof(filename), "%s%05d", label.c_str(), + aux_out->num_butteraugli_iters); + aux_out->DumpImage(filename, heatmap); +} + +void DumpHeatmaps(const AuxOut* aux_out, float ba_target, + const ImageF& quant_field, const ImageF& tile_heatmap, + const ImageF& bt_diffmap) { + if (!WantDebugOutput(aux_out)) return; + ImageF inv_qmap(quant_field.xsize(), quant_field.ysize()); + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* JXL_RESTRICT row_q = quant_field.ConstRow(y); + float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + row_inv_q[x] = 1.0f / row_q[x]; // never zero + } + } + DumpHeatmap(aux_out, "quant_heatmap", inv_qmap, 4.0f * ba_target, + 6.0f * ba_target); + DumpHeatmap(aux_out, "tile_heatmap", tile_heatmap, ba_target, + 1.5f * ba_target); + // matches heat maps produced by the command line tool. + DumpHeatmap(aux_out, "bt_diffmap", bt_diffmap, ButteraugliFuzzyInverse(1.5), + ButteraugliFuzzyInverse(0.5)); +} + +ImageF TileDistMap(const ImageF& distmap, int tile_size, int margin, + const AcStrategyImage& ac_strategy) { + PROFILER_FUNC; + const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size; + const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size; + ImageF tile_distmap(tile_xsize, tile_ysize); + size_t distmap_stride = tile_distmap.PixelsPerRow(); + for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) { + AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y); + float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y); + for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) { + AcStrategy acs = ac_strategy_row[tile_x]; + if (!acs.IsFirstBlock()) continue; + int this_tile_xsize = acs.covered_blocks_x() * tile_size; + int this_tile_ysize = acs.covered_blocks_y() * tile_size; + int y_begin = std::max(0, tile_size * tile_y - margin); + int y_end = std::min(distmap.ysize(), + tile_size * tile_y + this_tile_ysize + margin); + int x_begin = std::max(0, tile_size * tile_x - margin); + int x_end = std::min(distmap.xsize(), + tile_size * tile_x + this_tile_xsize + margin); + float dist_norm = 0.0; + double pixels = 0; + for (int y = y_begin; y < y_end; ++y) { + float ymul = 1.0; + constexpr float kBorderMul = 0.98f; + constexpr float kCornerMul = 0.7f; + if (margin != 0 && (y == y_begin || y == y_end - 1)) { + ymul = kBorderMul; + } + const float* const JXL_RESTRICT row = distmap.Row(y); + for (int x = x_begin; x < x_end; ++x) { + float xmul = ymul; + if (margin != 0 && (x == x_begin || x == x_end - 1)) { + if (xmul == 1.0) { + xmul = kBorderMul; + } else { + xmul = kCornerMul; + } + } + float v = row[x]; + v *= v; + v *= v; + v *= v; + v *= v; + dist_norm += xmul * v; + pixels += xmul; + } + } + if (pixels == 0) pixels = 1; + // 16th norm is less than the max norm, we reduce the difference + // with this normalization factor. + constexpr float kTileNorm = 1.2f; + const float tile_dist = + kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f); + dist_row[tile_x] = tile_dist; + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + dist_row[tile_x + distmap_stride * iy + ix] = tile_dist; + } + } + } + } + return tile_distmap; +} + +constexpr float kDcQuantPow = 0.57f; +static const float kDcQuant = 1.12f; +static const float kAcQuant = 0.8294f; + +void FindBestQuantization(const ImageBundle& linear, const Image3F& opsin, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out) { + const CompressParams& cparams = enc_state->cparams; + if (cparams.resampling > 1 && + cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) { + // For downsampled opsin image, the butteraugli based adaptive quantization + // loop would only make the size bigger without improving the distance much, + // so in this case we enable it only for very high butteraugli targets. + return; + } + Quantizer& quantizer = enc_state->shared.quantizer; + ImageI& raw_quant_field = enc_state->shared.raw_quant_field; + ImageF& quant_field = enc_state->initial_quant_field; + + // TODO(veluca): this should really be rather handled on the + // ButteraugliComparator side. + struct TemporaryShrink { + TemporaryShrink(ImageBundle& bundle, size_t xsize, size_t ysize) + : bundle(bundle), + orig_xsize(bundle.xsize()), + orig_ysize(bundle.ysize()) { + bundle.ShrinkTo(xsize, ysize); + } + TemporaryShrink(const TemporaryShrink&) = delete; + TemporaryShrink(TemporaryShrink&&) = delete; + + ~TemporaryShrink() { bundle.ShrinkTo(orig_xsize, orig_ysize); } + + ImageBundle& bundle; + size_t orig_xsize; + size_t orig_ysize; + } t(const_cast(linear), + enc_state->shared.frame_header.nonserialized_metadata->xsize(), + enc_state->shared.frame_header.nonserialized_metadata->ysize()); + + const float butteraugli_target = cparams.butteraugli_distance; + const float original_butteraugli = cparams.original_butteraugli_distance; + ButteraugliParams params = cparams.ba_params; + params.intensity_target = linear.metadata()->IntensityTarget(); + // Hack the default intensity target value to be 80.0, the intensity + // target of sRGB images and a more reasonable viewing default than + // JPEG XL file format's default. + if (fabs(params.intensity_target - 255.0f) < 1e-3) { + params.intensity_target = 80.0f; + } + JxlButteraugliComparator comparator(params, cms); + JXL_CHECK(comparator.SetReferenceImage(linear)); + bool lower_is_better = + (comparator.GoodQualityScore() < comparator.BadQualityScore()); + const float initial_quant_dc = InitialQuantDC(butteraugli_target); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + &quant_field); + ImageF tile_distmap; + ImageF initial_quant_field = CopyImage(quant_field); + + float initial_qf_min, initial_qf_max; + ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max); + float initial_qf_ratio = initial_qf_max / initial_qf_min; + float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio); + float asymmetry = 2; + if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low; + float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low); + float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry); + + JXL_ASSERT(qf_higher / qf_lower < 253); + + constexpr int kOriginalComparisonRound = 1; + int iters = cparams.max_butteraugli_iters; + if (iters > 7) { + iters = 7; + } + if (cparams.speed_tier != SpeedTier::kTortoise) { + iters = 2; + } + for (int i = 0; i < iters + 1; ++i) { + if (FLAGS_dump_quant_state) { + printf("\nQuantization field:\n"); + for (size_t y = 0; y < quant_field.ysize(); ++y) { + for (size_t x = 0; x < quant_field.xsize(); ++x) { + printf(" %.5f", quant_field.Row(y)[x]); + } + printf("\n"); + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + ImageBundle dec_linear = RoundtripImage(opsin, enc_state, cms, pool); + PROFILER_ZONE("enc Butteraugli"); + float score; + ImageF diffmap; + JXL_CHECK(comparator.CompareWith(dec_linear, &diffmap, &score)); + if (!lower_is_better) { + score = -score; + diffmap = ScaleImage(-1.0f, diffmap); + } + tile_distmap = TileDistMap(diffmap, 8 * cparams.resampling, 0, + enc_state->shared.ac_strategy); + if (WantDebugOutput(aux_out)) { + aux_out->DumpImage(("dec" + ToString(i)).c_str(), *dec_linear.color()); + DumpHeatmaps(aux_out, butteraugli_target, quant_field, tile_distmap, + diffmap); + } + if (aux_out != nullptr) ++aux_out->num_butteraugli_iters; + if (cparams.log_search_state) { + float minval, maxval; + ImageMinMax(quant_field, &minval, &maxval); + printf("\nButteraugli iter: %d/%d\n", i, cparams.max_butteraugli_iters); + printf("Butteraugli distance: %f (target = %f)\n", score, + original_butteraugli); + printf("quant range: %f ... %f DC quant: %f\n", minval, maxval, + initial_quant_dc); + if (FLAGS_dump_quant_state) { + quantizer.DumpQuantizationMap(raw_quant_field); + } + } + + if (i == iters) break; + + double kPow[8] = { + 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + double kPowMod[8] = { + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + if (i == kOriginalComparisonRound) { + // Don't allow optimization to make the quant field a lot worse than + // what the initial guess was. This allows the AC field to have enough + // precision to reduce the oscillations due to the dc reconstruction. + double kInitMul = 0.6; + const double kOneMinusInitMul = 1.0 - kInitMul; + for (size_t y = 0; y < quant_field.ysize(); ++y) { + float* const JXL_RESTRICT row_q = quant_field.Row(y); + const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x]; + if (row_q[x] < clamp) { + row_q[x] = clamp; + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + + double cur_pow = 0.0; + if (i < 7) { + cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i]; + if (cur_pow < 0) { + cur_pow = 0; + } + } + if (cur_pow == 0.0) { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / original_butteraugli; + if (diff > 1.0f) { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } else { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / original_butteraugli; + if (diff <= 1.0f) { + row_q[x] *= std::pow(diff, cur_pow); + } else { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +void FindBestQuantizationMaxError(const Image3F& opsin, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out) { + // TODO(szabadka): Make this work for non-opsin color spaces. + const CompressParams& cparams = enc_state->cparams; + Quantizer& quantizer = enc_state->shared.quantizer; + ImageI& raw_quant_field = enc_state->shared.raw_quant_field; + ImageF& quant_field = enc_state->initial_quant_field; + + // TODO(veluca): better choice of this value. + const float initial_quant_dc = + 16 * std::sqrt(0.1f / cparams.butteraugli_distance); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + &quant_field); + + const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0], + 1.0f / enc_state->cparams.max_error[1], + 1.0f / enc_state->cparams.max_error[2]}; + + for (int i = 0; i < cparams.max_butteraugli_iters + 1; ++i) { + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + if (aux_out) { + aux_out->DumpXybImage(("ops" + ToString(i)).c_str(), opsin); + } + ImageBundle decoded = RoundtripImage(opsin, enc_state, cms, pool); + if (aux_out) { + aux_out->DumpXybImage(("dec" + ToString(i)).c_str(), *decoded.color()); + } + + for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) { + AcStrategyRow ac_strategy_row = + enc_state->shared.ac_strategy.ConstRow(by); + for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) { + AcStrategy acs = ac_strategy_row[bx]; + if (!acs.IsFirstBlock()) continue; + float max_error = 0; + for (size_t c = 0; c < 3; c++) { + for (size_t y = by * kBlockDim; + y < (by + acs.covered_blocks_y()) * kBlockDim; y++) { + if (y >= decoded.ysize()) continue; + const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y); + const float* JXL_RESTRICT dec_row = + decoded.color()->ConstPlaneRow(c, y); + for (size_t x = bx * kBlockDim; + x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) { + if (x >= decoded.xsize()) continue; + max_error = std::max( + std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error); + } + } + } + // Target an error between max_error/2 and max_error. + // If the error in the varblock is above the target, increase the qf to + // compensate. If the error is below the target, decrease the qf. + // However, to avoid an excessive increase of the qf, only do so if the + // error is less than half the maximum allowed error. + const float qf_mul = (max_error < 0.5f) ? max_error * 2.0f + : (max_error > 1.0f) ? max_error + : 1.0f; + for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) { + float* JXL_RESTRICT quant_field_row = quant_field.Row(qy); + for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) { + quant_field_row[qx] *= qf_mul; + } + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +} // namespace + +void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect, + ImageF* quant_field) { + // Replace the whole quant_field in non-8x8 blocks with the maximum of each + // 8x8 block. + size_t stride = quant_field->PixelsPerRow(); + for (size_t y = 0; y < rect.ysize(); ++y) { + AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y); + float* JXL_RESTRICT quant_row = rect.Row(quant_field, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + AcStrategy acs = ac_strategy_row[x]; + if (!acs.IsFirstBlock()) continue; + JXL_ASSERT(x + acs.covered_blocks_x() <= quant_field->xsize()); + JXL_ASSERT(y + acs.covered_blocks_y() <= quant_field->ysize()); + float max = quant_row[x]; + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + max = std::max(quant_row[x + ix + iy * stride], max); + } + } + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + quant_row[x + ix + iy * stride] = max; + } + } + } + } +} + +float InitialQuantDC(float butteraugli_target) { + const float kDcMul = 2.9; // Butteraugli target where non-linearity kicks in. + const float butteraugli_target_dc = std::max( + 0.5f * butteraugli_target, + std::min(butteraugli_target, + kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target, + kDcQuantPow))); + // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc. + // The maximum DC value might not be in the kXybRange because of inverse + // gaborish, so we add some slack to the maximum theoretical quant obtained + // this way (64). + return std::min(kDcQuant / butteraugli_target_dc, 50.f); +} + +ImageF InitialQuantField(const float butteraugli_target, const Image3F& opsin, + const FrameDimensions& frame_dim, ThreadPool* pool, + float rescale, ImageF* mask) { + PROFILER_FUNC; + const float quant_ac = kAcQuant / butteraugli_target; + return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)( + butteraugli_target, opsin, frame_dim, quant_ac * rescale, pool, mask); +} + +void FindBestQuantizer(const ImageBundle* linear, const Image3F& opsin, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, double rescale) { + const CompressParams& cparams = enc_state->cparams; + if (cparams.max_error_mode) { + PROFILER_ZONE("enc find best maxerr"); + FindBestQuantizationMaxError(opsin, enc_state, cms, pool, aux_out); + } else if (cparams.speed_tier <= SpeedTier::kKitten) { + // Normal encoding to a butteraugli score. + PROFILER_ZONE("enc find best2"); + FindBestQuantization(*linear, opsin, enc_state, cms, pool, aux_out); + } +} + +ImageBundle RoundtripImage(const Image3F& opsin, PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool) { + PROFILER_ZONE("enc roundtrip"); + std::unique_ptr dec_state = + jxl::make_unique(); + JXL_CHECK(dec_state->output_encoding_info.SetFromMetadata( + *enc_state->shared.metadata)); + dec_state->shared = &enc_state->shared; + JXL_ASSERT(opsin.ysize() % kBlockDim == 0); + + const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim); + const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim); + const size_t num_groups = xsize_groups * ysize_groups; + + size_t num_special_frames = enc_state->special_frames.size(); + + std::unique_ptr modular_frame_encoder = + jxl::make_unique(enc_state->shared.frame_header, + enc_state->cparams); + JXL_CHECK(InitializePassesEncoder(opsin, cms, pool, enc_state, + modular_frame_encoder.get(), nullptr)); + JXL_CHECK(dec_state->Init()); + JXL_CHECK(dec_state->InitForAC(pool)); + + ImageBundle decoded(&enc_state->shared.metadata->m); + decoded.origin = enc_state->shared.frame_header.frame_origin; + decoded.SetFromImage(Image3F(opsin.xsize(), opsin.ysize()), + dec_state->output_encoding_info.color_encoding); + + PassesDecoderState::PipelineOptions options; + options.use_slow_render_pipeline = false; + options.coalescing = true; + options.render_spotcolors = false; + + // Same as dec_state->shared->frame_header.nonserialized_metadata->m + const ImageMetadata& metadata = *decoded.metadata(); + + JXL_CHECK(dec_state->PreparePipeline(&decoded, options)); + + hwy::AlignedUniquePtr group_dec_caches; + const auto allocate_storage = [&](const size_t num_threads) -> Status { + JXL_RETURN_IF_ERROR( + dec_state->render_pipeline->PrepareForThreads(num_threads, + /*use_group_ids=*/false)); + group_dec_caches = hwy::MakeUniqueAlignedArray(num_threads); + return true; + }; + const auto process_group = [&](const uint32_t group_index, + const size_t thread) { + if (dec_state->shared->frame_header.loop_filter.epf_iters > 0) { + ComputeSigma(dec_state->shared->BlockGroupRect(group_index), + dec_state.get()); + } + RenderPipelineInput input = + dec_state->render_pipeline->GetInputBuffers(group_index, thread); + JXL_CHECK(DecodeGroupForRoundtrip( + enc_state->coeffs, group_index, dec_state.get(), + &group_dec_caches[thread], thread, input, &decoded, nullptr)); + for (size_t c = 0; c < metadata.num_extra_channels; c++) { + std::pair ri = input.GetBuffer(3 + c); + FillPlane(0.0f, ri.first, ri.second); + } + input.Done(); + }; + JXL_CHECK(RunOnPool(pool, 0, num_groups, allocate_storage, process_group, + "AQ loop")); + + // Ensure we don't create any new special frames. + enc_state->special_frames.resize(num_special_frames); + + return decoded; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_adaptive_quantization.h b/media/libjxl/src/lib/jxl/enc_adaptive_quantization.h new file mode 100644 index 000000000..724353b47 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_adaptive_quantization.h @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_ +#define LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_ + +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/splines.h" + +// Heuristics to find a good quantizer for a given image. InitialQuantField +// produces a quantization field (i.e. relative quantization amounts for each +// block) out of an opsin-space image. `InitialQuantField` uses heuristics, +// `FindBestQuantizer` (in non-fast mode) will run multiple encoding-decoding +// steps and try to improve the given quant field. + +namespace jxl { + +// Computes the decoded image for a given set of compression parameters. Mainly +// used in the FindBestQuantization loops and in some tests. +// TODO(veluca): this doesn't seem the best possible file for this function. +ImageBundle RoundtripImage(const Image3F& opsin, PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool); + +// Returns an image subsampled by kBlockDim in each direction. If the value +// at pixel (x,y) in the returned image is greater than 1.0, it means that +// more fine-grained quantization should be used in the corresponding block +// of the input image, while a value less than 1.0 indicates that less +// fine-grained quantization should be enough. Returns a mask, too, which +// can later be used to make better decisions about ac strategy. +ImageF InitialQuantField(float butteraugli_target, const Image3F& opsin, + const FrameDimensions& frame_dim, ThreadPool* pool, + float rescale, ImageF* initial_quant_mask); + +float InitialQuantDC(float butteraugli_target); + +void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect, + ImageF* quant_field); + +// Returns a quantizer that uses an adjusted version of the provided +// quant_field. Also computes the dequant_map corresponding to the given +// dequant_float_map and chosen quantization levels. +// `linear` is only used in Kitten mode or slower. +void FindBestQuantizer(const ImageBundle* linear, const Image3F& opsin, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, double rescale = 1.0); + +} // namespace jxl + +#endif // LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_ diff --git a/media/libjxl/src/lib/jxl/enc_ans.cc b/media/libjxl/src/lib/jxl/enc_ans.cc new file mode 100644 index 000000000..81ff83675 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_ans.cc @@ -0,0 +1,1686 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_ans.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_cluster.h" +#include "lib/jxl/enc_context_map.h" +#include "lib/jxl/enc_huffman.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +namespace { + +bool ans_fuzzer_friendly_ = false; + +static const int kMaxNumSymbolsForSmallCode = 4; + +void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table, + size_t alphabet_size, size_t log_alpha_size, + ANSEncSymbolInfo* info) { + size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size; + size_t entry_size_minus_1 = (1 << log_entry_size) - 1; + // create valid alias table for empty streams. + for (size_t s = 0; s < std::max(1, alphabet_size); ++s) { + const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s]; + info[s].freq_ = static_cast(freq); +#ifdef USE_MULT_BY_RECIPROCAL + if (freq != 0) { + info[s].ifreq_ = + ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_; + } else { + info[s].ifreq_ = 1; // shouldn't matter (symbol shouldn't occur), but... + } +#endif + info[s].reverse_map_.resize(freq); + } + for (int i = 0; i < ANS_TAB_SIZE; i++) { + AliasTable::Symbol s = + AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1); + info[s.value].reverse_map_[s.offset] = i; + } +} + +float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts, + size_t len) { + float sum = 0.0f; + int total_histogram = 0; + int total_counts = 0; + for (size_t i = 0; i < len; ++i) { + total_histogram += histogram[i]; + total_counts += counts[i]; + if (histogram[i] > 0) { + JXL_ASSERT(counts[i] > 0); + // += histogram[i] * -log(counts[i]/total_counts) + sum += histogram[i] * + std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i])); + } + } + if (total_histogram > 0) { + JXL_ASSERT(total_counts == ANS_TAB_SIZE); + } + return sum; +} + +float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) { + const float flat_bits = std::max(FastLog2f(len), 0.0f); + float total_histogram = 0; + for (size_t i = 0; i < len; ++i) { + total_histogram += histogram[i]; + } + return total_histogram * flat_bits; +} + +// Static Huffman code for encoding logcounts. The last symbol is used as RLE +// sequence. +static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = { + 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7, +}; +static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = { + 17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65, +}; + +// Returns the difference between largest count that can be represented and is +// smaller than "count" and smallest representable count larger than "count". +static int SmallestIncrement(uint32_t count, uint32_t shift) { + int bits = count == 0 ? -1 : FloorLog2Nonzero(count); + int drop_bits = bits - GetPopulationCountPrecision(bits, shift); + return drop_bits < 0 ? 1 : (1 << drop_bits); +} + +template +bool RebalanceHistogram(const float* targets, int max_symbol, int table_size, + uint32_t shift, int* omit_pos, ANSHistBin* counts) { + int sum = 0; + float sum_nonrounded = 0.0; + int remainder_pos = 0; // if all of them are handled in first loop + int remainder_log = -1; + for (int n = 0; n < max_symbol; ++n) { + if (targets[n] > 0 && targets[n] < 1.0f) { + counts[n] = 1; + sum_nonrounded += targets[n]; + sum += counts[n]; + } + } + const float discount_ratio = + (table_size - sum) / (table_size - sum_nonrounded); + JXL_ASSERT(discount_ratio > 0); + JXL_ASSERT(discount_ratio <= 1.0f); + // Invariant for minimize_error_of_sum == true: + // abs(sum - sum_nonrounded) + // <= SmallestIncrement(max(targets[])) + max_symbol + for (int n = 0; n < max_symbol; ++n) { + if (targets[n] >= 1.0f) { + sum_nonrounded += targets[n]; + counts[n] = + static_cast(targets[n] * discount_ratio); // truncate + if (counts[n] == 0) counts[n] = 1; + if (counts[n] == table_size) counts[n] = table_size - 1; + // Round the count to the closest nonzero multiple of SmallestIncrement + // (when minimize_error_of_sum is false) or one of two closest so as to + // keep the sum as close as possible to sum_nonrounded. + int inc = SmallestIncrement(counts[n], shift); + counts[n] -= counts[n] & (inc - 1); + // TODO(robryk): Should we rescale targets[n]? + const float target = + minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n]; + if (counts[n] == 0 || + (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) { + counts[n] += inc; + } + sum += counts[n]; + const int count_log = FloorLog2Nonzero(static_cast(counts[n])); + if (count_log > remainder_log) { + remainder_pos = n; + remainder_log = count_log; + } + } + } + JXL_ASSERT(remainder_pos != -1); + // NOTE: This is the only place where counts could go negative. We could + // detect that, return false and make ANSHistBin uint32_t. + counts[remainder_pos] -= sum - table_size; + *omit_pos = remainder_pos; + return counts[remainder_pos] > 0; +} + +Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length, + const int precision_bits, uint32_t shift, + int* num_symbols, int* symbols) { + const int32_t table_size = 1 << precision_bits; // target sum / table size + uint64_t total = 0; + int max_symbol = 0; + int symbol_count = 0; + for (int n = 0; n < length; ++n) { + total += counts[n]; + if (counts[n] > 0) { + if (symbol_count < kMaxNumSymbolsForSmallCode) { + symbols[symbol_count] = n; + } + ++symbol_count; + max_symbol = n + 1; + } + } + *num_symbols = symbol_count; + if (symbol_count == 0) { + return true; + } + if (symbol_count == 1) { + counts[symbols[0]] = table_size; + return true; + } + if (symbol_count > table_size) + return JXL_FAILURE("Too many entries in an ANS histogram"); + + const float norm = 1.f * table_size / total; + std::vector targets(max_symbol); + for (size_t n = 0; n < targets.size(); ++n) { + targets[n] = norm * counts[n]; + } + if (!RebalanceHistogram(&targets[0], max_symbol, table_size, shift, + omit_pos, counts)) { + // Use an alternative rebalancing mechanism if the one above failed + // to create a histogram that is positive wherever the original one was. + if (!RebalanceHistogram(&targets[0], max_symbol, table_size, shift, + omit_pos, counts)) { + return JXL_FAILURE("Logic error: couldn't rebalance a histogram"); + } + } + return true; +} + +struct SizeWriter { + size_t size = 0; + void Write(size_t num, size_t bits) { size += num; } +}; + +template +void StoreVarLenUint8(size_t n, Writer* writer) { + JXL_DASSERT(n <= 255); + if (n == 0) { + writer->Write(1, 0); + } else { + writer->Write(1, 1); + size_t nbits = FloorLog2Nonzero(n); + writer->Write(3, nbits); + writer->Write(nbits, n - (1ULL << nbits)); + } +} + +template +void StoreVarLenUint16(size_t n, Writer* writer) { + JXL_DASSERT(n <= 65535); + if (n == 0) { + writer->Write(1, 0); + } else { + writer->Write(1, 1); + size_t nbits = FloorLog2Nonzero(n); + writer->Write(4, nbits); + writer->Write(nbits, n - (1ULL << nbits)); + } +} + +template +bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size, + const int omit_pos, const int num_symbols, uint32_t shift, + const int* symbols, Writer* writer) { + bool ok = true; + if (num_symbols <= 2) { + // Small tree marker to encode 1-2 symbols. + writer->Write(1, 1); + if (num_symbols == 0) { + writer->Write(1, 0); + StoreVarLenUint8(0, writer); + } else { + writer->Write(1, num_symbols - 1); + for (int i = 0; i < num_symbols; ++i) { + StoreVarLenUint8(symbols[i], writer); + } + } + if (num_symbols == 2) { + writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]); + } + } else { + // Mark non-small tree. + writer->Write(1, 0); + // Mark non-flat histogram. + writer->Write(1, 0); + + // Precompute sequences for RLE encoding. Contains the number of identical + // values starting at a given index. Only contains the value at the first + // element of the series. + std::vector same(alphabet_size, 0); + int last = 0; + for (int i = 1; i < alphabet_size; i++) { + // Store the sequence length once different symbol reached, or we're at + // the end, or the length is longer than we can encode, or we are at + // the omit_pos. We don't support including the omit_pos in an RLE + // sequence because this value may use a different amount of log2 bits + // than standard, it is too complex to handle in the decoder. + if (counts[i] != counts[last] || i + 1 == alphabet_size || + (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) { + same[last] = (i - last); + last = i + 1; + } + } + + int length = 0; + std::vector logcounts(alphabet_size); + int omit_log = 0; + for (int i = 0; i < alphabet_size; ++i) { + JXL_ASSERT(counts[i] <= ANS_TAB_SIZE); + JXL_ASSERT(counts[i] >= 0); + if (i == omit_pos) { + length = i + 1; + } else if (counts[i] > 0) { + logcounts[i] = FloorLog2Nonzero(static_cast(counts[i])) + 1; + length = i + 1; + if (i < omit_pos) { + omit_log = std::max(omit_log, logcounts[i] + 1); + } else { + omit_log = std::max(omit_log, logcounts[i]); + } + } + } + logcounts[omit_pos] = omit_log; + + // Elias gamma-like code for shift. Only difference is that if the number + // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip + // the terminating 0 in unary coding. + int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); + int log = FloorLog2Nonzero(shift + 1); + writer->Write(log, (1 << log) - 1); + if (log != upper_bound_log) writer->Write(1, 0); + writer->Write(log, ((1 << log) - 1) & (shift + 1)); + + // Since num_symbols >= 3, we know that length >= 3, therefore we encode + // length - 3. + if (length - 3 > 255) { + // Pretend that everything is OK, but complain about correctness later. + StoreVarLenUint8(255, writer); + ok = false; + } else { + StoreVarLenUint8(length - 3, writer); + } + + // The logcount values are encoded with a static Huffman code. + static const size_t kMinReps = 4; + size_t rep = ANS_LOG_TAB_SIZE + 1; + for (int i = 0; i < length; ++i) { + if (i > 0 && same[i - 1] > kMinReps) { + // Encode the RLE symbol and skip the repeated ones. + writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]); + StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer); + i += same[i - 1] - 2; + continue; + } + writer->Write(kLogCountBitLengths[logcounts[i]], + kLogCountSymbols[logcounts[i]]); + } + for (int i = 0; i < length; ++i) { + if (i > 0 && same[i - 1] > kMinReps) { + // Skip symbols encoded by RLE. + i += same[i - 1] - 2; + continue; + } + if (logcounts[i] > 1 && i != omit_pos) { + int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift); + int drop_bits = logcounts[i] - 1 - bitcount; + JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0); + writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount)); + } + } + } + return ok; +} + +void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) { + // Mark non-small tree. + writer->Write(1, 0); + // Mark uniform histogram. + writer->Write(1, 1); + JXL_ASSERT(alphabet_size > 0); + // Encode alphabet size. + StoreVarLenUint8(alphabet_size - 1, writer); +} + +float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size, + uint32_t method) { + if (method == 0) { // Flat code + return ANS_LOG_TAB_SIZE + 2 + + EstimateDataBitsFlat(histogram, alphabet_size); + } + // Non-flat: shift = method-1. + uint32_t shift = method - 1; + std::vector counts(histogram, histogram + alphabet_size); + int omit_pos = 0; + int num_symbols; + int symbols[kMaxNumSymbolsForSmallCode] = {}; + JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size, + ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols)); + SizeWriter writer; + // Ignore the correctness, no real encoding happens at this stage. + (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift, + symbols, &writer); + return writer.size + + EstimateDataBits(histogram, counts.data(), alphabet_size); +} + +uint32_t ComputeBestMethod( + const ANSHistBin* histogram, size_t alphabet_size, float* cost, + HistogramParams::ANSHistogramStrategy ans_histogram_strategy) { + size_t method = 0; + float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0); + auto try_shift = [&](size_t shift) { + float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1); + if (c < fcost) { + method = shift + 1; + fcost = c; + } + }; + switch (ans_histogram_strategy) { + case HistogramParams::ANSHistogramStrategy::kPrecise: { + for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) { + try_shift(shift); + } + break; + } + case HistogramParams::ANSHistogramStrategy::kApproximate: { + for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) { + try_shift(shift); + } + break; + } + case HistogramParams::ANSHistogramStrategy::kFast: { + try_shift(0); + try_shift(ANS_LOG_TAB_SIZE / 2); + try_shift(ANS_LOG_TAB_SIZE); + break; + } + }; + *cost = fcost; + return method; +} + +} // namespace + +// Returns an estimate of the cost of encoding this histogram and the +// corresponding data. +size_t BuildAndStoreANSEncodingData( + HistogramParams::ANSHistogramStrategy ans_histogram_strategy, + const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size, + bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) { + if (use_prefix_code) { + if (alphabet_size <= 1) return 0; + std::vector histo(alphabet_size); + for (size_t i = 0; i < alphabet_size; i++) { + histo[i] = histogram[i]; + JXL_CHECK(histogram[i] >= 0); + } + size_t cost = 0; + { + std::vector depths(alphabet_size); + std::vector bits(alphabet_size); + if (writer == nullptr) { + BitWriter tmp_writer; + BitWriter::Allotment allotment( + &tmp_writer, 8 * alphabet_size + 8); // safe upper bound + BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(), + bits.data(), &tmp_writer); + ReclaimAndCharge(&tmp_writer, &allotment, 0, /*aux_out=*/nullptr); + cost = tmp_writer.BitsWritten(); + } else { + size_t start = writer->BitsWritten(); + BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(), + bits.data(), writer); + cost = writer->BitsWritten() - start; + } + for (size_t i = 0; i < alphabet_size; i++) { + info[i].bits = depths[i] == 0 ? 0 : bits[i]; + info[i].depth = depths[i]; + } + } + // Estimate data cost. + for (size_t i = 0; i < alphabet_size; i++) { + cost += histogram[i] * info[i].depth; + } + return cost; + } + JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE); + // Ensure we ignore trailing zeros in the histogram. + if (alphabet_size != 0) { + size_t largest_symbol = 0; + for (size_t i = 0; i < alphabet_size; i++) { + if (histogram[i] != 0) largest_symbol = i; + } + alphabet_size = largest_symbol + 1; + } + float cost; + uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost, + ans_histogram_strategy); + JXL_ASSERT(cost >= 0); + int num_symbols; + int symbols[kMaxNumSymbolsForSmallCode] = {}; + std::vector counts(histogram, histogram + alphabet_size); + if (!counts.empty()) { + size_t sum = 0; + for (size_t i = 0; i < counts.size(); i++) { + sum += counts[i]; + } + if (sum == 0) { + counts[0] = ANS_TAB_SIZE; + } + } + if (method == 0) { + counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE); + AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE]; + InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a); + ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info); + if (writer != nullptr) { + EncodeFlatHistogram(alphabet_size, writer); + } + return cost; + } + int omit_pos = 0; + uint32_t shift = method - 1; + JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size, + ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols)); + AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE]; + InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a); + ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info); + if (writer != nullptr) { + bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, + shift, symbols, writer); + (void)ok; + JXL_DASSERT(ok); + } + return cost; +} + +float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) { + float c; + ComputeBestMethod(data, alphabet_size, &c, + HistogramParams::ANSHistogramStrategy::kFast); + return c; +} + +template +void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer, + size_t log_alpha_size) { + writer->Write(CeilLog2Nonzero(log_alpha_size + 1), + uint_config.split_exponent); + if (uint_config.split_exponent == log_alpha_size) { + return; // msb/lsb don't matter. + } + size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1); + writer->Write(nbits, uint_config.msb_in_token); + nbits = CeilLog2Nonzero(uint_config.split_exponent - + uint_config.msb_in_token + 1); + writer->Write(nbits, uint_config.lsb_in_token); +} +template +void EncodeUintConfigs(const std::vector& uint_config, + Writer* writer, size_t log_alpha_size) { + // TODO(veluca): RLE? + for (size_t i = 0; i < uint_config.size(); i++) { + EncodeUintConfig(uint_config[i], writer, log_alpha_size); + } +} +template void EncodeUintConfigs(const std::vector&, + BitWriter*, size_t); + +namespace { + +void ChooseUintConfigs(const HistogramParams& params, + const std::vector>& tokens, + const std::vector& context_map, + std::vector* clustered_histograms, + EntropyEncodingData* codes, size_t* log_alpha_size) { + codes->uint_config.resize(clustered_histograms->size()); + + if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return; + if (params.uint_method == HistogramParams::HybridUintMethod::k000) { + codes->uint_config.clear(); + codes->uint_config.resize(clustered_histograms->size(), + HybridUintConfig(0, 0, 0)); + return; + } + if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) { + codes->uint_config.clear(); + codes->uint_config.resize(clustered_histograms->size(), + HybridUintConfig(2, 0, 1)); + return; + } + + // Brute-force method that tries a few options. + std::vector configs; + if (params.uint_method == HistogramParams::HybridUintMethod::kBest) { + configs = { + HybridUintConfig(4, 2, 0), // default + HybridUintConfig(4, 1, 0), // less precise + HybridUintConfig(4, 2, 1), // add sign + HybridUintConfig(4, 2, 2), // add sign+parity + HybridUintConfig(4, 1, 2), // add parity but less msb + // Same as above, but more direct coding. + HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0), + HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2), + HybridUintConfig(5, 1, 2), + // Same as above, but less direct coding. + HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0), + HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2), + // For near-lossless. + HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4), + HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5), + HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0), + // Other + HybridUintConfig(0, 0, 0), // varlenuint + HybridUintConfig(2, 0, 1), // works well for ctx map + HybridUintConfig(7, 0, 0), // direct coding + HybridUintConfig(8, 0, 0), // direct coding + HybridUintConfig(9, 0, 0), // direct coding + HybridUintConfig(10, 0, 0), // direct coding + HybridUintConfig(11, 0, 0), // direct coding + HybridUintConfig(12, 0, 0), // direct coding + }; + } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) { + configs = { + HybridUintConfig(4, 2, 0), // default + HybridUintConfig(4, 1, 2), // add parity but less msb + HybridUintConfig(0, 0, 0), // smallest histograms + HybridUintConfig(2, 0, 1), // works well for ctx map + }; + } + + std::vector costs(clustered_histograms->size(), + std::numeric_limits::max()); + std::vector extra_bits(clustered_histograms->size()); + std::vector is_valid(clustered_histograms->size()); + size_t max_alpha = + codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE; + for (HybridUintConfig cfg : configs) { + std::fill(is_valid.begin(), is_valid.end(), true); + std::fill(extra_bits.begin(), extra_bits.end(), 0); + + for (size_t i = 0; i < clustered_histograms->size(); i++) { + (*clustered_histograms)[i].Clear(); + } + for (size_t i = 0; i < tokens.size(); ++i) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token token = tokens[i][j]; + // TODO(veluca): do not ignore lz77 commands. + if (token.is_lz77_length) continue; + size_t histo = context_map[token.context]; + uint32_t tok, nbits, bits; + cfg.Encode(token.value, &tok, &nbits, &bits); + if (tok >= max_alpha || + (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) { + is_valid[histo] = false; + continue; + } + extra_bits[histo] += nbits; + (*clustered_histograms)[histo].Add(tok); + } + } + + for (size_t i = 0; i < clustered_histograms->size(); i++) { + if (!is_valid[i]) continue; + float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i]; + // add signaling cost of the hybriduintconfig itself + cost += CeilLog2Nonzero(cfg.split_exponent + 1); + cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1); + if (cost < costs[i]) { + codes->uint_config[i] = cfg; + costs[i] = cost; + } + } + } + + // Rebuild histograms. + for (size_t i = 0; i < clustered_histograms->size(); i++) { + (*clustered_histograms)[i].Clear(); + } + *log_alpha_size = 4; + for (size_t i = 0; i < tokens.size(); ++i) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token token = tokens[i][j]; + uint32_t tok, nbits, bits; + size_t histo = context_map[token.context]; + (token.is_lz77_length ? codes->lz77.length_uint_config + : codes->uint_config[histo]) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes->lz77.min_symbol : 0; + (*clustered_histograms)[histo].Add(tok); + while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++; + } + } +#if JXL_ENABLE_ASSERT + size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8; + JXL_ASSERT(*log_alpha_size <= max_log_alpha_size); +#endif +} + +class HistogramBuilder { + public: + explicit HistogramBuilder(const size_t num_contexts) + : histograms_(num_contexts) {} + + void VisitSymbol(int symbol, size_t histo_idx) { + JXL_DASSERT(histo_idx < histograms_.size()); + histograms_[histo_idx].Add(symbol); + } + + // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge. + size_t BuildAndStoreEntropyCodes( + const HistogramParams& params, + const std::vector>& tokens, EntropyEncodingData* codes, + std::vector* context_map, bool use_prefix_code, + BitWriter* writer, size_t layer, AuxOut* aux_out) const { + size_t cost = 0; + codes->encoding_info.clear(); + std::vector clustered_histograms(histograms_); + context_map->resize(histograms_.size()); + if (histograms_.size() > 1) { + if (!ans_fuzzer_friendly_) { + std::vector histogram_symbols; + ClusterHistograms(params, histograms_, kClustersLimit, + &clustered_histograms, &histogram_symbols); + for (size_t c = 0; c < histograms_.size(); ++c) { + (*context_map)[c] = static_cast(histogram_symbols[c]); + } + } else { + fill(context_map->begin(), context_map->end(), 0); + size_t max_symbol = 0; + for (const Histogram& h : histograms_) { + max_symbol = std::max(h.data_.size(), max_symbol); + } + size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1); + clustered_histograms.resize(1); + clustered_histograms[0].Clear(); + for (size_t i = 0; i < num_symbols; i++) { + clustered_histograms[0].Add(i); + } + } + if (writer != nullptr) { + EncodeContextMap(*context_map, clustered_histograms.size(), writer, + layer, aux_out); + } + } + if (aux_out != nullptr) { + for (size_t i = 0; i < clustered_histograms.size(); ++i) { + aux_out->layers[layer].clustered_entropy += + clustered_histograms[i].ShannonEntropy(); + } + } + codes->use_prefix_code = use_prefix_code; + size_t log_alpha_size = codes->lz77.enabled ? 8 : 7; // Sane default. + if (ans_fuzzer_friendly_) { + codes->uint_config.clear(); + codes->uint_config.resize(1, HybridUintConfig(7, 0, 0)); + } else { + ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms, + codes, &log_alpha_size); + } + if (log_alpha_size < 5) log_alpha_size = 5; + SizeWriter size_writer; // Used if writer == nullptr to estimate costs. + cost += 1; + if (writer) writer->Write(1, use_prefix_code); + + if (use_prefix_code) { + log_alpha_size = PREFIX_MAX_BITS; + } else { + cost += 2; + } + if (writer == nullptr) { + EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size); + } else { + if (!use_prefix_code) writer->Write(2, log_alpha_size - 5); + EncodeUintConfigs(codes->uint_config, writer, log_alpha_size); + } + if (use_prefix_code) { + for (size_t c = 0; c < clustered_histograms.size(); ++c) { + size_t num_symbol = 1; + for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) { + if (clustered_histograms[c].data_[i]) num_symbol = i + 1; + } + if (writer) { + StoreVarLenUint16(num_symbol - 1, writer); + } else { + StoreVarLenUint16(num_symbol - 1, &size_writer); + } + } + } + cost += size_writer.size; + for (size_t c = 0; c < clustered_histograms.size(); ++c) { + size_t num_symbol = 1; + for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) { + if (clustered_histograms[c].data_[i]) num_symbol = i + 1; + } + codes->encoding_info.emplace_back(); + codes->encoding_info.back().resize(std::max(1, num_symbol)); + + BitWriter::Allotment allotment(writer, 256 + num_symbol * 24); + cost += BuildAndStoreANSEncodingData( + params.ans_histogram_strategy, clustered_histograms[c].data_.data(), + num_symbol, log_alpha_size, use_prefix_code, + codes->encoding_info.back().data(), writer); + allotment.FinishedHistogram(writer); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + } + return cost; + } + + const Histogram& Histo(size_t i) const { return histograms_[i]; } + + private: + std::vector histograms_; +}; + +class SymbolCostEstimator { + public: + SymbolCostEstimator(size_t num_contexts, bool force_huffman, + const std::vector>& tokens, + const LZ77Params& lz77) { + HistogramBuilder builder(num_contexts); + // Build histograms for estimating lz77 savings. + HybridUintConfig uint_config; + for (size_t i = 0; i < tokens.size(); ++i) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token token = tokens[i][j]; + uint32_t tok, nbits, bits; + (token.is_lz77_length ? lz77.length_uint_config : uint_config) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? lz77.min_symbol : 0; + builder.VisitSymbol(tok, token.context); + } + } + max_alphabet_size_ = 0; + for (size_t i = 0; i < num_contexts; i++) { + max_alphabet_size_ = + std::max(max_alphabet_size_, builder.Histo(i).data_.size()); + } + bits_.resize(num_contexts * max_alphabet_size_); + // TODO(veluca): SIMD? + add_symbol_cost_.resize(num_contexts); + for (size_t i = 0; i < num_contexts; i++) { + float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f); + float total_cost = 0; + for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) { + size_t cnt = builder.Histo(i).data_[j]; + float cost = 0; + if (cnt != 0 && cnt != builder.Histo(i).total_count_) { + cost = -FastLog2f(cnt * inv_total); + if (force_huffman) cost = std::ceil(cost); + } else if (cnt == 0) { + cost = ANS_LOG_TAB_SIZE; // Highest possible cost. + } + bits_[i * max_alphabet_size_ + j] = cost; + total_cost += cost * builder.Histo(i).data_[j]; + } + // Penalty for adding a lz77 symbol to this contest (only used for static + // cost model). Higher penalty for contexts that have a very low + // per-symbol entropy. + add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total); + } + } + float Bits(size_t ctx, size_t sym) const { + return bits_[ctx * max_alphabet_size_ + sym]; + } + float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const { + uint32_t nbits, bits, tok; + lz77.length_uint_config.Encode(len, &tok, &nbits, &bits); + tok += lz77.min_symbol; + return nbits + Bits(ctx, tok); + } + float DistCost(size_t len, const LZ77Params& lz77) const { + uint32_t nbits, bits, tok; + HybridUintConfig().Encode(len, &tok, &nbits, &bits); + return nbits + Bits(lz77.nonserialized_distance_context, tok); + } + float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; } + + private: + size_t max_alphabet_size_; + std::vector bits_; + std::vector add_symbol_cost_; +}; + +void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts, + const std::vector>& tokens, + LZ77Params& lz77, + std::vector>& tokens_lz77) { + // TODO(veluca): tune heuristics here. + SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); + float bit_decrease = 0; + size_t total_symbols = 0; + tokens_lz77.resize(tokens.size()); + std::vector sym_cost; + HybridUintConfig uint_config; + for (size_t stream = 0; stream < tokens.size(); stream++) { + size_t distance_multiplier = + params.image_widths.size() > stream ? params.image_widths[stream] : 0; + const auto& in = tokens[stream]; + auto& out = tokens_lz77[stream]; + total_symbols += in.size(); + // Cumulative sum of bit costs. + sym_cost.resize(in.size() + 1); + for (size_t i = 0; i < in.size(); i++) { + uint32_t tok, nbits, unused_bits; + uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); + sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; + } + out.reserve(in.size()); + for (size_t i = 0; i < in.size(); i++) { + size_t num_to_copy = 0; + size_t distance_symbol = 0; // 1 for RLE. + if (distance_multiplier != 0) { + distance_symbol = 1; // Special distance 1 if enabled. + JXL_DASSERT(kSpecialDistances[1][0] == 1); + JXL_DASSERT(kSpecialDistances[1][1] == 0); + } + if (i > 0) { + for (; i + num_to_copy < in.size(); num_to_copy++) { + if (in[i + num_to_copy].value != in[i - 1].value) { + break; + } + } + } + if (num_to_copy == 0) { + out.push_back(in[i]); + continue; + } + float cost = sym_cost[i + num_to_copy] - sym_cost[i]; + // This subtraction might overflow, but that's OK. + size_t lz77_len = num_to_copy - lz77.min_length; + float lz77_cost = num_to_copy >= lz77.min_length + ? CeilLog2Nonzero(lz77_len + 1) + 1 + : 0; + if (num_to_copy < lz77.min_length || cost <= lz77_cost) { + for (size_t j = 0; j < num_to_copy; j++) { + out.push_back(in[i + j]); + } + i += num_to_copy - 1; + continue; + } + // Output the LZ77 length + out.emplace_back(in[i].context, lz77_len); + out.back().is_lz77_length = true; + i += num_to_copy - 1; + bit_decrease += cost - lz77_cost; + // Output the LZ77 copy distance. + out.emplace_back(lz77.nonserialized_distance_context, distance_symbol); + } + } + + if (bit_decrease > total_symbols * 0.2 + 16) { + lz77.enabled = true; + } +} + +// Hash chain for LZ77 matching +struct HashChain { + size_t size_; + std::vector data_; + + unsigned hash_num_values_ = 32768; + unsigned hash_mask_ = hash_num_values_ - 1; + unsigned hash_shift_ = 5; + + std::vector head; + std::vector chain; + std::vector val; + + // Speed up repetitions of zero + std::vector headz; + std::vector chainz; + std::vector zeros; + uint32_t numzeros = 0; + + size_t window_size_; + size_t window_mask_; + size_t min_length_; + size_t max_length_; + + // Map of special distance codes. + std::unordered_map special_dist_table_; + size_t num_special_distances_ = 0; + + uint32_t maxchainlength = 256; // window_size_ to allow all + + HashChain(const Token* data, size_t size, size_t window_size, + size_t min_length, size_t max_length, size_t distance_multiplier) + : size_(size), + window_size_(window_size), + window_mask_(window_size - 1), + min_length_(min_length), + max_length_(max_length) { + data_.resize(size); + for (size_t i = 0; i < size; i++) { + data_[i] = data[i].value; + } + + head.resize(hash_num_values_, -1); + val.resize(window_size_, -1); + chain.resize(window_size_); + for (uint32_t i = 0; i < window_size_; ++i) { + chain[i] = i; // same value as index indicates uninitialized + } + + zeros.resize(window_size_); + headz.resize(window_size_ + 1, -1); + chainz.resize(window_size_); + for (uint32_t i = 0; i < window_size_; ++i) { + chainz[i] = i; + } + // Translate distance to special distance code. + if (distance_multiplier) { + // Count down, so if due to small distance multiplier multiple distances + // map to the same code, the smallest code will be used in the end. + for (int i = kNumSpecialDistances - 1; i >= 0; --i) { + int xi = kSpecialDistances[i][0]; + int yi = kSpecialDistances[i][1]; + int distance = yi * distance_multiplier + xi; + // Ensure that we map distance 1 to the lowest symbols. + if (distance < 1) distance = 1; + special_dist_table_[distance] = i; + } + num_special_distances_ = kNumSpecialDistances; + } + } + + uint32_t GetHash(size_t pos) const { + uint32_t result = 0; + if (pos + 2 < size_) { + // TODO(lode): take the MSB's of the uint32_t values into account as well, + // given that the hash code itself is less than 32 bits. + result ^= (uint32_t)(data_[pos + 0] << 0u); + result ^= (uint32_t)(data_[pos + 1] << hash_shift_); + result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2)); + } else { + // No need to compute hash of last 2 bytes, the length 2 is too short. + return 0; + } + return result & hash_mask_; + } + + uint32_t CountZeros(size_t pos, uint32_t prevzeros) const { + size_t end = pos + window_size_; + if (end > size_) end = size_; + if (prevzeros > 0) { + if (prevzeros >= window_mask_ && data_[end - 1] == 0 && + end == pos + window_size_) { + return prevzeros; + } else { + return prevzeros - 1; + } + } + uint32_t num = 0; + while (pos + num < end && data_[pos + num] == 0) num++; + return num; + } + + void Update(size_t pos) { + uint32_t hashval = GetHash(pos); + uint32_t wpos = pos & window_mask_; + + val[wpos] = (int)hashval; + if (head[hashval] != -1) chain[wpos] = head[hashval]; + head[hashval] = wpos; + + if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0; + numzeros = CountZeros(pos, numzeros); + + zeros[wpos] = numzeros; + if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros]; + headz[numzeros] = wpos; + } + + void Update(size_t pos, size_t len) { + for (size_t i = 0; i < len; i++) { + Update(pos + i); + } + } + + template + void FindMatches(size_t pos, int max_dist, const CB& found_match) const { + uint32_t wpos = pos & window_mask_; + uint32_t hashval = GetHash(pos); + uint32_t hashpos = chain[wpos]; + + int prev_dist = 0; + int end = std::min(pos + max_length_, size_); + uint32_t chainlength = 0; + uint32_t best_len = 0; + for (;;) { + int dist = (hashpos <= wpos) ? (wpos - hashpos) + : (wpos - hashpos + window_mask_ + 1); + if (dist < prev_dist) break; + prev_dist = dist; + uint32_t len = 0; + if (dist > 0) { + int i = pos; + int j = pos - dist; + if (numzeros > 3) { + int r = std::min(numzeros - 1, zeros[hashpos]); + if (i + r >= end) r = end - i - 1; + i += r; + j += r; + } + while (i < end && data_[i] == data_[j]) { + i++; + j++; + } + len = i - pos; + // This can trigger even if the new length is slightly smaller than the + // best length, because it is possible for a slightly cheaper distance + // symbol to occur. + if (len >= min_length_ && len + 2 >= best_len) { + auto it = special_dist_table_.find(dist); + int dist_symbol = (it == special_dist_table_.end()) + ? (num_special_distances_ + dist - 1) + : it->second; + found_match(len, dist_symbol); + if (len > best_len) best_len = len; + } + } + + chainlength++; + if (chainlength >= maxchainlength) break; + + if (numzeros >= 3 && len > numzeros) { + if (hashpos == chainz[hashpos]) break; + hashpos = chainz[hashpos]; + if (zeros[hashpos] != numzeros) break; + } else { + if (hashpos == chain[hashpos]) break; + hashpos = chain[hashpos]; + if (val[hashpos] != (int)hashval) break; // outdated hash value + } + } + } + void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol, + size_t* result_len) const { + *result_dist_symbol = 0; + *result_len = 1; + FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) { + if (len > *result_len || + (len == *result_len && *result_dist_symbol > dist_symbol)) { + *result_len = len; + *result_dist_symbol = dist_symbol; + } + }); + } +}; + +float LenCost(size_t len) { + uint32_t nbits, bits, tok; + HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits); + constexpr float kCostTable[] = { + 2.797667318563126, 3.213177690381199, 2.5706009246743737, + 2.408392498667534, 2.829649191872326, 3.3923087753324577, + 4.029267451554331, 4.415576699706408, 4.509357574741465, + 9.21481543803004, 10.020590190114898, 11.858671627804766, + 12.45853300490526, 11.713105831990857, 12.561996324849314, + 13.775477692278367, 13.174027068768641, + }; + size_t table_size = sizeof kCostTable / sizeof *kCostTable; + if (tok >= table_size) tok = table_size - 1; + return kCostTable[tok] + nbits; +} + +// TODO(veluca): this does not take into account usage or non-usage of distance +// multipliers. +float DistCost(size_t dist) { + uint32_t nbits, bits, tok; + HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits); + constexpr float kCostTable[] = { + 6.368282626312716, 5.680793277090298, 8.347404197105247, + 7.641619201599141, 6.914328374119438, 7.959808291537444, + 8.70023120759855, 8.71378518934703, 9.379132523982769, + 9.110472749092708, 9.159029569270908, 9.430936766731973, + 7.278284055315169, 7.8278514904267755, 10.026641158289236, + 9.976049229827066, 9.64351607048908, 9.563403863480442, + 10.171474111762747, 10.45950155077234, 9.994813912104219, + 10.322524683741156, 8.465808729388186, 8.756254166066853, + 10.160930174662234, 10.247329273413435, 10.04090403724809, + 10.129398517544082, 9.342311691539546, 9.07608009102374, + 10.104799540677513, 10.378079384990906, 10.165828974075072, + 10.337595322341553, 7.940557464567944, 10.575665823319431, + 11.023344321751955, 10.736144698831827, 11.118277044595054, + 7.468468230648442, 10.738305230932939, 10.906980780216568, + 10.163468216353817, 10.17805759656433, 11.167283670483565, + 11.147050200274544, 10.517921919244333, 10.651764778156886, + 10.17074446448919, 11.217636876224745, 11.261630721139484, + 11.403140815247259, 10.892472096873417, 11.1859607804481, + 8.017346947551262, 7.895143720278828, 11.036577113822025, + 11.170562110315794, 10.326988722591086, 10.40872184751056, + 11.213498225466386, 11.30580635516863, 10.672272515665442, + 10.768069466228063, 11.145257364153565, 11.64668307145549, + 10.593156194627339, 11.207499484844943, 10.767517766396908, + 10.826629811407042, 10.737764794499988, 10.6200448518045, + 10.191315385198092, 8.468384171390085, 11.731295299170432, + 11.824619886654398, 10.41518844301179, 10.16310536548649, + 10.539423685097576, 10.495136599328031, 10.469112847728267, + 11.72057686174922, 10.910326337834674, 11.378921834673758, + 11.847759036098536, 11.92071647623854, 10.810628276345282, + 11.008601085273893, 11.910326337834674, 11.949212023423133, + 11.298614839104337, 11.611603659010392, 10.472930394619985, + 11.835564720850282, 11.523267392285337, 12.01055816679611, + 8.413029688994023, 11.895784139536406, 11.984679534970505, + 11.220654278717394, 11.716311684833672, 10.61036646226114, + 10.89849965960364, 10.203762898863669, 10.997560826267238, + 11.484217379438984, 11.792836176993665, 12.24310468755171, + 11.464858097919262, 12.212747017409377, 11.425595666074955, + 11.572048533398757, 12.742093965163013, 11.381874288645637, + 12.191870445817015, 11.683156920035426, 11.152442115262197, + 11.90303691580457, 11.653292787169159, 11.938615382266098, + 16.970641701570223, 16.853602280380002, 17.26240782594733, + 16.644655390108507, 17.14310889757499, 16.910935455445955, + 17.505678976959697, 17.213498225466388, 2.4162310293553024, + 3.494587244462329, 3.5258600986408344, 3.4959806589517095, + 3.098390886949687, 3.343454654302911, 3.588847442290287, + 4.14614790111827, 5.152948641990529, 7.433696808092598, + 9.716311684833672, + }; + size_t table_size = sizeof kCostTable / sizeof *kCostTable; + if (tok >= table_size) tok = table_size - 1; + return kCostTable[tok] + nbits; +} + +void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts, + const std::vector>& tokens, + LZ77Params& lz77, + std::vector>& tokens_lz77) { + // TODO(veluca): tune heuristics here. + SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); + float bit_decrease = 0; + size_t total_symbols = 0; + tokens_lz77.resize(tokens.size()); + HybridUintConfig uint_config; + std::vector sym_cost; + for (size_t stream = 0; stream < tokens.size(); stream++) { + size_t distance_multiplier = + params.image_widths.size() > stream ? params.image_widths[stream] : 0; + const auto& in = tokens[stream]; + auto& out = tokens_lz77[stream]; + total_symbols += in.size(); + // Cumulative sum of bit costs. + sym_cost.resize(in.size() + 1); + for (size_t i = 0; i < in.size(); i++) { + uint32_t tok, nbits, unused_bits; + uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); + sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; + } + + out.reserve(in.size()); + size_t max_distance = in.size(); + size_t min_length = lz77.min_length; + JXL_ASSERT(min_length >= 3); + size_t max_length = in.size(); + + // Use next power of two as window size. + size_t window_size = 1; + while (window_size < max_distance && window_size < kWindowSize) { + window_size <<= 1; + } + + HashChain chain(in.data(), in.size(), window_size, min_length, max_length, + distance_multiplier); + size_t len, dist_symbol; + + const size_t max_lazy_match_len = 256; // 0 to disable lazy matching + + // Whether the next symbol was already updated (to test lazy matching) + bool already_updated = false; + for (size_t i = 0; i < in.size(); i++) { + out.push_back(in[i]); + if (!already_updated) chain.Update(i); + already_updated = false; + chain.FindMatch(i, max_distance, &dist_symbol, &len); + if (len >= min_length) { + if (len < max_lazy_match_len && i + 1 < in.size()) { + // Try length at next symbol lazy matching + chain.Update(i + 1); + already_updated = true; + size_t len2, dist_symbol2; + chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2); + if (len2 > len) { + // Use the lazy match. Add literal, and use the next length starting + // from the next byte. + ++i; + already_updated = false; + len = len2; + dist_symbol = dist_symbol2; + out.push_back(in[i]); + } + } + + float cost = sym_cost[i + len] - sym_cost[i]; + size_t lz77_len = len - lz77.min_length; + float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) + + sce.AddSymbolCost(out.back().context); + + if (lz77_cost <= cost) { + out.back().value = len - min_length; + out.back().is_lz77_length = true; + out.emplace_back(lz77.nonserialized_distance_context, dist_symbol); + bit_decrease += cost - lz77_cost; + } else { + // LZ77 match ignored, and symbol already pushed. Push all other + // symbols and skip. + for (size_t j = 1; j < len; j++) { + out.push_back(in[i + j]); + } + } + + if (already_updated) { + chain.Update(i + 2, len - 2); + already_updated = false; + } else { + chain.Update(i + 1, len - 1); + } + i += len - 1; + } else { + // Literal, already pushed + } + } + } + + if (bit_decrease > total_symbols * 0.2 + 16) { + lz77.enabled = true; + } +} + +void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts, + const std::vector>& tokens, + LZ77Params& lz77, + std::vector>& tokens_lz77) { + std::vector> tokens_for_cost_estimate; + ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate); + // If greedy-LZ77 does not give better compression than no-lz77, no reason to + // run the optimal matching. + if (!lz77.enabled) return; + SymbolCostEstimator sce(num_contexts + 1, params.force_huffman, + tokens_for_cost_estimate, lz77); + tokens_lz77.resize(tokens.size()); + HybridUintConfig uint_config; + std::vector sym_cost; + std::vector dist_symbols; + for (size_t stream = 0; stream < tokens.size(); stream++) { + size_t distance_multiplier = + params.image_widths.size() > stream ? params.image_widths[stream] : 0; + const auto& in = tokens[stream]; + auto& out = tokens_lz77[stream]; + // Cumulative sum of bit costs. + sym_cost.resize(in.size() + 1); + for (size_t i = 0; i < in.size(); i++) { + uint32_t tok, nbits, unused_bits; + uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); + sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; + } + + out.reserve(in.size()); + size_t max_distance = in.size(); + size_t min_length = lz77.min_length; + JXL_ASSERT(min_length >= 3); + size_t max_length = in.size(); + + // Use next power of two as window size. + size_t window_size = 1; + while (window_size < max_distance && window_size < kWindowSize) { + window_size <<= 1; + } + + HashChain chain(in.data(), in.size(), window_size, min_length, max_length, + distance_multiplier); + + struct MatchInfo { + uint32_t len; + uint32_t dist_symbol; + uint32_t ctx; + float total_cost = std::numeric_limits::max(); + }; + // Total cost to encode the first N symbols. + std::vector prefix_costs(in.size() + 1); + prefix_costs[0].total_cost = 0; + + size_t rle_length = 0; + size_t skip_lz77 = 0; + for (size_t i = 0; i < in.size(); i++) { + chain.Update(i); + float lit_cost = + prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i]; + if (prefix_costs[i + 1].total_cost > lit_cost) { + prefix_costs[i + 1].dist_symbol = 0; + prefix_costs[i + 1].len = 1; + prefix_costs[i + 1].ctx = in[i].context; + prefix_costs[i + 1].total_cost = lit_cost; + } + if (skip_lz77 > 0) { + skip_lz77--; + continue; + } + dist_symbols.clear(); + chain.FindMatches(i, max_distance, + [&dist_symbols](size_t len, size_t dist_symbol) { + if (dist_symbols.size() <= len) { + dist_symbols.resize(len + 1, dist_symbol); + } + if (dist_symbol < dist_symbols[len]) { + dist_symbols[len] = dist_symbol; + } + }); + if (dist_symbols.size() <= min_length) continue; + { + size_t best_cost = dist_symbols.back(); + for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) { + if (dist_symbols[j] < best_cost) { + best_cost = dist_symbols[j]; + } + dist_symbols[j] = best_cost; + } + } + for (size_t j = min_length; j < dist_symbols.size(); j++) { + // Cost model that uses results from lazy LZ77. + float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) + + sce.DistCost(dist_symbols[j], lz77); + float cost = prefix_costs[i].total_cost + lz77_cost; + if (prefix_costs[i + j].total_cost > cost) { + prefix_costs[i + j].len = j; + prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1; + prefix_costs[i + j].ctx = in[i].context; + prefix_costs[i + j].total_cost = cost; + } + } + // We are in a RLE sequence: skip all the symbols except the first 8 and + // the last 8. This avoid quadratic costs for sequences with long runs of + // the same symbol. + if ((dist_symbols.back() == 0 && distance_multiplier == 0) || + (dist_symbols.back() == 1 && distance_multiplier != 0)) { + rle_length++; + } else { + rle_length = 0; + } + if (rle_length >= 8 && dist_symbols.size() > 9) { + skip_lz77 = dist_symbols.size() - 10; + rle_length = 0; + } + } + size_t pos = in.size(); + while (pos > 0) { + bool is_lz77_length = prefix_costs[pos].dist_symbol != 0; + if (is_lz77_length) { + size_t dist_symbol = prefix_costs[pos].dist_symbol - 1; + out.emplace_back(lz77.nonserialized_distance_context, dist_symbol); + } + size_t val = is_lz77_length ? prefix_costs[pos].len - min_length + : in[pos - 1].value; + out.emplace_back(prefix_costs[pos].ctx, val); + out.back().is_lz77_length = is_lz77_length; + pos -= prefix_costs[pos].len; + } + std::reverse(out.begin(), out.end()); + } +} + +void ApplyLZ77(const HistogramParams& params, size_t num_contexts, + const std::vector>& tokens, LZ77Params& lz77, + std::vector>& tokens_lz77) { + lz77.enabled = false; + if (params.force_huffman) { + lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512); + } else { + lz77.min_symbol = 224; + } + if (params.lz77_method == HistogramParams::LZ77Method::kNone) { + return; + } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) { + ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77); + } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) { + ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77); + } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) { + ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77); + } else { + JXL_ABORT("Not implemented"); + } +} +} // namespace + +size_t BuildAndEncodeHistograms(const HistogramParams& params, + size_t num_contexts, + std::vector>& tokens, + EntropyEncodingData* codes, + std::vector* context_map, + BitWriter* writer, size_t layer, + AuxOut* aux_out) { + size_t total_bits = 0; + codes->lz77.nonserialized_distance_context = num_contexts; + std::vector> tokens_lz77; + ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77); + if (ans_fuzzer_friendly_) { + codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0); + codes->lz77.min_symbol = 2048; + } + + const size_t max_contexts = std::min(num_contexts, kClustersLimit); + BitWriter::Allotment allotment(writer, + 128 + num_contexts * 40 + max_contexts * 96); + if (writer) { + JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out)); + } else { + size_t ebits, bits; + JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits)); + total_bits += bits; + } + if (codes->lz77.enabled) { + if (writer) { + size_t b = writer->BitsWritten(); + EncodeUintConfig(codes->lz77.length_uint_config, writer, + /*log_alpha_size=*/8); + total_bits += writer->BitsWritten() - b; + } else { + SizeWriter size_writer; + EncodeUintConfig(codes->lz77.length_uint_config, &size_writer, + /*log_alpha_size=*/8); + total_bits += size_writer.size; + } + num_contexts += 1; + tokens = std::move(tokens_lz77); + } + size_t total_tokens = 0; + // Build histograms. + HistogramBuilder builder(num_contexts); + HybridUintConfig uint_config; // Default config for clustering. + // Unless we are using the kContextMap histogram option. + if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) { + uint_config = HybridUintConfig(2, 0, 1); + } + if (params.uint_method == HistogramParams::HybridUintMethod::k000) { + uint_config = HybridUintConfig(0, 0, 0); + } + if (ans_fuzzer_friendly_) { + uint_config = HybridUintConfig(10, 0, 0); + } + for (size_t i = 0; i < tokens.size(); ++i) { + if (codes->lz77.enabled) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token& token = tokens[i][j]; + total_tokens++; + uint32_t tok, nbits, bits; + (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes->lz77.min_symbol : 0; + builder.VisitSymbol(tok, token.context); + } + } else if (num_contexts == 1) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token& token = tokens[i][j]; + total_tokens++; + uint32_t tok, nbits, bits; + uint_config.Encode(token.value, &tok, &nbits, &bits); + builder.VisitSymbol(tok, /*token.context=*/0); + } + } else { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token& token = tokens[i][j]; + total_tokens++; + uint32_t tok, nbits, bits; + uint_config.Encode(token.value, &tok, &nbits, &bits); + builder.VisitSymbol(tok, token.context); + } + } + } + + bool use_prefix_code = + params.force_huffman || total_tokens < 100 || + params.clustering == HistogramParams::ClusteringType::kFastest || + ans_fuzzer_friendly_; + if (!use_prefix_code) { + bool all_singleton = true; + for (size_t i = 0; i < num_contexts; i++) { + if (builder.Histo(i).ShannonEntropy() >= 1e-5) { + all_singleton = false; + } + } + if (all_singleton) { + use_prefix_code = true; + } + } + + // Encode histograms. + total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes, + context_map, use_prefix_code, + writer, layer, aux_out); + allotment.FinishedHistogram(writer); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + + if (aux_out != nullptr) { + aux_out->layers[layer].num_clustered_histograms += + codes->encoding_info.size(); + } + return total_bits; +} + +size_t WriteTokens(const std::vector& tokens, + const EntropyEncodingData& codes, + const std::vector& context_map, BitWriter* writer) { + size_t num_extra_bits = 0; + if (codes.use_prefix_code) { + for (size_t i = 0; i < tokens.size(); i++) { + uint32_t tok, nbits, bits; + const Token& token = tokens[i]; + size_t histo = context_map[token.context]; + (token.is_lz77_length ? codes.lz77.length_uint_config + : codes.uint_config[histo]) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes.lz77.min_symbol : 0; + // Combine two calls to the BitWriter. Equivalent to: + // writer->Write(codes.encoding_info[histo][tok].depth, + // codes.encoding_info[histo][tok].bits); + // writer->Write(nbits, bits); + uint64_t data = codes.encoding_info[histo][tok].bits; + data |= bits << codes.encoding_info[histo][tok].depth; + writer->Write(codes.encoding_info[histo][tok].depth + nbits, data); + num_extra_bits += nbits; + } + return num_extra_bits; + } + std::vector out; + std::vector out_nbits; + out.reserve(tokens.size()); + out_nbits.reserve(tokens.size()); + uint64_t allbits = 0; + size_t numallbits = 0; + // Writes in *reversed* order. + auto addbits = [&](size_t bits, size_t nbits) { + if (JXL_UNLIKELY(nbits)) { + JXL_DASSERT(bits >> nbits == 0); + if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) { + out.push_back(allbits); + out_nbits.push_back(numallbits); + numallbits = allbits = 0; + } + allbits <<= nbits; + allbits |= bits; + numallbits += nbits; + } + }; + const int end = tokens.size(); + ANSCoder ans; + if (codes.lz77.enabled || context_map.size() > 1) { + for (int i = end - 1; i >= 0; --i) { + const Token token = tokens[i]; + const uint8_t histo = context_map[token.context]; + uint32_t tok, nbits, bits; + (token.is_lz77_length ? codes.lz77.length_uint_config + : codes.uint_config[histo]) + .Encode(tokens[i].value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes.lz77.min_symbol : 0; + const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok]; + // Extra bits first as this is reversed. + addbits(bits, nbits); + num_extra_bits += nbits; + uint8_t ans_nbits = 0; + uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits); + addbits(ans_bits, ans_nbits); + } + } else { + for (int i = end - 1; i >= 0; --i) { + uint32_t tok, nbits, bits; + codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits); + const ANSEncSymbolInfo& info = codes.encoding_info[0][tok]; + // Extra bits first as this is reversed. + addbits(bits, nbits); + num_extra_bits += nbits; + uint8_t ans_nbits = 0; + uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits); + addbits(ans_bits, ans_nbits); + } + } + const uint32_t state = ans.GetState(); + writer->Write(32, state); + writer->Write(numallbits, allbits); + for (int i = out.size(); i > 0; --i) { + writer->Write(out_nbits[i - 1], out[i - 1]); + } + return num_extra_bits; +} + +void WriteTokens(const std::vector& tokens, + const EntropyEncodingData& codes, + const std::vector& context_map, BitWriter* writer, + size_t layer, AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4); + size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + if (aux_out != nullptr) { + aux_out->layers[layer].extra_bits += num_extra_bits; + } +} + +void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) { +#if JXL_IS_DEBUG_BUILD // Guard against accidental / malicious changes. + ans_fuzzer_friendly_ = ans_fuzzer_friendly; +#endif +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_ans.h b/media/libjxl/src/lib/jxl/enc_ans.h new file mode 100644 index 000000000..2f720f560 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_ans.h @@ -0,0 +1,143 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ANS_H_ +#define LIB_JXL_ENC_ANS_H_ + +// Library to encode the ANS population counts to the bit-stream and encode +// symbols based on the respective distributions. + +#include +#include +#include +#include +#include + +#include +#include + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_ans_params.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/huffman_table.h" + +namespace jxl { + +#define USE_MULT_BY_RECIPROCAL + +// precision must be equal to: #bits(state_) + #bits(freq) +#define RECIPROCAL_PRECISION (32 + ANS_LOG_TAB_SIZE) + +// Data structure representing one element of the encoding table built +// from a distribution. +// TODO(veluca): split this up, or use an union. +struct ANSEncSymbolInfo { + // ANS + uint16_t freq_; + std::vector reverse_map_; +#ifdef USE_MULT_BY_RECIPROCAL + uint64_t ifreq_; +#endif + // Prefix coding. + uint8_t depth; + uint16_t bits; +}; + +class ANSCoder { + public: + ANSCoder() : state_(ANS_SIGNATURE << 16) {} + + uint32_t PutSymbol(const ANSEncSymbolInfo& t, uint8_t* nbits) { + uint32_t bits = 0; + *nbits = 0; + if ((state_ >> (32 - ANS_LOG_TAB_SIZE)) >= t.freq_) { + bits = state_ & 0xffff; + state_ >>= 16; + *nbits = 16; + } +#ifdef USE_MULT_BY_RECIPROCAL + // We use mult-by-reciprocal trick, but that requires 64b calc. + const uint32_t v = (state_ * t.ifreq_) >> RECIPROCAL_PRECISION; + const uint32_t offset = t.reverse_map_[state_ - v * t.freq_]; + state_ = (v << ANS_LOG_TAB_SIZE) + offset; +#else + state_ = ((state_ / t.freq_) << ANS_LOG_TAB_SIZE) + + t.reverse_map_[state_ % t.freq_]; +#endif + return bits; + } + + uint32_t GetState() const { return state_; } + + private: + uint32_t state_; +}; + +// RebalanceHistogram requires a signed type. +using ANSHistBin = int32_t; + +struct EntropyEncodingData { + std::vector> encoding_info; + bool use_prefix_code; + std::vector uint_config; + LZ77Params lz77; +}; + +// Integer to be encoded by an entropy coder, either ANS or Huffman. +struct Token { + Token() {} + Token(uint32_t c, uint32_t value) + : is_lz77_length(false), context(c), value(value) {} + uint32_t is_lz77_length : 1; + uint32_t context : 31; + uint32_t value; +}; + +// Returns an estimate of the number of bits required to encode the given +// histogram (header bits plus data bits). +float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size); + +// Apply context clustering, compute histograms and encode them. Returns an +// estimate of the total bits used for encoding the stream. If `writer` == +// nullptr, the bit estimate will not take into account the context map (which +// does not get written if `num_contexts` == 1). +size_t BuildAndEncodeHistograms(const HistogramParams& params, + size_t num_contexts, + std::vector>& tokens, + EntropyEncodingData* codes, + std::vector* context_map, + BitWriter* writer, size_t layer, + AuxOut* aux_out); + +// Write the tokens to a string. +void WriteTokens(const std::vector& tokens, + const EntropyEncodingData& codes, + const std::vector& context_map, BitWriter* writer, + size_t layer, AuxOut* aux_out); + +// Same as above, but assumes allotment created by caller. +size_t WriteTokens(const std::vector& tokens, + const EntropyEncodingData& codes, + const std::vector& context_map, BitWriter* writer); + +// Exposed for tests; to be used with Writer=BitWriter only. +template +void EncodeUintConfigs(const std::vector& uint_config, + Writer* writer, size_t log_alpha_size); +extern template void EncodeUintConfigs(const std::vector&, + BitWriter*, size_t); + +// Globally set the option to create fuzzer-friendly ANS streams. Negatively +// impacts compression. Not thread-safe. +void SetANSFuzzerFriendly(bool ans_fuzzer_friendly); +} // namespace jxl + +#endif // LIB_JXL_ENC_ANS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_ans_params.h b/media/libjxl/src/lib/jxl/enc_ans_params.h new file mode 100644 index 000000000..50ca31dc0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_ans_params.h @@ -0,0 +1,76 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ANS_PARAMS_H_ +#define LIB_JXL_ENC_ANS_PARAMS_H_ + +// Encoder-only parameter needed for ANS entropy encoding methods. + +#include +#include + +#include "lib/jxl/enc_params.h" + +namespace jxl { + +struct HistogramParams { + enum class ClusteringType { + kFastest, // Only 4 clusters. + kFast, + kBest, + }; + + enum class HybridUintMethod { + kNone, // just use kHybridUint420Config. + k000, // force the fastest option. + kFast, // just try a couple of options. + kContextMap, // fast choice for ctx map. + kBest, + }; + + enum class LZ77Method { + kNone, // do not try lz77. + kRLE, // only try doing RLE. + kLZ77, // try lz77 with backward references. + kOptimal, // optimal-matching LZ77 parsing. + }; + + enum class ANSHistogramStrategy { + kFast, // Only try some methods, early exit. + kApproximate, // Only try some methods. + kPrecise, // Try all methods. + }; + + HistogramParams() = default; + + HistogramParams(SpeedTier tier, size_t num_ctx) { + if (tier > SpeedTier::kFalcon) { + clustering = ClusteringType::kFastest; + lz77_method = LZ77Method::kNone; + } else if (tier > SpeedTier::kTortoise) { + clustering = ClusteringType::kFast; + } else { + clustering = ClusteringType::kBest; + } + if (tier > SpeedTier::kTortoise) { + uint_method = HybridUintMethod::kNone; + } + if (tier >= SpeedTier::kSquirrel) { + ans_histogram_strategy = ANSHistogramStrategy::kApproximate; + } + } + + ClusteringType clustering = ClusteringType::kBest; + HybridUintMethod uint_method = HybridUintMethod::kBest; + LZ77Method lz77_method = LZ77Method::kRLE; + ANSHistogramStrategy ans_histogram_strategy = ANSHistogramStrategy::kPrecise; + std::vector image_widths; + size_t max_histograms = ~0; + bool force_huffman = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_ANS_PARAMS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_ar_control_field.cc b/media/libjxl/src/lib/jxl/enc_ar_control_field.cc new file mode 100644 index 000000000..9030430e2 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_ar_control_field.cc @@ -0,0 +1,325 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_ar_control_field.h" + +#include +#include + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_ar_control_field.cc" +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sqrt; + +void ProcessTile(const Image3F& opsin, PassesEncoderState* enc_state, + const Rect& rect, + ArControlFieldHeuristics::TempImages* temp_image) { + constexpr size_t N = kBlockDim; + ImageB* JXL_RESTRICT epf_sharpness = &enc_state->shared.epf_sharpness; + ImageF* JXL_RESTRICT quant = &enc_state->initial_quant_field; + JXL_ASSERT( + epf_sharpness->xsize() == enc_state->shared.frame_dim.xsize_blocks && + epf_sharpness->ysize() == enc_state->shared.frame_dim.ysize_blocks); + + if (enc_state->cparams.butteraugli_distance < kMinButteraugliForDynamicAR || + enc_state->cparams.speed_tier > SpeedTier::kWombat || + enc_state->shared.frame_header.loop_filter.epf_iters == 0) { + FillPlane(static_cast(4), epf_sharpness, rect); + return; + } + + // Likely better to have a higher X weight, like: + // const float kChannelWeights[3] = {47.0f, 4.35f, 0.287f}; + const float kChannelWeights[3] = {4.35f, 4.35f, 0.287f}; + const float kChannelWeightsLapNeg[3] = {-0.125f * kChannelWeights[0], + -0.125f * kChannelWeights[1], + -0.125f * kChannelWeights[2]}; + const size_t sharpness_stride = + static_cast(epf_sharpness->PixelsPerRow()); + + size_t by0 = rect.y0(); + size_t by1 = rect.y0() + rect.ysize(); + size_t bx0 = rect.x0(); + size_t bx1 = rect.x0() + rect.xsize(); + temp_image->InitOnce(); + ImageF& laplacian_sqrsum = temp_image->laplacian_sqrsum; + // Calculate the L2 of the 3x3 Laplacian in an integral transform + // (for example 32x32 dct). This relates to transforms ability + // to propagate artefacts. + size_t y0 = by0 == 0 ? 2 : 0; + size_t y1 = by1 * N + 4 <= opsin.ysize() + 2 ? (by1 - by0) * N + 4 + : opsin.ysize() + 2 - by0 * N; + size_t x0 = bx0 == 0 ? 2 : 0; + size_t x1 = bx1 * N + 4 <= opsin.xsize() + 2 ? (bx1 - bx0) * N + 4 + : opsin.xsize() + 2 - bx0 * N; + HWY_FULL(float) df; + for (size_t y = y0; y < y1; y++) { + float* JXL_RESTRICT laplacian_sqrsum_row = laplacian_sqrsum.Row(y); + size_t cy = y + by0 * N - 2; + const float* JXL_RESTRICT in_row_t[3]; + const float* JXL_RESTRICT in_row[3]; + const float* JXL_RESTRICT in_row_b[3]; + for (size_t c = 0; c < 3; c++) { + in_row_t[c] = opsin.PlaneRow(c, cy > 0 ? cy - 1 : cy); + in_row[c] = opsin.PlaneRow(c, cy); + in_row_b[c] = opsin.PlaneRow(c, cy + 1 < opsin.ysize() ? cy + 1 : cy); + } + auto compute_laplacian_scalar = [&](size_t x) { + size_t cx = x + bx0 * N - 2; + const size_t prevX = cx >= 1 ? cx - 1 : cx; + const size_t nextX = cx + 1 < opsin.xsize() ? cx + 1 : cx; + float sumsqr = 0; + for (size_t c = 0; c < 3; c++) { + float laplacian = + kChannelWeights[c] * in_row[c][cx] + + kChannelWeightsLapNeg[c] * + (in_row[c][prevX] + in_row[c][nextX] + in_row_b[c][prevX] + + in_row_b[c][cx] + in_row_b[c][nextX] + in_row_t[c][prevX] + + in_row_t[c][cx] + in_row_t[c][nextX]); + sumsqr += laplacian * laplacian; + } + laplacian_sqrsum_row[x] = sumsqr; + }; + size_t x = x0; + for (; x + bx0 * N < 3; x++) { + compute_laplacian_scalar(x); + } + // Interior. One extra pixel of border as the last pixel is special. + for (; x + Lanes(df) <= x1 && x + Lanes(df) + bx0 * N - 1 <= opsin.xsize(); + x += Lanes(df)) { + size_t cx = x + bx0 * N - 2; + auto sumsqr = Zero(df); + for (size_t c = 0; c < 3; c++) { + auto laplacian = + Mul(LoadU(df, in_row[c] + cx), Set(df, kChannelWeights[c])); + auto sum_oth0 = LoadU(df, in_row[c] + cx - 1); + auto sum_oth1 = LoadU(df, in_row[c] + cx + 1); + auto sum_oth2 = LoadU(df, in_row_t[c] + cx - 1); + auto sum_oth3 = LoadU(df, in_row_t[c] + cx); + sum_oth0 = Add(sum_oth0, LoadU(df, in_row_t[c] + cx + 1)); + sum_oth1 = Add(sum_oth1, LoadU(df, in_row_b[c] + cx - 1)); + sum_oth2 = Add(sum_oth2, LoadU(df, in_row_b[c] + cx)); + sum_oth3 = Add(sum_oth3, LoadU(df, in_row_b[c] + cx + 1)); + sum_oth0 = Add(sum_oth0, sum_oth1); + sum_oth2 = Add(sum_oth2, sum_oth3); + sum_oth0 = Add(sum_oth0, sum_oth2); + laplacian = + MulAdd(Set(df, kChannelWeightsLapNeg[c]), sum_oth0, laplacian); + sumsqr = MulAdd(laplacian, laplacian, sumsqr); + } + StoreU(sumsqr, df, laplacian_sqrsum_row + x); + } + for (; x < x1; x++) { + compute_laplacian_scalar(x); + } + } + HWY_CAPPED(float, 4) df4; + // Calculate the L2 of the 3x3 Laplacian in 4x4 blocks within the area + // of the integral transform. Sample them within the integral transform + // with two offsets (0,0) and (-2, -2) pixels (sqrsum_00 and sqrsum_22, + // respectively). + ImageF& sqrsum_00 = temp_image->sqrsum_00; + size_t sqrsum_00_stride = sqrsum_00.PixelsPerRow(); + float* JXL_RESTRICT sqrsum_00_row = sqrsum_00.Row(0); + for (size_t y = 0; y < (by1 - by0) * 2; y++) { + const float* JXL_RESTRICT rows_in[4]; + for (size_t iy = 0; iy < 4; iy++) { + rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy + 2); + } + float* JXL_RESTRICT row_out = sqrsum_00_row + y * sqrsum_00_stride; + for (size_t x = 0; x < (bx1 - bx0) * 2; x++) { + auto sum = Zero(df4); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix += Lanes(df4)) { + sum = Add(sum, LoadU(df4, rows_in[iy] + x * 4 + ix + 2)); + } + } + row_out[x] = GetLane(Sqrt(SumOfLanes(df4, sum))) * (1.0f / 4.0f); + } + } + // Indexing iy and ix is a bit tricky as we include a 2 pixel border + // around the block for evenness calculations. This is similar to what + // we did in guetzli for the observability of artefacts, except there + // the element is a sliding 5x5, not sparsely sampled 4x4 box like here. + ImageF& sqrsum_22 = temp_image->sqrsum_22; + size_t sqrsum_22_stride = sqrsum_22.PixelsPerRow(); + float* JXL_RESTRICT sqrsum_22_row = sqrsum_22.Row(0); + for (size_t y = 0; y < (by1 - by0) * 2 + 1; y++) { + const float* JXL_RESTRICT rows_in[4]; + for (size_t iy = 0; iy < 4; iy++) { + rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy); + } + float* JXL_RESTRICT row_out = sqrsum_22_row + y * sqrsum_22_stride; + // ignore pixels outside the image. + // Y coordinates are relative to by0*8+y*4. + size_t sy = y * 4 + by0 * 8 > 0 ? 0 : 2; + size_t ey = y * 4 + by0 * 8 + 4 <= opsin.ysize() + 2 + ? 4 + : opsin.ysize() - y * 4 - by0 * 8 + 2; + for (size_t x = 0; x < (bx1 - bx0) * 2 + 1; x++) { + // ignore pixels outside the image. + // X coordinates are relative to bx0*8. + size_t sx = x * 4 + bx0 * 8 > 0 ? x * 4 : x * 4 + 2; + size_t ex = x * 4 + bx0 * 8 + 4 <= opsin.xsize() + 2 + ? x * 4 + 4 + : opsin.xsize() - bx0 * 8 + 2; + if (ex - sx == 4 && ey - sy == 4) { + auto sum = Zero(df4); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix += Lanes(df4)) { + sum = Add(sum, Load(df4, rows_in[iy] + sx + ix)); + } + } + row_out[x] = GetLane(Sqrt(SumOfLanes(df4, sum))) * (1.0f / 4.0f); + } else { + float sum = 0; + for (size_t iy = sy; iy < ey; iy++) { + for (size_t ix = sx; ix < ex; ix++) { + sum += rows_in[iy][ix]; + } + } + row_out[x] = std::sqrt(sum / ((ex - sx) * (ey - sy))); + } + } + } + for (size_t by = by0; by < by1; by++) { + AcStrategyRow acs_row = enc_state->shared.ac_strategy.ConstRow(by); + uint8_t* JXL_RESTRICT out_row = epf_sharpness->Row(by); + float* JXL_RESTRICT quant_row = quant->Row(by); + for (size_t bx = bx0; bx < bx1; bx++) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + // The errors are going to be linear to the quantization value in this + // locality. We only have access to the initial quant field here. + float quant_val = 1.0f / quant_row[bx]; + + const auto sq00 = [&](size_t y, size_t x) { + return sqrsum_00_row[((by - by0) * 2 + y) * sqrsum_00_stride + + (bx - bx0) * 2 + x]; + }; + const auto sq22 = [&](size_t y, size_t x) { + return sqrsum_22_row[((by - by0) * 2 + y) * sqrsum_22_stride + + (bx - bx0) * 2 + x]; + }; + float sqrsum_integral_transform = 0; + for (size_t iy = 0; iy < acs.covered_blocks_y() * 2; iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x() * 2; ix++) { + sqrsum_integral_transform += sq00(iy, ix) * sq00(iy, ix); + } + } + sqrsum_integral_transform /= + 4 * acs.covered_blocks_x() * acs.covered_blocks_y(); + sqrsum_integral_transform = std::sqrt(sqrsum_integral_transform); + // If masking is high or amplitude of the artefacts is low, then no + // smoothing is needed. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + // Five 4x4 blocks for masking estimation, all within the + // 8x8 area. + float minval_1 = std::min(sq00(2 * iy + 0, 2 * ix + 0), + sq00(2 * iy + 0, 2 * ix + 1)); + float minval_2 = std::min(sq00(2 * iy + 1, 2 * ix + 0), + sq00(2 * iy + 1, 2 * ix + 1)); + float minval = std::min(minval_1, minval_2); + minval = std::min(minval, sq22(2 * iy + 1, 2 * ix + 1)); + // Nine more 4x4 blocks for masking estimation, includes + // the 2 pixel area around the 8x8 block being controlled. + float minval2_1 = std::min(sq22(2 * iy + 0, 2 * ix + 0), + sq22(2 * iy + 0, 2 * ix + 1)); + float minval2_2 = std::min(sq22(2 * iy + 0, 2 * ix + 2), + sq22(2 * iy + 1, 2 * ix + 0)); + float minval2_3 = std::min(sq22(2 * iy + 1, 2 * ix + 1), + sq22(2 * iy + 1, 2 * ix + 2)); + float minval2_4 = std::min(sq22(2 * iy + 2, 2 * ix + 0), + sq22(2 * iy + 2, 2 * ix + 1)); + float minval2_5 = std::min(minval2_1, minval2_2); + float minval2_6 = std::min(minval2_3, minval2_4); + float minval2 = std::min(minval2_5, minval2_6); + minval2 = std::min(minval2, sq22(2 * iy + 2, 2 * ix + 2)); + float minval3 = std::min(minval, minval2); + minval *= 0.125f; + minval += 0.625f * minval3; + minval += + 0.125f * std::min(1.5f * minval3, sq22(2 * iy + 1, 2 * ix + 1)); + minval += 0.125f * minval2; + // Larger kBias, less smoothing for low intensity changes. + float kDeltaLimit = 3.2; + float bias = 0.0625f * quant_val; + float delta = + (sqrsum_integral_transform + (kDeltaLimit + 0.05) * bias) / + (minval + bias); + int out = 4; + if (delta > kDeltaLimit) { + out = 4; // smooth + } else { + out = 0; + } + // 'threshold' is separate from 'bias' for easier tuning of these + // heuristics. + float threshold = 0.0625f * quant_val; + const float kSmoothLimit = 0.085f; + float smooth = 0.20f * (sq00(2 * iy + 0, 2 * ix + 0) + + sq00(2 * iy + 0, 2 * ix + 1) + + sq00(2 * iy + 1, 2 * ix + 0) + + sq00(2 * iy + 1, 2 * ix + 1) + minval); + if (smooth < kSmoothLimit * threshold) { + out = 4; + } + out_row[bx + sharpness_stride * iy + ix] = out; + } + } + } + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ProcessTile); + +void ArControlFieldHeuristics::RunRect(const Rect& block_rect, + const Image3F& opsin, + PassesEncoderState* enc_state, + size_t thread) { + HWY_DYNAMIC_DISPATCH(ProcessTile) + (opsin, enc_state, block_rect, &temp_images[thread]); +} + +} // namespace jxl + +#endif diff --git a/media/libjxl/src/lib/jxl/enc_ar_control_field.h b/media/libjxl/src/lib/jxl/enc_ar_control_field.h new file mode 100644 index 000000000..ae9d399b9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_ar_control_field.h @@ -0,0 +1,49 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_AR_CONTROL_FIELD_H_ +#define LIB_JXL_ENC_AR_CONTROL_FIELD_H_ + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +struct ArControlFieldHeuristics { + struct TempImages { + void InitOnce() { + if (laplacian_sqrsum.xsize() != 0) return; + laplacian_sqrsum = ImageF(kEncTileDim + 4, kEncTileDim + 4); + sqrsum_00 = ImageF(kEncTileDim / 4, kEncTileDim / 4); + sqrsum_22 = ImageF(kEncTileDim / 4 + 1, kEncTileDim / 4 + 1); + } + + ImageF laplacian_sqrsum; + ImageF sqrsum_00; + ImageF sqrsum_22; + }; + + void PrepareForThreads(size_t num_threads) { + temp_images.resize(num_threads); + } + + void RunRect(const Rect& block_rect, const Image3F& opsin, + PassesEncoderState* enc_state, size_t thread); + + std::vector temp_images; + ImageB* epf_sharpness; + ImageF* quant; + bool all_default; +}; + +} // namespace jxl + +#endif // LIB_JXL_AR_ENC_CONTROL_FIELD_H_ diff --git a/media/libjxl/src/lib/jxl/enc_bit_writer.cc b/media/libjxl/src/lib/jxl/enc_bit_writer.cc new file mode 100644 index 000000000..7bac7b9ba --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_bit_writer.cc @@ -0,0 +1,180 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_bit_writer.h" + +#include // memcpy + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/dec_bit_reader.h" + +namespace jxl { + +BitWriter::Allotment::Allotment(BitWriter* JXL_RESTRICT writer, size_t max_bits) + : max_bits_(max_bits) { + if (writer == nullptr) return; + prev_bits_written_ = writer->BitsWritten(); + const size_t prev_bytes = writer->storage_.size(); + const size_t next_bytes = DivCeil(max_bits, kBitsPerByte); + writer->storage_.resize(prev_bytes + next_bytes); + parent_ = writer->current_allotment_; + writer->current_allotment_ = this; +} + +BitWriter::Allotment::~Allotment() { + if (!called_) { + // Not calling is a bug - unused storage will not be reclaimed. + JXL_ABORT("Did not call Allotment::ReclaimUnused"); + } +} + +void BitWriter::Allotment::FinishedHistogram(BitWriter* JXL_RESTRICT writer) { + if (writer == nullptr) return; + JXL_ASSERT(!called_); // Call before ReclaimUnused + JXL_ASSERT(histogram_bits_ == 0); // Do not call twice + JXL_ASSERT(writer->BitsWritten() >= prev_bits_written_); + histogram_bits_ = writer->BitsWritten() - prev_bits_written_; +} + +void BitWriter::Allotment::PrivateReclaim(BitWriter* JXL_RESTRICT writer, + size_t* JXL_RESTRICT used_bits, + size_t* JXL_RESTRICT unused_bits) { + JXL_ASSERT(!called_); // Do not call twice + called_ = true; + if (writer == nullptr) return; + + JXL_ASSERT(writer->BitsWritten() >= prev_bits_written_); + *used_bits = writer->BitsWritten() - prev_bits_written_; + JXL_ASSERT(*used_bits <= max_bits_); + *unused_bits = max_bits_ - *used_bits; + + // Reclaim unused bytes whole bytes from writer's allotment. + const size_t unused_bytes = *unused_bits / kBitsPerByte; // truncate + JXL_ASSERT(writer->storage_.size() >= unused_bytes); + writer->storage_.resize(writer->storage_.size() - unused_bytes); + writer->current_allotment_ = parent_; + // Ensure we don't also charge the parent for these bits. + auto parent = parent_; + while (parent != nullptr) { + parent->prev_bits_written_ += *used_bits; + parent = parent->parent_; + } +} + +void BitWriter::AppendByteAligned(const Span& span) { + if (!span.size()) return; + storage_.resize(storage_.size() + span.size() + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are bytes. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += span.size() * kBitsPerByte; +} + +void BitWriter::AppendByteAligned(const BitWriter& other) { + JXL_ASSERT(other.BitsWritten() % kBitsPerByte == 0); + JXL_ASSERT(other.BitsWritten() / kBitsPerByte != 0); + + AppendByteAligned(other.GetSpan()); +} + +void BitWriter::AppendByteAligned(const std::vector& others) { + // Total size to add so we can preallocate + size_t other_bytes = 0; + for (const BitWriter& writer : others) { + JXL_ASSERT(writer.BitsWritten() % kBitsPerByte == 0); + other_bytes += writer.BitsWritten() / kBitsPerByte; + } + if (other_bytes == 0) { + // No bytes to append: this happens for example when creating per-group + // storage for groups, but not writing anything in them for e.g. lossless + // images with no alpha. Do nothing. + return; + } + storage_.resize(storage_.size() + other_bytes + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are bytes. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + for (const BitWriter& writer : others) { + const Span span = writer.GetSpan(); + if (!span.empty()) { + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + } + } + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += other_bytes * kBitsPerByte; +} + +// TODO(lode): avoid code duplication +void BitWriter::AppendByteAligned( + const std::vector>& others) { + // Total size to add so we can preallocate + size_t other_bytes = 0; + for (const auto& writer : others) { + JXL_ASSERT(writer->BitsWritten() % kBitsPerByte == 0); + other_bytes += writer->BitsWritten() / kBitsPerByte; + } + if (other_bytes == 0) { + // No bytes to append: this happens for example when creating per-group + // storage for groups, but not writing anything in them for e.g. lossless + // images with no alpha. Do nothing. + return; + } + storage_.resize(storage_.size() + other_bytes + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are bytes. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + for (const auto& writer : others) { + const Span span = writer->GetSpan(); + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + } + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += other_bytes * kBitsPerByte; +} + +// Example: let's assume that 3 bits (Rs below) have been written already: +// BYTE+0 BYTE+1 BYTE+2 +// 0000 0RRR ???? ???? ???? ???? +// +// Now, we could write up to 5 bits by just shifting them left by 3 bits and +// OR'ing to BYTE-0. +// +// For n > 5 bits, we write the lowest 5 bits as above, then write the next +// lowest bits into BYTE+1 starting from its lower bits and so on. +void BitWriter::Write(size_t n_bits, uint64_t bits) { + JXL_DASSERT((bits >> n_bits) == 0); + JXL_DASSERT(n_bits <= kMaxBitsPerCall); + uint8_t* p = &storage_[bits_written_ / kBitsPerByte]; + const size_t bits_in_first_byte = bits_written_ % kBitsPerByte; + bits <<= bits_in_first_byte; +#if JXL_BYTE_ORDER_LITTLE + uint64_t v = *p; + // Last (partial) or next byte to write must be zero-initialized! + // PaddedBytes initializes the first, and Write/Append maintain this. + JXL_DASSERT(v >> bits_in_first_byte == 0); + v |= bits; + memcpy(p, &v, sizeof(v)); // Write bytes: possibly more than n_bits/8 +#else + *p++ |= static_cast(bits & 0xFF); + for (size_t bits_left_to_write = n_bits + bits_in_first_byte; + bits_left_to_write >= 9; bits_left_to_write -= 8) { + bits >>= 8; + *p++ = static_cast(bits & 0xFF); + } + *p = 0; +#endif + bits_written_ += n_bits; +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_bit_writer.h b/media/libjxl/src/lib/jxl/enc_bit_writer.h new file mode 100644 index 000000000..4cac8dfbe --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_bit_writer.h @@ -0,0 +1,126 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_BIT_WRITER_H_ +#define LIB_JXL_ENC_BIT_WRITER_H_ + +// BitWriter class: unbuffered writes using unaligned 64-bit stores. + +#include +#include + +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" + +namespace jxl { + +struct BitWriter { + // Upper bound on `n_bits` in each call to Write. We shift a 64-bit word by + // 7 bits (max already valid bits in the last byte) and at least 1 bit is + // needed to zero-initialize the bit-stream ahead (i.e. if 7 bits are valid + // and we write 57 bits, then the next write will access a byte that was not + // yet zero-initialized). + static constexpr size_t kMaxBitsPerCall = 56; + + BitWriter() : bits_written_(0) {} + + // Disallow copying - may lead to bugs. + BitWriter(const BitWriter&) = delete; + BitWriter& operator=(const BitWriter&) = delete; + BitWriter(BitWriter&&) = default; + BitWriter& operator=(BitWriter&&) = default; + + size_t BitsWritten() const { return bits_written_; } + + Span GetSpan() const { + // Callers must ensure byte alignment to avoid uninitialized bits. + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + return Span(storage_.data(), bits_written_ / kBitsPerByte); + } + + // Example usage: bytes = std::move(writer).TakeBytes(); Useful for the + // top-level encoder which returns PaddedBytes, not a BitWriter. + // *this must be an rvalue reference and is invalid afterwards. + PaddedBytes&& TakeBytes() && { + // Callers must ensure byte alignment to avoid uninitialized bits. + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + storage_.resize(bits_written_ / kBitsPerByte); + return std::move(storage_); + } + + private: + // Must be byte-aligned before calling. + void AppendByteAligned(const Span& span); + + public: + // NOTE: no allotment needed, the other BitWriters have already been charged. + void AppendByteAligned(const BitWriter& other); + void AppendByteAligned(const std::vector>& others); + void AppendByteAligned(const std::vector& others); + + class Allotment { + public: + // Expands a BitWriter's storage. Must happen before calling Write or + // ZeroPadToByte. Must call ReclaimUnused after writing to reclaim the + // unused storage so that BitWriter memory use remains tightly bounded. + Allotment(BitWriter* JXL_RESTRICT writer, size_t max_bits); + ~Allotment(); + + size_t MaxBits() const { return max_bits_; } + + // Call after writing a histogram, but before ReclaimUnused. + void FinishedHistogram(BitWriter* JXL_RESTRICT writer); + + size_t HistogramBits() const { + JXL_ASSERT(called_); + return histogram_bits_; + } + + // Do not call directly - use ::ReclaimAndCharge instead, which ensures + // the bits are charged to a layer. + void PrivateReclaim(BitWriter* JXL_RESTRICT writer, + size_t* JXL_RESTRICT used_bits, + size_t* JXL_RESTRICT unused_bits); + + private: + size_t prev_bits_written_; + const size_t max_bits_; + size_t histogram_bits_ = 0; + bool called_ = false; + Allotment* parent_; + }; + + // Writes bits into bytes in increasing addresses, and within a byte + // least-significant-bit first. + // + // The function can write up to 56 bits in one go. + void Write(size_t n_bits, uint64_t bits); + + // This should only rarely be used - e.g. when the current location will be + // referenced via byte offset (TOCs point to groups), or byte-aligned reading + // is required for speed. + void ZeroPadToByte() { + const size_t remainder_bits = + RoundUpBitsToByteMultiple(bits_written_) - bits_written_; + if (remainder_bits == 0) return; + Write(remainder_bits, 0); + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + } + + private: + size_t bits_written_; + PaddedBytes storage_; + Allotment* current_allotment_ = nullptr; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_BIT_WRITER_H_ diff --git a/media/libjxl/src/lib/jxl/enc_butteraugli_comparator.cc b/media/libjxl/src/lib/jxl/enc_butteraugli_comparator.cc new file mode 100644 index 000000000..e79c4b5c0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_butteraugli_comparator.cc @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_butteraugli_comparator.h" + +#include +#include + +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_image_bundle.h" + +namespace jxl { + +JxlButteraugliComparator::JxlButteraugliComparator( + const ButteraugliParams& params, const JxlCmsInterface& cms) + : params_(params), cms_(cms) {} + +Status JxlButteraugliComparator::SetReferenceImage(const ImageBundle& ref) { + const ImageBundle* ref_linear_srgb; + ImageMetadata metadata = *ref.metadata(); + ImageBundle store(&metadata); + if (!TransformIfNeeded(ref, ColorEncoding::LinearSRGB(ref.IsGray()), cms_, + /*pool=*/nullptr, &store, &ref_linear_srgb)) { + return false; + } + + comparator_.reset( + new ButteraugliComparator(ref_linear_srgb->color(), params_)); + xsize_ = ref.xsize(); + ysize_ = ref.ysize(); + return true; +} + +Status JxlButteraugliComparator::CompareWith(const ImageBundle& actual, + ImageF* diffmap, float* score) { + if (!comparator_) { + return JXL_FAILURE("Must set reference image first"); + } + if (xsize_ != actual.xsize() || ysize_ != actual.ysize()) { + return JXL_FAILURE("Images must have same size"); + } + + const ImageBundle* actual_linear_srgb; + ImageMetadata metadata = *actual.metadata(); + ImageBundle store(&metadata); + if (!TransformIfNeeded(actual, ColorEncoding::LinearSRGB(actual.IsGray()), + cms_, + /*pool=*/nullptr, &store, &actual_linear_srgb)) { + return false; + } + + ImageF temp_diffmap(xsize_, ysize_); + comparator_->Diffmap(actual_linear_srgb->color(), temp_diffmap); + + if (score != nullptr) { + *score = ButteraugliScoreFromDiffmap(temp_diffmap, ¶ms_); + } + if (diffmap != nullptr) { + diffmap->Swap(temp_diffmap); + } + + return true; +} + +float JxlButteraugliComparator::GoodQualityScore() const { + return ButteraugliFuzzyInverse(1.5); +} + +float JxlButteraugliComparator::BadQualityScore() const { + return ButteraugliFuzzyInverse(0.5); +} + +float ButteraugliDistance(const ImageBundle& rgb0, const ImageBundle& rgb1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap, + ThreadPool* pool) { + JxlButteraugliComparator comparator(params, cms); + return ComputeScore(rgb0, rgb1, &comparator, cms, distmap, pool); +} + +float ButteraugliDistance(const CodecInOut& rgb0, const CodecInOut& rgb1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap, + ThreadPool* pool) { + JxlButteraugliComparator comparator(params, cms); + JXL_ASSERT(rgb0.frames.size() == rgb1.frames.size()); + float max_dist = 0.0f; + for (size_t i = 0; i < rgb0.frames.size(); ++i) { + max_dist = + std::max(max_dist, ComputeScore(rgb0.frames[i], rgb1.frames[i], + &comparator, cms, distmap, pool)); + } + return max_dist; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_butteraugli_comparator.h b/media/libjxl/src/lib/jxl/enc_butteraugli_comparator.h new file mode 100644 index 000000000..6d0751c41 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_butteraugli_comparator.h @@ -0,0 +1,58 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ +#define LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ + +#include + +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_comparator.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +class JxlButteraugliComparator : public Comparator { + public: + explicit JxlButteraugliComparator(const ButteraugliParams& params, + const JxlCmsInterface& cms); + + Status SetReferenceImage(const ImageBundle& ref) override; + + Status CompareWith(const ImageBundle& actual, ImageF* diffmap, + float* score) override; + + float GoodQualityScore() const override; + float BadQualityScore() const override; + + private: + ButteraugliParams params_; + JxlCmsInterface cms_; + std::unique_ptr comparator_; + size_t xsize_ = 0; + size_t ysize_ = 0; +}; + +// Returns the butteraugli distance between rgb0 and rgb1. +// If distmap is not null, it must be the same size as rgb0 and rgb1. +float ButteraugliDistance(const ImageBundle& rgb0, const ImageBundle& rgb1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap = nullptr, + ThreadPool* pool = nullptr); + +float ButteraugliDistance(const CodecInOut& rgb0, const CodecInOut& rgb1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap = nullptr, + ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ diff --git a/media/libjxl/src/lib/jxl/enc_butteraugli_pnorm.cc b/media/libjxl/src/lib/jxl/enc_butteraugli_pnorm.cc new file mode 100644 index 000000000..fe5629dcd --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_butteraugli_pnorm.cc @@ -0,0 +1,211 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_butteraugli_pnorm.h" + +#include +#include + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_butteraugli_pnorm.cc" +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::Rebind; + +double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params, + double p) { + PROFILER_FUNC; + + const double onePerPixels = 1.0 / (distmap.ysize() * distmap.xsize()); + if (std::abs(p - 3.0) < 1E-6) { + double sum1[3] = {0.0}; + +// Prefer double if possible, but otherwise use float rather than scalar. +#if HWY_CAP_FLOAT64 + using T = double; + const Rebind df; +#else + using T = float; +#endif + const HWY_FULL(T) d; + constexpr size_t N = MaxLanes(HWY_FULL(T)()); + // Manually aligned storage to avoid asan crash on clang-7 due to + // unaligned spill. + HWY_ALIGN T sum_totals0[N] = {0}; + HWY_ALIGN T sum_totals1[N] = {0}; + HWY_ALIGN T sum_totals2[N] = {0}; + + for (size_t y = 0; y < distmap.ysize(); ++y) { + const float* JXL_RESTRICT row = distmap.ConstRow(y); + + auto sums0 = Zero(d); + auto sums1 = Zero(d); + auto sums2 = Zero(d); + + size_t x = 0; + for (; x + Lanes(d) <= distmap.xsize(); x += Lanes(d)) { +#if HWY_CAP_FLOAT64 + const auto d1 = PromoteTo(d, Load(df, row + x)); +#else + const auto d1 = Load(d, row + x); +#endif + const auto d2 = Mul(d1, Mul(d1, d1)); + sums0 = Add(sums0, d2); + const auto d3 = Mul(d2, d2); + sums1 = Add(sums1, d3); + const auto d4 = Mul(d3, d3); + sums2 = Add(sums2, d4); + } + + Store(Add(sums0, Load(d, sum_totals0)), d, sum_totals0); + Store(Add(sums1, Load(d, sum_totals1)), d, sum_totals1); + Store(Add(sums2, Load(d, sum_totals2)), d, sum_totals2); + + for (; x < distmap.xsize(); ++x) { + const double d1 = row[x]; + double d2 = d1 * d1 * d1; + sum1[0] += d2; + d2 *= d2; + sum1[1] += d2; + d2 *= d2; + sum1[2] += d2; + } + } + double v = 0; + v += pow( + onePerPixels * (sum1[0] + GetLane(SumOfLanes(d, Load(d, sum_totals0)))), + 1.0 / (p * 1.0)); + v += pow( + onePerPixels * (sum1[1] + GetLane(SumOfLanes(d, Load(d, sum_totals1)))), + 1.0 / (p * 2.0)); + v += pow( + onePerPixels * (sum1[2] + GetLane(SumOfLanes(d, Load(d, sum_totals2)))), + 1.0 / (p * 4.0)); + v /= 3.0; + return v; + } else { + static std::atomic once{0}; + if (once.fetch_add(1, std::memory_order_relaxed) == 0) { + JXL_WARNING("WARNING: using slow ComputeDistanceP"); + } + double sum1[3] = {0.0}; + for (size_t y = 0; y < distmap.ysize(); ++y) { + const float* JXL_RESTRICT row = distmap.ConstRow(y); + for (size_t x = 0; x < distmap.xsize(); ++x) { + double d2 = std::pow(row[x], p); + sum1[0] += d2; + d2 *= d2; + sum1[1] += d2; + d2 *= d2; + sum1[2] += d2; + } + } + double v = 0; + for (int i = 0; i < 3; ++i) { + v += pow(onePerPixels * (sum1[i]), 1.0 / (p * (1 << i))); + } + v /= 3.0; + return v; + } +} + +// TODO(lode): take alpha into account when needed +double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2, + const JxlCmsInterface& cms) { + PROFILER_FUNC; + // Convert to sRGB - closer to perception than linear. + const Image3F* srgb1 = &ib1.color(); + Image3F copy1; + if (!ib1.IsSRGB()) { + JXL_CHECK( + ib1.CopyTo(Rect(ib1), ColorEncoding::SRGB(ib1.IsGray()), cms, ©1)); + srgb1 = ©1; + } + const Image3F* srgb2 = &ib2.color(); + Image3F copy2; + if (!ib2.IsSRGB()) { + JXL_CHECK( + ib2.CopyTo(Rect(ib2), ColorEncoding::SRGB(ib2.IsGray()), cms, ©2)); + srgb2 = ©2; + } + + JXL_CHECK(SameSize(*srgb1, *srgb2)); + + // TODO(veluca): SIMD. + float yuvmatrix[3][3] = {{0.299, 0.587, 0.114}, + {-0.14713, -0.28886, 0.436}, + {0.615, -0.51499, -0.10001}}; + double sum_of_squares[3] = {}; + for (size_t y = 0; y < srgb1->ysize(); ++y) { + const float* JXL_RESTRICT row1[3]; + const float* JXL_RESTRICT row2[3]; + for (size_t j = 0; j < 3; j++) { + row1[j] = srgb1->ConstPlaneRow(j, y); + row2[j] = srgb2->ConstPlaneRow(j, y); + } + for (size_t x = 0; x < srgb1->xsize(); ++x) { + float cdiff[3] = {}; + // YUV conversion is linear, so we can run it on the difference. + for (size_t j = 0; j < 3; j++) { + cdiff[j] = row1[j][x] - row2[j][x]; + } + float yuvdiff[3] = {}; + for (size_t j = 0; j < 3; j++) { + for (size_t k = 0; k < 3; k++) { + yuvdiff[j] += yuvmatrix[j][k] * cdiff[k]; + } + } + for (size_t j = 0; j < 3; j++) { + sum_of_squares[j] += yuvdiff[j] * yuvdiff[j]; + } + } + } + // Weighted PSNR as in JPEG-XL: chroma counts 1/8. + const float weights[3] = {6.0f / 8, 1.0f / 8, 1.0f / 8}; + // Avoid squaring the weight - 1/64 is too extreme. + double norm = 0; + for (size_t i = 0; i < 3; i++) { + norm += std::sqrt(sum_of_squares[i]) * weights[i]; + } + // This function returns distance *squared*. + return norm * norm; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ComputeDistanceP); +double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params, + double p) { + return HWY_DYNAMIC_DISPATCH(ComputeDistanceP)(distmap, params, p); +} + +HWY_EXPORT(ComputeDistance2); +double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2, + const JxlCmsInterface& cms) { + return HWY_DYNAMIC_DISPATCH(ComputeDistance2)(ib1, ib2, cms); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/enc_butteraugli_pnorm.h b/media/libjxl/src/lib/jxl/enc_butteraugli_pnorm.h new file mode 100644 index 000000000..cf6872e5d --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_butteraugli_pnorm.h @@ -0,0 +1,25 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_BUTTERAUGLI_PNORM_H_ +#define LIB_JXL_ENC_BUTTERAUGLI_PNORM_H_ + +#include + +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Computes p-norm given the butteraugli distmap. +double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params, + double p); + +double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2, + const JxlCmsInterface& cms); + +} // namespace jxl + +#endif // LIB_JXL_ENC_BUTTERAUGLI_PNORM_H_ diff --git a/media/libjxl/src/lib/jxl/enc_cache.cc b/media/libjxl/src/lib/jxl/enc_cache.cc new file mode 100644 index 000000000..a1f2a0887 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_cache.cc @@ -0,0 +1,218 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_cache.h" + +#include +#include + +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +Status InitializePassesEncoder(const Image3F& opsin, const JxlCmsInterface& cms, + ThreadPool* pool, PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + AuxOut* aux_out) { + PROFILER_FUNC; + + PassesSharedState& JXL_RESTRICT shared = enc_state->shared; + + enc_state->histogram_idx.resize(shared.frame_dim.num_groups); + + enc_state->x_qm_multiplier = + std::pow(1.25f, shared.frame_header.x_qm_scale - 2.0f); + enc_state->b_qm_multiplier = + std::pow(1.25f, shared.frame_header.b_qm_scale - 2.0f); + + if (enc_state->coeffs.size() < shared.frame_header.passes.num_passes) { + enc_state->coeffs.reserve(shared.frame_header.passes.num_passes); + for (size_t i = enc_state->coeffs.size(); + i < shared.frame_header.passes.num_passes; i++) { + // Allocate enough coefficients for each group on every row. + enc_state->coeffs.emplace_back(make_unique>( + kGroupDim * kGroupDim, shared.frame_dim.num_groups)); + } + } + while (enc_state->coeffs.size() > shared.frame_header.passes.num_passes) { + enc_state->coeffs.pop_back(); + } + + float scale = + shared.quantizer.ScaleGlobalScale(enc_state->cparams.quant_ac_rescale); + DequantMatricesScaleDC(&shared.matrices, scale); + shared.quantizer.RecomputeFromGlobalScale(); + + Image3F dc(shared.frame_dim.xsize_blocks, shared.frame_dim.ysize_blocks); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, shared.frame_dim.num_groups, ThreadPool::NoInit, + [&](size_t group_idx, size_t _) { + ComputeCoefficients(group_idx, enc_state, opsin, &dc); + }, + "Compute coeffs")); + + if (shared.frame_header.flags & FrameHeader::kUseDcFrame) { + CompressParams cparams = enc_state->cparams; + cparams.dots = Override::kOff; + cparams.noise = Override::kOff; + cparams.patches = Override::kOff; + cparams.gaborish = Override::kOff; + cparams.epf = 0; + cparams.resampling = 1; + cparams.ec_resampling = 1; + // The DC frame will have alpha=0. Don't erase its contents. + cparams.keep_invisible = Override::kOn; + JXL_ASSERT(cparams.progressive_dc > 0); + cparams.progressive_dc--; + // Use kVarDCT in max_error_mode for intermediate progressive DC, + // and kModular for the smallest DC (first in the bitstream) + if (cparams.progressive_dc == 0) { + cparams.modular_mode = true; + // TODO(jon): tweak mapping from image dist to dist for modular DC + cparams.butteraugli_distance = + std::max(kMinButteraugliDistance, + enc_state->cparams.butteraugli_distance * 0.03f); + } else { + cparams.max_error_mode = true; + for (size_t c = 0; c < 3; c++) { + cparams.max_error[c] = shared.quantizer.MulDC()[c]; + } + // Guess a distance that produces good initial results. + cparams.butteraugli_distance = + std::max(kMinButteraugliDistance, + enc_state->cparams.butteraugli_distance * 0.1f); + } + ImageBundle ib(&shared.metadata->m); + // This is a lie - dc is in XYB + // (but EncodeFrame will skip RGB->XYB conversion anyway) + ib.SetFromImage( + std::move(dc), + ColorEncoding::LinearSRGB(shared.metadata->m.color_encoding.IsGray())); + if (!ib.metadata()->extra_channel_info.empty()) { + // Add dummy extra channels to the patch image: dc_level frames do not yet + // support extra channels, but the codec expects that the amount of extra + // channels in frames matches that in the metadata of the codestream. + std::vector extra_channels; + extra_channels.reserve(ib.metadata()->extra_channel_info.size()); + for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) { + extra_channels.emplace_back(ib.xsize(), ib.ysize()); + // Must initialize the image with data to not affect blending with + // uninitialized memory. + // TODO(lode): dc_level must copy and use the real extra channels + // instead. + ZeroFillImage(&extra_channels.back()); + } + ib.SetExtraChannels(std::move(extra_channels)); + } + std::unique_ptr state = + jxl::make_unique(); + + auto special_frame = std::unique_ptr(new BitWriter()); + FrameInfo dc_frame_info; + dc_frame_info.frame_type = FrameType::kDCFrame; + dc_frame_info.dc_level = shared.frame_header.dc_level + 1; + dc_frame_info.ib_needs_color_transform = false; + dc_frame_info.save_before_color_transform = true; // Implicitly true + AuxOut dc_aux_out; + if (aux_out) { + dc_aux_out.debug_prefix = aux_out->debug_prefix; + } + JXL_CHECK(EncodeFrame(cparams, dc_frame_info, shared.metadata, ib, + state.get(), cms, pool, special_frame.get(), + aux_out ? &dc_aux_out : nullptr)); + if (aux_out) { + for (const auto& l : dc_aux_out.layers) { + aux_out->layers[kLayerDC].Assimilate(l); + } + } + const Span encoded = special_frame->GetSpan(); + enc_state->special_frames.emplace_back(std::move(special_frame)); + + ImageBundle decoded(&shared.metadata->m); + std::unique_ptr dec_state = + jxl::make_unique(); + JXL_CHECK( + dec_state->output_encoding_info.SetFromMetadata(*shared.metadata)); + const uint8_t* frame_start = encoded.data(); + size_t encoded_size = encoded.size(); + for (int i = 0; i <= cparams.progressive_dc; ++i) { + JXL_CHECK(DecodeFrame(dec_state.get(), pool, frame_start, encoded_size, + &decoded, *shared.metadata)); + frame_start += decoded.decoded_bytes(); + encoded_size -= decoded.decoded_bytes(); + } + // TODO(lode): shared.frame_header.dc_level should be equal to + // dec_state.shared->frame_header.dc_level - 1 here, since above we set + // dc_frame_info.dc_level = shared.frame_header.dc_level + 1, and + // dc_frame_info.dc_level is used by EncodeFrame. However, if EncodeFrame + // outputs multiple frames, this assumption could be wrong. + shared.dc_storage = + CopyImage(dec_state->shared->dc_frames[shared.frame_header.dc_level]); + ZeroFillImage(&shared.quant_dc); + shared.dc = &shared.dc_storage; + JXL_CHECK(encoded_size == 0); + } else { + auto compute_dc_coeffs = [&](int group_index, int /* thread */) { + modular_frame_encoder->AddVarDCTDC( + dc, group_index, + enc_state->cparams.butteraugli_distance >= 2.0f && + enc_state->cparams.speed_tier < SpeedTier::kFalcon, + enc_state, /*jpeg_transcode=*/false); + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_dc_groups, + ThreadPool::NoInit, compute_dc_coeffs, + "Compute DC coeffs")); + // TODO(veluca): this is only useful in tests and if inspection is enabled. + if (!(shared.frame_header.flags & FrameHeader::kSkipAdaptiveDCSmoothing)) { + AdaptiveDCSmoothing(shared.quantizer.MulDC(), &shared.dc_storage, pool); + } + } + auto compute_ac_meta = [&](int group_index, int /* thread */) { + modular_frame_encoder->AddACMetadata(group_index, /*jpeg_transcode=*/false, + enc_state); + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_dc_groups, + ThreadPool::NoInit, compute_ac_meta, + "Compute AC Metadata")); + + if (aux_out != nullptr) { + aux_out->InspectImage3F("compressed_image:InitializeFrameEncCache:dc_dec", + shared.dc_storage); + } + return true; +} + +void EncCache::InitOnce() { + PROFILER_FUNC; + + if (num_nzeroes.xsize() == 0) { + num_nzeroes = Image3I(kGroupDimInBlocks, kGroupDimInBlocks); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_cache.h b/media/libjxl/src/lib/jxl/enc_cache.h new file mode 100644 index 000000000..04dff0bed --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_cache.h @@ -0,0 +1,93 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_CACHE_H_ +#define LIB_JXL_ENC_CACHE_H_ + +#include +#include + +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_heuristics.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/progressive_split.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +// Contains encoder state. +struct PassesEncoderState { + PassesSharedState shared; + + ImageF initial_quant_field; // Invalid in Falcon mode. + ImageF initial_quant_masking; // Invalid in Falcon mode. + + // Per-pass DCT coefficients for the image. One row per group. + std::vector> coeffs; + + // Raw data for special (reference+DC) frames. + std::vector> special_frames; + + // For splitting into passes. + ProgressiveSplitter progressive_splitter; + + CompressParams cparams; + + struct PassData { + std::vector> ac_tokens; + std::vector context_map; + EntropyEncodingData codes; + }; + + std::vector passes; + std::vector histogram_idx; + + // Coefficient orders that are non-default. + std::vector used_orders; + + // Multiplier to be applied to the quant matrices of the x channel. + float x_qm_multiplier = 1.0f; + float b_qm_multiplier = 1.0f; + + // Heuristics to be used by the encoder. + std::unique_ptr heuristics = + make_unique(); +}; + +// Initialize per-frame information. +class ModularFrameEncoder; +Status InitializePassesEncoder(const Image3F& opsin, const JxlCmsInterface& cms, + ThreadPool* pool, + PassesEncoderState* passes_enc_state, + ModularFrameEncoder* modular_frame_encoder, + AuxOut* aux_out); + +// Working area for ComputeCoefficients (per-group!) +struct EncCache { + // Allocates memory when first called, shrinks images to current group size. + void InitOnce(); + + // TokenizeCoefficients + Image3I num_nzeroes; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_CACHE_H_ diff --git a/media/libjxl/src/lib/jxl/enc_chroma_from_luma.cc b/media/libjxl/src/lib/jxl/enc_chroma_from_luma.cc new file mode 100644 index 000000000..4f0798e99 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_chroma_from_luma.cc @@ -0,0 +1,388 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_chroma_from_luma.h" + +#include +#include + +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_chroma_from_luma.cc" +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/quantizer.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Ge; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; + +static HWY_FULL(float) df; + +struct CFLFunction { + static constexpr float kCoeff = 1.f / 3; + static constexpr float kThres = 100.0f; + static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; + CFLFunction(const float* values_m, const float* values_s, size_t num, + float base, float distance_mul) + : values_m(values_m), + values_s(values_s), + num(num), + base(base), + distance_mul(distance_mul) {} + + // Returns f'(x), where f is 1/3 * sum ((|color residual| + 1)^2-1) + + // distance_mul * x^2 * num. + float Compute(float x, float eps, float* fpeps, float* fmeps) const { + float first_derivative = 2 * distance_mul * num * x; + float first_derivative_peps = 2 * distance_mul * num * (x + eps); + float first_derivative_meps = 2 * distance_mul * num * (x - eps); + + const auto inv_color_factor = Set(df, kInvColorFactor); + const auto thres = Set(df, kThres); + const auto coeffx2 = Set(df, kCoeff * 2.0f); + const auto one = Set(df, 1.0f); + const auto zero = Set(df, 0.0f); + const auto base_v = Set(df, base); + const auto x_v = Set(df, x); + const auto xpe_v = Set(df, x + eps); + const auto xme_v = Set(df, x - eps); + auto fd_v = Zero(df); + auto fdpe_v = Zero(df); + auto fdme_v = Zero(df); + JXL_ASSERT(num % Lanes(df) == 0); + + for (size_t i = 0; i < num; i += Lanes(df)) { + // color residual = ax + b + const auto a = Mul(inv_color_factor, Load(df, values_m + i)); + const auto b = + Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); + const auto v = MulAdd(a, x_v, b); + const auto vpe = MulAdd(a, xpe_v, b); + const auto vme = MulAdd(a, xme_v, b); + const auto av = Abs(v); + const auto avpe = Abs(vpe); + const auto avme = Abs(vme); + const auto acoeffx2 = Mul(coeffx2, a); + auto d = Mul(acoeffx2, Add(av, one)); + auto dpe = Mul(acoeffx2, Add(avpe, one)); + auto dme = Mul(acoeffx2, Add(avme, one)); + d = IfThenElse(Lt(v, zero), Sub(zero, d), d); + dpe = IfThenElse(Lt(vpe, zero), Sub(zero, dpe), dpe); + dme = IfThenElse(Lt(vme, zero), Sub(zero, dme), dme); + const auto above = Ge(av, thres); + // TODO(eustas): use IfThenElseZero + fd_v = Add(fd_v, IfThenElse(above, zero, d)); + fdpe_v = Add(fdpe_v, IfThenElse(above, zero, dpe)); + fdme_v = Add(fdme_v, IfThenElse(above, zero, dme)); + } + + *fpeps = first_derivative_peps + GetLane(SumOfLanes(df, fdpe_v)); + *fmeps = first_derivative_meps + GetLane(SumOfLanes(df, fdme_v)); + return first_derivative + GetLane(SumOfLanes(df, fd_v)); + } + + const float* JXL_RESTRICT values_m; + const float* JXL_RESTRICT values_s; + size_t num; + float base; + float distance_mul; +}; + +int32_t FindBestMultiplier(const float* values_m, const float* values_s, + size_t num, float base, float distance_mul, + bool fast) { + if (num == 0) { + return 0; + } + float x; + if (fast) { + static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; + auto ca = Zero(df); + auto cb = Zero(df); + const auto inv_color_factor = Set(df, kInvColorFactor); + const auto base_v = Set(df, base); + for (size_t i = 0; i < num; i += Lanes(df)) { + // color residual = ax + b + const auto a = Mul(inv_color_factor, Load(df, values_m + i)); + const auto b = + Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); + ca = MulAdd(a, a, ca); + cb = MulAdd(a, b, cb); + } + // + distance_mul * x^2 * num + x = -GetLane(SumOfLanes(df, cb)) / + (GetLane(SumOfLanes(df, ca)) + num * distance_mul * 0.5f); + } else { + constexpr float eps = 1; + constexpr float kClamp = 20.0f; + CFLFunction fn(values_m, values_s, num, base, distance_mul); + x = 0; + // Up to 20 Newton iterations, with approximate derivatives. + // Derivatives are approximate due to the high amount of noise in the exact + // derivatives. + for (size_t i = 0; i < 20; i++) { + float dfpeps, dfmeps; + float df = fn.Compute(x, eps, &dfpeps, &dfmeps); + float ddf = (dfpeps - dfmeps) / (2 * eps); + float step = df / ddf; + x -= std::min(kClamp, std::max(-kClamp, step)); + if (std::abs(step) < 3e-3) break; + } + } + return std::max(-128.0f, std::min(127.0f, roundf(x))); +} + +void InitDCStorage(size_t num_blocks, ImageF* dc_values) { + // First row: Y channel + // Second row: X channel + // Third row: Y channel + // Fourth row: B channel + *dc_values = ImageF(RoundUpTo(num_blocks, Lanes(df)), 4); + + JXL_ASSERT(dc_values->xsize() != 0); + // Zero-fill the last lanes + for (size_t y = 0; y < 4; y++) { + for (size_t x = dc_values->xsize() - Lanes(df); x < dc_values->xsize(); + x++) { + dc_values->Row(y)[x] = 0; + } + } +} + +void ComputeDC(const ImageF& dc_values, bool fast, int32_t* dc_x, + int32_t* dc_b) { + constexpr float kDistanceMultiplierDC = 1e-5f; + const float* JXL_RESTRICT dc_values_yx = dc_values.Row(0); + const float* JXL_RESTRICT dc_values_x = dc_values.Row(1); + const float* JXL_RESTRICT dc_values_yb = dc_values.Row(2); + const float* JXL_RESTRICT dc_values_b = dc_values.Row(3); + *dc_x = FindBestMultiplier(dc_values_yx, dc_values_x, dc_values.xsize(), 0.0f, + kDistanceMultiplierDC, fast); + *dc_b = FindBestMultiplier(dc_values_yb, dc_values_b, dc_values.xsize(), + kYToBRatio, kDistanceMultiplierDC, fast); +} + +void ComputeTile(const Image3F& opsin, const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, const Quantizer* quantizer, + const Rect& r, bool fast, bool use_dct8, ImageSB* map_x, + ImageSB* map_b, ImageF* dc_values, float* mem) { + static_assert(kEncTileDimInBlocks == kColorTileDimInBlocks, + "Invalid color tile dim"); + size_t xsize_blocks = opsin.xsize() / kBlockDim; + constexpr float kDistanceMultiplierAC = 1e-3f; + + const size_t y0 = r.y0(); + const size_t x0 = r.x0(); + const size_t x1 = r.x0() + r.xsize(); + const size_t y1 = r.y0() + r.ysize(); + + int ty = y0 / kColorTileDimInBlocks; + int tx = x0 / kColorTileDimInBlocks; + + int8_t* JXL_RESTRICT row_out_x = map_x->Row(ty); + int8_t* JXL_RESTRICT row_out_b = map_b->Row(ty); + + float* JXL_RESTRICT dc_values_yx = dc_values->Row(0); + float* JXL_RESTRICT dc_values_x = dc_values->Row(1); + float* JXL_RESTRICT dc_values_yb = dc_values->Row(2); + float* JXL_RESTRICT dc_values_b = dc_values->Row(3); + + // All are aligned. + float* HWY_RESTRICT block_y = mem; + float* HWY_RESTRICT block_x = block_y + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT block_b = block_x + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT coeffs_yx = block_b + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT coeffs_x = coeffs_yx + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT coeffs_yb = coeffs_x + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT coeffs_b = coeffs_yb + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT scratch_space = coeffs_b + kColorTileDim * kColorTileDim; + JXL_DASSERT(scratch_space + 2 * AcStrategy::kMaxCoeffArea == + block_y + CfLHeuristics::kItemsPerThread); + + // Small (~256 bytes each) + HWY_ALIGN_MAX float + dc_y[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + HWY_ALIGN_MAX float + dc_x[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + HWY_ALIGN_MAX float + dc_b[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + size_t num_ac = 0; + + for (size_t y = y0; y < y1; ++y) { + const float* JXL_RESTRICT row_y = opsin.ConstPlaneRow(1, y * kBlockDim); + const float* JXL_RESTRICT row_x = opsin.ConstPlaneRow(0, y * kBlockDim); + const float* JXL_RESTRICT row_b = opsin.ConstPlaneRow(2, y * kBlockDim); + size_t stride = opsin.PixelsPerRow(); + + for (size_t x = x0; x < x1; x++) { + AcStrategy acs = use_dct8 + ? AcStrategy::FromRawStrategy(AcStrategy::Type::DCT) + : ac_strategy->ConstRow(y)[x]; + if (!acs.IsFirstBlock()) continue; + size_t xs = acs.covered_blocks_x(); + TransformFromPixels(acs.Strategy(), row_y + x * kBlockDim, stride, + block_y, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_y, dc_y, xs); + TransformFromPixels(acs.Strategy(), row_x + x * kBlockDim, stride, + block_x, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_x, dc_x, xs); + TransformFromPixels(acs.Strategy(), row_b + x * kBlockDim, stride, + block_b, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_b, dc_b, xs); + const float* const JXL_RESTRICT qm_x = + dequant.InvMatrix(acs.Strategy(), 0); + const float* const JXL_RESTRICT qm_b = + dequant.InvMatrix(acs.Strategy(), 2); + // Why does a constant seem to work better than + // raw_quant_field->Row(y)[x] ? + float q = use_dct8 ? 1 : quantizer->Scale() * 400.0f; + float q_dc_x = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(0); + float q_dc_b = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(2); + + // Copy DCs in dc_values. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < xs; ix++) { + dc_values_yx[(iy + y) * xsize_blocks + ix + x] = + dc_y[iy * xs + ix] * q_dc_x; + dc_values_x[(iy + y) * xsize_blocks + ix + x] = + dc_x[iy * xs + ix] * q_dc_x; + dc_values_yb[(iy + y) * xsize_blocks + ix + x] = + dc_y[iy * xs + ix] * q_dc_b; + dc_values_b[(iy + y) * xsize_blocks + ix + x] = + dc_b[iy * xs + ix] * q_dc_b; + } + } + + // Do not use this block for computing AC CfL. + if (acs.covered_blocks_x() + x0 > x1 || + acs.covered_blocks_y() + y0 > y1) { + continue; + } + + // Copy AC coefficients in the local block. The order in which + // coefficients get stored does not matter. + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + // Zero out LFs. This introduces terms in the optimization loop that + // don't affect the result, as they are all 0, but allow for simpler + // SIMDfication. + for (size_t iy = 0; iy < cy; iy++) { + for (size_t ix = 0; ix < cx; ix++) { + block_y[cx * kBlockDim * iy + ix] = 0; + block_x[cx * kBlockDim * iy + ix] = 0; + block_b[cx * kBlockDim * iy + ix] = 0; + } + } + const auto qv = Set(df, q); + for (size_t i = 0; i < cx * cy * 64; i += Lanes(df)) { + const auto b_y = Load(df, block_y + i); + const auto b_x = Load(df, block_x + i); + const auto b_b = Load(df, block_b + i); + const auto qqm_x = Mul(qv, Load(df, qm_x + i)); + const auto qqm_b = Mul(qv, Load(df, qm_b + i)); + Store(Mul(b_y, qqm_x), df, coeffs_yx + num_ac); + Store(Mul(b_x, qqm_x), df, coeffs_x + num_ac); + Store(Mul(b_y, qqm_b), df, coeffs_yb + num_ac); + Store(Mul(b_b, qqm_b), df, coeffs_b + num_ac); + num_ac += Lanes(df); + } + } + } + JXL_CHECK(num_ac % Lanes(df) == 0); + row_out_x[tx] = FindBestMultiplier(coeffs_yx, coeffs_x, num_ac, 0.0f, + kDistanceMultiplierAC, fast); + row_out_b[tx] = FindBestMultiplier(coeffs_yb, coeffs_b, num_ac, kYToBRatio, + kDistanceMultiplierAC, fast); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(InitDCStorage); +HWY_EXPORT(ComputeDC); +HWY_EXPORT(ComputeTile); + +void CfLHeuristics::Init(const Image3F& opsin) { + size_t xsize_blocks = opsin.xsize() / kBlockDim; + size_t ysize_blocks = opsin.ysize() / kBlockDim; + HWY_DYNAMIC_DISPATCH(InitDCStorage) + (xsize_blocks * ysize_blocks, &dc_values); +} + +void CfLHeuristics::ComputeTile(const Rect& r, const Image3F& opsin, + const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, + const Quantizer* quantizer, bool fast, + size_t thread, ColorCorrelationMap* cmap) { + bool use_dct8 = ac_strategy == nullptr; + HWY_DYNAMIC_DISPATCH(ComputeTile) + (opsin, dequant, ac_strategy, quantizer, r, fast, use_dct8, &cmap->ytox_map, + &cmap->ytob_map, &dc_values, mem.get() + thread * kItemsPerThread); +} + +void CfLHeuristics::ComputeDC(bool fast, ColorCorrelationMap* cmap) { + int32_t ytob_dc = 0; + int32_t ytox_dc = 0; + HWY_DYNAMIC_DISPATCH(ComputeDC)(dc_values, fast, &ytox_dc, &ytob_dc); + cmap->SetYToBDC(ytob_dc); + cmap->SetYToXDC(ytox_dc); +} + +void ColorCorrelationMapEncodeDC(ColorCorrelationMap* map, BitWriter* writer, + size_t layer, AuxOut* aux_out) { + float color_factor = map->GetColorFactor(); + float base_correlation_x = map->GetBaseCorrelationX(); + float base_correlation_b = map->GetBaseCorrelationB(); + int32_t ytox_dc = map->GetYToXDC(); + int32_t ytob_dc = map->GetYToBDC(); + + BitWriter::Allotment allotment(writer, 1 + 2 * kBitsPerByte + 12 + 32); + if (ytox_dc == 0 && ytob_dc == 0 && color_factor == kDefaultColorFactor && + base_correlation_x == 0.0f && base_correlation_b == kYToBRatio) { + writer->Write(1, 1); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + return; + } + writer->Write(1, 0); + JXL_CHECK(U32Coder::Write(kColorFactorDist, color_factor, writer)); + JXL_CHECK(F16Coder::Write(base_correlation_x, writer)); + JXL_CHECK(F16Coder::Write(base_correlation_b, writer)); + writer->Write(kBitsPerByte, ytox_dc - std::numeric_limits::min()); + writer->Write(kBitsPerByte, ytob_dc - std::numeric_limits::min()); + ReclaimAndCharge(writer, &allotment, layer, aux_out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_chroma_from_luma.h b/media/libjxl/src/lib/jxl/enc_chroma_from_luma.h new file mode 100644 index 000000000..a09777403 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_chroma_from_luma.h @@ -0,0 +1,67 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_CHROMA_FROM_LUMA_H_ +#define LIB_JXL_ENC_CHROMA_FROM_LUMA_H_ + +// Chroma-from-luma, computed using heuristics to determine the best linear +// model for the X and B channels from the Y channel. + +#include +#include + +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +void ColorCorrelationMapEncodeDC(ColorCorrelationMap* map, BitWriter* writer, + size_t layer, AuxOut* aux_out); + +struct CfLHeuristics { + void Init(const Image3F& opsin); + + void PrepareForThreads(size_t num_threads) { + mem = hwy::AllocateAligned(num_threads * kItemsPerThread); + } + + void ComputeTile(const Rect& r, const Image3F& opsin, + const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, + const Quantizer* quantizer, bool fast, size_t thread, + ColorCorrelationMap* cmap); + + void ComputeDC(bool fast, ColorCorrelationMap* cmap); + + ImageF dc_values; + hwy::AlignedFreeUniquePtr mem; + + // Working set is too large for stack; allocate dynamically. + constexpr static size_t kItemsPerThread = + AcStrategy::kMaxCoeffArea * 3 // Blocks + + kColorTileDim * kColorTileDim * 4 // AC coeff storage + + AcStrategy::kMaxCoeffArea * 2; // Scratch space +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_CHROMA_FROM_LUMA_H_ diff --git a/media/libjxl/src/lib/jxl/enc_cluster.cc b/media/libjxl/src/lib/jxl/enc_cluster.cc new file mode 100644 index 000000000..c79b3ac83 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_cluster.cc @@ -0,0 +1,295 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_cluster.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc" +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/fast_math-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenZeroElse; + +template +V Entropy(V count, V inv_total, V total) { + const HWY_CAPPED(float, Histogram::kRounding) d; + const auto zero = Set(d, 0.0f); + // TODO(eustas): why (0 - x) instead of Neg(x)? + return IfThenZeroElse( + Eq(count, total), + Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); +} + +void HistogramEntropy(const Histogram& a) { + a.entropy_ = 0.0f; + if (a.total_count_ == 0) return; + + const HWY_CAPPED(float, Histogram::kRounding) df; + const HWY_CAPPED(int32_t, Histogram::kRounding) di; + + const auto inv_tot = Set(df, 1.0f / a.total_count_); + auto entropy_lanes = Zero(df); + auto total = Set(df, a.total_count_); + + for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) { + const auto counts = LoadU(di, &a.data_[i]); + entropy_lanes = + Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); + } + a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes)); +} + +float HistogramDistance(const Histogram& a, const Histogram& b) { + if (a.total_count_ == 0 || b.total_count_ == 0) return 0; + + const HWY_CAPPED(float, Histogram::kRounding) df; + const HWY_CAPPED(int32_t, Histogram::kRounding) di; + + const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_)); + auto distance_lanes = Zero(df); + auto total = Set(df, a.total_count_ + b.total_count_); + + for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size()); + i += Lanes(di)) { + const auto a_counts = + a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di); + const auto b_counts = + b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di); + const auto counts = ConvertTo(df, Add(a_counts, b_counts)); + distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); + } + const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); + return total_distance - a.entropy_ - b.entropy_; +} + +// First step of a k-means clustering with a fancy distance metric. +void FastClusterHistograms(const std::vector& in, + size_t max_histograms, std::vector* out, + std::vector* histogram_symbols) { + PROFILER_FUNC; + out->clear(); + out->reserve(max_histograms); + histogram_symbols->clear(); + histogram_symbols->resize(in.size(), max_histograms); + + std::vector dists(in.size(), std::numeric_limits::max()); + size_t largest_idx = 0; + for (size_t i = 0; i < in.size(); i++) { + if (in[i].total_count_ == 0) { + (*histogram_symbols)[i] = 0; + dists[i] = 0.0f; + continue; + } + HistogramEntropy(in[i]); + if (in[i].total_count_ > in[largest_idx].total_count_) { + largest_idx = i; + } + } + + constexpr float kMinDistanceForDistinct = 48.0f; + while (out->size() < max_histograms) { + (*histogram_symbols)[largest_idx] = out->size(); + out->push_back(in[largest_idx]); + dists[largest_idx] = 0.0f; + largest_idx = 0; + for (size_t i = 0; i < in.size(); i++) { + if (dists[i] == 0.0f) continue; + dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); + if (dists[i] > dists[largest_idx]) largest_idx = i; + } + if (dists[largest_idx] < kMinDistanceForDistinct) break; + } + + for (size_t i = 0; i < in.size(); i++) { + if ((*histogram_symbols)[i] != max_histograms) continue; + size_t best = 0; + float best_dist = HistogramDistance(in[i], (*out)[best]); + for (size_t j = 1; j < out->size(); j++) { + float dist = HistogramDistance(in[i], (*out)[j]); + if (dist < best_dist) { + best = j; + best_dist = dist; + } + } + (*out)[best].AddHistogram(in[i]); + HistogramEntropy((*out)[best]); + (*histogram_symbols)[i] = best; + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(FastClusterHistograms); // Local function +HWY_EXPORT(HistogramEntropy); // Local function + +float Histogram::ShannonEntropy() const { + HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this); + return entropy_; +} + +namespace { +// ----------------------------------------------------------------------------- +// Histogram refinement + +// Reorder histograms in *out so that the new symbols in *symbols come in +// increasing order. +void HistogramReindex(std::vector* out, + std::vector* symbols) { + std::vector tmp(*out); + std::map new_index; + int next_index = 0; + for (uint32_t symbol : *symbols) { + if (new_index.find(symbol) == new_index.end()) { + new_index[symbol] = next_index; + (*out)[next_index] = tmp[symbol]; + ++next_index; + } + } + out->resize(next_index); + for (uint32_t& symbol : *symbols) { + symbol = new_index[symbol]; + } +} + +} // namespace + +// Clusters similar histograms in 'in' together, the selected histograms are +// placed in 'out', and for each index in 'in', *histogram_symbols will +// indicate which of the 'out' histograms is the best approximation. +void ClusterHistograms(const HistogramParams params, + const std::vector& in, size_t max_histograms, + std::vector* out, + std::vector* histogram_symbols) { + max_histograms = std::min(max_histograms, params.max_histograms); + max_histograms = std::min(max_histograms, in.size()); + if (params.clustering == HistogramParams::ClusteringType::kFastest) { + max_histograms = std::min(max_histograms, static_cast(4)); + } + + HWY_DYNAMIC_DISPATCH(FastClusterHistograms) + (in, max_histograms, out, histogram_symbols); + + if (params.clustering == HistogramParams::ClusteringType::kBest) { + for (size_t i = 0; i < out->size(); i++) { + (*out)[i].entropy_ = + ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size()); + } + uint32_t next_version = 2; + std::vector version(out->size(), 1); + std::vector renumbering(out->size()); + std::iota(renumbering.begin(), renumbering.end(), 0); + + // Try to pair up clusters if doing so reduces the total cost. + + struct HistogramPair { + // validity of a pair: p.version == max(version[i], version[j]) + float cost; + uint32_t first; + uint32_t second; + uint32_t version; + // We use > because priority queues sort in *decreasing* order, but we + // want lower cost elements to appear first. + bool operator<(const HistogramPair& other) const { + return std::make_tuple(cost, first, second, version) > + std::make_tuple(other.cost, other.first, other.second, + other.version); + } + }; + + // Create list of all pairs by increasing merging cost. + std::priority_queue pairs_to_merge; + for (uint32_t i = 0; i < out->size(); i++) { + for (uint32_t j = i + 1; j < out->size(); j++) { + Histogram histo; + histo.AddHistogram((*out)[i]); + histo.AddHistogram((*out)[j]); + float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - + (*out)[i].entropy_ - (*out)[j].entropy_; + // Avoid enqueueing pairs that are not advantageous to merge. + if (cost >= 0) continue; + pairs_to_merge.push( + HistogramPair{cost, i, j, std::max(version[i], version[j])}); + } + } + + // Merge the best pair to merge, add new pairs that get formed as a + // consequence. + while (!pairs_to_merge.empty()) { + uint32_t first = pairs_to_merge.top().first; + uint32_t second = pairs_to_merge.top().second; + uint32_t ver = pairs_to_merge.top().version; + pairs_to_merge.pop(); + if (ver != std::max(version[first], version[second]) || + version[first] == 0 || version[second] == 0) { + continue; + } + (*out)[first].AddHistogram((*out)[second]); + (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(), + (*out)[first].data_.size()); + for (size_t i = 0; i < renumbering.size(); i++) { + if (renumbering[i] == second) { + renumbering[i] = first; + } + } + version[second] = 0; + version[first] = next_version++; + for (uint32_t j = 0; j < out->size(); j++) { + if (j == first) continue; + if (version[j] == 0) continue; + Histogram histo; + histo.AddHistogram((*out)[first]); + histo.AddHistogram((*out)[j]); + float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - + (*out)[first].entropy_ - (*out)[j].entropy_; + // Avoid enqueueing pairs that are not advantageous to merge. + if (cost >= 0) continue; + pairs_to_merge.push( + HistogramPair{cost, std::min(first, j), std::max(first, j), + std::max(version[first], version[j])}); + } + } + std::vector reverse_renumbering(out->size(), -1); + size_t num_alive = 0; + for (size_t i = 0; i < out->size(); i++) { + if (version[i] == 0) continue; + (*out)[num_alive++] = (*out)[i]; + reverse_renumbering[i] = num_alive - 1; + } + out->resize(num_alive); + for (size_t i = 0; i < histogram_symbols->size(); i++) { + (*histogram_symbols)[i] = + reverse_renumbering[renumbering[(*histogram_symbols)[i]]]; + } + } + + // Convert the context map to a canonical form. + HistogramReindex(out, histogram_symbols); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_cluster.h b/media/libjxl/src/lib/jxl/enc_cluster.h new file mode 100644 index 000000000..a06783fcc --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_cluster.h @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Functions for clustering similar histograms together. + +#ifndef LIB_JXL_ENC_CLUSTER_H_ +#define LIB_JXL_ENC_CLUSTER_H_ + +#include +#include +#include + +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/enc_ans.h" + +namespace jxl { + +struct Histogram { + Histogram() { total_count_ = 0; } + void Clear() { + data_.clear(); + total_count_ = 0; + } + void Add(size_t symbol) { + if (data_.size() <= symbol) { + data_.resize(DivCeil(symbol + 1, kRounding) * kRounding); + } + ++data_[symbol]; + ++total_count_; + } + void AddHistogram(const Histogram& other) { + if (other.data_.size() > data_.size()) { + data_.resize(other.data_.size()); + } + for (size_t i = 0; i < other.data_.size(); ++i) { + data_[i] += other.data_[i]; + } + total_count_ += other.total_count_; + } + float PopulationCost() const { + return ANSPopulationCost(data_.data(), data_.size()); + } + float ShannonEntropy() const; + + std::vector data_; + size_t total_count_; + mutable float entropy_; // WARNING: not kept up-to-date. + static constexpr size_t kRounding = 8; +}; + +void ClusterHistograms(HistogramParams params, const std::vector& in, + size_t max_histograms, std::vector* out, + std::vector* histogram_symbols); +} // namespace jxl + +#endif // LIB_JXL_ENC_CLUSTER_H_ diff --git a/media/libjxl/src/lib/jxl/enc_coeff_order.cc b/media/libjxl/src/lib/jxl/enc_coeff_order.cc new file mode 100644 index 000000000..8d75cc038 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_coeff_order.cc @@ -0,0 +1,290 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/lehmer_code.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +std::pair ComputeUsedOrders( + const SpeedTier speed, const AcStrategyImage& ac_strategy, + const Rect& rect) { + // Only uses DCT8 = 0, so bitfield = 1. + if (speed >= SpeedTier::kFalcon) return {1, 1}; + + uint32_t ret = 0; + uint32_t ret_customize = 0; + size_t xsize_blocks = rect.xsize(); + size_t ysize_blocks = rect.ysize(); + // TODO(veluca): precompute when doing DCT. + for (size_t by = 0; by < ysize_blocks; ++by) { + AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by); + for (size_t bx = 0; bx < xsize_blocks; ++bx) { + int ord = kStrategyOrder[acs_row[bx].RawStrategy()]; + // Do not customize coefficient orders for blocks bigger than 32x32. + ret |= 1u << ord; + if (ord > 6) { + continue; + } + ret_customize |= 1u << ord; + } + } + // Use default orders for small images. + if (ac_strategy.xsize() < 5 && ac_strategy.ysize() < 5) return {ret, 0}; + return {ret, ret_customize}; +} + +void ComputeCoeffOrder(SpeedTier speed, const ACImage& acs, + const AcStrategyImage& ac_strategy, + const FrameDimensions& frame_dim, uint32_t& used_orders, + uint16_t used_acs, coeff_order_t* JXL_RESTRICT order) { + std::vector num_zeros(kCoeffOrderMaxSize); + // If compressing at high speed and only using 8x8 DCTs, only consider a + // subset of blocks. + double block_fraction = 1.0f; + // TODO(veluca): figure out why sampling blocks if non-8x8s are used makes + // encoding significantly less dense. + if (speed >= SpeedTier::kSquirrel && used_orders == 1) { + block_fraction = 0.5f; + } + // No need to compute number of zero coefficients if all orders are the + // default. + if (used_orders != 0) { + uint64_t threshold = + (std::numeric_limits::max() >> 32) * block_fraction; + uint64_t s[2] = {static_cast(0x94D049BB133111EBull), + static_cast(0xBF58476D1CE4E5B9ull)}; + // Xorshift128+ adapted from xorshift128+-inl.h + auto use_sample = [&]() { + auto s1 = s[0]; + const auto s0 = s[1]; + const auto bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return (bits >> 32) <= threshold; + }; + + // Count number of zero coefficients, separately for each DCT band. + // TODO(veluca): precompute when doing DCT. + for (size_t group_index = 0; group_index < frame_dim.num_groups; + group_index++) { + const size_t gx = group_index % frame_dim.xsize_groups; + const size_t gy = group_index / frame_dim.xsize_groups; + const Rect rect(gx * kGroupDimInBlocks, gy * kGroupDimInBlocks, + kGroupDimInBlocks, kGroupDimInBlocks, + frame_dim.xsize_blocks, frame_dim.ysize_blocks); + ConstACPtr rows[3]; + ACType type = acs.Type(); + for (size_t c = 0; c < 3; c++) { + rows[c] = acs.PlaneRow(c, group_index, 0); + } + size_t ac_offset = 0; + + // TODO(veluca): SIMDfy. + for (size_t by = 0; by < rect.ysize(); ++by) { + AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by); + for (size_t bx = 0; bx < rect.xsize(); ++bx) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + if (!use_sample()) continue; + size_t size = kDCTBlockSize << acs.log2_covered_blocks(); + for (size_t c = 0; c < 3; ++c) { + const size_t order_offset = + CoeffOrderOffset(kStrategyOrder[acs.RawStrategy()], c); + if (type == ACType::k16) { + for (size_t k = 0; k < size; k++) { + bool is_zero = rows[c].ptr16[ac_offset + k] == 0; + num_zeros[order_offset + k] += is_zero ? 1 : 0; + } + } else { + for (size_t k = 0; k < size; k++) { + bool is_zero = rows[c].ptr32[ac_offset + k] == 0; + num_zeros[order_offset + k] += is_zero ? 1 : 0; + } + } + // Ensure LLFs are first in the order. + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + for (size_t iy = 0; iy < cy; iy++) { + for (size_t ix = 0; ix < cx; ix++) { + num_zeros[order_offset + iy * kBlockDim * cx + ix] = -1; + } + } + } + ac_offset += size; + } + } + } + } + struct PosAndCount { + uint32_t pos; + uint32_t count; + }; + auto mem = hwy::AllocateAligned(AcStrategy::kMaxCoeffArea); + + std::vector natural_order_buffer; + + uint16_t computed = 0; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + uint8_t ord = kStrategyOrder[o]; + if (computed & (1 << ord)) continue; + computed |= 1 << ord; + AcStrategy acs = AcStrategy::FromRawStrategy(o); + size_t sz = kDCTBlockSize * acs.covered_blocks_x() * acs.covered_blocks_y(); + + // Do nothing for transforms that don't appear. + if ((1 << ord) & ~used_acs) continue; + + if (natural_order_buffer.size() < sz) natural_order_buffer.resize(sz); + acs.ComputeNaturalCoeffOrder(natural_order_buffer.data()); + + // Ensure natural coefficient order is not permuted if the order is + // not transmitted. + if ((1 << ord) & ~used_orders) { + for (size_t c = 0; c < 3; c++) { + size_t offset = CoeffOrderOffset(ord, c); + JXL_DASSERT(CoeffOrderOffset(ord, c + 1) - offset == sz); + memcpy(&order[offset], natural_order_buffer.data(), + sz * sizeof(*order)); + } + continue; + } + + bool is_nondefault = false; + for (uint8_t c = 0; c < 3; c++) { + // Apply zig-zag order. + PosAndCount* pos_and_val = mem.get(); + size_t offset = CoeffOrderOffset(ord, c); + JXL_DASSERT(CoeffOrderOffset(ord, c + 1) - offset == sz); + float inv_sqrt_sz = 1.0f / std::sqrt(sz); + for (size_t i = 0; i < sz; ++i) { + size_t pos = natural_order_buffer[i]; + pos_and_val[i].pos = pos; + // We don't care for the exact number -> quantize number of zeros, + // to get less permuted order. + pos_and_val[i].count = num_zeros[offset + pos] * inv_sqrt_sz + 0.1f; + } + + // Stable-sort -> elements with same number of zeros will preserve their + // order. + auto comparator = [](const PosAndCount& a, const PosAndCount& b) -> bool { + return a.count < b.count; + }; + std::stable_sort(pos_and_val, pos_and_val + sz, comparator); + + // Grab indices. + for (size_t i = 0; i < sz; ++i) { + order[offset + i] = pos_and_val[i].pos; + is_nondefault |= natural_order_buffer[i] != pos_and_val[i].pos; + } + } + if (!is_nondefault) { + used_orders &= ~(1 << ord); + } + } +} + +namespace { + +void TokenizePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, + size_t size, std::vector* tokens) { + std::vector lehmer(size); + std::vector temp(size + 1); + ComputeLehmerCode(order, temp.data(), size, lehmer.data()); + size_t end = size; + while (end > skip && lehmer[end - 1] == 0) { + --end; + } + tokens->emplace_back(CoeffOrderContext(size), end - skip); + uint32_t last = 0; + for (size_t i = skip; i < end; ++i) { + tokens->emplace_back(CoeffOrderContext(last), lehmer[i]); + last = lehmer[i]; + } +} + +} // namespace + +void EncodePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, + size_t size, BitWriter* writer, int layer, + AuxOut* aux_out) { + std::vector> tokens(1); + TokenizePermutation(order, skip, size, &tokens[0]); + std::vector context_map; + EntropyEncodingData codes; + BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens, + &codes, &context_map, writer, layer, aux_out); + WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out); +} + +namespace { +void EncodeCoeffOrder(const coeff_order_t* JXL_RESTRICT order, AcStrategy acs, + std::vector* tokens, coeff_order_t* order_zigzag, + std::vector& natural_order_lut) { + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + for (size_t i = 0; i < size; ++i) { + order_zigzag[i] = natural_order_lut[order[i]]; + } + TokenizePermutation(order_zigzag, llf, size, tokens); +} +} // namespace + +void EncodeCoeffOrders(uint16_t used_orders, + const coeff_order_t* JXL_RESTRICT order, + BitWriter* writer, size_t layer, + AuxOut* JXL_RESTRICT aux_out) { + auto mem = hwy::AllocateAligned(AcStrategy::kMaxCoeffArea); + uint16_t computed = 0; + std::vector> tokens(1); + std::vector natural_order_lut; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + uint8_t ord = kStrategyOrder[o]; + if (computed & (1 << ord)) continue; + computed |= 1 << ord; + if ((used_orders & (1 << ord)) == 0) continue; + AcStrategy acs = AcStrategy::FromRawStrategy(o); + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + if (natural_order_lut.size() < size) natural_order_lut.resize(size); + acs.ComputeNaturalCoeffOrderLut(natural_order_lut.data()); + for (size_t c = 0; c < 3; c++) { + EncodeCoeffOrder(&order[CoeffOrderOffset(ord, c)], acs, &tokens[0], + mem.get(), natural_order_lut); + } + } + // Do not write anything if no order is used. + if (used_orders != 0) { + std::vector context_map; + EntropyEncodingData codes; + BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens, + &codes, &context_map, writer, layer, aux_out); + WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_coeff_order.h b/media/libjxl/src/lib/jxl/enc_coeff_order.h new file mode 100644 index 000000000..7a237f2b4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_coeff_order.h @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_COEFF_ORDER_H_ +#define LIB_JXL_ENC_COEFF_ORDER_H_ + +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_params.h" + +namespace jxl { + +// Orders that are actually used in part of image. `rect` is in block units. +// Returns {orders that are used, orders that might be made non-default}. +std::pair ComputeUsedOrders( + SpeedTier speed, const AcStrategyImage& ac_strategy, const Rect& rect); + +// Modify zig-zag order, so that DCT bands with more zeros go later. +// Order of DCT bands with same number of zeros is untouched, so +// permutation will be cheaper to encode. +void ComputeCoeffOrder(SpeedTier speed, const ACImage& acs, + const AcStrategyImage& ac_strategy, + const FrameDimensions& frame_dim, uint32_t& used_orders, + uint16_t used_acs, coeff_order_t* JXL_RESTRICT order); + +void EncodeCoeffOrders(uint16_t used_orders, + const coeff_order_t* JXL_RESTRICT order, + BitWriter* writer, size_t layer, + AuxOut* JXL_RESTRICT aux_out); + +// Encoding/decoding of a single permutation. `size`: number of elements in the +// permutation. `skip`: number of elements to skip from the *beginning* of the +// permutation. +void EncodePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, + size_t size, BitWriter* writer, int layer, + AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_COEFF_ORDER_H_ diff --git a/media/libjxl/src/lib/jxl/enc_color_management.cc b/media/libjxl/src/lib/jxl/enc_color_management.cc new file mode 100644 index 000000000..0b031d2dc --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_color_management.cc @@ -0,0 +1,1191 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_color_management.h" + +#ifndef JPEGXL_ENABLE_SKCMS +#define JPEGXL_ENABLE_SKCMS 0 +#endif + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_color_management.cc" +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/linalg.h" +#include "lib/jxl/transfer_functions-inl.h" +#if JPEGXL_ENABLE_SKCMS +#include "lib/jxl/enc_jxl_skcms.h" +#else // JPEGXL_ENABLE_SKCMS +#include "lcms2.h" +#include "lcms2_plugin.h" +#endif // JPEGXL_ENABLE_SKCMS + +#define JXL_CMS_VERBOSE 0 + +// Define these only once. We can't use HWY_ONCE here because it is defined as +// 1 only on the last pass. +#ifndef LIB_JXL_ENC_COLOR_MANAGEMENT_CC_ +#define LIB_JXL_ENC_COLOR_MANAGEMENT_CC_ + +namespace jxl { +namespace { +struct JxlCms { +#if JPEGXL_ENABLE_SKCMS + PaddedBytes icc_src, icc_dst; + skcms_ICCProfile profile_src, profile_dst; +#else + void* lcms_transform; +#endif + + // These fields are used when the HLG OOTF or inverse OOTF must be applied. + bool apply_hlg_ootf; + size_t hlg_ootf_num_channels; + // Y component of the primaries. + std::array hlg_ootf_luminances; + + size_t channels_src; + size_t channels_dst; + ImageF buf_src; + ImageF buf_dst; + float intensity_target; + bool skip_lcms = false; + ExtraTF preprocess = ExtraTF::kNone; + ExtraTF postprocess = ExtraTF::kNone; +}; + +Status ApplyHlgOotf(JxlCms* t, float* JXL_RESTRICT buf, size_t xsize, + bool forward); +} // namespace +} // namespace jxl + +#endif // LIB_JXL_ENC_COLOR_MANAGEMENT_CC_ + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +#if JXL_CMS_VERBOSE >= 2 +const size_t kX = 0; // pixel index, multiplied by 3 for RGB +#endif + +// xform_src = UndoGammaCompression(buf_src). +Status BeforeTransform(JxlCms* t, const float* buf_src, float* xform_src, + size_t buf_size) { + switch (t->preprocess) { + case ExtraTF::kNone: + JXL_DASSERT(false); // unreachable + break; + + case ExtraTF::kPQ: { + // By default, PQ content has an intensity target of 10000, stored + // exactly. + HWY_FULL(float) df; + const auto multiplier = Set(df, t->intensity_target == 10000.f + ? 1.0f + : 10000.f / t->intensity_target); + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_src + i); + const auto result = + Mul(multiplier, TF_PQ().DisplayFromEncoded(df, val)); + Store(result, df, xform_src + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("pre in %.4f %.4f %.4f undoPQ %.4f %.4f %.4f\n", buf_src[3 * kX], + buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX], + xform_src[3 * kX + 1], xform_src[3 * kX + 2]); +#endif + break; + } + + case ExtraTF::kHLG: + for (size_t i = 0; i < buf_size; ++i) { + xform_src[i] = static_cast( + TF_HLG().DisplayFromEncoded(static_cast(buf_src[i]))); + } + if (t->apply_hlg_ootf) { + JXL_RETURN_IF_ERROR( + ApplyHlgOotf(t, xform_src, buf_size, /*forward=*/true)); + } +#if JXL_CMS_VERBOSE >= 2 + printf("pre in %.4f %.4f %.4f undoHLG %.4f %.4f %.4f\n", buf_src[3 * kX], + buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX], + xform_src[3 * kX + 1], xform_src[3 * kX + 2]); +#endif + break; + + case ExtraTF::kSRGB: + HWY_FULL(float) df; + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_src + i); + const auto result = TF_SRGB().DisplayFromEncoded(val); + Store(result, df, xform_src + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("pre in %.4f %.4f %.4f undoSRGB %.4f %.4f %.4f\n", buf_src[3 * kX], + buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX], + xform_src[3 * kX + 1], xform_src[3 * kX + 2]); +#endif + break; + } + return true; +} + +// Applies gamma compression in-place. +Status AfterTransform(JxlCms* t, float* JXL_RESTRICT buf_dst, size_t buf_size) { + switch (t->postprocess) { + case ExtraTF::kNone: + JXL_DASSERT(false); // unreachable + break; + case ExtraTF::kPQ: { + HWY_FULL(float) df; + const auto multiplier = + Set(df, t->intensity_target == 10000.f ? 1.0f + : t->intensity_target * 1e-4f); + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_dst + i); + const auto result = + TF_PQ().EncodedFromDisplay(df, Mul(multiplier, val)); + Store(result, df, buf_dst + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("after PQ enc %.4f %.4f %.4f\n", buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + break; + } + case ExtraTF::kHLG: + if (t->apply_hlg_ootf) { + JXL_RETURN_IF_ERROR( + ApplyHlgOotf(t, buf_dst, buf_size, /*forward=*/false)); + } + for (size_t i = 0; i < buf_size; ++i) { + buf_dst[i] = static_cast( + TF_HLG().EncodedFromDisplay(static_cast(buf_dst[i]))); + } +#if JXL_CMS_VERBOSE >= 2 + printf("after HLG enc %.4f %.4f %.4f\n", buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + break; + case ExtraTF::kSRGB: + HWY_FULL(float) df; + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_dst + i); + const auto result = + TF_SRGB().EncodedFromDisplay(HWY_FULL(float)(), val); + Store(result, df, buf_dst + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("after SRGB enc %.4f %.4f %.4f\n", buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + break; + } + return true; +} + +Status DoColorSpaceTransform(void* cms_data, const size_t thread, + const float* buf_src, float* buf_dst, + size_t xsize) { + // No lock needed. + JxlCms* t = reinterpret_cast(cms_data); + + const float* xform_src = buf_src; // Read-only. + if (t->preprocess != ExtraTF::kNone) { + float* mutable_xform_src = t->buf_src.Row(thread); // Writable buffer. + JXL_RETURN_IF_ERROR(BeforeTransform(t, buf_src, mutable_xform_src, + xsize * t->channels_src)); + xform_src = mutable_xform_src; + } + +#if JPEGXL_ENABLE_SKCMS + if (t->channels_src == 1 && !t->skip_lcms) { + // Expand from 1 to 3 channels, starting from the end in case + // xform_src == t->buf_src.Row(thread). + float* mutable_xform_src = t->buf_src.Row(thread); + for (size_t i = 0; i < xsize; ++i) { + const size_t x = xsize - i - 1; + mutable_xform_src[x * 3] = mutable_xform_src[x * 3 + 1] = + mutable_xform_src[x * 3 + 2] = xform_src[x]; + } + xform_src = mutable_xform_src; + } +#else + if (t->channels_src == 4 && !t->skip_lcms) { + // LCMS does CMYK in a weird way: 0 = white, 100 = max ink + float* mutable_xform_src = t->buf_src.Row(thread); + for (size_t x = 0; x < xsize * 4; ++x) { + mutable_xform_src[x] = 100.f - 100.f * mutable_xform_src[x]; + } + xform_src = mutable_xform_src; + } +#endif + +#if JXL_CMS_VERBOSE >= 2 + // Save inputs for printing before in-place transforms overwrite them. + const float in0 = xform_src[3 * kX + 0]; + const float in1 = xform_src[3 * kX + 1]; + const float in2 = xform_src[3 * kX + 2]; +#endif + + if (t->skip_lcms) { + if (buf_dst != xform_src) { + memcpy(buf_dst, xform_src, xsize * t->channels_src * sizeof(*buf_dst)); + } // else: in-place, no need to copy + } else { +#if JPEGXL_ENABLE_SKCMS + JXL_CHECK( + skcms_Transform(xform_src, + (t->channels_src == 4 ? skcms_PixelFormat_RGBA_ffff + : skcms_PixelFormat_RGB_fff), + skcms_AlphaFormat_Opaque, &t->profile_src, buf_dst, + skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque, + &t->profile_dst, xsize)); +#else // JPEGXL_ENABLE_SKCMS + cmsDoTransform(t->lcms_transform, xform_src, buf_dst, + static_cast(xsize)); +#endif // JPEGXL_ENABLE_SKCMS + } +#if JXL_CMS_VERBOSE >= 2 + printf("xform skip%d: %.4f %.4f %.4f (%p) -> (%p) %.4f %.4f %.4f\n", + t->skip_lcms, in0, in1, in2, xform_src, buf_dst, buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + +#if JPEGXL_ENABLE_SKCMS + if (t->channels_dst == 1 && !t->skip_lcms) { + // Contract back from 3 to 1 channel, this time forward. + float* grayscale_buf_dst = t->buf_dst.Row(thread); + for (size_t x = 0; x < xsize; ++x) { + grayscale_buf_dst[x] = buf_dst[x * 3]; + } + buf_dst = grayscale_buf_dst; + } +#endif + + if (t->postprocess != ExtraTF::kNone) { + JXL_RETURN_IF_ERROR(AfterTransform(t, buf_dst, xsize * t->channels_dst)); + } + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { + +HWY_EXPORT(DoColorSpaceTransform); +int DoColorSpaceTransform(void* t, size_t thread, const float* buf_src, + float* buf_dst, size_t xsize) { + return HWY_DYNAMIC_DISPATCH(DoColorSpaceTransform)(t, thread, buf_src, + buf_dst, xsize); +} + +// Define to 1 on OS X as a workaround for older LCMS lacking MD5. +#define JXL_CMS_OLD_VERSION 0 + +#if JPEGXL_ENABLE_SKCMS + +JXL_MUST_USE_RESULT CIExy CIExyFromXYZ(const float XYZ[3]) { + const float factor = 1.f / (XYZ[0] + XYZ[1] + XYZ[2]); + CIExy xy; + xy.x = XYZ[0] * factor; + xy.y = XYZ[1] * factor; + return xy; +} + +#else // JPEGXL_ENABLE_SKCMS +// (LCMS interface requires xyY but we omit the Y for white points/primaries.) + +JXL_MUST_USE_RESULT CIExy CIExyFromxyY(const cmsCIExyY& xyY) { + CIExy xy; + xy.x = xyY.x; + xy.y = xyY.y; + return xy; +} + +JXL_MUST_USE_RESULT CIExy CIExyFromXYZ(const cmsCIEXYZ& XYZ) { + cmsCIExyY xyY; + cmsXYZ2xyY(/*Dest=*/&xyY, /*Source=*/&XYZ); + return CIExyFromxyY(xyY); +} + +JXL_MUST_USE_RESULT cmsCIEXYZ D50_XYZ() { + // Quantized D50 as stored in ICC profiles. + return {0.96420288, 1.0, 0.82490540}; +} + +JXL_MUST_USE_RESULT cmsCIExyY xyYFromCIExy(const CIExy& xy) { + const cmsCIExyY xyY = {xy.x, xy.y, 1.0}; + return xyY; +} + +// RAII + +struct ProfileDeleter { + void operator()(void* p) { cmsCloseProfile(p); } +}; +using Profile = std::unique_ptr; + +struct TransformDeleter { + void operator()(void* p) { cmsDeleteTransform(p); } +}; +using Transform = std::unique_ptr; + +struct CurveDeleter { + void operator()(cmsToneCurve* p) { cmsFreeToneCurve(p); } +}; +using Curve = std::unique_ptr; + +Status CreateProfileXYZ(const cmsContext context, + Profile* JXL_RESTRICT profile) { + profile->reset(cmsCreateXYZProfileTHR(context)); + if (profile->get() == nullptr) return JXL_FAILURE("Failed to create XYZ"); + return true; +} + +#endif // !JPEGXL_ENABLE_SKCMS + +#if JPEGXL_ENABLE_SKCMS +// IMPORTANT: icc must outlive profile. +Status DecodeProfile(const uint8_t* icc, size_t size, + skcms_ICCProfile* const profile) { + if (!skcms_Parse(icc, size, profile)) { + return JXL_FAILURE("Failed to parse ICC profile with %" PRIuS " bytes", + size); + } + return true; +} +#else // JPEGXL_ENABLE_SKCMS +Status DecodeProfile(const cmsContext context, const PaddedBytes& icc, + Profile* profile) { + profile->reset(cmsOpenProfileFromMemTHR(context, icc.data(), icc.size())); + if (profile->get() == nullptr) { + return JXL_FAILURE("Failed to decode profile"); + } + + // WARNING: due to the LCMS MD5 issue mentioned above, many existing + // profiles have incorrect MD5, so do not even bother checking them nor + // generating warning clutter. + + return true; +} +#endif // JPEGXL_ENABLE_SKCMS + +#if JPEGXL_ENABLE_SKCMS + +ColorSpace ColorSpaceFromProfile(const skcms_ICCProfile& profile) { + switch (profile.data_color_space) { + case skcms_Signature_RGB: + case skcms_Signature_CMYK: + // spec says CMYK is encoded as RGB (the kBlack extra channel signals that + // it is actually CMYK) + return ColorSpace::kRGB; + case skcms_Signature_Gray: + return ColorSpace::kGray; + default: + return ColorSpace::kUnknown; + } +} + +// "profile1" is pre-decoded to save time in DetectTransferFunction. +Status ProfileEquivalentToICC(const skcms_ICCProfile& profile1, + const PaddedBytes& icc) { + skcms_ICCProfile profile2; + JXL_RETURN_IF_ERROR(skcms_Parse(icc.data(), icc.size(), &profile2)); + return skcms_ApproximatelyEqualProfiles(&profile1, &profile2); +} + +// vector_out := matmul(matrix, vector_in) +void MatrixProduct(const skcms_Matrix3x3& matrix, const float vector_in[3], + float vector_out[3]) { + for (int i = 0; i < 3; ++i) { + vector_out[i] = 0; + for (int j = 0; j < 3; ++j) { + vector_out[i] += matrix.vals[i][j] * vector_in[j]; + } + } +} + +// Returns white point that was specified when creating the profile. +JXL_MUST_USE_RESULT Status UnadaptedWhitePoint(const skcms_ICCProfile& profile, + CIExy* out) { + float media_white_point_XYZ[3]; + if (!skcms_GetWTPT(&profile, media_white_point_XYZ)) { + return JXL_FAILURE("ICC profile does not contain WhitePoint tag"); + } + skcms_Matrix3x3 CHAD; + if (!skcms_GetCHAD(&profile, &CHAD)) { + // If there is no chromatic adaptation matrix, it means that the white point + // is already unadapted. + *out = CIExyFromXYZ(media_white_point_XYZ); + return true; + } + // Otherwise, it has been adapted to the PCS white point using said matrix, + // and the adaptation needs to be undone. + skcms_Matrix3x3 inverse_CHAD; + if (!skcms_Matrix3x3_invert(&CHAD, &inverse_CHAD)) { + return JXL_FAILURE("Non-invertible ChromaticAdaptation matrix"); + } + float unadapted_white_point_XYZ[3]; + MatrixProduct(inverse_CHAD, media_white_point_XYZ, unadapted_white_point_XYZ); + *out = CIExyFromXYZ(unadapted_white_point_XYZ); + return true; +} + +Status IdentifyPrimaries(const skcms_ICCProfile& profile, + const CIExy& wp_unadapted, ColorEncoding* c) { + if (!c->HasPrimaries()) return true; + + skcms_Matrix3x3 CHAD, inverse_CHAD; + if (skcms_GetCHAD(&profile, &CHAD)) { + JXL_RETURN_IF_ERROR(skcms_Matrix3x3_invert(&CHAD, &inverse_CHAD)); + } else { + static constexpr skcms_Matrix3x3 kLMSFromXYZ = { + {{0.8951, 0.2664, -0.1614}, + {-0.7502, 1.7135, 0.0367}, + {0.0389, -0.0685, 1.0296}}}; + static constexpr skcms_Matrix3x3 kXYZFromLMS = { + {{0.9869929, -0.1470543, 0.1599627}, + {0.4323053, 0.5183603, 0.0492912}, + {-0.0085287, 0.0400428, 0.9684867}}}; + static constexpr float kWpD50XYZ[3] = {0.96420288, 1.0, 0.82490540}; + float wp_unadapted_XYZ[3]; + JXL_RETURN_IF_ERROR(CIEXYZFromWhiteCIExy(wp_unadapted, wp_unadapted_XYZ)); + float wp_D50_LMS[3], wp_unadapted_LMS[3]; + MatrixProduct(kLMSFromXYZ, kWpD50XYZ, wp_D50_LMS); + MatrixProduct(kLMSFromXYZ, wp_unadapted_XYZ, wp_unadapted_LMS); + inverse_CHAD = {{{wp_unadapted_LMS[0] / wp_D50_LMS[0], 0, 0}, + {0, wp_unadapted_LMS[1] / wp_D50_LMS[1], 0}, + {0, 0, wp_unadapted_LMS[2] / wp_D50_LMS[2]}}}; + inverse_CHAD = skcms_Matrix3x3_concat(&kXYZFromLMS, &inverse_CHAD); + inverse_CHAD = skcms_Matrix3x3_concat(&inverse_CHAD, &kLMSFromXYZ); + } + + float XYZ[3]; + PrimariesCIExy primaries; + CIExy* const chromaticities[] = {&primaries.r, &primaries.g, &primaries.b}; + for (int i = 0; i < 3; ++i) { + float RGB[3] = {}; + RGB[i] = 1; + skcms_Transform(RGB, skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque, + &profile, XYZ, skcms_PixelFormat_RGB_fff, + skcms_AlphaFormat_Opaque, skcms_XYZD50_profile(), 1); + float unadapted_XYZ[3]; + MatrixProduct(inverse_CHAD, XYZ, unadapted_XYZ); + *chromaticities[i] = CIExyFromXYZ(unadapted_XYZ); + } + return c->SetPrimaries(primaries); +} + +void DetectTransferFunction(const skcms_ICCProfile& profile, + ColorEncoding* JXL_RESTRICT c) { + if (c->tf.SetImplicit()) return; + + for (TransferFunction tf : Values()) { + // Can only create profile from known transfer function. + if (tf == TransferFunction::kUnknown) continue; + + c->tf.SetTransferFunction(tf); + + skcms_ICCProfile profile_test; + PaddedBytes bytes; + if (MaybeCreateProfile(*c, &bytes) && + DecodeProfile(bytes.data(), bytes.size(), &profile_test) && + skcms_ApproximatelyEqualProfiles(&profile, &profile_test)) { + return; + } + } + + c->tf.SetTransferFunction(TransferFunction::kUnknown); +} + +#else // JPEGXL_ENABLE_SKCMS + +uint32_t Type32(const ColorEncoding& c, bool cmyk) { + if (cmyk) return TYPE_CMYK_FLT; + if (c.IsGray()) return TYPE_GRAY_FLT; + return TYPE_RGB_FLT; +} + +uint32_t Type64(const ColorEncoding& c) { + if (c.IsGray()) return TYPE_GRAY_DBL; + return TYPE_RGB_DBL; +} + +ColorSpace ColorSpaceFromProfile(const Profile& profile) { + switch (cmsGetColorSpace(profile.get())) { + case cmsSigRgbData: + case cmsSigCmykData: + return ColorSpace::kRGB; + case cmsSigGrayData: + return ColorSpace::kGray; + default: + return ColorSpace::kUnknown; + } +} + +// "profile1" is pre-decoded to save time in DetectTransferFunction. +Status ProfileEquivalentToICC(const cmsContext context, const Profile& profile1, + const PaddedBytes& icc, const ColorEncoding& c) { + const uint32_t type_src = Type64(c); + + Profile profile2; + JXL_RETURN_IF_ERROR(DecodeProfile(context, icc, &profile2)); + + Profile profile_xyz; + JXL_RETURN_IF_ERROR(CreateProfileXYZ(context, &profile_xyz)); + + const uint32_t intent = INTENT_RELATIVE_COLORIMETRIC; + const uint32_t flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_BLACKPOINTCOMPENSATION | + cmsFLAGS_HIGHRESPRECALC; + Transform xform1(cmsCreateTransformTHR(context, profile1.get(), type_src, + profile_xyz.get(), TYPE_XYZ_DBL, + intent, flags)); + Transform xform2(cmsCreateTransformTHR(context, profile2.get(), type_src, + profile_xyz.get(), TYPE_XYZ_DBL, + intent, flags)); + if (xform1 == nullptr || xform2 == nullptr) { + return JXL_FAILURE("Failed to create transform"); + } + + double in[3]; + double out1[3]; + double out2[3]; + + // Uniformly spaced samples from very dark to almost fully bright. + const double init = 1E-3; + const double step = 0.2; + + if (c.IsGray()) { + // Finer sampling and replicate each component. + for (in[0] = init; in[0] < 1.0; in[0] += step / 8) { + cmsDoTransform(xform1.get(), in, out1, 1); + cmsDoTransform(xform2.get(), in, out2, 1); + if (!ApproxEq(out1[0], out2[0], 2E-4)) { + return false; + } + } + } else { + for (in[0] = init; in[0] < 1.0; in[0] += step) { + for (in[1] = init; in[1] < 1.0; in[1] += step) { + for (in[2] = init; in[2] < 1.0; in[2] += step) { + cmsDoTransform(xform1.get(), in, out1, 1); + cmsDoTransform(xform2.get(), in, out2, 1); + for (size_t i = 0; i < 3; ++i) { + if (!ApproxEq(out1[i], out2[i], 2E-4)) { + return false; + } + } + } + } + } + } + + return true; +} + +// Returns white point that was specified when creating the profile. +// NOTE: we can't just use cmsSigMediaWhitePointTag because its interpretation +// differs between ICC versions. +JXL_MUST_USE_RESULT cmsCIEXYZ UnadaptedWhitePoint(const cmsContext context, + const Profile& profile, + const ColorEncoding& c) { + const cmsCIEXYZ* white_point = static_cast( + cmsReadTag(profile.get(), cmsSigMediaWhitePointTag)); + if (white_point != nullptr && + cmsReadTag(profile.get(), cmsSigChromaticAdaptationTag) == nullptr) { + // No chromatic adaptation matrix: the white point is already unadapted. + return *white_point; + } + + cmsCIEXYZ XYZ = {1.0, 1.0, 1.0}; + Profile profile_xyz; + if (!CreateProfileXYZ(context, &profile_xyz)) return XYZ; + // Array arguments are one per profile. + cmsHPROFILE profiles[2] = {profile.get(), profile_xyz.get()}; + // Leave white point unchanged - that is what we're trying to extract. + cmsUInt32Number intents[2] = {INTENT_ABSOLUTE_COLORIMETRIC, + INTENT_ABSOLUTE_COLORIMETRIC}; + cmsBool black_compensation[2] = {0, 0}; + cmsFloat64Number adaption[2] = {0.0, 0.0}; + // Only transforming a single pixel, so skip expensive optimizations. + cmsUInt32Number flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_HIGHRESPRECALC; + Transform xform(cmsCreateExtendedTransform( + context, 2, profiles, black_compensation, intents, adaption, nullptr, 0, + Type64(c), TYPE_XYZ_DBL, flags)); + if (!xform) return XYZ; // TODO(lode): return error + + // xy are relative, so magnitude does not matter if we ignore output Y. + const cmsFloat64Number in[3] = {1.0, 1.0, 1.0}; + cmsDoTransform(xform.get(), in, &XYZ.X, 1); + return XYZ; +} + +Status IdentifyPrimaries(const cmsContext context, const Profile& profile, + const cmsCIEXYZ& wp_unadapted, ColorEncoding* c) { + if (!c->HasPrimaries()) return true; + if (ColorSpaceFromProfile(profile) == ColorSpace::kUnknown) return true; + + // These were adapted to the profile illuminant before storing in the profile. + const cmsCIEXYZ* adapted_r = static_cast( + cmsReadTag(profile.get(), cmsSigRedColorantTag)); + const cmsCIEXYZ* adapted_g = static_cast( + cmsReadTag(profile.get(), cmsSigGreenColorantTag)); + const cmsCIEXYZ* adapted_b = static_cast( + cmsReadTag(profile.get(), cmsSigBlueColorantTag)); + + cmsCIEXYZ converted_rgb[3]; + if (adapted_r == nullptr || adapted_g == nullptr || adapted_b == nullptr) { + // No colorant tag, determine the XYZ coordinates of the primaries by + // converting from the colorspace. + Profile profile_xyz; + if (!CreateProfileXYZ(context, &profile_xyz)) { + return JXL_FAILURE("Failed to retrieve colorants"); + } + // Array arguments are one per profile. + cmsHPROFILE profiles[2] = {profile.get(), profile_xyz.get()}; + cmsUInt32Number intents[2] = {INTENT_RELATIVE_COLORIMETRIC, + INTENT_RELATIVE_COLORIMETRIC}; + cmsBool black_compensation[2] = {0, 0}; + cmsFloat64Number adaption[2] = {0.0, 0.0}; + // Only transforming three pixels, so skip expensive optimizations. + cmsUInt32Number flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_HIGHRESPRECALC; + Transform xform(cmsCreateExtendedTransform( + context, 2, profiles, black_compensation, intents, adaption, nullptr, 0, + Type64(*c), TYPE_XYZ_DBL, flags)); + if (!xform) return JXL_FAILURE("Failed to retrieve colorants"); + + const cmsFloat64Number in[9] = {1.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0}; + cmsDoTransform(xform.get(), in, &converted_rgb->X, 3); + adapted_r = &converted_rgb[0]; + adapted_g = &converted_rgb[1]; + adapted_b = &converted_rgb[2]; + } + + // TODO(janwas): no longer assume Bradford and D50. + // Undo the chromatic adaptation. + const cmsCIEXYZ d50 = D50_XYZ(); + + cmsCIEXYZ r, g, b; + cmsAdaptToIlluminant(&r, &d50, &wp_unadapted, adapted_r); + cmsAdaptToIlluminant(&g, &d50, &wp_unadapted, adapted_g); + cmsAdaptToIlluminant(&b, &d50, &wp_unadapted, adapted_b); + + const PrimariesCIExy rgb = {CIExyFromXYZ(r), CIExyFromXYZ(g), + CIExyFromXYZ(b)}; + return c->SetPrimaries(rgb); +} + +void DetectTransferFunction(const cmsContext context, const Profile& profile, + ColorEncoding* JXL_RESTRICT c) { + if (c->tf.SetImplicit()) return; + + for (TransferFunction tf : Values()) { + // Can only create profile from known transfer function. + if (tf == TransferFunction::kUnknown) continue; + + c->tf.SetTransferFunction(tf); + + PaddedBytes icc_test; + if (MaybeCreateProfile(*c, &icc_test) && + ProfileEquivalentToICC(context, profile, icc_test, *c)) { + return; + } + } + + c->tf.SetTransferFunction(TransferFunction::kUnknown); +} + +void ErrorHandler(cmsContext context, cmsUInt32Number code, const char* text) { + JXL_WARNING("LCMS error %u: %s", code, text); +} + +// Returns a context for the current thread, creating it if necessary. +cmsContext GetContext() { + static thread_local void* context_; + if (context_ == nullptr) { + context_ = cmsCreateContext(nullptr, nullptr); + JXL_ASSERT(context_ != nullptr); + + cmsSetLogErrorHandlerTHR(static_cast(context_), &ErrorHandler); + } + return static_cast(context_); +} + +#endif // JPEGXL_ENABLE_SKCMS + +Status GetPrimariesLuminances(const ColorEncoding& encoding, + float luminances[3]) { + // Explanation: + // We know that the three primaries must sum to white: + // + // [Xr, Xg, Xb; [1; [Xw; + // Yr, Yg, Yb; × 1; = Yw; + // Zr, Zg, Zb] 1] Zw] + // + // By noting that X = x·(X+Y+Z), Y = y·(X+Y+Z) and Z = z·(X+Y+Z) (note the + // lower case indicating chromaticity), and factoring the totals (X+Y+Z) out + // of the left matrix and into the all-ones vector, we get: + // + // [xr, xg, xb; [Xr + Yr + Zr; [Xw; + // yr, yg, yb; × Xg + Yg + Zg; = Yw; + // zr, zg, zb] Xb + Yb + Zb] Zw] + // + // Which makes it apparent that we can compute those totals as: + // + // [Xr + Yr + Zr; inv([xr, xg, xb; [Xw; + // Xg + Yg + Zg; = yr, yg, yb; × Yw; + // Xb + Yb + Zb] zr, zg, zb]) Zw] + // + // From there, by multiplying each total by its corresponding y, we get Y for + // that primary. + + float white_XYZ[3]; + JXL_RETURN_IF_ERROR( + CIEXYZFromWhiteCIExy(encoding.GetWhitePoint(), white_XYZ)); + + const PrimariesCIExy primaries = encoding.GetPrimaries(); + double chromaticities[3][3] = { + {primaries.r.x, primaries.g.x, primaries.b.x}, + {primaries.r.y, primaries.g.y, primaries.b.y}, + {1 - primaries.r.x - primaries.r.y, 1 - primaries.g.x - primaries.g.y, + 1 - primaries.b.x - primaries.b.y}}; + JXL_RETURN_IF_ERROR(Inv3x3Matrix(&chromaticities[0][0])); + const double ys[3] = {primaries.r.y, primaries.g.y, primaries.b.y}; + for (size_t i = 0; i < 3; ++i) { + luminances[i] = ys[i] * (chromaticities[i][0] * white_XYZ[0] + + chromaticities[i][1] * white_XYZ[1] + + chromaticities[i][2] * white_XYZ[2]); + } + return true; +} + +Status ApplyHlgOotf(JxlCms* t, float* JXL_RESTRICT buf, size_t xsize, + bool forward) { + if (295 <= t->intensity_target && t->intensity_target <= 305) { + // The gamma is approximately 1 so this can essentially be skipped. + return true; + } + float gamma = 1.2f * std::pow(1.111f, std::log2(t->intensity_target * 1e-3f)); + if (!forward) gamma = 1.f / gamma; + + switch (t->hlg_ootf_num_channels) { + case 1: + for (size_t x = 0; x < xsize; ++x) { + buf[x] = std::pow(buf[x], gamma); + } + break; + + case 3: + for (size_t x = 0; x < xsize; x += 3) { + const float luminance = buf[x] * t->hlg_ootf_luminances[0] + + buf[x + 1] * t->hlg_ootf_luminances[1] + + buf[x + 2] * t->hlg_ootf_luminances[2]; + const float ratio = std::pow(luminance, gamma - 1); + if (std::isfinite(ratio)) { + buf[x] *= ratio; + buf[x + 1] *= ratio; + buf[x + 2] *= ratio; + if (forward && gamma < 1) { + // If gamma < 1, the ratio above will be > 1 which can push bright + // saturated highlights out of gamut. There are several possible + // ways to bring them back in-gamut; this one preserves hue and + // saturation at the slight expense of luminance. If !forward, the + // previously-applied forward OOTF with gamma > 1 already pushed + // those highlights down and we are simply putting them back where + // they were so this is not necessary. + const float maximum = + std::max(buf[x], std::max(buf[x + 1], buf[x + 2])); + if (maximum > 1) { + const float normalizer = 1.f / maximum; + buf[x] *= normalizer; + buf[x + 1] *= normalizer; + buf[x + 2] *= normalizer; + } + } + } + } + break; + + default: + return JXL_FAILURE("HLG OOTF not implemented for %" PRIuS " channels", + t->hlg_ootf_num_channels); + } + return true; +} + +} // namespace + +Status ColorEncoding::SetFieldsFromICC() { + // In case parsing fails, mark the ColorEncoding as invalid. + SetColorSpace(ColorSpace::kUnknown); + tf.SetTransferFunction(TransferFunction::kUnknown); + + if (icc_.empty()) return JXL_FAILURE("Empty ICC profile"); + +#if JPEGXL_ENABLE_SKCMS + if (icc_.size() < 128) { + return JXL_FAILURE("ICC file too small"); + } + + skcms_ICCProfile profile; + JXL_RETURN_IF_ERROR(skcms_Parse(icc_.data(), icc_.size(), &profile)); + + // skcms does not return the rendering intent, so get it from the file. It + // is encoded as big-endian 32-bit integer in bytes 60..63. + uint32_t rendering_intent32 = icc_[67]; + if (rendering_intent32 > 3 || icc_[64] != 0 || icc_[65] != 0 || + icc_[66] != 0) { + return JXL_FAILURE("Invalid rendering intent %u\n", rendering_intent32); + } + + SetColorSpace(ColorSpaceFromProfile(profile)); + cmyk_ = (profile.data_color_space == skcms_Signature_CMYK); + + CIExy wp_unadapted; + JXL_RETURN_IF_ERROR(UnadaptedWhitePoint(profile, &wp_unadapted)); + JXL_RETURN_IF_ERROR(SetWhitePoint(wp_unadapted)); + + // Relies on color_space. + JXL_RETURN_IF_ERROR(IdentifyPrimaries(profile, wp_unadapted, this)); + + // Relies on color_space/white point/primaries being set already. + DetectTransferFunction(profile, this); + // ICC and RenderingIntent have the same values (0..3). + rendering_intent = static_cast(rendering_intent32); +#else // JPEGXL_ENABLE_SKCMS + + const cmsContext context = GetContext(); + + Profile profile; + JXL_RETURN_IF_ERROR(DecodeProfile(context, icc_, &profile)); + + const cmsUInt32Number rendering_intent32 = + cmsGetHeaderRenderingIntent(profile.get()); + if (rendering_intent32 > 3) { + return JXL_FAILURE("Invalid rendering intent %u\n", rendering_intent32); + } + // ICC and RenderingIntent have the same values (0..3). + rendering_intent = static_cast(rendering_intent32); + + SetColorSpace(ColorSpaceFromProfile(profile)); + if (cmsGetColorSpace(profile.get()) == cmsSigCmykData) { + cmyk_ = true; + return true; + } + + const cmsCIEXYZ wp_unadapted = UnadaptedWhitePoint(context, profile, *this); + JXL_RETURN_IF_ERROR(SetWhitePoint(CIExyFromXYZ(wp_unadapted))); + + // Relies on color_space. + JXL_RETURN_IF_ERROR(IdentifyPrimaries(context, profile, wp_unadapted, this)); + + // Relies on color_space/white point/primaries being set already. + DetectTransferFunction(context, profile, this); + +#endif // JPEGXL_ENABLE_SKCMS + + return true; +} + +void ColorEncoding::DecideIfWantICC() { + PaddedBytes icc_new; + bool equivalent; +#if JPEGXL_ENABLE_SKCMS + skcms_ICCProfile profile; + if (!DecodeProfile(ICC().data(), ICC().size(), &profile)) return; + if (!MaybeCreateProfile(*this, &icc_new)) return; + equivalent = ProfileEquivalentToICC(profile, icc_new); +#else // JPEGXL_ENABLE_SKCMS + const cmsContext context = GetContext(); + Profile profile; + if (!DecodeProfile(context, ICC(), &profile)) return; + if (cmsGetColorSpace(profile.get()) == cmsSigCmykData) return; + if (!MaybeCreateProfile(*this, &icc_new)) return; + equivalent = ProfileEquivalentToICC(context, profile, icc_new, *this); +#endif // JPEGXL_ENABLE_SKCMS + + // Successfully created a profile => reconstruction should be equivalent. + JXL_ASSERT(equivalent); + want_icc_ = false; +} + +namespace { + +void JxlCmsDestroy(void* cms_data) { + if (cms_data == nullptr) return; + JxlCms* t = reinterpret_cast(cms_data); +#if !JPEGXL_ENABLE_SKCMS + TransformDeleter()(t->lcms_transform); +#endif + delete t; +} + +void* JxlCmsInit(void* init_data, size_t num_threads, size_t xsize, + const JxlColorProfile* input, const JxlColorProfile* output, + float intensity_target) { + auto t = jxl::make_unique(); + PaddedBytes icc_src, icc_dst; + icc_src.assign(input->icc.data, input->icc.data + input->icc.size); + ColorEncoding c_src; + if (!c_src.SetICC(std::move(icc_src))) { + JXL_NOTIFY_ERROR("JxlCmsInit: failed to parse input ICC"); + return nullptr; + } + icc_dst.assign(output->icc.data, output->icc.data + output->icc.size); + ColorEncoding c_dst; + if (!c_dst.SetICC(std::move(icc_dst))) { + JXL_NOTIFY_ERROR("JxlCmsInit: failed to parse output ICC"); + return nullptr; + } +#if JXL_CMS_VERBOSE + printf("%s -> %s\n", Description(c_src).c_str(), Description(c_dst).c_str()); +#endif + +#if JPEGXL_ENABLE_SKCMS + if (!DecodeProfile(input->icc.data, input->icc.size, &t->profile_src)) { + JXL_NOTIFY_ERROR("JxlCmsInit: skcms failed to parse input ICC"); + return nullptr; + } + if (!DecodeProfile(output->icc.data, output->icc.size, &t->profile_dst)) { + JXL_NOTIFY_ERROR("JxlCmsInit: skcms failed to parse output ICC"); + return nullptr; + } +#else // JPEGXL_ENABLE_SKCMS + const cmsContext context = GetContext(); + Profile profile_src, profile_dst; + if (!DecodeProfile(context, c_src.ICC(), &profile_src)) { + JXL_NOTIFY_ERROR("JxlCmsInit: lcms failed to parse input ICC"); + return nullptr; + } + if (!DecodeProfile(context, c_dst.ICC(), &profile_dst)) { + JXL_NOTIFY_ERROR("JxlCmsInit: lcms failed to parse output ICC"); + return nullptr; + } +#endif // JPEGXL_ENABLE_SKCMS + + t->skip_lcms = false; + if (c_src.SameColorEncoding(c_dst)) { + t->skip_lcms = true; +#if JXL_CMS_VERBOSE + printf("Skip CMS\n"); +#endif + } + + t->apply_hlg_ootf = c_src.tf.IsHLG() != c_dst.tf.IsHLG(); + if (t->apply_hlg_ootf) { + const ColorEncoding* c_hlg = c_src.tf.IsHLG() ? &c_src : &c_dst; + t->hlg_ootf_num_channels = c_hlg->Channels(); + if (t->hlg_ootf_num_channels == 3 && + !GetPrimariesLuminances(*c_hlg, t->hlg_ootf_luminances.data())) { + JXL_NOTIFY_ERROR( + "JxlCmsInit: failed to compute the luminances of primaries"); + return nullptr; + } + } + + // Special-case SRGB <=> linear if the primaries / white point are the same, + // or any conversion where PQ or HLG is involved: + bool src_linear = c_src.tf.IsLinear(); + const bool dst_linear = c_dst.tf.IsLinear(); + + if (c_src.tf.IsPQ() || c_src.tf.IsHLG() || + (c_src.tf.IsSRGB() && dst_linear && c_src.SameColorSpace(c_dst))) { + // Construct new profile as if the data were already/still linear. + ColorEncoding c_linear_src = c_src; + c_linear_src.tf.SetTransferFunction(TransferFunction::kLinear); +#if JPEGXL_ENABLE_SKCMS + skcms_ICCProfile new_src; +#else // JPEGXL_ENABLE_SKCMS + Profile new_src; +#endif // JPEGXL_ENABLE_SKCMS + // Only enable ExtraTF if profile creation succeeded. + if (MaybeCreateProfile(c_linear_src, &icc_src) && +#if JPEGXL_ENABLE_SKCMS + DecodeProfile(icc_src.data(), icc_src.size(), &new_src)) { +#else // JPEGXL_ENABLE_SKCMS + DecodeProfile(context, icc_src, &new_src)) { +#endif // JPEGXL_ENABLE_SKCMS +#if JXL_CMS_VERBOSE + printf("Special HLG/PQ/sRGB -> linear\n"); +#endif +#if JPEGXL_ENABLE_SKCMS + t->icc_src = std::move(icc_src); + t->profile_src = new_src; +#else // JPEGXL_ENABLE_SKCMS + profile_src.swap(new_src); +#endif // JPEGXL_ENABLE_SKCMS + t->preprocess = c_src.tf.IsSRGB() + ? ExtraTF::kSRGB + : (c_src.tf.IsPQ() ? ExtraTF::kPQ : ExtraTF::kHLG); + c_src = c_linear_src; + src_linear = true; + } else { + if (t->apply_hlg_ootf) { + JXL_NOTIFY_ERROR( + "Failed to create extra linear source profile, and HLG OOTF " + "required"); + return nullptr; + } + JXL_WARNING("Failed to create extra linear destination profile"); + } + } + + if (c_dst.tf.IsPQ() || c_dst.tf.IsHLG() || + (c_dst.tf.IsSRGB() && src_linear && c_src.SameColorSpace(c_dst))) { + ColorEncoding c_linear_dst = c_dst; + c_linear_dst.tf.SetTransferFunction(TransferFunction::kLinear); +#if JPEGXL_ENABLE_SKCMS + skcms_ICCProfile new_dst; +#else // JPEGXL_ENABLE_SKCMS + Profile new_dst; +#endif // JPEGXL_ENABLE_SKCMS + // Only enable ExtraTF if profile creation succeeded. + if (MaybeCreateProfile(c_linear_dst, &icc_dst) && +#if JPEGXL_ENABLE_SKCMS + DecodeProfile(icc_dst.data(), icc_dst.size(), &new_dst)) { +#else // JPEGXL_ENABLE_SKCMS + DecodeProfile(context, icc_dst, &new_dst)) { +#endif // JPEGXL_ENABLE_SKCMS +#if JXL_CMS_VERBOSE + printf("Special linear -> HLG/PQ/sRGB\n"); +#endif +#if JPEGXL_ENABLE_SKCMS + t->icc_dst = std::move(icc_dst); + t->profile_dst = new_dst; +#else // JPEGXL_ENABLE_SKCMS + profile_dst.swap(new_dst); +#endif // JPEGXL_ENABLE_SKCMS + t->postprocess = c_dst.tf.IsSRGB() + ? ExtraTF::kSRGB + : (c_dst.tf.IsPQ() ? ExtraTF::kPQ : ExtraTF::kHLG); + c_dst = c_linear_dst; + } else { + if (t->apply_hlg_ootf) { + JXL_NOTIFY_ERROR( + "Failed to create extra linear destination profile, and inverse " + "HLG OOTF required"); + return nullptr; + } + JXL_WARNING("Failed to create extra linear destination profile"); + } + } + + if (c_src.SameColorEncoding(c_dst)) { +#if JXL_CMS_VERBOSE + printf("Same intermediary linear profiles, skipping CMS\n"); +#endif + t->skip_lcms = true; + } + +#if JPEGXL_ENABLE_SKCMS + if (!skcms_MakeUsableAsDestination(&t->profile_dst)) { + JXL_NOTIFY_ERROR( + "Failed to make %s usable as a color transform destination", + Description(c_dst).c_str()); + return nullptr; + } +#endif // JPEGXL_ENABLE_SKCMS + + // Not including alpha channel (copied separately). + const size_t channels_src = (c_src.IsCMYK() ? 4 : c_src.Channels()); + const size_t channels_dst = c_dst.Channels(); + JXL_CHECK(channels_src == channels_dst || + (channels_src == 4 && channels_dst == 3)); +#if JXL_CMS_VERBOSE + printf("Channels: %" PRIuS "; Threads: %" PRIuS "\n", channels_src, + num_threads); +#endif + +#if !JPEGXL_ENABLE_SKCMS + // Type includes color space (XYZ vs RGB), so can be different. + const uint32_t type_src = Type32(c_src, channels_src == 4); + const uint32_t type_dst = Type32(c_dst, false); + const uint32_t intent = static_cast(c_dst.rendering_intent); + // Use cmsFLAGS_NOCACHE to disable the 1-pixel cache and make calling + // cmsDoTransform() thread-safe. + const uint32_t flags = cmsFLAGS_NOCACHE | cmsFLAGS_BLACKPOINTCOMPENSATION | + cmsFLAGS_HIGHRESPRECALC; + t->lcms_transform = + cmsCreateTransformTHR(context, profile_src.get(), type_src, + profile_dst.get(), type_dst, intent, flags); + if (t->lcms_transform == nullptr) { + JXL_NOTIFY_ERROR("Failed to create transform"); + return nullptr; + } +#endif // !JPEGXL_ENABLE_SKCMS + + // Ideally LCMS would convert directly from External to Image3. However, + // cmsDoTransformLineStride only accepts 32-bit BytesPerPlaneIn, whereas our + // planes can be more than 4 GiB apart. Hence, transform inputs/outputs must + // be interleaved. Calling cmsDoTransform for each pixel is expensive + // (indirect call). We therefore transform rows, which requires per-thread + // buffers. To avoid separate allocations, we use the rows of an image. + // Because LCMS apparently also cannot handle <= 16 bit inputs and 32-bit + // outputs (or vice versa), we use floating point input/output. + t->channels_src = channels_src; + t->channels_dst = channels_dst; +#if JPEGXL_ENABLE_SKCMS + // SkiaCMS doesn't support grayscale float buffers, so we create space for RGB + // float buffers anyway. + t->buf_src = ImageF(xsize * (channels_src == 4 ? 4 : 3), num_threads); + t->buf_dst = ImageF(xsize * 3, num_threads); +#else + t->buf_src = ImageF(xsize * channels_src, num_threads); + t->buf_dst = ImageF(xsize * channels_dst, num_threads); +#endif + t->intensity_target = intensity_target; + return t.release(); +} + +float* JxlCmsGetSrcBuf(void* cms_data, size_t thread) { + JxlCms* t = reinterpret_cast(cms_data); + return t->buf_src.Row(thread); +} + +float* JxlCmsGetDstBuf(void* cms_data, size_t thread) { + JxlCms* t = reinterpret_cast(cms_data); + return t->buf_dst.Row(thread); +} + +} // namespace + +const JxlCmsInterface& GetJxlCms() { + static constexpr JxlCmsInterface kInterface = { + /*init_data=*/nullptr, + /*init=*/&JxlCmsInit, + /*get_src_buf=*/&JxlCmsGetSrcBuf, + /*get_dst_buf=*/&JxlCmsGetDstBuf, + /*run=*/&DoColorSpaceTransform, + /*destroy=*/&JxlCmsDestroy}; + return kInterface; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_color_management.h b/media/libjxl/src/lib/jxl/enc_color_management.h new file mode 100644 index 000000000..0d701d74f --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_color_management.h @@ -0,0 +1,90 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_COLOR_MANAGEMENT_H_ +#define LIB_JXL_ENC_COLOR_MANAGEMENT_H_ + +// ICC profiles and color space conversions. + +#include +#include + +#include + +#include "jxl/cms_interface.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Internal C++ wrapper for a JxlCmsInterface. +class ColorSpaceTransform { + public: + explicit ColorSpaceTransform(const JxlCmsInterface& cms) : cms_(cms) {} + ~ColorSpaceTransform() { + if (cms_data_ != nullptr) { + cms_.destroy(cms_data_); + } + } + + // Cannot copy. + ColorSpaceTransform(const ColorSpaceTransform&) = delete; + ColorSpaceTransform& operator=(const ColorSpaceTransform&) = delete; + + Status Init(const ColorEncoding& c_src, const ColorEncoding& c_dst, + float intensity_target, size_t xsize, size_t num_threads) { + xsize_ = xsize; + JxlColorProfile input_profile; + icc_src_ = c_src.ICC(); + input_profile.icc.data = icc_src_.data(); + input_profile.icc.size = icc_src_.size(); + ConvertInternalToExternalColorEncoding(c_src, + &input_profile.color_encoding); + input_profile.num_channels = c_src.IsCMYK() ? 4 : c_src.Channels(); + JxlColorProfile output_profile; + icc_dst_ = c_dst.ICC(); + output_profile.icc.data = icc_dst_.data(); + output_profile.icc.size = icc_dst_.size(); + ConvertInternalToExternalColorEncoding(c_dst, + &output_profile.color_encoding); + if (c_dst.IsCMYK()) + return JXL_FAILURE("Conversion to CMYK is not supported"); + output_profile.num_channels = c_dst.Channels(); + cms_data_ = cms_.init(cms_.init_data, num_threads, xsize, &input_profile, + &output_profile, intensity_target); + JXL_RETURN_IF_ERROR(cms_data_ != nullptr); + return true; + } + + float* BufSrc(const size_t thread) const { + return cms_.get_src_buf(cms_data_, thread); + } + + float* BufDst(const size_t thread) const { + return cms_.get_dst_buf(cms_data_, thread); + } + + Status Run(const size_t thread, const float* buf_src, float* buf_dst) { + return cms_.run(cms_data_, thread, buf_src, buf_dst, xsize_); + } + + private: + JxlCmsInterface cms_; + void* cms_data_ = nullptr; + // The interface may retain pointers into these. + PaddedBytes icc_src_; + PaddedBytes icc_dst_; + size_t xsize_; +}; + +const JxlCmsInterface& GetJxlCms(); + +} // namespace jxl + +#endif // LIB_JXL_ENC_COLOR_MANAGEMENT_H_ diff --git a/media/libjxl/src/lib/jxl/enc_comparator.cc b/media/libjxl/src/lib/jxl/enc_comparator.cc new file mode 100644 index 000000000..a2d170d87 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_comparator.cc @@ -0,0 +1,142 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_comparator.h" + +#include +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_gamma_correct.h" +#include "lib/jxl/enc_image_bundle.h" + +namespace jxl { +namespace { + +// color is linear, but blending happens in gamma-compressed space using +// (gamma-compressed) grayscale background color, alpha image represents +// weights of the sRGB colors in the [0 .. (1 << bit_depth) - 1] interval, +// output image is in linear space. +void AlphaBlend(const Image3F& in, const size_t c, float background_linear, + const ImageF& alpha, Image3F* out) { + const float background = LinearToSrgb8Direct(background_linear); + + for (size_t y = 0; y < out->ysize(); ++y) { + const float* JXL_RESTRICT row_a = alpha.ConstRow(y); + const float* JXL_RESTRICT row_i = in.ConstPlaneRow(c, y); + float* JXL_RESTRICT row_o = out->PlaneRow(c, y); + for (size_t x = 0; x < out->xsize(); ++x) { + const float a = row_a[x]; + if (a <= 0.f) { + row_o[x] = background_linear; + } else if (a >= 1.f) { + row_o[x] = row_i[x]; + } else { + const float w_fg = a; + const float w_bg = 1.0f - w_fg; + const float fg = w_fg * LinearToSrgb8Direct(row_i[x]); + const float bg = w_bg * background; + row_o[x] = Srgb8ToLinearDirect(fg + bg); + } + } + } +} + +const Image3F* AlphaBlend(const ImageBundle& ib, const Image3F& linear, + float background_linear, Image3F* copy) { + // No alpha => all opaque. + if (!ib.HasAlpha()) return &linear; + + *copy = Image3F(linear.xsize(), linear.ysize()); + for (size_t c = 0; c < 3; ++c) { + AlphaBlend(linear, c, background_linear, ib.alpha(), copy); + } + return copy; +} + +void AlphaBlend(float background_linear, ImageBundle* io_linear_srgb) { + // No alpha => all opaque. + if (!io_linear_srgb->HasAlpha()) return; + + for (size_t c = 0; c < 3; ++c) { + AlphaBlend(*io_linear_srgb->color(), c, background_linear, + *io_linear_srgb->alpha(), io_linear_srgb->color()); + } +} + +float ComputeScoreImpl(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, ImageF* distmap) { + JXL_CHECK(comparator->SetReferenceImage(rgb0)); + float score; + JXL_CHECK(comparator->CompareWith(rgb1, distmap, &score)); + return score; +} + +} // namespace + +float ComputeScore(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, const JxlCmsInterface& cms, + ImageF* diffmap, ThreadPool* pool) { + PROFILER_FUNC; + // Convert to linear sRGB (unless already in that space) + ImageMetadata metadata0 = *rgb0.metadata(); + ImageBundle store0(&metadata0); + const ImageBundle* linear_srgb0; + JXL_CHECK(TransformIfNeeded(rgb0, ColorEncoding::LinearSRGB(rgb0.IsGray()), + cms, pool, &store0, &linear_srgb0)); + ImageMetadata metadata1 = *rgb1.metadata(); + ImageBundle store1(&metadata1); + const ImageBundle* linear_srgb1; + JXL_CHECK(TransformIfNeeded(rgb1, ColorEncoding::LinearSRGB(rgb1.IsGray()), + cms, pool, &store1, &linear_srgb1)); + + // No alpha: skip blending, only need a single call to Butteraugli. + if (!rgb0.HasAlpha() && !rgb1.HasAlpha()) { + return ComputeScoreImpl(*linear_srgb0, *linear_srgb1, comparator, diffmap); + } + + // Blend on black and white backgrounds + + const float black = 0.0f; + ImageBundle blended_black0 = linear_srgb0->Copy(); + ImageBundle blended_black1 = linear_srgb1->Copy(); + AlphaBlend(black, &blended_black0); + AlphaBlend(black, &blended_black1); + + const float white = 1.0f; + ImageBundle blended_white0 = linear_srgb0->Copy(); + ImageBundle blended_white1 = linear_srgb1->Copy(); + + AlphaBlend(white, &blended_white0); + AlphaBlend(white, &blended_white1); + + ImageF diffmap_black, diffmap_white; + const float dist_black = ComputeScoreImpl(blended_black0, blended_black1, + comparator, &diffmap_black); + const float dist_white = ComputeScoreImpl(blended_white0, blended_white1, + comparator, &diffmap_white); + + // diffmap and return values are the max of diffmap_black/white. + if (diffmap != nullptr) { + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + *diffmap = ImageF(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const float* JXL_RESTRICT row_black = diffmap_black.ConstRow(y); + const float* JXL_RESTRICT row_white = diffmap_white.ConstRow(y); + float* JXL_RESTRICT row_out = diffmap->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = std::max(row_black[x], row_white[x]); + } + } + } + return std::max(dist_black, dist_white); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_comparator.h b/media/libjxl/src/lib/jxl/enc_comparator.h new file mode 100644 index 000000000..0ac4df829 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_comparator.h @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_COMPARATOR_H_ +#define LIB_JXL_ENC_COMPARATOR_H_ + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +class Comparator { + public: + virtual ~Comparator() = default; + + // Sets the reference image, the first to compare + // Image must be in linear sRGB (gamma expanded) in range 0.0f-1.0f as + // the range from standard black point to standard white point, but values + // outside permitted. + virtual Status SetReferenceImage(const ImageBundle& ref) = 0; + + // Sets the actual image (with loss), the second to compare + // Image must be in linear sRGB (gamma expanded) in range 0.0f-1.0f as + // the range from standard black point to standard white point, but values + // outside permitted. + // In diffmap it outputs the local score per pixel, while in score it outputs + // a single score. Any one may be set to nullptr to not compute it. + virtual Status CompareWith(const ImageBundle& actual, ImageF* diffmap, + float* score) = 0; + + // Quality thresholds for diffmap and score values. + // The good score must represent a value where the images are considered to + // be perceptually indistinguishable (but not identical) + // The bad value must be larger than good to indicate "lower means better" + // and smaller than good to indicate "higher means better" + virtual float GoodQualityScore() const = 0; + virtual float BadQualityScore() const = 0; +}; + +// Computes the score given images in any RGB color model, optionally with +// alpha channel. +float ComputeScore(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, const JxlCmsInterface& cms, + ImageF* diffmap = nullptr, ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_JXL_ENC_COMPARATOR_H_ diff --git a/media/libjxl/src/lib/jxl/enc_context_map.cc b/media/libjxl/src/lib/jxl/enc_context_map.cc new file mode 100644 index 000000000..82e5e6181 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_context_map.cc @@ -0,0 +1,140 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Library to encode the context map. + +#include "lib/jxl/enc_context_map.h" + +#include + +#include +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" + +namespace jxl { + +namespace { + +size_t IndexOf(const std::vector& v, uint8_t value) { + size_t i = 0; + for (; i < v.size(); ++i) { + if (v[i] == value) return i; + } + return i; +} + +void MoveToFront(std::vector* v, size_t index) { + uint8_t value = (*v)[index]; + for (size_t i = index; i != 0; --i) { + (*v)[i] = (*v)[i - 1]; + } + (*v)[0] = value; +} + +std::vector MoveToFrontTransform(const std::vector& v) { + if (v.empty()) return v; + uint8_t max_value = *std::max_element(v.begin(), v.end()); + std::vector mtf(max_value + 1); + for (size_t i = 0; i <= max_value; ++i) mtf[i] = i; + std::vector result(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + size_t index = IndexOf(mtf, v[i]); + JXL_ASSERT(index < mtf.size()); + result[i] = static_cast(index); + MoveToFront(&mtf, index); + } + return result; +} + +} // namespace + +void EncodeContextMap(const std::vector& context_map, + size_t num_histograms, BitWriter* writer, size_t layer, + AuxOut* aux_out) { + if (num_histograms == 1) { + // Simple code + writer->Write(1, 1); + // 0 bits per entry. + writer->Write(2, 0); + return; + } + + std::vector transformed_symbols = MoveToFrontTransform(context_map); + std::vector> tokens(1), mtf_tokens(1); + EntropyEncodingData codes; + std::vector dummy_context_map; + for (size_t i = 0; i < context_map.size(); i++) { + tokens[0].emplace_back(0, context_map[i]); + } + for (size_t i = 0; i < transformed_symbols.size(); i++) { + mtf_tokens[0].emplace_back(0, transformed_symbols[i]); + } + HistogramParams params; + params.uint_method = HistogramParams::HybridUintMethod::kContextMap; + size_t ans_cost = BuildAndEncodeHistograms( + params, 1, tokens, &codes, &dummy_context_map, nullptr, 0, nullptr); + size_t mtf_cost = BuildAndEncodeHistograms( + params, 1, mtf_tokens, &codes, &dummy_context_map, nullptr, 0, nullptr); + bool use_mtf = mtf_cost < ans_cost; + // Rebuild token list. + tokens[0].clear(); + for (size_t i = 0; i < transformed_symbols.size(); i++) { + tokens[0].emplace_back(0, + use_mtf ? transformed_symbols[i] : context_map[i]); + } + size_t entry_bits = CeilLog2Nonzero(num_histograms); + size_t simple_cost = entry_bits * context_map.size(); + if (entry_bits < 4 && simple_cost < ans_cost && simple_cost < mtf_cost) { + writer->Write(1, 1); + writer->Write(2, entry_bits); + for (size_t i = 0; i < context_map.size(); i++) { + writer->Write(entry_bits, context_map[i]); + } + } else { + writer->Write(1, 0); + writer->Write(1, use_mtf); // Use/don't use MTF. + BuildAndEncodeHistograms(params, 1, tokens, &codes, &dummy_context_map, + writer, layer, aux_out); + WriteTokens(tokens[0], codes, dummy_context_map, writer); + } +} + +void EncodeBlockCtxMap(const BlockCtxMap& block_ctx_map, BitWriter* writer, + AuxOut* aux_out) { + auto& dct = block_ctx_map.dc_thresholds; + auto& qft = block_ctx_map.qf_thresholds; + auto& ctx_map = block_ctx_map.ctx_map; + BitWriter::Allotment allotment( + writer, + (dct[0].size() + dct[1].size() + dct[2].size() + qft.size()) * 34 + 1 + + 4 + 4 + ctx_map.size() * 10 + 1024); + if (dct[0].empty() && dct[1].empty() && dct[2].empty() && qft.empty() && + ctx_map.size() == 21 && + std::equal(ctx_map.begin(), ctx_map.end(), BlockCtxMap::kDefaultCtxMap)) { + writer->Write(1, 1); // default + ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out); + return; + } + writer->Write(1, 0); + for (int j : {0, 1, 2}) { + writer->Write(4, dct[j].size()); + for (int i : dct[j]) { + JXL_CHECK(U32Coder::Write(kDCThresholdDist, PackSigned(i), writer)); + } + } + writer->Write(4, qft.size()); + for (uint32_t i : qft) { + JXL_CHECK(U32Coder::Write(kQFThresholdDist, i - 1, writer)); + } + EncodeContextMap(ctx_map, block_ctx_map.num_ctxs, writer, kLayerAC, aux_out); + ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_context_map.h b/media/libjxl/src/lib/jxl/enc_context_map.h new file mode 100644 index 000000000..57e79a173 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_context_map.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_CONTEXT_MAP_H_ +#define LIB_JXL_ENC_CONTEXT_MAP_H_ + +#include +#include + +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Max limit is 255 because encoding assumes numbers < 255 +// More clusters can help compression, but makes encode/decode somewhat slower +static const size_t kClustersLimit = 128; + +// Encodes the given context map to the bit stream. The number of different +// histogram ids is given by num_histograms. +void EncodeContextMap(const std::vector& context_map, + size_t num_histograms, BitWriter* writer, size_t layer, + AuxOut* aux_out); + +void EncodeBlockCtxMap(const BlockCtxMap& block_ctx_map, BitWriter* writer, + AuxOut* aux_out); +} // namespace jxl + +#endif // LIB_JXL_ENC_CONTEXT_MAP_H_ diff --git a/media/libjxl/src/lib/jxl/enc_detect_dots.cc b/media/libjxl/src/lib/jxl/enc_detect_dots.cc new file mode 100644 index 000000000..f7021d6cc --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_detect_dots.cc @@ -0,0 +1,627 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_detect_dots.h" + +#include + +#include +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_detect_dots.cc" +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/common.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/linalg.h" +#include "lib/jxl/optimize.h" + +// Set JXL_DEBUG_DOT_DETECT to 1 to enable debugging. +#ifndef JXL_DEBUG_DOT_DETECT +#define JXL_DEBUG_DOT_DETECT 0 +#endif + +#if JXL_DEBUG_DOT_DETECT +#include "lib/jxl/aux_out.h" +#endif + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::Sub; + +ImageF SumOfSquareDifferences(const Image3F& forig, const Image3F& smooth, + ThreadPool* pool) { + const HWY_FULL(float) d; + const auto color_coef0 = Set(d, 0.0f); + const auto color_coef1 = Set(d, 10.0f); + const auto color_coef2 = Set(d, 0.0f); + + ImageF sum_of_squares(forig.xsize(), forig.ysize()); + JXL_CHECK(RunOnPool( + pool, 0, forig.ysize(), ThreadPool::NoInit, + [&](const uint32_t task, size_t thread) { + const size_t y = static_cast(task); + const float* JXL_RESTRICT orig_row0 = forig.Plane(0).ConstRow(y); + const float* JXL_RESTRICT orig_row1 = forig.Plane(1).ConstRow(y); + const float* JXL_RESTRICT orig_row2 = forig.Plane(2).ConstRow(y); + const float* JXL_RESTRICT smooth_row0 = smooth.Plane(0).ConstRow(y); + const float* JXL_RESTRICT smooth_row1 = smooth.Plane(1).ConstRow(y); + const float* JXL_RESTRICT smooth_row2 = smooth.Plane(2).ConstRow(y); + float* JXL_RESTRICT sos_row = sum_of_squares.Row(y); + + for (size_t x = 0; x < forig.xsize(); x += Lanes(d)) { + auto v0 = Sub(Load(d, orig_row0 + x), Load(d, smooth_row0 + x)); + auto v1 = Sub(Load(d, orig_row1 + x), Load(d, smooth_row1 + x)); + auto v2 = Sub(Load(d, orig_row2 + x), Load(d, smooth_row2 + x)); + v0 = Mul(Mul(v0, v0), color_coef0); + v1 = Mul(Mul(v1, v1), color_coef1); + v2 = Mul(Mul(v2, v2), color_coef2); + const auto sos = + Add(v0, Add(v1, v2)); // weighted sum of square diffs + Store(sos, d, sos_row + x); + } + }, + "ComputeEnergyImage")); + return sum_of_squares; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(SumOfSquareDifferences); // Local function + +const int kEllipseWindowSize = 5; + +namespace { +struct GaussianEllipse { + double x; // position in x + double y; // position in y + double sigma_x; // scale in x + double sigma_y; // scale in y + double angle; // ellipse rotation in radians + std::array intensity; // intensity in each channel + + // The following variables do not need to be encoded + double l2_loss; // error after the Gaussian was fit + double l1_loss; + double ridge_loss; // the l2_loss plus regularization term + double custom_loss; // experimental custom loss + std::array bgColor; // best background color + size_t neg_pixels; // number of negative pixels when subtracting dot + std::array neg_value; // debt due to channel truncation +}; +double DotGaussianModel(double dx, double dy, double ct, double st, + double sigma_x, double sigma_y, double intensity) { + double rx = ct * dx + st * dy; + double ry = -st * dx + ct * dy; + double md = (rx * rx / sigma_x) + (ry * ry / sigma_y); + double value = intensity * exp(-0.5 * md); + return value; +} + +constexpr bool kOptimizeBackground = true; + +// Gaussian that smooths noise but preserves dots +const WeightsSeparable5& WeightsSeparable5Gaussian0_65() { + constexpr float w0 = 0.558311f; + constexpr float w1 = 0.210395f; + constexpr float w2 = 0.010449f; + static constexpr WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}}; + return weights; +} + +// (Iterated) Gaussian that removes dots. +const WeightsSeparable5& WeightsSeparable5Gaussian3() { + constexpr float w0 = 0.222338f; + constexpr float w1 = 0.210431f; + constexpr float w2 = 0.1784f; + static constexpr WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}}; + return weights; +} + +ImageF ComputeEnergyImage(const Image3F& orig, Image3F* smooth, + ThreadPool* pool) { + PROFILER_FUNC; + + // Prepare guidance images for dot selection. + Image3F forig(orig.xsize(), orig.ysize()); + *smooth = Image3F(orig.xsize(), orig.ysize()); + Rect rect(orig); + + const auto& weights1 = WeightsSeparable5Gaussian0_65(); + const auto& weights3 = WeightsSeparable5Gaussian3(); + + for (size_t c = 0; c < 3; ++c) { + // Use forig as temporary storage to reduce memory and keep it warmer. + Separable5(orig.Plane(c), rect, weights3, pool, &forig.Plane(c)); + Separable5(forig.Plane(c), rect, weights3, pool, &smooth->Plane(c)); + Separable5(orig.Plane(c), rect, weights1, pool, &forig.Plane(c)); + } + +#if JXL_DEBUG_DOT_DETECT + AuxOut aux; + aux.debug_prefix = "/tmp/sebastian/"; + aux.DumpImage("filtered", forig); + aux.DumpImage("sm", *smooth); +#endif + + return HWY_DYNAMIC_DISPATCH(SumOfSquareDifferences)(forig, *smooth, pool); +} + +struct Pixel { + int x; + int y; +}; + +Pixel operator+(const Pixel& a, const Pixel& b) { + return Pixel{a.x + b.x, a.y + b.y}; +} + +// Maximum area in pixels of a ellipse +const size_t kMaxCCSize = 1000; + +// Extracts a connected component from a Binary image where seed is part +// of the component +bool ExtractComponent(ImageF* img, std::vector* pixels, + const Pixel& seed, double threshold) { + PROFILER_FUNC; + static const std::vector neighbors{{1, -1}, {1, 0}, {1, 1}, {0, -1}, + {0, 1}, {-1, -1}, {-1, 1}, {1, 0}}; + std::vector q{seed}; + while (!q.empty()) { + Pixel current = q.back(); + q.pop_back(); + pixels->push_back(current); + if (pixels->size() > kMaxCCSize) return false; + for (const Pixel& delta : neighbors) { + Pixel child = current + delta; + if (child.x >= 0 && static_cast(child.x) < img->xsize() && + child.y >= 0 && static_cast(child.y) < img->ysize()) { + float* value = &img->Row(child.y)[child.x]; + if (*value > threshold) { + *value = 0.0; + q.push_back(child); + } + } + } + } + return true; +} + +inline bool PointInRect(const Rect& r, const Pixel& p) { + return (static_cast(p.x) >= r.x0() && + static_cast(p.x) < (r.x0() + r.xsize()) && + static_cast(p.y) >= r.y0() && + static_cast(p.y) < (r.y0() + r.ysize())); +} + +struct ConnectedComponent { + ConnectedComponent(const Rect& bounds, const std::vector&& pixels) + : bounds(bounds), pixels(pixels) {} + Rect bounds; + std::vector pixels; + float maxEnergy; + float meanEnergy; + float varEnergy; + float meanBg; + float varBg; + float score; + Pixel mode; + + void CompStats(const ImageF& energy, int extra) { + PROFILER_FUNC; + maxEnergy = 0.0; + meanEnergy = 0.0; + varEnergy = 0.0; + meanBg = 0.0; + varBg = 0.0; + int nIn = 0; + int nOut = 0; + mode.x = 0; + mode.y = 0; + for (int sy = -extra; sy < (static_cast(bounds.ysize()) + extra); + sy++) { + int y = sy + static_cast(bounds.y0()); + if (y < 0 || static_cast(y) >= energy.ysize()) continue; + const float* JXL_RESTRICT erow = energy.ConstRow(y); + for (int sx = -extra; sx < (static_cast(bounds.xsize()) + extra); + sx++) { + int x = sx + static_cast(bounds.x0()); + if (x < 0 || static_cast(x) >= energy.xsize()) continue; + if (erow[x] > maxEnergy) { + maxEnergy = erow[x]; + mode.x = x; + mode.y = y; + } + if (PointInRect(bounds, Pixel{x, y})) { + meanEnergy += erow[x]; + varEnergy += erow[x] * erow[x]; + nIn++; + } else { + meanBg += erow[x]; + varBg += erow[x] * erow[x]; + nOut++; + } + } + } + meanEnergy = meanEnergy / nIn; + meanBg = meanBg / nOut; + varEnergy = (varEnergy / nIn) - meanEnergy * meanEnergy; + varBg = (varBg / nOut) - meanBg * meanBg; + score = (meanEnergy - meanBg) / std::sqrt(varBg); + } +}; + +Rect BoundingRectangle(const std::vector& pixels) { + PROFILER_FUNC; + JXL_ASSERT(!pixels.empty()); + int low_x, high_x, low_y, high_y; + low_x = high_x = pixels[0].x; + low_y = high_y = pixels[0].y; + for (const Pixel& p : pixels) { + low_x = std::min(low_x, p.x); + high_x = std::max(high_x, p.x); + low_y = std::min(low_y, p.y); + high_y = std::max(high_y, p.y); + } + return Rect(low_x, low_y, high_x - low_x + 1, high_y - low_y + 1); +} + +std::vector FindCC(const ImageF& energy, double t_low, + double t_high, uint32_t maxWindow, + double minScore) { + PROFILER_FUNC; + const int kExtraRect = 4; + ImageF img = CopyImage(energy); + std::vector ans; + for (size_t y = 0; y < img.ysize(); y++) { + float* JXL_RESTRICT row = img.Row(y); + for (size_t x = 0; x < img.xsize(); x++) { + if (row[x] > t_high) { + std::vector pixels; + row[x] = 0.0; + bool success = ExtractComponent( + &img, &pixels, Pixel{static_cast(x), static_cast(y)}, + t_low); + if (!success) continue; +#if JXL_DEBUG_DOT_DETECT + for (size_t i = 0; i < pixels.size(); i++) { + fprintf(stderr, "(%d,%d) ", pixels[i].x, pixels[i].y); + } + fprintf(stderr, "\n"); +#endif // JXL_DEBUG_DOT_DETECT + Rect bounds = BoundingRectangle(pixels); + if (bounds.xsize() < maxWindow && bounds.ysize() < maxWindow) { + ConnectedComponent cc{bounds, std::move(pixels)}; + cc.CompStats(energy, kExtraRect); + if (cc.score < minScore) continue; + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "cc mode: (%d,%d), max: %f, bgMean: %f bgVar: " + "%f bound:(%" PRIuS ",%" PRIuS ",%" PRIuS ",%" PRIuS ")\n", + cc.mode.x, cc.mode.y, cc.maxEnergy, cc.meanEnergy, + cc.varEnergy, cc.bounds.x0(), cc.bounds.y0(), + cc.bounds.xsize(), cc.bounds.ysize()); + ans.push_back(cc); + } + } + } + } + return ans; +} + +// TODO (sggonzalez): Adapt this function for the different color spaces or +// remove it if the color space with the best performance does not need it +void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc, + const Image3F& img, const Image3F& background) { + PROFILER_FUNC; + const int rectBounds = 2; + const double kIntensityR = 0.0; // 0.015; + const double kSigmaR = 0.0; // 0.01; + const double kZeroEpsilon = 0.1; // Tolerance to consider a value negative + double ct = cos(ellipse->angle), st = sin(ellipse->angle); + const std::array channelGains{{1.0, 1.0, 1.0}}; + int N = 0; + ellipse->l1_loss = 0.0; + ellipse->l2_loss = 0.0; + ellipse->neg_pixels = 0; + ellipse->neg_value.fill(0.0); + double distMeanModeSq = (cc.mode.x - ellipse->x) * (cc.mode.x - ellipse->x) + + (cc.mode.y - ellipse->y) * (cc.mode.y - ellipse->y); + ellipse->custom_loss = 0.0; + for (int c = 0; c < 3; c++) { + for (int sy = -rectBounds; + sy < (static_cast(cc.bounds.ysize()) + rectBounds); sy++) { + int y = sy + cc.bounds.y0(); + if (y < 0 || static_cast(y) >= img.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y); + // bgrow is only used if kOptimizeBackground is false. + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y); + for (int sx = -rectBounds; + sx < (static_cast(cc.bounds.xsize()) + rectBounds); sx++) { + int x = sx + cc.bounds.x0(); + if (x < 0 || static_cast(x) >= img.xsize()) continue; + double target = row[x]; + double dotDelta = DotGaussianModel( + x - ellipse->x, y - ellipse->y, ct, st, ellipse->sigma_x, + ellipse->sigma_y, ellipse->intensity[c]); + if (dotDelta > target + kZeroEpsilon) { + ellipse->neg_pixels++; + ellipse->neg_value[c] += dotDelta - target; + } + double bkg = kOptimizeBackground ? ellipse->bgColor[c] : bgrow[x]; + double pred = bkg + dotDelta; + double diff = target - pred; + double l2 = channelGains[c] * diff * diff; + double l1 = channelGains[c] * std::fabs(diff); + ellipse->l2_loss += l2; + ellipse->l1_loss += l1; + double w = DotGaussianModel(x - cc.mode.x, y - cc.mode.y, 1.0, 0.0, + 1.0 + ellipse->sigma_x, + 1.0 + ellipse->sigma_y, 1.0); + ellipse->custom_loss += w * l2; + N++; + } + } + } + ellipse->l2_loss /= N; + ellipse->custom_loss /= N; + ellipse->custom_loss += 20.0 * distMeanModeSq + ellipse->neg_value[1]; + ellipse->l1_loss /= N; + double ridgeTerm = kSigmaR * ellipse->sigma_x + kSigmaR * ellipse->sigma_y; + for (int c = 0; c < 3; c++) { + ridgeTerm += kIntensityR * ellipse->intensity[c] * ellipse->intensity[c]; + } + ellipse->ridge_loss = ellipse->l2_loss + ridgeTerm; +} + +GaussianEllipse FitGaussianFast(const ConnectedComponent& cc, + const ImageF& energy, const Image3F& img, + const Image3F& background) { + PROFILER_FUNC; + constexpr bool leastSqIntensity = true; + constexpr double kEpsilon = 1e-6; + GaussianEllipse ans; + constexpr int kRectBounds = (kEllipseWindowSize >> 1); + + // Compute the 1st and 2nd moments of the CC + double sum = 0.0; + int N = 0; + std::array m1{{0.0, 0.0, 0.0}}; + std::array m2{{0.0, 0.0, 0.0}}; + std::array color{{0.0, 0.0, 0.0}}; + std::array bgColor{{0.0, 0.0, 0.0}}; + + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "%" PRIuS " %" PRIuS " %" PRIuS " %" PRIuS "\n", cc.bounds.x0(), + cc.bounds.y0(), cc.bounds.xsize(), cc.bounds.ysize()); + for (int c = 0; c < 3; c++) { + color[c] = img.ConstPlaneRow(c, cc.mode.y)[cc.mode.x] - + background.ConstPlaneRow(c, cc.mode.y)[cc.mode.x]; + } + double sign = (color[1] > 0) ? 1 : -1; + for (int sy = -kRectBounds; sy <= kRectBounds; sy++) { + int y = sy + cc.mode.y; + if (y < 0 || static_cast(y) >= energy.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(1, y); + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(1, y); + for (int sx = -kRectBounds; sx <= kRectBounds; sx++) { + int x = sx + cc.mode.x; + if (x < 0 || static_cast(x) >= energy.xsize()) continue; + double w = std::max(kEpsilon, sign * (row[x] - bgrow[x])); + sum += w; + + m1[0] += w * x; + m1[1] += w * y; + m2[0] += w * x * x; + m2[1] += w * x * y; + m2[2] += w * y * y; + for (int c = 0; c < 3; c++) { + bgColor[c] += background.ConstPlaneRow(c, y)[x]; + } + N++; + } + } + JXL_CHECK(N > 0); + + for (int i = 0; i < 3; i++) { + m1[i] /= sum; + m2[i] /= sum; + bgColor[i] /= N; + } + + // Some magic constants + constexpr double kSigmaMult = 1.0; + constexpr std::array kScaleMult{{1.1, 1.1, 1.1}}; + + // Now set the parameters of the Gaussian + ans.x = m1[0]; + ans.y = m1[1]; + for (int j = 0; j < 3; j++) { + ans.intensity[j] = kScaleMult[j] * color[j]; + } + + ImageD Sigma(2, 2), D(1, 2), U(2, 2); + Sigma.Row(0)[0] = m2[0] - m1[0] * m1[0]; + Sigma.Row(1)[1] = m2[2] - m1[1] * m1[1]; + Sigma.Row(0)[1] = Sigma.Row(1)[0] = m2[1] - m1[0] * m1[1]; + ConvertToDiagonal(Sigma, &D, &U); + const double* JXL_RESTRICT d = D.ConstRow(0); + const double* JXL_RESTRICT u = U.ConstRow(1); + int p1 = 0, p2 = 1; + if (d[0] < d[1]) std::swap(p1, p2); + ans.sigma_x = kSigmaMult * d[p1]; + ans.sigma_y = kSigmaMult * d[p2]; + ans.angle = std::atan2(u[p1], u[p2]); + ans.l2_loss = 0.0; + ans.bgColor = bgColor; + if (leastSqIntensity) { + GaussianEllipse* ellipse = &ans; + double ct = cos(ans.angle), st = sin(ans.angle); + // Estimate intensity with least squares (fixed background) + for (int c = 0; c < 3; c++) { + double gg = 0.0; + double gd = 0.0; + int yc = static_cast(cc.mode.y); + int xc = static_cast(cc.mode.x); + for (int y = yc - kRectBounds; y <= yc + kRectBounds; y++) { + if (y < 0 || static_cast(y) >= img.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y); + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y); + for (int x = xc - kRectBounds; x <= xc + kRectBounds; x++) { + if (x < 0 || static_cast(x) >= img.xsize()) continue; + double target = row[x] - bgrow[x]; + double gaussian = + DotGaussianModel(x - ellipse->x, y - ellipse->y, ct, st, + ellipse->sigma_x, ellipse->sigma_y, 1.0); + gg += gaussian * gaussian; + gd += gaussian * target; + } + } + ans.intensity[c] = gd / (gg + 1e-6); // Regularized least squares + } + } + ComputeDotLosses(&ans, cc, img, background); + return ans; +} + +GaussianEllipse FitGaussian(const ConnectedComponent& cc, const ImageF& energy, + const Image3F& img, const Image3F& background) { + auto ellipse = FitGaussianFast(cc, energy, img, background); + if (ellipse.sigma_x < ellipse.sigma_y) { + std::swap(ellipse.sigma_x, ellipse.sigma_y); + ellipse.angle += kPi / 2.0; + } + ellipse.angle -= kPi * std::floor(ellipse.angle / kPi); + if (fabs(ellipse.angle - kPi) < 1e-6 || fabs(ellipse.angle) < 1e-6) { + ellipse.angle = 0.0; + } + JXL_CHECK(ellipse.angle >= 0 && ellipse.angle <= kPi && + ellipse.sigma_x >= ellipse.sigma_y); + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "Ellipse mu=(%lf,%lf) sigma=(%lf,%lf) angle=%lf " + "intensity=(%lf,%lf,%lf) bg=(%lf,%lf,%lf) l2_loss=%lf " + "custom_loss=%lf, neg_pix=%" PRIuS ", neg_v=(%lf,%lf,%lf)\n", + ellipse.x, ellipse.y, ellipse.sigma_x, ellipse.sigma_y, + ellipse.angle, ellipse.intensity[0], ellipse.intensity[1], + ellipse.intensity[2], ellipse.bgColor[0], ellipse.bgColor[1], + ellipse.bgColor[2], ellipse.l2_loss, ellipse.custom_loss, + ellipse.neg_pixels, ellipse.neg_value[0], ellipse.neg_value[1], + ellipse.neg_value[2]); + return ellipse; +} + +} // namespace + +std::vector DetectGaussianEllipses( + const Image3F& opsin, const GaussianDetectParams& params, + const EllipseQuantParams& qParams, ThreadPool* pool) { + PROFILER_FUNC; + std::vector dots; + Image3F smooth(opsin.xsize(), opsin.ysize()); + ImageF energy = ComputeEnergyImage(opsin, &smooth, pool); +#if JXL_DEBUG_DOT_DETECT + AuxOut aux; + aux.debug_prefix = "/tmp/sebastian/"; + aux.DumpXybImage("smooth", smooth); + aux.DumpPlaneNormalized("energy", energy); +#endif // JXL_DEBUG_DOT_DETECT + std::vector components = FindCC( + energy, params.t_low, params.t_high, params.maxWinSize, params.minScore); + size_t numCC = + std::min(params.maxCC, (components.size() * params.percCC) / 100); + if (components.size() > numCC) { + std::sort( + components.begin(), components.end(), + [](const ConnectedComponent& a, const ConnectedComponent& b) -> bool { + return a.score > b.score; + }); + components.erase(components.begin() + numCC, components.end()); + } + for (const auto& cc : components) { + GaussianEllipse ellipse = FitGaussian(cc, energy, opsin, smooth); + if (ellipse.x < 0.0 || + std::ceil(ellipse.x) >= static_cast(opsin.xsize()) || + ellipse.y < 0.0 || + std::ceil(ellipse.y) >= static_cast(opsin.ysize())) { + continue; + } + if (ellipse.neg_pixels > params.maxNegPixels) continue; + double intensity = 0.21 * ellipse.intensity[0] + + 0.72 * ellipse.intensity[1] + + 0.07 * ellipse.intensity[2]; + double intensitySq = intensity * intensity; + // for (int c = 0; c < 3; c++) { + // intensitySq += ellipse.intensity[c] * ellipse.intensity[c]; + //} + double sqDistMeanMode = (ellipse.x - cc.mode.x) * (ellipse.x - cc.mode.x) + + (ellipse.y - cc.mode.y) * (ellipse.y - cc.mode.y); + if (ellipse.l2_loss < params.maxL2Loss && + ellipse.custom_loss < params.maxCustomLoss && + intensitySq > (params.minIntensity * params.minIntensity) && + sqDistMeanMode < params.maxDistMeanMode * params.maxDistMeanMode) { + size_t x0 = cc.bounds.x0(); + size_t y0 = cc.bounds.y0(); + dots.emplace_back(); + dots.back().second.emplace_back(x0, y0); + QuantizedPatch& patch = dots.back().first; + patch.xsize = cc.bounds.xsize(); + patch.ysize = cc.bounds.ysize(); + for (size_t y = 0; y < patch.ysize; y++) { + for (size_t x = 0; x < patch.xsize; x++) { + for (size_t c = 0; c < 3; c++) { + patch.fpixels[c][y * patch.xsize + x] = + opsin.ConstPlaneRow(c, y0 + y)[x0 + x] - + smooth.ConstPlaneRow(c, y0 + y)[x0 + x]; + } + } + } + } + } +#if JXL_DEBUG_DOT_DETECT + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, "Candidates: %" PRIuS ", Dots: %" PRIuS "\n", + components.size(), dots.size()); + ApplyGaussianEllipses(&smooth, dots, 1.0); + aux.DumpXybImage("draw", smooth); + ApplyGaussianEllipses(&smooth, dots, -1.0); + + auto qdots = QuantizeGaussianEllipses(dots, qParams); + auto deq = DequantizeGaussianEllipses(qdots, qParams); + ApplyGaussianEllipses(&smooth, deq, 1.0); + aux.DumpXybImage("qdraw", smooth); + ApplyGaussianEllipses(&smooth, deq, -1.0); +#endif // JXL_DEBUG_DOT_DETECT + return dots; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_detect_dots.h b/media/libjxl/src/lib/jxl/enc_detect_dots.h new file mode 100644 index 000000000..c3071d9a2 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_detect_dots.h @@ -0,0 +1,67 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// We attempt to remove dots, or speckle from images using Gaussian blur. +#ifndef LIB_JXL_ENC_DETECT_DOTS_H_ +#define LIB_JXL_ENC_DETECT_DOTS_H_ + +#include +#include + +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/image.h" + +namespace jxl { + +struct GaussianDetectParams { + double t_high = 0; // at least one pixel must have larger energy than t_high + double t_low = 0; // all pixels must have a larger energy than tLow + uint32_t maxWinSize = 0; // discard dots larger than this containing window + double maxL2Loss = 0; + double maxCustomLoss = 0; + double minIntensity = 0; // If the intensity is too low, discard it + double maxDistMeanMode = 0; // The mean and the mode must be close + size_t maxNegPixels = 0; // Maximum number of negative pixel + size_t minScore = 0; + size_t maxCC = 50; // Maximum number of CC to keep + size_t percCC = 15; // Percentage in [0,100] of CC to keep +}; + +// Ellipse Quantization Params +struct EllipseQuantParams { + size_t xsize; // Image size in x + size_t ysize; // Image size in y + size_t qPosition; // Position quantization delta + // Quantization for the Gaussian sigma parameters + double minSigma; + double maxSigma; + size_t qSigma; // number of quantization levels + // Quantization for the rotation angle (between -pi and pi) + size_t qAngle; + // Quantization for the intensity + std::array minIntensity; + std::array maxIntensity; + std::array qIntensity; // number of quantization levels + // Extra parameters for the encoding + bool subtractQuantized; // Should we subtract quantized or detected dots? + float ytox; + float ytob; + + void QuantPositionSize(size_t* xsize, size_t* ysize) const; +}; + +// Detects dots in XYB image. +std::vector DetectGaussianEllipses( + const Image3F& opsin, const GaussianDetectParams& params, + const EllipseQuantParams& qParams, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_ENC_DETECT_DOTS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_dot_dictionary.cc b/media/libjxl/src/lib/jxl/enc_dot_dictionary.cc new file mode 100644 index 000000000..1b5413bdf --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_dot_dictionary.cc @@ -0,0 +1,72 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_dot_dictionary.h" + +#include +#include + +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_detect_dots.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Private implementation of Dictionary Encode/Decode +namespace { + +/* Quantization constants for Ellipse dots */ +const size_t kEllipsePosQ = 2; // Quantization level for the position +const double kEllipseMinSigma = 0.1; // Minimum sigma value +const double kEllipseMaxSigma = 3.1; // Maximum Sigma value +const size_t kEllipseSigmaQ = 16; // Number of quantization levels for sigma +const size_t kEllipseAngleQ = 8; // Quantization level for the angle +// TODO: fix these values. +const std::array kEllipseMinIntensity{{-0.05, 0.0, -0.5}}; +const std::array kEllipseMaxIntensity{{0.05, 1.0, 0.4}}; +const std::array kEllipseIntensityQ{{10, 36, 10}}; +} // namespace + +std::vector FindDotDictionary(const CompressParams& cparams, + const Image3F& opsin, + const ColorCorrelationMap& cmap, + ThreadPool* pool) { + if (ApplyOverride(cparams.dots, + cparams.butteraugli_distance >= kMinButteraugliForDots)) { + GaussianDetectParams ellipse_params; + ellipse_params.t_high = 0.04; + ellipse_params.t_low = 0.02; + ellipse_params.maxWinSize = 5; + ellipse_params.maxL2Loss = 0.005; + ellipse_params.maxCustomLoss = 300; + ellipse_params.minIntensity = 0.12; + ellipse_params.maxDistMeanMode = 1.0; + ellipse_params.maxNegPixels = 0; + ellipse_params.minScore = 12.0; + ellipse_params.maxCC = 100; + ellipse_params.percCC = 100; + EllipseQuantParams qParams{ + opsin.xsize(), opsin.ysize(), kEllipsePosQ, + kEllipseMinSigma, kEllipseMaxSigma, kEllipseSigmaQ, + kEllipseAngleQ, kEllipseMinIntensity, kEllipseMaxIntensity, + kEllipseIntensityQ, kEllipsePosQ <= 5, cmap.YtoXRatio(0), + cmap.YtoBRatio(0)}; + + return DetectGaussianEllipses(opsin, ellipse_params, qParams, pool); + } + return {}; +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_dot_dictionary.h b/media/libjxl/src/lib/jxl/enc_dot_dictionary.h new file mode 100644 index 000000000..af76bfc3b --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_dot_dictionary.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_DOT_DICTIONARY_H_ +#define LIB_JXL_ENC_DOT_DICTIONARY_H_ + +// Dots are stored in a dictionary to avoid storing similar dots multiple +// times. + +#include + +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/image.h" + +namespace jxl { + +std::vector FindDotDictionary(const CompressParams& cparams, + const Image3F& opsin, + const ColorCorrelationMap& cmap, + ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_ENC_DOT_DICTIONARY_H_ diff --git a/media/libjxl/src/lib/jxl/enc_entropy_coder.cc b/media/libjxl/src/lib/jxl/enc_entropy_coder.cc new file mode 100644 index 000000000..c634445e8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_entropy_coder.cc @@ -0,0 +1,274 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_entropy_coder.h" + +#include +#include + +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_entropy_coder.cc" +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::AndNot; +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::GetLane; + +// Returns number of non-zero coefficients (but skip LLF). +// We cannot rely on block[] being all-zero bits, so first truncate to integer. +// Also writes the per-8x8 block nzeros starting at nzeros_pos. +int32_t NumNonZeroExceptLLF(const size_t cx, const size_t cy, + const AcStrategy acs, const size_t covered_blocks, + const size_t log2_covered_blocks, + const int32_t* JXL_RESTRICT block, + const size_t nzeros_stride, + int32_t* JXL_RESTRICT nzeros_pos) { + const HWY_CAPPED(int32_t, kBlockDim) di; + + const auto zero = Zero(di); + // Add FF..FF for every zero coefficient, negate to get #zeros. + auto neg_sum_zero = zero; + + { + // Mask sufficient for one row of coefficients. + HWY_ALIGN const int32_t + llf_mask_lanes[AcStrategy::kMaxCoeffBlocks * (1 + kBlockDim)] = { + -1, -1, -1, -1}; + // First cx=1,2,4 elements are FF..FF, others 0. + const int32_t* llf_mask_pos = + llf_mask_lanes + AcStrategy::kMaxCoeffBlocks - cx; + + // Rows with LLF: mask out the LLF + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx * kBlockDim; x += Lanes(di)) { + const auto llf_mask = LoadU(di, llf_mask_pos + x); + + // LLF counts as zero so we don't include it in nzeros. + const auto coef = + AndNot(llf_mask, Load(di, &block[y * cx * kBlockDim + x])); + + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + } + + // Remaining rows: no mask + for (size_t y = cy; y < cy * kBlockDim; y++) { + for (size_t x = 0; x < cx * kBlockDim; x += Lanes(di)) { + const auto coef = Load(di, &block[y * cx * kBlockDim + x]); + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + + // We want area - sum_zero, add because neg_sum_zero is already negated. + const int32_t nzeros = + int32_t(cx * cy * kDCTBlockSize) + GetLane(SumOfLanes(di, neg_sum_zero)); + + const int32_t shifted_nzeros = static_cast( + (nzeros + covered_blocks - 1) >> log2_covered_blocks); + // Need non-canonicalized dimensions! + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + nzeros_pos[x + y * nzeros_stride] = shifted_nzeros; + } + } + + return nzeros; +} + +// Specialization for 8x8, where only top-left is LLF/DC. +// About 1% overall speedup vs. NumNonZeroExceptLLF. +int32_t NumNonZero8x8ExceptDC(const int32_t* JXL_RESTRICT block, + int32_t* JXL_RESTRICT nzeros_pos) { + const HWY_CAPPED(int32_t, kBlockDim) di; + + const auto zero = Zero(di); + // Add FF..FF for every zero coefficient, negate to get #zeros. + auto neg_sum_zero = zero; + + { + // First row has DC, so mask + const size_t y = 0; + HWY_ALIGN const int32_t dc_mask_lanes[kBlockDim] = {-1}; + + for (size_t x = 0; x < kBlockDim; x += Lanes(di)) { + const auto dc_mask = Load(di, dc_mask_lanes + x); + + // DC counts as zero so we don't include it in nzeros. + const auto coef = AndNot(dc_mask, Load(di, &block[y * kBlockDim + x])); + + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + + // Remaining rows: no mask + for (size_t y = 1; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x += Lanes(di)) { + const auto coef = Load(di, &block[y * kBlockDim + x]); + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + + // We want 64 - sum_zero, add because neg_sum_zero is already negated. + const int32_t nzeros = + int32_t(kDCTBlockSize) + GetLane(SumOfLanes(di, neg_sum_zero)); + + *nzeros_pos = nzeros; + + return nzeros; +} + +// The number of nonzeros of each block is predicted from the top and the left +// blocks, with opportune scaling to take into account the number of blocks of +// each strategy. The predicted number of nonzeros divided by two is used as a +// context; if this number is above 63, a specific context is used. If the +// number of nonzeros of a strategy is above 63, it is written directly using a +// fixed number of bits (that depends on the size of the strategy). +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map) { + const size_t xsize_blocks = rect.xsize(); + const size_t ysize_blocks = rect.ysize(); + + // TODO(user): update the estimate: usually less coefficients are used. + output->reserve(output->size() + + 3 * xsize_blocks * ysize_blocks * kDCTBlockSize); + + size_t offset[3] = {}; + const size_t nzeros_stride = tmp_num_nzeroes->PixelsPerRow(); + for (size_t by = 0; by < ysize_blocks; ++by) { + size_t sby[3] = {by >> cs.VShift(0), by >> cs.VShift(1), + by >> cs.VShift(2)}; + int32_t* JXL_RESTRICT row_nzeros[3] = { + tmp_num_nzeroes->PlaneRow(0, sby[0]), + tmp_num_nzeroes->PlaneRow(1, sby[1]), + tmp_num_nzeroes->PlaneRow(2, sby[2]), + }; + const int32_t* JXL_RESTRICT row_nzeros_top[3] = { + sby[0] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(0, sby[0] - 1), + sby[1] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(1, sby[1] - 1), + sby[2] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(2, sby[2] - 1), + }; + const uint8_t* JXL_RESTRICT row_qdc = + qdc.ConstRow(rect.y0() + by) + rect.x0(); + const int32_t* JXL_RESTRICT row_qf = rect.ConstRow(qf, by); + AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by); + for (size_t bx = 0; bx < xsize_blocks; ++bx) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + size_t sbx[3] = {bx >> cs.HShift(0), bx >> cs.HShift(1), + bx >> cs.HShift(2)}; + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + const size_t covered_blocks = cx * cy; // = #LLF coefficients + const size_t log2_covered_blocks = + Num0BitsBelowLS1Bit_Nonzero(covered_blocks); + const size_t size = covered_blocks * kDCTBlockSize; + + CoefficientLayout(&cy, &cx); // swap cx/cy to canonical order + + for (int c : {1, 0, 2}) { + if (sbx[c] << cs.HShift(c) != bx) continue; + if (sby[c] << cs.VShift(c) != by) continue; + const int32_t* JXL_RESTRICT block = ac_rows[c] + offset[c]; + + int32_t nzeros = + (covered_blocks == 1) + ? NumNonZero8x8ExceptDC(block, row_nzeros[c] + sbx[c]) + : NumNonZeroExceptLLF(cx, cy, acs, covered_blocks, + log2_covered_blocks, block, nzeros_stride, + row_nzeros[c] + sbx[c]); + + int ord = kStrategyOrder[acs.RawStrategy()]; + const coeff_order_t* JXL_RESTRICT order = + &orders[CoeffOrderOffset(ord, c)]; + + int32_t predicted_nzeros = + PredictFromTopAndLeft(row_nzeros_top[c], row_nzeros[c], sbx[c], 32); + size_t block_ctx = + block_ctx_map.Context(row_qdc[bx], row_qf[sbx[c]], ord, c); + const int32_t nzero_ctx = + block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx); + + output->emplace_back(nzero_ctx, nzeros); + const size_t histo_offset = + block_ctx_map.ZeroDensityContextsOffset(block_ctx); + // Skip LLF. + size_t prev = (nzeros > static_cast(size / 16) ? 0 : 1); + for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) { + int32_t coeff = block[order[k]]; + size_t ctx = + histo_offset + ZeroDensityContext(nzeros, k, covered_blocks, + log2_covered_blocks, prev); + uint32_t u_coeff = PackSigned(coeff); + output->emplace_back(ctx, u_coeff); + prev = coeff != 0; + nzeros -= prev; + } + JXL_DASSERT(nzeros == 0); + offset[c] += size; + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(TokenizeCoefficients); +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map) { + return HWY_DYNAMIC_DISPATCH(TokenizeCoefficients)( + orders, rect, ac_rows, ac_strategy, cs, tmp_num_nzeroes, output, qdc, qf, + block_ctx_map); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_entropy_coder.h b/media/libjxl/src/lib/jxl/enc_entropy_coder.h new file mode 100644 index 000000000..7dfc71c72 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_entropy_coder.h @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ENTROPY_CODER_H_ +#define LIB_JXL_ENC_ENTROPY_CODER_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/ac_context.h" // BlockCtxMap +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/frame_header.h" // YCbCrChromaSubsampling +#include "lib/jxl/image.h" + +// Entropy coding and context modeling of DC and AC coefficients, as well as AC +// strategy and quantization field. + +namespace jxl { + +// Generate DCT NxN quantized AC values tokens. +// Only the subset "rect" [in units of blocks] within all images. +// See also DecodeACVarBlock. +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map); + +} // namespace jxl + +#endif // LIB_JXL_ENC_ENTROPY_CODER_H_ diff --git a/media/libjxl/src/lib/jxl/enc_external_image.cc b/media/libjxl/src/lib/jxl/enc_external_image.cc new file mode 100644 index 000000000..346182b3b --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_external_image.cc @@ -0,0 +1,428 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_external_image.h" + +#include + +#include +#include +#include +#include +#include + +#include "jxl/types.h" +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" + +namespace jxl { +namespace { + +// Based on highway scalar implementation, for testing +float LoadFloat16(uint16_t bits16) { + const uint32_t sign = bits16 >> 15; + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = (1.0f / 16384) * (mantissa * (1.0f / 1024)); + return sign ? -subnormal : subnormal; + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + + float result; + memcpy(&result, &bits32, 4); + return result; +} + +float LoadLEFloat16(const uint8_t* p) { + uint16_t bits16 = LoadLE16(p); + return LoadFloat16(bits16); +} + +float LoadBEFloat16(const uint8_t* p) { + uint16_t bits16 = LoadBE16(p); + return LoadFloat16(bits16); +} + +// Loads a float in big endian +float LoadBEFloat(const uint8_t* p) { + float value; + const uint32_t u = LoadBE32(p); + memcpy(&value, &u, 4); + return value; +} + +// Loads a float in little endian +float LoadLEFloat(const uint8_t* p) { + float value; + const uint32_t u = LoadLE32(p); + memcpy(&value, &u, 4); + return value; +} + +typedef uint32_t(LoadFuncType)(const uint8_t* p); +template +void JXL_INLINE LoadFloatRow(float* JXL_RESTRICT row_out, const uint8_t* in, + float mul, size_t xsize, size_t bytes_per_pixel) { + size_t i = 0; + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = mul * LoadFunc(in + i); + i += bytes_per_pixel; + } +} + +uint32_t JXL_INLINE Load8(const uint8_t* p) { return *p; } + +Status PixelFormatToExternal(const JxlPixelFormat& pixel_format, + size_t* bitdepth, bool* float_in) { + if (pixel_format.data_type == JXL_TYPE_FLOAT) { + *bitdepth = 32; + *float_in = true; + } else if (pixel_format.data_type == JXL_TYPE_FLOAT16) { + *bitdepth = 16; + *float_in = true; + } else if (pixel_format.data_type == JXL_TYPE_UINT8) { + *bitdepth = 8; + *float_in = false; + } else if (pixel_format.data_type == JXL_TYPE_UINT16) { + *bitdepth = 16; + *float_in = false; + } else { + return JXL_FAILURE("unsupported pixel format data type"); + } + return true; +} +} // namespace + +Status ConvertFromExternal(Span bytes, size_t xsize, + size_t ysize, size_t bits_per_sample, + JxlEndianness endianness, ThreadPool* pool, + ImageF* channel, bool float_in, size_t align) { + // TODO(firsching): Avoid code duplication with the function below. + JXL_CHECK(float_in ? bits_per_sample == 16 || bits_per_sample == 32 + : bits_per_sample > 0 && bits_per_sample <= 16); + const size_t bytes_per_pixel = DivCeil(bits_per_sample, jxl::kBitsPerByte); + const size_t last_row_size = xsize * bytes_per_pixel; + const size_t row_size = + (align > 1 ? jxl::DivCeil(last_row_size, align) * align : last_row_size); + const size_t bytes_to_read = row_size * (ysize - 1) + last_row_size; + if (xsize == 0 || ysize == 0) return JXL_FAILURE("Empty image"); + if (bytes.size() < bytes_to_read) { + return JXL_FAILURE("Buffer size is too small"); + } + JXL_ASSERT(channel->xsize() == xsize); + JXL_ASSERT(channel->ysize() == ysize); + // Too large buffer is likely an application bug, so also fail for that. + // Do allow padding to stride in last row though. + if (bytes.size() > row_size * ysize) { + return JXL_FAILURE("Buffer size is too large"); + } + + const bool little_endian = + endianness == JXL_LITTLE_ENDIAN || + (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian()); + + const uint8_t* const in = bytes.data(); + if (float_in) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + size_t i = row_size * task; + float* JXL_RESTRICT row_out = channel->Row(y); + if (bits_per_sample == 16) { + if (little_endian) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadLEFloat16(in + i); + i += bytes_per_pixel; + } + } else { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadBEFloat16(in + i); + i += bytes_per_pixel; + } + } + } else { + if (little_endian) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadLEFloat(in + i); + i += bytes_per_pixel; + } + } else { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadBEFloat(in + i); + i += bytes_per_pixel; + } + } + } + }, + "ConvertExtraChannelFloat")); + } else { + float mul = 1. / ((1ull << bits_per_sample) - 1); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + size_t i = row_size * task; + float* JXL_RESTRICT row_out = channel->Row(y); + if (bits_per_sample <= 8) { + LoadFloatRow(row_out, in + i, mul, xsize, bytes_per_pixel); + } else { + if (little_endian) { + LoadFloatRow(row_out, in + i, mul, xsize, + bytes_per_pixel); + } else { + LoadFloatRow(row_out, in + i, mul, xsize, + bytes_per_pixel); + } + } + }, + "ConvertExtraChannelUint")); + } + + return true; +} +Status ConvertFromExternal(Span bytes, size_t xsize, + size_t ysize, const ColorEncoding& c_current, + size_t channels, bool alpha_is_premultiplied, + size_t bits_per_sample, JxlEndianness endianness, + ThreadPool* pool, ImageBundle* ib, bool float_in, + size_t align) { + JXL_CHECK(float_in ? bits_per_sample == 16 || bits_per_sample == 32 + : bits_per_sample > 0 && bits_per_sample <= 16); + + const size_t color_channels = c_current.Channels(); + bool has_alpha = channels == 2 || channels == 4; + if (channels < color_channels) { + return JXL_FAILURE("Expected %" PRIuS + " color channels, received only %" PRIuS " channels", + color_channels, channels); + } + + const size_t bytes_per_channel = DivCeil(bits_per_sample, jxl::kBitsPerByte); + const size_t bytes_per_pixel = channels * bytes_per_channel; + if (bits_per_sample > 16 && bits_per_sample < 32) { + return JXL_FAILURE("not supported, try bits_per_sample=32"); + } + + const size_t last_row_size = xsize * bytes_per_pixel; + const size_t row_size = + (align > 1 ? jxl::DivCeil(last_row_size, align) * align : last_row_size); + const size_t bytes_to_read = row_size * (ysize - 1) + last_row_size; + if (xsize == 0 || ysize == 0) return JXL_FAILURE("Empty image"); + if (bytes.size() < bytes_to_read) { + return JXL_FAILURE( + "Buffer size is too small: expected at least %" PRIuS + " bytes (= %" PRIuS " * %" PRIuS " * %" PRIuS "), got %" PRIuS " bytes", + bytes_to_read, xsize, ysize, bytes_per_pixel, bytes.size()); + } + // Too large buffer is likely an application bug, so also fail for that. + // Do allow padding to stride in last row though. + if (bytes.size() > row_size * ysize) { + return JXL_FAILURE( + "Buffer size is too large: expected at most %" PRIuS " bytes (= %" PRIuS + " * %" PRIuS " * %" PRIuS "), got %" PRIuS " bytes", + row_size * ysize, xsize, ysize, bytes_per_pixel, bytes.size()); + } + const bool little_endian = + endianness == JXL_LITTLE_ENDIAN || + (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian()); + + const uint8_t* const in = bytes.data(); + + Image3F color(xsize, ysize); + + if (float_in) { + for (size_t c = 0; c < color_channels; ++c) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + size_t i = + row_size * task + (c * bits_per_sample / jxl::kBitsPerByte); + float* JXL_RESTRICT row_out = color.PlaneRow(c, y); + if (bits_per_sample == 16) { + if (little_endian) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadLEFloat16(in + i); + i += bytes_per_pixel; + } + } else { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadBEFloat16(in + i); + i += bytes_per_pixel; + } + } + } else { + if (little_endian) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadLEFloat(in + i); + i += bytes_per_pixel; + } + } else { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadBEFloat(in + i); + i += bytes_per_pixel; + } + } + } + }, + "ConvertRGBFloat")); + } + } else { + // Multiplier to convert from the integer range to floating point 0-1 range. + float mul = 1. / ((1ull << bits_per_sample) - 1); + for (size_t c = 0; c < color_channels; ++c) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + size_t i = row_size * task + c * bytes_per_channel; + float* JXL_RESTRICT row_out = color.PlaneRow(c, y); + if (bits_per_sample <= 8) { + LoadFloatRow(row_out, in + i, mul, xsize, bytes_per_pixel); + } else { + if (little_endian) { + LoadFloatRow(row_out, in + i, mul, xsize, + bytes_per_pixel); + } else { + LoadFloatRow(row_out, in + i, mul, xsize, + bytes_per_pixel); + } + } + }, + "ConvertRGBUint")); + } + } + + if (color_channels == 1) { + CopyImageTo(color.Plane(0), &color.Plane(1)); + CopyImageTo(color.Plane(0), &color.Plane(2)); + } + + ib->SetFromImage(std::move(color), c_current); + + // Passing an interleaved image with an alpha channel to an image that doesn't + // have alpha channel just discards the passed alpha channel. + if (has_alpha && ib->HasAlpha()) { + ImageF alpha(xsize, ysize); + + if (float_in) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + size_t i = row_size * task + + ((channels - 1) * bits_per_sample / jxl::kBitsPerByte); + float* JXL_RESTRICT row_out = alpha.Row(y); + if (bits_per_sample == 16) { + if (little_endian) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadLEFloat16(in + i); + i += bytes_per_pixel; + } + } else { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadBEFloat16(in + i); + i += bytes_per_pixel; + } + } + } else { + if (little_endian) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadLEFloat(in + i); + i += bytes_per_pixel; + } + } else { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = LoadBEFloat(in + i); + i += bytes_per_pixel; + } + } + } + }, + "ConvertAlphaFloat")); + } else { + float mul = 1. / ((1ull << bits_per_sample) - 1); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + size_t i = row_size * task + (channels - 1) * bytes_per_channel; + float* JXL_RESTRICT row_out = alpha.Row(y); + if (bits_per_sample <= 8) { + LoadFloatRow(row_out, in + i, mul, xsize, bytes_per_pixel); + } else { + if (little_endian) { + LoadFloatRow(row_out, in + i, mul, xsize, + bytes_per_pixel); + } else { + LoadFloatRow(row_out, in + i, mul, xsize, + bytes_per_pixel); + } + } + }, + "ConvertAlphaUint")); + } + + ib->SetAlpha(std::move(alpha), alpha_is_premultiplied); + } else if (!has_alpha && ib->HasAlpha()) { + // if alpha is not passed, but it is expected, then assume + // it is all-opaque + ImageF alpha(xsize, ysize); + FillImage(1.0f, &alpha); + ib->SetAlpha(std::move(alpha), alpha_is_premultiplied); + } + + return true; +} + +Status BufferToImageF(const JxlPixelFormat& pixel_format, size_t xsize, + size_t ysize, const void* buffer, size_t size, + ThreadPool* pool, ImageF* channel) { + size_t bitdepth; + bool float_in; + + JXL_RETURN_IF_ERROR( + PixelFormatToExternal(pixel_format, &bitdepth, &float_in)); + + JXL_RETURN_IF_ERROR(ConvertFromExternal( + jxl::Span(static_cast(buffer), size), + xsize, ysize, bitdepth, pixel_format.endianness, pool, channel, float_in, + pixel_format.align)); + + return true; +} + +Status BufferToImageBundle(const JxlPixelFormat& pixel_format, uint32_t xsize, + uint32_t ysize, const void* buffer, size_t size, + jxl::ThreadPool* pool, + const jxl::ColorEncoding& c_current, + jxl::ImageBundle* ib) { + size_t bitdepth; + bool float_in; + JXL_RETURN_IF_ERROR( + PixelFormatToExternal(pixel_format, &bitdepth, &float_in)); + + JXL_RETURN_IF_ERROR(ConvertFromExternal( + jxl::Span(static_cast(buffer), size), + xsize, ysize, c_current, pixel_format.num_channels, + /*alpha_is_premultiplied=*/false, bitdepth, pixel_format.endianness, pool, + ib, float_in, pixel_format.align)); + ib->VerifyMetadata(); + + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_external_image.h b/media/libjxl/src/lib/jxl/enc_external_image.h new file mode 100644 index 000000000..73b7175a9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_external_image.h @@ -0,0 +1,48 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_EXTERNAL_IMAGE_H_ +#define LIB_JXL_ENC_EXTERNAL_IMAGE_H_ + +// Interleaved image for color transforms and Codec. + +#include +#include + +#include "jxl/types.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { +Status ConvertFromExternal(Span bytes, size_t xsize, + size_t ysize, size_t bits_per_sample, + JxlEndianness endianness, ThreadPool* pool, + ImageF* channel, bool float_in, size_t align); + +// Convert an interleaved pixel buffer to the internal ImageBundle +// representation. This is the opposite of ConvertToExternal(). +Status ConvertFromExternal(Span bytes, size_t xsize, + size_t ysize, const ColorEncoding& c_current, + size_t channels, bool alpha_is_premultiplied, + size_t bits_per_sample, JxlEndianness endianness, + ThreadPool* pool, ImageBundle* ib, bool float_in, + size_t align); +Status BufferToImageF(const JxlPixelFormat& pixel_format, size_t xsize, + size_t ysize, const void* buffer, size_t size, + ThreadPool* pool, ImageF* channel); +Status BufferToImageBundle(const JxlPixelFormat& pixel_format, uint32_t xsize, + uint32_t ysize, const void* buffer, size_t size, + jxl::ThreadPool* pool, + const jxl::ColorEncoding& c_current, + jxl::ImageBundle* ib); + +} // namespace jxl + +#endif // LIB_JXL_ENC_EXTERNAL_IMAGE_H_ diff --git a/media/libjxl/src/lib/jxl/enc_external_image_gbench.cc b/media/libjxl/src/lib/jxl/enc_external_image_gbench.cc new file mode 100644 index 000000000..a123d4b35 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_external_image_gbench.cc @@ -0,0 +1,48 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +// Encoder case, deinterleaves a buffer. +void BM_EncExternalImage_ConvertImageRGBA(benchmark::State& state) { + const size_t kNumIter = 5; + size_t xsize = state.range(); + size_t ysize = state.range(); + + ImageMetadata im; + im.SetAlphaBits(8); + ImageBundle ib(&im); + + std::vector interleaved(xsize * ysize * 4); + + for (auto _ : state) { + for (size_t i = 0; i < kNumIter; ++i) { + JXL_CHECK(ConvertFromExternal( + Span(interleaved.data(), interleaved.size()), xsize, + ysize, + /*c_current=*/ColorEncoding::SRGB(), + /*channels=*/4, + /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/8, JXL_NATIVE_ENDIAN, + /*pool=*/nullptr, &ib, /*float_in=*/false, /*align=*/0)); + } + } + + // Pixels per second. + state.SetItemsProcessed(kNumIter * state.iterations() * xsize * ysize); + state.SetBytesProcessed(kNumIter * state.iterations() * interleaved.size()); +} + +BENCHMARK(BM_EncExternalImage_ConvertImageRGBA) + ->RangeMultiplier(2) + ->Range(256, 2048); + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_external_image_test.cc b/media/libjxl/src/lib/jxl/enc_external_image_test.cc new file mode 100644 index 000000000..2c5fa5aca --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_external_image_test.cc @@ -0,0 +1,69 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_external_image.h" + +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +#if !defined(JXL_CRASH_ON_ERROR) +TEST(ExternalImageTest, InvalidSize) { + ImageMetadata im; + im.SetAlphaBits(8); + ImageBundle ib(&im); + + const uint8_t buf[10 * 100 * 8] = {}; + EXPECT_FALSE(ConvertFromExternal( + Span(buf, 10), /*xsize=*/10, /*ysize=*/100, + /*c_current=*/ColorEncoding::SRGB(), /*channels=*/4, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, JXL_BIG_ENDIAN, + nullptr, &ib, /*float_in=*/false, /*align=*/0)); + EXPECT_FALSE(ConvertFromExternal( + Span(buf, sizeof(buf) - 1), /*xsize=*/10, /*ysize=*/100, + /*c_current=*/ColorEncoding::SRGB(), /*channels=*/4, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, JXL_BIG_ENDIAN, + nullptr, &ib, /*float_in=*/false, /*align=*/0)); + EXPECT_TRUE( + ConvertFromExternal(Span(buf, sizeof(buf)), /*xsize=*/10, + /*ysize=*/100, /*c_current=*/ColorEncoding::SRGB(), + /*channels=*/4, /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/16, JXL_BIG_ENDIAN, nullptr, &ib, + /*float_in=*/false, /*align=*/0)); +} +#endif + +TEST(ExternalImageTest, AlphaMissing) { + ImageMetadata im; + im.SetAlphaBits(0); // No alpha + ImageBundle ib(&im); + + const size_t xsize = 10; + const size_t ysize = 20; + const uint8_t buf[xsize * ysize * 4] = {}; + + // has_alpha is true but the ImageBundle has no alpha. Alpha channel should + // be ignored. + EXPECT_TRUE( + ConvertFromExternal(Span(buf, sizeof(buf)), xsize, ysize, + /*c_current=*/ColorEncoding::SRGB(), + /*channels=*/4, /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/8, JXL_BIG_ENDIAN, nullptr, &ib, + /*float_in=*/false, /*align=*/0)); + EXPECT_FALSE(ib.HasAlpha()); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_file.cc b/media/libjxl/src/lib/jxl/enc_file.cc new file mode 100644 index 000000000..0f29bd92d --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_file.cc @@ -0,0 +1,238 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_file.h" + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/enc_icc_codec.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +namespace { + +// DC + 'Very Low Frequency' +PassDefinition progressive_passes_dc_vlf[] = { + {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/4}}; + +PassDefinition progressive_passes_dc_lf[] = { + {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/4}, + {/*num_coefficients=*/3, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/2}}; + +PassDefinition progressive_passes_dc_lf_salient_ac[] = { + {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/4}, + {/*num_coefficients=*/3, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/2}, + {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/true, + /*suitable_for_downsampling_of_at_least=*/0}}; + +PassDefinition progressive_passes_dc_lf_salient_ac_other_ac[] = { + {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/4}, + {/*num_coefficients=*/3, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/2}, + {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/true, + /*suitable_for_downsampling_of_at_least=*/0}, + {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/0}}; + +PassDefinition progressive_passes_dc_quant_ac_full_ac[] = { + {/*num_coefficients=*/8, /*shift=*/1, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/2}, + {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/0}, +}; + +Status PrepareCodecMetadataFromIO(const CompressParams& cparams, + const CodecInOut* io, + CodecMetadata* metadata) { + *metadata = io->metadata; + size_t ups = 1; + if (cparams.already_downsampled) ups = cparams.resampling; + + JXL_RETURN_IF_ERROR(metadata->size.Set(io->xsize() * ups, io->ysize() * ups)); + + // Keep ICC profile in lossless modes because a reconstructed profile may be + // slightly different (quantization). + // Also keep ICC in JPEG reconstruction mode as we need byte-exact profiles. + if (!cparams.IsLossless() && !io->Main().IsJPEG()) { + metadata->m.color_encoding.DecideIfWantICC(); + } + + metadata->m.xyb_encoded = + cparams.color_transform == ColorTransform::kXYB ? true : false; + + // TODO(firsching): move this EncodeFile to test_utils / re-implement this + // using API functions + return true; +} + +} // namespace + +Status EncodePreview(const CompressParams& cparams, const ImageBundle& ib, + const CodecMetadata* metadata, const JxlCmsInterface& cms, + ThreadPool* pool, BitWriter* JXL_RESTRICT writer) { + BitWriter preview_writer; + // TODO(janwas): also support generating preview by downsampling + if (ib.HasColor()) { + AuxOut aux_out; + PassesEncoderState passes_enc_state; + // TODO(lode): check if we want all extra channels and matching xyb_encoded + // for the preview, such that using the main ImageMetadata object for + // encoding this frame is warrented. + FrameInfo frame_info; + frame_info.is_preview = true; + JXL_RETURN_IF_ERROR(EncodeFrame(cparams, frame_info, metadata, ib, + &passes_enc_state, cms, pool, + &preview_writer, &aux_out)); + preview_writer.ZeroPadToByte(); + } + + if (preview_writer.BitsWritten() != 0) { + writer->ZeroPadToByte(); + writer->AppendByteAligned(preview_writer); + } + + return true; +} + +Status WriteHeaders(CodecMetadata* metadata, BitWriter* writer, + AuxOut* aux_out) { + // Marker/signature + BitWriter::Allotment allotment(writer, 16); + writer->Write(8, 0xFF); + writer->Write(8, kCodestreamMarker); + ReclaimAndCharge(writer, &allotment, kLayerHeader, aux_out); + + JXL_RETURN_IF_ERROR( + WriteSizeHeader(metadata->size, writer, kLayerHeader, aux_out)); + + JXL_RETURN_IF_ERROR( + WriteImageMetadata(metadata->m, writer, kLayerHeader, aux_out)); + + metadata->transform_data.nonserialized_xyb_encoded = metadata->m.xyb_encoded; + JXL_RETURN_IF_ERROR( + Bundle::Write(metadata->transform_data, writer, kLayerHeader, aux_out)); + + return true; +} + +Status EncodeFile(const CompressParams& params, const CodecInOut* io, + PassesEncoderState* passes_enc_state, PaddedBytes* compressed, + const JxlCmsInterface& cms, AuxOut* aux_out, + ThreadPool* pool) { + io->CheckMetadata(); + BitWriter writer; + + CompressParams cparams = params; + if (io->Main().color_transform != ColorTransform::kNone) { + // Set the color transform to YCbCr or XYB if the original image is such. + cparams.color_transform = io->Main().color_transform; + } + + JXL_RETURN_IF_ERROR(ParamsPostInit(&cparams)); + + std::unique_ptr metadata = jxl::make_unique(); + JXL_RETURN_IF_ERROR(PrepareCodecMetadataFromIO(cparams, io, metadata.get())); + JXL_RETURN_IF_ERROR(WriteHeaders(metadata.get(), &writer, aux_out)); + + // Only send ICC (at least several hundred bytes) if fields aren't enough. + if (metadata->m.color_encoding.WantICC()) { + JXL_RETURN_IF_ERROR(WriteICC(metadata->m.color_encoding.ICC(), &writer, + kLayerHeader, aux_out)); + } + + if (metadata->m.have_preview) { + JXL_RETURN_IF_ERROR(EncodePreview(cparams, io->preview_frame, + metadata.get(), cms, pool, &writer)); + } + + // Each frame should start on byte boundaries. + BitWriter::Allotment allotment(&writer, 8); + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, kLayerHeader, aux_out); + + if (cparams.progressive_mode || cparams.qprogressive_mode) { + if (cparams.saliency_map != nullptr) { + passes_enc_state->progressive_splitter.SetSaliencyMap( + cparams.saliency_map); + } + passes_enc_state->progressive_splitter.SetSaliencyThreshold( + cparams.saliency_threshold); + if (cparams.qprogressive_mode) { + passes_enc_state->progressive_splitter.SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_quant_ac_full_ac}); + } else { + switch (cparams.saliency_num_progressive_steps) { + case 1: + passes_enc_state->progressive_splitter.SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_vlf}); + break; + case 2: + passes_enc_state->progressive_splitter.SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_lf}); + break; + case 3: + passes_enc_state->progressive_splitter.SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_lf_salient_ac}); + break; + case 4: + if (cparams.saliency_threshold == 0.0f) { + // No need for a 4th pass if saliency-threshold regards everything + // as salient. + passes_enc_state->progressive_splitter.SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_lf_salient_ac}); + } else { + passes_enc_state->progressive_splitter.SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_lf_salient_ac_other_ac}); + } + break; + default: + return JXL_FAILURE("Invalid saliency_num_progressive_steps."); + } + } + } + + for (size_t i = 0; i < io->frames.size(); i++) { + FrameInfo info; + info.is_last = i == io->frames.size() - 1; + if (io->frames[i].use_for_next_frame) { + info.save_as_reference = 1; + } + JXL_RETURN_IF_ERROR(EncodeFrame(cparams, info, metadata.get(), + io->frames[i], passes_enc_state, cms, pool, + &writer, aux_out)); + } + + // Clean up passes_enc_state in case it gets reused. + for (size_t i = 0; i < 4; i++) { + passes_enc_state->shared.dc_frames[i] = Image3F(); + passes_enc_state->shared.reference_frames[i].storage = ImageBundle(); + } + + *compressed = std::move(writer).TakeBytes(); + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_file.h b/media/libjxl/src/lib/jxl/enc_file.h new file mode 100644 index 000000000..37b3a27f3 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_file.h @@ -0,0 +1,55 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_FILE_H_ +#define LIB_JXL_ENC_FILE_H_ + +// Facade for JXL encoding. + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" + +namespace jxl { + +// Write preview from `io`. +Status EncodePreview(const CompressParams& cparams, const ImageBundle& ib, + const CodecMetadata* metadata, const JxlCmsInterface& cms, + ThreadPool* pool, BitWriter* JXL_RESTRICT writer); + +// Write headers from the CodecMetadata. Also may modify nonserialized_... +// fields of the metadata. +Status WriteHeaders(CodecMetadata* metadata, BitWriter* writer, + AuxOut* aux_out); + +// Compresses pixels from `io` (given in any ColorEncoding). +// `io->metadata.m.original` must be set. +Status EncodeFile(const CompressParams& params, const CodecInOut* io, + PassesEncoderState* passes_enc_state, PaddedBytes* compressed, + const JxlCmsInterface& cms, AuxOut* aux_out = nullptr, + ThreadPool* pool = nullptr); + +// Backwards-compatible interface. Don't use in new code. +// TODO(deymo): Remove this function once we migrate users to C encoder API. +struct FrameEncCache {}; +JXL_INLINE Status EncodeFile(const CompressParams& params, const CodecInOut* io, + FrameEncCache* /* unused */, + PaddedBytes* compressed, + const JxlCmsInterface& cms, + AuxOut* aux_out = nullptr, + ThreadPool* pool = nullptr) { + PassesEncoderState passes_enc_state; + return EncodeFile(params, io, &passes_enc_state, compressed, cms, aux_out, + pool); +} + +} // namespace jxl + +#endif // LIB_JXL_ENC_FILE_H_ diff --git a/media/libjxl/src/lib/jxl/enc_frame.cc b/media/libjxl/src/lib/jxl/enc_frame.cc new file mode 100644 index 000000000..f57175ba0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_frame.cc @@ -0,0 +1,1547 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_frame.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_chroma_from_luma.h" +#include "lib/jxl/enc_coeff_order.h" +#include "lib/jxl/enc_context_map.h" +#include "lib/jxl/enc_entropy_coder.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_noise.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/enc_photon_noise.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/enc_splines.h" +#include "lib/jxl/enc_toc.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/gaborish.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/splines.h" +#include "lib/jxl/toc.h" + +namespace jxl { +namespace { + +void ClusterGroups(PassesEncoderState* enc_state) { + if (enc_state->shared.frame_header.passes.num_passes > 1) { + // TODO(veluca): implement this for progressive modes. + return; + } + // This only considers pass 0 for now. + std::vector context_map; + EntropyEncodingData codes; + auto& ac = enc_state->passes[0].ac_tokens; + size_t limit = std::ceil(std::sqrt(ac.size())); + if (limit == 1) return; + size_t num_contexts = enc_state->shared.block_ctx_map.NumACContexts(); + std::vector costs(ac.size()); + HistogramParams params; + params.uint_method = HistogramParams::HybridUintMethod::kNone; + params.lz77_method = HistogramParams::LZ77Method::kNone; + params.ans_histogram_strategy = + HistogramParams::ANSHistogramStrategy::kApproximate; + size_t max = 0; + auto token_cost = [&](std::vector>& tokens, size_t num_ctx, + bool estimate = true) { + // TODO(veluca): not estimating is very expensive. + BitWriter writer; + size_t c = BuildAndEncodeHistograms( + params, num_ctx, tokens, &codes, &context_map, + estimate ? nullptr : &writer, 0, /*aux_out=*/0); + if (estimate) return c; + for (size_t i = 0; i < tokens.size(); i++) { + WriteTokens(tokens[i], codes, context_map, &writer, 0, nullptr); + } + return writer.BitsWritten(); + }; + for (size_t i = 0; i < ac.size(); i++) { + std::vector> tokens{ac[i]}; + costs[i] = + token_cost(tokens, enc_state->shared.block_ctx_map.NumACContexts()); + if (costs[i] > costs[max]) { + max = i; + } + } + auto dist = [&](int i, int j) { + std::vector> tokens{ac[i], ac[j]}; + return token_cost(tokens, num_contexts) - costs[i] - costs[j]; + }; + std::vector out{max}; + std::vector dists(ac.size()); + size_t farthest = 0; + for (size_t i = 0; i < ac.size(); i++) { + if (i == max) continue; + dists[i] = dist(max, i); + if (dists[i] > dists[farthest]) { + farthest = i; + } + } + + while (dists[farthest] > 0 && out.size() < limit) { + out.push_back(farthest); + dists[farthest] = 0; + enc_state->histogram_idx[farthest] = out.size() - 1; + for (size_t i = 0; i < ac.size(); i++) { + float d = dist(out.back(), i); + if (d < dists[i]) { + dists[i] = d; + enc_state->histogram_idx[i] = out.size() - 1; + } + if (dists[i] > dists[farthest]) { + farthest = i; + } + } + } + + std::vector remap(out.size()); + std::iota(remap.begin(), remap.end(), 0); + for (size_t i = 0; i < enc_state->histogram_idx.size(); i++) { + enc_state->histogram_idx[i] = remap[enc_state->histogram_idx[i]]; + } + auto remap_cost = [&](std::vector remap) { + std::vector re_remap(remap.size(), remap.size()); + size_t r = 0; + for (size_t i = 0; i < remap.size(); i++) { + if (re_remap[remap[i]] == remap.size()) { + re_remap[remap[i]] = r++; + } + remap[i] = re_remap[remap[i]]; + } + auto tokens = ac; + size_t max_hist = 0; + for (size_t i = 0; i < tokens.size(); i++) { + for (size_t j = 0; j < tokens[i].size(); j++) { + size_t hist = remap[enc_state->histogram_idx[i]]; + tokens[i][j].context += hist * num_contexts; + max_hist = std::max(hist + 1, max_hist); + } + } + return token_cost(tokens, max_hist * num_contexts, /*estimate=*/false); + }; + + for (size_t src = 0; src < out.size(); src++) { + float cost = remap_cost(remap); + size_t best = src; + for (size_t j = src + 1; j < out.size(); j++) { + if (remap[src] == remap[j]) continue; + auto remap_c = remap; + std::replace(remap_c.begin(), remap_c.end(), remap[src], remap[j]); + float c = remap_cost(remap_c); + if (c < cost) { + best = j; + cost = c; + } + } + if (src != best) { + std::replace(remap.begin(), remap.end(), remap[src], remap[best]); + } + } + std::vector re_remap(remap.size(), remap.size()); + size_t r = 0; + for (size_t i = 0; i < remap.size(); i++) { + if (re_remap[remap[i]] == remap.size()) { + re_remap[remap[i]] = r++; + } + remap[i] = re_remap[remap[i]]; + } + + enc_state->shared.num_histograms = + *std::max_element(remap.begin(), remap.end()) + 1; + for (size_t i = 0; i < enc_state->histogram_idx.size(); i++) { + enc_state->histogram_idx[i] = remap[enc_state->histogram_idx[i]]; + } + for (size_t i = 0; i < ac.size(); i++) { + for (size_t j = 0; j < ac[i].size(); j++) { + ac[i][j].context += enc_state->histogram_idx[i] * num_contexts; + } + } +} + +uint64_t FrameFlagsFromParams(const CompressParams& cparams) { + uint64_t flags = 0; + + const float dist = cparams.butteraugli_distance; + + // We don't add noise at low butteraugli distances because the original + // noise is stored within the compressed image and adding noise makes things + // worse. + if (ApplyOverride(cparams.noise, dist >= kMinButteraugliForNoise) || + cparams.photon_noise_iso > 0 || + cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) { + flags |= FrameHeader::kNoise; + } + + if (cparams.progressive_dc > 0 && cparams.modular_mode == false) { + flags |= FrameHeader::kUseDcFrame; + } + + return flags; +} + +Status LoopFilterFromParams(const CompressParams& cparams, + FrameHeader* JXL_RESTRICT frame_header) { + LoopFilter* loop_filter = &frame_header->loop_filter; + + // Gaborish defaults to enabled in Hare or slower. + loop_filter->gab = ApplyOverride( + cparams.gaborish, cparams.speed_tier <= SpeedTier::kHare && + frame_header->encoding == FrameEncoding::kVarDCT && + cparams.decoding_speed_tier < 4); + + if (cparams.epf != -1) { + loop_filter->epf_iters = cparams.epf; + } else { + if (frame_header->encoding == FrameEncoding::kModular) { + loop_filter->epf_iters = 0; + } else { + constexpr float kThresholds[3] = {0.7, 1.5, 4.0}; + loop_filter->epf_iters = 0; + if (cparams.decoding_speed_tier < 3) { + for (size_t i = cparams.decoding_speed_tier == 2 ? 1 : 0; i < 3; i++) { + if (cparams.butteraugli_distance >= kThresholds[i]) { + loop_filter->epf_iters++; + } + } + } + } + } + // Strength of EPF in modular mode. + if (frame_header->encoding == FrameEncoding::kModular && + !cparams.IsLossless()) { + // TODO(veluca): this formula is nonsense. + loop_filter->epf_sigma_for_modular = cparams.butteraugli_distance; + } + if (frame_header->encoding == FrameEncoding::kModular && + cparams.lossy_palette) { + loop_filter->epf_sigma_for_modular = 1.0f; + } + + return true; +} + +Status MakeFrameHeader(const CompressParams& cparams, + const ProgressiveSplitter& progressive_splitter, + const FrameInfo& frame_info, const ImageBundle& ib, + FrameHeader* JXL_RESTRICT frame_header) { + frame_header->nonserialized_is_preview = frame_info.is_preview; + frame_header->is_last = frame_info.is_last; + frame_header->save_before_color_transform = + frame_info.save_before_color_transform; + frame_header->frame_type = frame_info.frame_type; + frame_header->name = ib.name; + + progressive_splitter.InitPasses(&frame_header->passes); + + if (cparams.modular_mode) { + frame_header->encoding = FrameEncoding::kModular; + frame_header->group_size_shift = cparams.modular_group_size_shift; + } + + frame_header->chroma_subsampling = ib.chroma_subsampling; + if (ib.IsJPEG()) { + // we are transcoding a JPEG, so we don't get to choose + frame_header->encoding = FrameEncoding::kVarDCT; + frame_header->color_transform = ib.color_transform; + } else { + frame_header->color_transform = cparams.color_transform; + if (!cparams.modular_mode && + (frame_header->chroma_subsampling.MaxHShift() != 0 || + frame_header->chroma_subsampling.MaxVShift() != 0)) { + return JXL_FAILURE( + "Chroma subsampling is not supported in VarDCT mode when not " + "recompressing JPEGs"); + } + } + + frame_header->flags = FrameFlagsFromParams(cparams); + // Non-photon noise is not supported in the Modular encoder for now. + if (frame_header->encoding != FrameEncoding::kVarDCT && + cparams.photon_noise_iso == 0 && cparams.manual_noise.empty()) { + frame_header->UpdateFlag(false, FrameHeader::Flags::kNoise); + } + + JXL_RETURN_IF_ERROR(LoopFilterFromParams(cparams, frame_header)); + + frame_header->dc_level = frame_info.dc_level; + if (frame_header->dc_level > 2) { + // With 3 or more progressive_dc frames, the implementation does not yet + // work, see enc_cache.cc. + return JXL_FAILURE("progressive_dc > 2 is not yet supported"); + } + if (cparams.progressive_dc > 0 && + (cparams.ec_resampling != 1 || cparams.resampling != 1)) { + return JXL_FAILURE("Resampling not supported with DC frames"); + } + if (cparams.resampling != 1 && cparams.resampling != 2 && + cparams.resampling != 4 && cparams.resampling != 8) { + return JXL_FAILURE("Invalid resampling factor"); + } + if (cparams.ec_resampling != 1 && cparams.ec_resampling != 2 && + cparams.ec_resampling != 4 && cparams.ec_resampling != 8) { + return JXL_FAILURE("Invalid ec_resampling factor"); + } + // Resized frames. + if (frame_info.frame_type != FrameType::kDCFrame) { + frame_header->frame_origin = ib.origin; + size_t ups = 1; + if (cparams.already_downsampled) ups = cparams.resampling; + + // TODO(lode): this is not correct in case of odd original image sizes in + // combination with cparams.already_downsampled. Likely these values should + // be set to respectively frame_header->default_xsize() and + // frame_header->default_ysize() instead, the original (non downsampled) + // intended decoded image dimensions. But it may be more subtle than that + // if combined with crop. This issue causes custom_size_or_origin to be + // incorrectly set to true in case of already_downsampled with odd output + // image size when no cropping is used. + frame_header->frame_size.xsize = ib.xsize() * ups; + frame_header->frame_size.ysize = ib.ysize() * ups; + if (ib.origin.x0 != 0 || ib.origin.y0 != 0 || + frame_header->frame_size.xsize != frame_header->default_xsize() || + frame_header->frame_size.ysize != frame_header->default_ysize()) { + frame_header->custom_size_or_origin = true; + } + } + // Upsampling. + frame_header->upsampling = cparams.resampling; + const std::vector& extra_channels = + frame_header->nonserialized_metadata->m.extra_channel_info; + frame_header->extra_channel_upsampling.clear(); + frame_header->extra_channel_upsampling.resize(extra_channels.size(), + cparams.ec_resampling); + frame_header->save_as_reference = frame_info.save_as_reference; + + // Set blending-related information. + if (ib.blend || frame_header->custom_size_or_origin) { + // Set blend_channel to the first alpha channel. These values are only + // encoded in case a blend mode involving alpha is used and there are more + // than one extra channels. + size_t index = 0; + if (frame_info.alpha_channel == -1) { + if (extra_channels.size() > 1) { + for (size_t i = 0; i < extra_channels.size(); i++) { + if (extra_channels[i].type == ExtraChannel::kAlpha) { + index = i; + break; + } + } + } + } else { + index = static_cast(frame_info.alpha_channel); + JXL_ASSERT(index == 0 || index < extra_channels.size()); + } + frame_header->blending_info.alpha_channel = index; + frame_header->blending_info.mode = + ib.blend ? ib.blendmode : BlendMode::kReplace; + frame_header->blending_info.source = frame_info.source; + frame_header->blending_info.clamp = frame_info.clamp; + const auto& extra_channel_info = frame_info.extra_channel_blending_info; + for (size_t i = 0; i < extra_channels.size(); i++) { + if (i < extra_channel_info.size()) { + frame_header->extra_channel_blending_info[i] = extra_channel_info[i]; + } else { + frame_header->extra_channel_blending_info[i].alpha_channel = index; + BlendMode default_blend = ib.blendmode; + if (extra_channels[i].type != ExtraChannel::kBlack && i != index) { + // K needs to be blended, spot colors and other stuff gets added + default_blend = BlendMode::kAdd; + } + frame_header->extra_channel_blending_info[i].mode = + ib.blend ? default_blend : BlendMode::kReplace; + frame_header->extra_channel_blending_info[i].source = 1; + } + } + } + + frame_header->animation_frame.duration = ib.duration; + frame_header->animation_frame.timecode = ib.timecode; + + return true; +} + +// Invisible (alpha = 0) pixels tend to be a mess in optimized PNGs. +// Since they have no visual impact whatsoever, we can replace them with +// something that compresses better and reduces artifacts near the edges. This +// does some kind of smooth stuff that seems to work. +// Replace invisible pixels with a weighted average of the pixel to the left, +// the pixel to the topright, and non-invisible neighbours. +// Produces downward-blurry smears, with in the upwards direction only a 1px +// edge duplication but not more. It would probably be better to smear in all +// directions. That requires an alpha-weighed convolution with a large enough +// kernel though, which might be overkill... +void SimplifyInvisible(Image3F* image, const ImageF& alpha, bool lossless) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + float* JXL_RESTRICT row = image->PlaneRow(c, y); + const float* JXL_RESTRICT prow = + (y > 0 ? image->PlaneRow(c, y - 1) : nullptr); + const float* JXL_RESTRICT nrow = + (y + 1 < image->ysize() ? image->PlaneRow(c, y + 1) : nullptr); + const float* JXL_RESTRICT a = alpha.Row(y); + const float* JXL_RESTRICT pa = (y > 0 ? alpha.Row(y - 1) : nullptr); + const float* JXL_RESTRICT na = + (y + 1 < image->ysize() ? alpha.Row(y + 1) : nullptr); + for (size_t x = 0; x < image->xsize(); ++x) { + if (a[x] == 0) { + if (lossless) { + row[x] = 0; + continue; + } + float d = 0.f; + row[x] = 0; + if (x > 0) { + row[x] += row[x - 1]; + d++; + if (a[x - 1] > 0.f) { + row[x] += row[x - 1]; + d++; + } + } + if (x + 1 < image->xsize()) { + if (y > 0) { + row[x] += prow[x + 1]; + d++; + } + if (a[x + 1] > 0.f) { + row[x] += 2.f * row[x + 1]; + d += 2.f; + } + if (y > 0 && pa[x + 1] > 0.f) { + row[x] += 2.f * prow[x + 1]; + d += 2.f; + } + if (y + 1 < image->ysize() && na[x + 1] > 0.f) { + row[x] += 2.f * nrow[x + 1]; + d += 2.f; + } + } + if (y > 0 && pa[x] > 0.f) { + row[x] += 2.f * prow[x]; + d += 2.f; + } + if (y + 1 < image->ysize() && na[x] > 0.f) { + row[x] += 2.f * nrow[x]; + d += 2.f; + } + if (d > 1.f) row[x] /= d; + } + } + } + } +} + +} // namespace + +class LossyFrameEncoder { + public: + LossyFrameEncoder(const CompressParams& cparams, + const FrameHeader& frame_header, + PassesEncoderState* JXL_RESTRICT enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out) + : enc_state_(enc_state), cms_(cms), pool_(pool), aux_out_(aux_out) { + JXL_CHECK(InitializePassesSharedState(frame_header, &enc_state_->shared, + /*encoder=*/true)); + enc_state_->cparams = cparams; + enc_state_->passes.clear(); + } + + Status ComputeEncodingData(const ImageBundle* linear, + Image3F* JXL_RESTRICT opsin, + const JxlCmsInterface& cms, ThreadPool* pool, + ModularFrameEncoder* modular_frame_encoder, + FrameHeader* frame_header) { + PROFILER_ZONE("ComputeEncodingData uninstrumented"); + JXL_ASSERT((opsin->xsize() % kBlockDim) == 0 && + (opsin->ysize() % kBlockDim) == 0); + PassesSharedState& shared = enc_state_->shared; + + if (!enc_state_->cparams.max_error_mode) { + float x_qm_scale_steps[2] = {1.25f, 9.0f}; + shared.frame_header.x_qm_scale = 2; + for (float x_qm_scale_step : x_qm_scale_steps) { + if (enc_state_->cparams.butteraugli_distance > x_qm_scale_step) { + shared.frame_header.x_qm_scale++; + } + } + if (enc_state_->cparams.butteraugli_distance < 0.299f) { + // Favor chromacity preservation for making images appear more + // faithful to original even with extreme (5-10x) zooming. + shared.frame_header.x_qm_scale++; + } + } + + JXL_RETURN_IF_ERROR(enc_state_->heuristics->LossyFrameHeuristics( + enc_state_, modular_frame_encoder, linear, opsin, cms_, pool_, + aux_out_)); + + JXL_RETURN_IF_ERROR(InitializePassesEncoder( + *opsin, cms, pool_, enc_state_, modular_frame_encoder, aux_out_)); + + enc_state_->passes.resize(enc_state_->progressive_splitter.GetNumPasses()); + for (PassesEncoderState::PassData& pass : enc_state_->passes) { + pass.ac_tokens.resize(shared.frame_dim.num_groups); + } + + ComputeAllCoeffOrders(shared.frame_dim); + shared.num_histograms = 1; + + const auto tokenize_group_init = [&](const size_t num_threads) { + group_caches_.resize(num_threads); + return true; + }; + const auto tokenize_group = [&](const uint32_t group_index, + const size_t thread) { + // Tokenize coefficients. + const Rect rect = shared.BlockGroupRect(group_index); + for (size_t idx_pass = 0; idx_pass < enc_state_->passes.size(); + idx_pass++) { + JXL_ASSERT(enc_state_->coeffs[idx_pass]->Type() == ACType::k32); + const int32_t* JXL_RESTRICT ac_rows[3] = { + enc_state_->coeffs[idx_pass]->PlaneRow(0, group_index, 0).ptr32, + enc_state_->coeffs[idx_pass]->PlaneRow(1, group_index, 0).ptr32, + enc_state_->coeffs[idx_pass]->PlaneRow(2, group_index, 0).ptr32, + }; + // Ensure group cache is initialized. + group_caches_[thread].InitOnce(); + TokenizeCoefficients( + &shared.coeff_orders[idx_pass * shared.coeff_order_size], rect, + ac_rows, shared.ac_strategy, frame_header->chroma_subsampling, + &group_caches_[thread].num_nzeroes, + &enc_state_->passes[idx_pass].ac_tokens[group_index], + enc_state_->shared.quant_dc, enc_state_->shared.raw_quant_field, + enc_state_->shared.block_ctx_map); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool_, 0, shared.frame_dim.num_groups, + tokenize_group_init, tokenize_group, + "TokenizeGroup")); + + *frame_header = shared.frame_header; + return true; + } + + Status ComputeJPEGTranscodingData(const jpeg::JPEGData& jpeg_data, + ModularFrameEncoder* modular_frame_encoder, + FrameHeader* frame_header) { + PROFILER_ZONE("ComputeJPEGTranscodingData uninstrumented"); + PassesSharedState& shared = enc_state_->shared; + + frame_header->x_qm_scale = 2; + frame_header->b_qm_scale = 2; + + FrameDimensions frame_dim = frame_header->ToFrameDimensions(); + + const size_t xsize = frame_dim.xsize_padded; + const size_t ysize = frame_dim.ysize_padded; + const size_t xsize_blocks = frame_dim.xsize_blocks; + const size_t ysize_blocks = frame_dim.ysize_blocks; + + // no-op chroma from luma + shared.cmap = ColorCorrelationMap(xsize, ysize, false); + shared.ac_strategy.FillDCT8(); + FillImage(uint8_t(0), &shared.epf_sharpness); + + enc_state_->coeffs.clear(); + enc_state_->coeffs.emplace_back(make_unique>( + kGroupDim * kGroupDim, frame_dim.num_groups)); + + // convert JPEG quantization table to a Quantizer object + float dcquantization[3]; + std::vector qe(DequantMatrices::kNum, + QuantEncoding::Library(0)); + + auto jpeg_c_map = JpegOrder(frame_header->color_transform, + jpeg_data.components.size() == 1); + + std::vector qt(192); + for (size_t c = 0; c < 3; c++) { + size_t jpeg_c = jpeg_c_map[c]; + const int32_t* quant = + jpeg_data.quant[jpeg_data.components[jpeg_c].quant_idx].values.data(); + + dcquantization[c] = 255 * 8.0f / quant[0]; + for (size_t y = 0; y < 8; y++) { + for (size_t x = 0; x < 8; x++) { + // JPEG XL transposes the DCT, JPEG doesn't. + qt[c * 64 + 8 * x + y] = quant[8 * y + x]; + } + } + } + DequantMatricesSetCustomDC(&shared.matrices, dcquantization); + float dcquantization_r[3] = {1.0f / dcquantization[0], + 1.0f / dcquantization[1], + 1.0f / dcquantization[2]}; + + qe[AcStrategy::Type::DCT] = QuantEncoding::RAW(qt); + DequantMatricesSetCustom(&shared.matrices, qe, modular_frame_encoder); + + // Ensure that InvGlobalScale() is 1. + shared.quantizer = Quantizer(&shared.matrices, 1, kGlobalScaleDenom); + // Recompute MulDC() and InvMulDC(). + shared.quantizer.RecomputeFromGlobalScale(); + + // Per-block dequant scaling should be 1. + FillImage(static_cast(shared.quantizer.InvGlobalScale()), + &shared.raw_quant_field); + + std::vector scaled_qtable(192); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 64; i++) { + scaled_qtable[64 * c + i] = + (1 << kCFLFixedPointPrecision) * qt[64 + i] / qt[64 * c + i]; + } + } + + auto jpeg_row = [&](size_t c, size_t y) { + return jpeg_data.components[jpeg_c_map[c]].coeffs.data() + + jpeg_data.components[jpeg_c_map[c]].width_in_blocks * + kDCTBlockSize * y; + }; + + Image3F dc = Image3F(xsize_blocks, ysize_blocks); + bool DCzero = + (shared.frame_header.color_transform == ColorTransform::kYCbCr); + // Compute chroma-from-luma for AC (doesn't seem to be useful for DC) + if (frame_header->chroma_subsampling.Is444() && + enc_state_->cparams.force_cfl_jpeg_recompression && + jpeg_data.components.size() == 3) { + for (size_t c : {0, 2}) { + ImageSB* map = (c == 0 ? &shared.cmap.ytox_map : &shared.cmap.ytob_map); + const float kScale = kDefaultColorFactor; + const int kOffset = 127; + const float kBase = + c == 0 ? shared.cmap.YtoXRatio(0) : shared.cmap.YtoBRatio(0); + const float kZeroThresh = + kScale * kZeroBiasDefault[c] * + 0.9999f; // just epsilon less for better rounding + + auto process_row = [&](const uint32_t task, const size_t thread) { + size_t ty = task; + int8_t* JXL_RESTRICT row_out = map->Row(ty); + for (size_t tx = 0; tx < map->xsize(); ++tx) { + const size_t y0 = ty * kColorTileDimInBlocks; + const size_t x0 = tx * kColorTileDimInBlocks; + const size_t y1 = std::min(frame_dim.ysize_blocks, + (ty + 1) * kColorTileDimInBlocks); + const size_t x1 = std::min(frame_dim.xsize_blocks, + (tx + 1) * kColorTileDimInBlocks); + int32_t d_num_zeros[257] = {0}; + // TODO(veluca): this needs SIMD + fixed point adaptation, and/or + // conversion to the new CfL algorithm. + for (size_t y = y0; y < y1; ++y) { + const int16_t* JXL_RESTRICT row_m = jpeg_row(1, y); + const int16_t* JXL_RESTRICT row_s = jpeg_row(c, y); + for (size_t x = x0; x < x1; ++x) { + for (size_t coeffpos = 1; coeffpos < kDCTBlockSize; + coeffpos++) { + const float scaled_m = + row_m[x * kDCTBlockSize + coeffpos] * + scaled_qtable[64 * c + coeffpos] * + (1.0f / (1 << kCFLFixedPointPrecision)); + const float scaled_s = + kScale * row_s[x * kDCTBlockSize + coeffpos] + + (kOffset - kBase * kScale) * scaled_m; + if (std::abs(scaled_m) > 1e-8f) { + float from, to; + if (scaled_m > 0) { + from = (scaled_s - kZeroThresh) / scaled_m; + to = (scaled_s + kZeroThresh) / scaled_m; + } else { + from = (scaled_s + kZeroThresh) / scaled_m; + to = (scaled_s - kZeroThresh) / scaled_m; + } + if (from < 0.0f) { + from = 0.0f; + } + if (to > 255.0f) { + to = 255.0f; + } + // Instead of clamping the both values + // we just check that range is sane. + if (from <= to) { + d_num_zeros[static_cast(std::ceil(from))]++; + d_num_zeros[static_cast(std::floor(to + 1))]--; + } + } + } + } + } + int best = 0; + int32_t best_sum = 0; + FindIndexOfSumMaximum(d_num_zeros, 256, &best, &best_sum); + int32_t offset_sum = 0; + for (int i = 0; i < 256; ++i) { + if (i <= kOffset) { + offset_sum += d_num_zeros[i]; + } + } + row_out[tx] = 0; + if (best_sum > offset_sum + 1) { + row_out[tx] = best - kOffset; + } + } + }; + + JXL_RETURN_IF_ERROR(RunOnPool(pool_, 0, map->ysize(), + ThreadPool::NoInit, process_row, + "FindCorrelation")); + } + } + if (!frame_header->chroma_subsampling.Is444()) { + ZeroFillImage(&dc); + enc_state_->coeffs[0]->ZeroFill(); + } + // JPEG DC is from -1024 to 1023. + std::vector dc_counts[3] = {}; + dc_counts[0].resize(2048); + dc_counts[1].resize(2048); + dc_counts[2].resize(2048); + size_t total_dc[3] = {}; + for (size_t c : {1, 0, 2}) { + if (jpeg_data.components.size() == 1 && c != 1) { + enc_state_->coeffs[0]->ZeroFillPlane(c); + ZeroFillImage(&dc.Plane(c)); + // Ensure no division by 0. + dc_counts[c][1024] = 1; + total_dc[c] = 1; + continue; + } + size_t hshift = frame_header->chroma_subsampling.HShift(c); + size_t vshift = frame_header->chroma_subsampling.VShift(c); + ImageSB& map = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map); + for (size_t group_index = 0; group_index < frame_dim.num_groups; + group_index++) { + const size_t gx = group_index % frame_dim.xsize_groups; + const size_t gy = group_index / frame_dim.xsize_groups; + size_t offset = 0; + int32_t* JXL_RESTRICT ac = + enc_state_->coeffs[0]->PlaneRow(c, group_index, 0).ptr32; + for (size_t by = gy * kGroupDimInBlocks; + by < ysize_blocks && by < (gy + 1) * kGroupDimInBlocks; ++by) { + if ((by >> vshift) << vshift != by) continue; + const int16_t* JXL_RESTRICT inputjpeg = jpeg_row(c, by >> vshift); + const int16_t* JXL_RESTRICT inputjpegY = jpeg_row(1, by); + float* JXL_RESTRICT fdc = dc.PlaneRow(c, by >> vshift); + const int8_t* JXL_RESTRICT cm = + map.ConstRow(by / kColorTileDimInBlocks); + for (size_t bx = gx * kGroupDimInBlocks; + bx < xsize_blocks && bx < (gx + 1) * kGroupDimInBlocks; ++bx) { + if ((bx >> hshift) << hshift != bx) continue; + size_t base = (bx >> hshift) * kDCTBlockSize; + int idc; + if (DCzero) { + idc = inputjpeg[base]; + } else { + idc = inputjpeg[base] + 1024 / qt[c * 64]; + } + dc_counts[c][std::min(static_cast(idc + 1024), + uint32_t(2047))]++; + total_dc[c]++; + fdc[bx >> hshift] = idc * dcquantization_r[c]; + if (c == 1 || !enc_state_->cparams.force_cfl_jpeg_recompression || + !frame_header->chroma_subsampling.Is444()) { + for (size_t y = 0; y < 8; y++) { + for (size_t x = 0; x < 8; x++) { + ac[offset + y * 8 + x] = inputjpeg[base + x * 8 + y]; + } + } + } else { + const int32_t scale = + shared.cmap.RatioJPEG(cm[bx / kColorTileDimInBlocks]); + + for (size_t y = 0; y < 8; y++) { + for (size_t x = 0; x < 8; x++) { + int Y = inputjpegY[kDCTBlockSize * bx + x * 8 + y]; + int QChroma = inputjpeg[kDCTBlockSize * bx + x * 8 + y]; + // Fixed-point multiply of CfL scale with quant table ratio + // first, and Y value second. + int coeff_scale = (scale * scaled_qtable[64 * c + y * 8 + x] + + (1 << (kCFLFixedPointPrecision - 1))) >> + kCFLFixedPointPrecision; + int cfl_factor = (Y * coeff_scale + + (1 << (kCFLFixedPointPrecision - 1))) >> + kCFLFixedPointPrecision; + int QCR = QChroma - cfl_factor; + ac[offset + y * 8 + x] = QCR; + } + } + } + offset += 64; + } + } + } + } + + auto& dct = enc_state_->shared.block_ctx_map.dc_thresholds; + auto& num_dc_ctxs = enc_state_->shared.block_ctx_map.num_dc_ctxs; + num_dc_ctxs = 1; + for (size_t i = 0; i < 3; i++) { + dct[i].clear(); + int num_thresholds = (CeilLog2Nonzero(total_dc[i]) - 12) / 2; + // up to 3 buckets per channel: + // dark/medium/bright, yellow/unsat/blue, green/unsat/red + num_thresholds = std::min(std::max(num_thresholds, 0), 2); + size_t cumsum = 0; + size_t cut = total_dc[i] / (num_thresholds + 1); + for (int j = 0; j < 2048; j++) { + cumsum += dc_counts[i][j]; + if (cumsum > cut) { + dct[i].push_back(j - 1025); + cut = total_dc[i] * (dct[i].size() + 1) / (num_thresholds + 1); + } + } + num_dc_ctxs *= dct[i].size() + 1; + } + + auto& ctx_map = enc_state_->shared.block_ctx_map.ctx_map; + ctx_map.clear(); + ctx_map.resize(3 * kNumOrders * num_dc_ctxs, 0); + + int lbuckets = (dct[1].size() + 1); + for (size_t i = 0; i < num_dc_ctxs; i++) { + // up to 9 contexts for luma + ctx_map[i] = i / lbuckets; + // up to 3 contexts for chroma + ctx_map[kNumOrders * num_dc_ctxs + i] = + ctx_map[2 * kNumOrders * num_dc_ctxs + i] = + num_dc_ctxs / lbuckets + (i % lbuckets); + } + enc_state_->shared.block_ctx_map.num_ctxs = + *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; + + enc_state_->histogram_idx.resize(shared.frame_dim.num_groups); + + // disable DC frame for now + shared.frame_header.UpdateFlag(false, FrameHeader::kUseDcFrame); + auto compute_dc_coeffs = [&](const uint32_t group_index, + size_t /* thread */) { + modular_frame_encoder->AddVarDCTDC(dc, group_index, /*nl_dc=*/false, + enc_state_, /*jpeg_transcode=*/true); + modular_frame_encoder->AddACMetadata(group_index, /*jpeg_transcode=*/true, + enc_state_); + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool_, 0, shared.frame_dim.num_dc_groups, + ThreadPool::NoInit, compute_dc_coeffs, + "Compute DC coeffs")); + + // Must happen before WriteFrameHeader! + shared.frame_header.UpdateFlag(true, FrameHeader::kSkipAdaptiveDCSmoothing); + + enc_state_->passes.resize(enc_state_->progressive_splitter.GetNumPasses()); + for (PassesEncoderState::PassData& pass : enc_state_->passes) { + pass.ac_tokens.resize(shared.frame_dim.num_groups); + } + + JXL_CHECK(enc_state_->passes.size() == + 1); // skipping coeff splitting so need to have only one pass + + ComputeAllCoeffOrders(frame_dim); + shared.num_histograms = 1; + + const auto tokenize_group_init = [&](const size_t num_threads) { + group_caches_.resize(num_threads); + return true; + }; + const auto tokenize_group = [&](const uint32_t group_index, + const size_t thread) { + // Tokenize coefficients. + const Rect rect = shared.BlockGroupRect(group_index); + for (size_t idx_pass = 0; idx_pass < enc_state_->passes.size(); + idx_pass++) { + JXL_ASSERT(enc_state_->coeffs[idx_pass]->Type() == ACType::k32); + const int32_t* JXL_RESTRICT ac_rows[3] = { + enc_state_->coeffs[idx_pass]->PlaneRow(0, group_index, 0).ptr32, + enc_state_->coeffs[idx_pass]->PlaneRow(1, group_index, 0).ptr32, + enc_state_->coeffs[idx_pass]->PlaneRow(2, group_index, 0).ptr32, + }; + // Ensure group cache is initialized. + group_caches_[thread].InitOnce(); + TokenizeCoefficients( + &shared.coeff_orders[idx_pass * shared.coeff_order_size], rect, + ac_rows, shared.ac_strategy, frame_header->chroma_subsampling, + &group_caches_[thread].num_nzeroes, + &enc_state_->passes[idx_pass].ac_tokens[group_index], + enc_state_->shared.quant_dc, enc_state_->shared.raw_quant_field, + enc_state_->shared.block_ctx_map); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool_, 0, shared.frame_dim.num_groups, + tokenize_group_init, tokenize_group, + "TokenizeGroup")); + *frame_header = shared.frame_header; + doing_jpeg_recompression = true; + return true; + } + + Status EncodeGlobalDCInfo(const FrameHeader& frame_header, + BitWriter* writer) const { + // Encode quantizer DC and global scale. + JXL_RETURN_IF_ERROR( + enc_state_->shared.quantizer.Encode(writer, kLayerQuant, aux_out_)); + EncodeBlockCtxMap(enc_state_->shared.block_ctx_map, writer, aux_out_); + ColorCorrelationMapEncodeDC(&enc_state_->shared.cmap, writer, kLayerDC, + aux_out_); + return true; + } + + Status EncodeGlobalACInfo(BitWriter* writer, + ModularFrameEncoder* modular_frame_encoder) { + JXL_RETURN_IF_ERROR(DequantMatricesEncode(&enc_state_->shared.matrices, + writer, kLayerQuant, aux_out_, + modular_frame_encoder)); + if (enc_state_->cparams.speed_tier <= SpeedTier::kTortoise) { + if (!doing_jpeg_recompression) ClusterGroups(enc_state_); + } + size_t num_histo_bits = + CeilLog2Nonzero(enc_state_->shared.frame_dim.num_groups); + if (num_histo_bits != 0) { + BitWriter::Allotment allotment(writer, num_histo_bits); + writer->Write(num_histo_bits, enc_state_->shared.num_histograms - 1); + ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out_); + } + + for (size_t i = 0; i < enc_state_->progressive_splitter.GetNumPasses(); + i++) { + // Encode coefficient orders. + size_t order_bits = 0; + JXL_RETURN_IF_ERROR(U32Coder::CanEncode( + kOrderEnc, enc_state_->used_orders[i], &order_bits)); + BitWriter::Allotment allotment(writer, order_bits); + JXL_CHECK(U32Coder::Write(kOrderEnc, enc_state_->used_orders[i], writer)); + ReclaimAndCharge(writer, &allotment, kLayerOrder, aux_out_); + EncodeCoeffOrders( + enc_state_->used_orders[i], + &enc_state_->shared + .coeff_orders[i * enc_state_->shared.coeff_order_size], + writer, kLayerOrder, aux_out_); + + // Encode histograms. + HistogramParams hist_params( + enc_state_->cparams.speed_tier, + enc_state_->shared.block_ctx_map.NumACContexts()); + if (enc_state_->cparams.speed_tier > SpeedTier::kTortoise) { + hist_params.lz77_method = HistogramParams::LZ77Method::kNone; + } + if (enc_state_->cparams.decoding_speed_tier >= 1) { + hist_params.max_histograms = 6; + } + BuildAndEncodeHistograms( + hist_params, + enc_state_->shared.num_histograms * + enc_state_->shared.block_ctx_map.NumACContexts(), + enc_state_->passes[i].ac_tokens, &enc_state_->passes[i].codes, + &enc_state_->passes[i].context_map, writer, kLayerAC, aux_out_); + } + + return true; + } + + Status EncodeACGroup(size_t pass, size_t group_index, BitWriter* group_code, + AuxOut* local_aux_out) { + return EncodeGroupTokenizedCoefficients( + group_index, pass, enc_state_->histogram_idx[group_index], *enc_state_, + group_code, local_aux_out); + } + + PassesEncoderState* State() { return enc_state_; } + + private: + void ComputeAllCoeffOrders(const FrameDimensions& frame_dim) { + PROFILER_FUNC; + // No coefficient reordering in Falcon or faster. + auto used_orders_info = ComputeUsedOrders( + enc_state_->cparams.speed_tier, enc_state_->shared.ac_strategy, + Rect(enc_state_->shared.raw_quant_field)); + enc_state_->used_orders.clear(); + enc_state_->used_orders.resize( + enc_state_->progressive_splitter.GetNumPasses(), + used_orders_info.second); + for (size_t i = 0; i < enc_state_->progressive_splitter.GetNumPasses(); + i++) { + ComputeCoeffOrder( + enc_state_->cparams.speed_tier, *enc_state_->coeffs[i], + enc_state_->shared.ac_strategy, frame_dim, enc_state_->used_orders[i], + used_orders_info.first, + &enc_state_->shared + .coeff_orders[i * enc_state_->shared.coeff_order_size]); + } + } + + template + static inline void FindIndexOfSumMaximum(const V* array, const size_t len, + R* idx, V* sum) { + JXL_ASSERT(len > 0); + V maxval = 0; + V val = 0; + R maxidx = 0; + for (size_t i = 0; i < len; ++i) { + val += array[i]; + if (val > maxval) { + maxval = val; + maxidx = i; + } + } + *idx = maxidx; + *sum = maxval; + } + + PassesEncoderState* JXL_RESTRICT enc_state_; + JxlCmsInterface cms_; + ThreadPool* pool_; + AuxOut* aux_out_; + std::vector group_caches_; + bool doing_jpeg_recompression = false; +}; + +Status ParamsPostInit(CompressParams* p) { + if (!p->manual_noise.empty() && + p->manual_noise.size() != NoiseParams::kNumNoisePoints) { + return JXL_FAILURE("Invalid number of noise lut entries"); + } + if (!p->manual_xyb_factors.empty() && p->manual_xyb_factors.size() != 3) { + return JXL_FAILURE("Invalid number of XYB quantization factors"); + } + if (!p->modular_mode && p->butteraugli_distance == 0.0) { + p->butteraugli_distance = kMinButteraugliDistance; + } + if (p->original_butteraugli_distance == -1.0) { + p->original_butteraugli_distance = p->butteraugli_distance; + } + if (p->resampling <= 0) { + p->resampling = 1; + // For very low bit rates, using 2x2 resampling gives better results on + // most photographic images, with an adjusted butteraugli score chosen to + // give roughly the same amount of bits per pixel. + if (!p->already_downsampled && p->butteraugli_distance >= 20) { + p->resampling = 2; + p->butteraugli_distance = 6 + ((p->butteraugli_distance - 20) * 0.25); + } + } + if (p->ec_resampling <= 0) { + p->ec_resampling = p->resampling; + } + return true; +} + +Status EncodeFrame(const CompressParams& cparams_orig, + const FrameInfo& frame_info, const CodecMetadata* metadata, + const ImageBundle& ib, PassesEncoderState* passes_enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + BitWriter* writer, AuxOut* aux_out) { + CompressParams cparams = cparams_orig; + if (cparams_orig.target_bitrate > 0.0f && + frame_info.frame_type == FrameType::kRegularFrame) { + cparams.target_bitrate = 0.0f; + const float target_bitrate = cparams_orig.target_bitrate; + float bitrate = 0.0f; + float prev_bitrate = 0.0f; + float rescale = 1.0f; + size_t prev_bits = 0; + float error = 0.0f; + float best_error = 100.0f; + float best_rescale = 1.0f; + for (size_t i = 0; i < 10; ++i) { + std::unique_ptr state = + jxl::make_unique(); + BitWriter bw; + JXL_CHECK(EncodeFrame(cparams, frame_info, metadata, ib, state.get(), cms, + pool, &bw, nullptr)); + bitrate = bw.BitsWritten() * 1.0 / (ib.xsize() * ib.ysize()); + error = target_bitrate / bitrate - 1.0f; + if (std::abs(error) < std::abs(best_error)) { + best_error = error; + best_rescale = cparams.quant_ac_rescale; + } + if (bw.BitsWritten() == prev_bits || std::abs(error) < 0.0005f) { + break; + } + float lambda = 1.0f; + if (i > 0) { + lambda = (((bitrate / prev_bitrate) - 1.0f) / (rescale - 1.0f)); + } + rescale = (1.0f + ((target_bitrate / bitrate) - 1.0f) / lambda); + if (rescale < 0.0f) { + break; + } + cparams.quant_ac_rescale *= rescale; + prev_bitrate = bitrate; + prev_bits = bw.BitsWritten(); + } + if (aux_out) { + aux_out->max_quant_rescale = best_rescale; + aux_out->min_quant_rescale = best_rescale; + aux_out->min_bitrate_error = best_error; + aux_out->max_bitrate_error = best_error; + } + cparams.quant_ac_rescale = best_rescale; + } + ib.VerifyMetadata(); + + passes_enc_state->special_frames.clear(); + + JXL_RETURN_IF_ERROR(ParamsPostInit(&cparams)); + + if (cparams.progressive_dc < 0) { + if (cparams.progressive_dc != -1) { + return JXL_FAILURE("Invalid progressive DC setting value (%d)", + cparams.progressive_dc); + } + cparams.progressive_dc = 0; + // Enable progressive_dc for lower qualities, except for fast speeds where + // the modular encoder uses fixed tree. + if (cparams.speed_tier <= SpeedTier::kCheetah && + cparams.butteraugli_distance >= + kMinButteraugliDistanceForProgressiveDc) { + cparams.progressive_dc = 1; + } + } + if (cparams.ec_resampling < cparams.resampling) { + cparams.ec_resampling = cparams.resampling; + } + if (cparams.resampling > 1 || frame_info.is_preview) { + cparams.progressive_dc = 0; + } + + if (frame_info.dc_level + cparams.progressive_dc > 4) { + return JXL_FAILURE("Too many levels of progressive DC"); + } + + if (cparams.butteraugli_distance != 0 && + cparams.butteraugli_distance < kMinButteraugliDistance) { + return JXL_FAILURE("Butteraugli distance is too low (%f)", + cparams.butteraugli_distance); + } + + if (ib.IsJPEG()) { + cparams.gaborish = Override::kOff; + cparams.epf = 0; + cparams.modular_mode = false; + } + + if (ib.xsize() == 0 || ib.ysize() == 0) return JXL_FAILURE("Empty image"); + + // Assert that this metadata is correctly set up for the compression params, + // this should have been done by enc_file.cc + JXL_ASSERT(metadata->m.xyb_encoded == + (cparams.color_transform == ColorTransform::kXYB)); + std::unique_ptr frame_header = + jxl::make_unique(metadata); + JXL_RETURN_IF_ERROR(MakeFrameHeader(cparams, + passes_enc_state->progressive_splitter, + frame_info, ib, frame_header.get())); + // Check that if the codestream header says xyb_encoded, the color_transform + // matches the requirement. This is checked from the cparams here, even though + // optimally we'd be able to check this against what has actually been written + // in the main codestream header, but since ib is a const object and the data + // written to the main codestream header is (in modified form) in ib, the + // encoder cannot indicate this fact in the ib's metadata. + if (cparams_orig.color_transform == ColorTransform::kXYB) { + if (frame_header->color_transform != ColorTransform::kXYB) { + return JXL_FAILURE( + "The color transform of frames must be xyb if the codestream is xyb " + "encoded"); + } + } else { + if (frame_header->color_transform == ColorTransform::kXYB) { + return JXL_FAILURE( + "The color transform of frames cannot be xyb if the codestream is " + "not xyb encoded"); + } + } + + FrameDimensions frame_dim = frame_header->ToFrameDimensions(); + + const size_t num_groups = frame_dim.num_groups; + + Image3F opsin; + const ColorEncoding& c_linear = ColorEncoding::LinearSRGB(ib.IsGray()); + std::unique_ptr metadata_linear = + jxl::make_unique(); + metadata_linear->xyb_encoded = + (cparams.color_transform == ColorTransform::kXYB); + metadata_linear->color_encoding = c_linear; + ImageBundle linear_storage(metadata_linear.get()); + + std::vector aux_outs; + // LossyFrameEncoder stores a reference to a std::function + // so we need to keep the std::function being referenced + // alive while lossy_frame_encoder is used. We could make resize_aux_outs a + // lambda type by making LossyFrameEncoder a template instead, but this is + // simpler. + const std::function resize_aux_outs = + [&aux_outs, aux_out](const size_t num_threads) -> Status { + if (aux_out != nullptr) { + size_t old_size = aux_outs.size(); + for (size_t i = num_threads; i < old_size; i++) { + aux_out->Assimilate(aux_outs[i]); + } + aux_outs.resize(num_threads); + // Each thread needs these INPUTS. Don't copy the entire AuxOut + // because it may contain stats which would be Assimilated multiple + // times below. + for (size_t i = old_size; i < aux_outs.size(); i++) { + aux_outs[i].dump_image = aux_out->dump_image; + aux_outs[i].debug_prefix = aux_out->debug_prefix; + } + } + return true; + }; + + LossyFrameEncoder lossy_frame_encoder(cparams, *frame_header, + passes_enc_state, cms, pool, aux_out); + std::unique_ptr modular_frame_encoder = + jxl::make_unique(*frame_header, cparams); + + const std::vector* extra_channels = &ib.extra_channels(); + std::vector extra_channels_storage; + // Clear patches + passes_enc_state->shared.image_features.patches = PatchDictionary(); + passes_enc_state->shared.image_features.patches.SetPassesSharedState( + &passes_enc_state->shared); + + if (ib.IsJPEG()) { + JXL_RETURN_IF_ERROR(lossy_frame_encoder.ComputeJPEGTranscodingData( + *ib.jpeg_data, modular_frame_encoder.get(), frame_header.get())); + } else if (!lossy_frame_encoder.State()->heuristics->HandlesColorConversion( + cparams, ib) || + frame_header->encoding != FrameEncoding::kVarDCT) { + // Allocating a large enough image avoids a copy when padding. + opsin = + Image3F(RoundUpToBlockDim(ib.xsize()), RoundUpToBlockDim(ib.ysize())); + opsin.ShrinkTo(ib.xsize(), ib.ysize()); + + const bool want_linear = frame_header->encoding == FrameEncoding::kVarDCT && + cparams.speed_tier <= SpeedTier::kKitten; + const ImageBundle* JXL_RESTRICT ib_or_linear = &ib; + + if (frame_header->color_transform == ColorTransform::kXYB && + frame_info.ib_needs_color_transform) { + // linear_storage would only be used by the Butteraugli loop (passing + // linear sRGB avoids a color conversion there). Otherwise, don't + // fill it to reduce memory usage. + ib_or_linear = + ToXYB(ib, pool, &opsin, cms, want_linear ? &linear_storage : nullptr); + } else { // RGB or YCbCr: don't do anything (forward YCbCr is not + // implemented, this is only used when the input is already in + // YCbCr) + // If encoding a special DC or reference frame, don't do anything: + // input is already in XYB. + CopyImageTo(ib.color(), &opsin); + } + bool lossless = cparams.IsLossless(); + if (ib.HasAlpha() && !ib.AlphaIsPremultiplied() && + frame_header->frame_type == FrameType::kRegularFrame && + !ApplyOverride(cparams.keep_invisible, lossless) && + cparams.ec_resampling == cparams.resampling) { + // simplify invisible pixels + SimplifyInvisible(&opsin, ib.alpha(), lossless); + if (want_linear) { + SimplifyInvisible(const_cast(&ib_or_linear->color()), + ib.alpha(), lossless); + } + } + if (aux_out != nullptr) { + JXL_RETURN_IF_ERROR( + aux_out->InspectImage3F("enc_frame:OpsinDynamicsImage", opsin)); + } + if (frame_header->encoding == FrameEncoding::kVarDCT) { + PadImageToBlockMultipleInPlace(&opsin); + JXL_RETURN_IF_ERROR(lossy_frame_encoder.ComputeEncodingData( + ib_or_linear, &opsin, cms, pool, modular_frame_encoder.get(), + frame_header.get())); + } else if (frame_header->upsampling != 1 && !cparams.already_downsampled) { + // In VarDCT mode, LossyFrameHeuristics takes care of running downsampling + // after noise, if necessary. + DownsampleImage(&opsin, frame_header->upsampling); + } + } else { + JXL_RETURN_IF_ERROR(lossy_frame_encoder.ComputeEncodingData( + &ib, &opsin, cms, pool, modular_frame_encoder.get(), + frame_header.get())); + } + if (cparams.ec_resampling != 1 && !cparams.already_downsampled) { + extra_channels = &extra_channels_storage; + for (size_t i = 0; i < ib.extra_channels().size(); i++) { + extra_channels_storage.emplace_back(CopyImage(ib.extra_channels()[i])); + DownsampleImage(&extra_channels_storage.back(), cparams.ec_resampling); + } + } + // needs to happen *AFTER* VarDCT-ComputeEncodingData. + JXL_RETURN_IF_ERROR(modular_frame_encoder->ComputeEncodingData( + *frame_header, *ib.metadata(), &opsin, *extra_channels, + lossy_frame_encoder.State(), cms, pool, aux_out, + /* do_color=*/frame_header->encoding == FrameEncoding::kModular)); + + writer->AppendByteAligned(lossy_frame_encoder.State()->special_frames); + frame_header->UpdateFlag( + lossy_frame_encoder.State()->shared.image_features.patches.HasAny(), + FrameHeader::kPatches); + frame_header->UpdateFlag( + lossy_frame_encoder.State()->shared.image_features.splines.HasAny(), + FrameHeader::kSplines); + JXL_RETURN_IF_ERROR(WriteFrameHeader(*frame_header, writer, aux_out)); + + const size_t num_passes = + passes_enc_state->progressive_splitter.GetNumPasses(); + + // DC global info + DC groups + AC global info + AC groups * + // num_passes. + const bool has_ac_global = true; + std::vector group_codes(NumTocEntries(frame_dim.num_groups, + frame_dim.num_dc_groups, + num_passes, has_ac_global)); + const size_t global_ac_index = frame_dim.num_dc_groups + 1; + const bool is_small_image = frame_dim.num_groups == 1 && num_passes == 1; + const auto get_output = [&](const size_t index) { + return &group_codes[is_small_image ? 0 : index]; + }; + auto ac_group_code = [&](size_t pass, size_t group) { + return get_output(AcGroupIndex(pass, group, frame_dim.num_groups, + frame_dim.num_dc_groups, has_ac_global)); + }; + + if (frame_header->flags & FrameHeader::kPatches) { + PatchDictionaryEncoder::Encode( + lossy_frame_encoder.State()->shared.image_features.patches, + get_output(0), kLayerDictionary, aux_out); + } + + if (frame_header->flags & FrameHeader::kSplines) { + EncodeSplines(lossy_frame_encoder.State()->shared.image_features.splines, + get_output(0), kLayerSplines, HistogramParams(), aux_out); + } + + if (cparams.photon_noise_iso > 0) { + lossy_frame_encoder.State()->shared.image_features.noise_params = + SimulatePhotonNoise(ib.xsize(), ib.ysize(), cparams.photon_noise_iso); + } + if (cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) { + for (size_t i = 0; i < NoiseParams::kNumNoisePoints; i++) { + lossy_frame_encoder.State()->shared.image_features.noise_params.lut[i] = + cparams.manual_noise[i]; + } + } + if (frame_header->flags & FrameHeader::kNoise) { + EncodeNoise(lossy_frame_encoder.State()->shared.image_features.noise_params, + get_output(0), kLayerNoise, aux_out); + } + + JXL_RETURN_IF_ERROR( + DequantMatricesEncodeDC(&lossy_frame_encoder.State()->shared.matrices, + get_output(0), kLayerQuant, aux_out)); + if (frame_header->encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR( + lossy_frame_encoder.EncodeGlobalDCInfo(*frame_header, get_output(0))); + } + JXL_RETURN_IF_ERROR( + modular_frame_encoder->EncodeGlobalInfo(get_output(0), aux_out)); + JXL_RETURN_IF_ERROR(modular_frame_encoder->EncodeStream( + get_output(0), aux_out, kLayerModularGlobal, ModularStreamId::Global())); + + const auto process_dc_group = [&](const uint32_t group_index, + const size_t thread) { + AuxOut* my_aux_out = aux_out ? &aux_outs[thread] : nullptr; + BitWriter* output = get_output(group_index + 1); + if (frame_header->encoding == FrameEncoding::kVarDCT && + !(frame_header->flags & FrameHeader::kUseDcFrame)) { + BitWriter::Allotment allotment(output, 2); + output->Write(2, modular_frame_encoder->extra_dc_precision[group_index]); + ReclaimAndCharge(output, &allotment, kLayerDC, my_aux_out); + JXL_CHECK(modular_frame_encoder->EncodeStream( + output, my_aux_out, kLayerDC, + ModularStreamId::VarDCTDC(group_index))); + } + JXL_CHECK(modular_frame_encoder->EncodeStream( + output, my_aux_out, kLayerModularDcGroup, + ModularStreamId::ModularDC(group_index))); + if (frame_header->encoding == FrameEncoding::kVarDCT) { + const Rect& rect = + lossy_frame_encoder.State()->shared.DCGroupRect(group_index); + size_t nb_bits = CeilLog2Nonzero(rect.xsize() * rect.ysize()); + if (nb_bits != 0) { + BitWriter::Allotment allotment(output, nb_bits); + output->Write(nb_bits, + modular_frame_encoder->ac_metadata_size[group_index] - 1); + ReclaimAndCharge(output, &allotment, kLayerControlFields, my_aux_out); + } + JXL_CHECK(modular_frame_encoder->EncodeStream( + output, my_aux_out, kLayerControlFields, + ModularStreamId::ACMetadata(group_index))); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.num_dc_groups, + resize_aux_outs, process_dc_group, + "EncodeDCGroup")); + + if (frame_header->encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR(lossy_frame_encoder.EncodeGlobalACInfo( + get_output(global_ac_index), modular_frame_encoder.get())); + } + + std::atomic num_errors{0}; + const auto process_group = [&](const uint32_t group_index, + const size_t thread) { + AuxOut* my_aux_out = aux_out ? &aux_outs[thread] : nullptr; + + for (size_t i = 0; i < num_passes; i++) { + if (frame_header->encoding == FrameEncoding::kVarDCT) { + if (!lossy_frame_encoder.EncodeACGroup( + i, group_index, ac_group_code(i, group_index), my_aux_out)) { + num_errors.fetch_add(1, std::memory_order_relaxed); + return; + } + } + // Write all modular encoded data (color?, alpha, depth, extra channels) + if (!modular_frame_encoder->EncodeStream( + ac_group_code(i, group_index), my_aux_out, kLayerModularAcGroup, + ModularStreamId::ModularAC(group_index, i))) { + num_errors.fetch_add(1, std::memory_order_relaxed); + return; + } + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_groups, resize_aux_outs, + process_group, "EncodeGroupCoefficients")); + + // Resizing aux_outs to 0 also Assimilates the array. + static_cast(resize_aux_outs(0)); + JXL_RETURN_IF_ERROR(num_errors.load(std::memory_order_relaxed) == 0); + + for (BitWriter& bw : group_codes) { + BitWriter::Allotment allotment(&bw, 8); + bw.ZeroPadToByte(); // end of group. + ReclaimAndCharge(&bw, &allotment, kLayerAC, aux_out); + } + + std::vector* permutation_ptr = nullptr; + std::vector permutation; + if (cparams.centerfirst && !(num_passes == 1 && num_groups == 1)) { + permutation_ptr = &permutation; + // Don't permute global DC/AC or DC. + permutation.resize(global_ac_index + 1); + std::iota(permutation.begin(), permutation.end(), 0); + std::vector ac_group_order(num_groups); + std::iota(ac_group_order.begin(), ac_group_order.end(), 0); + size_t group_dim = frame_dim.group_dim; + + // The center of the image is either given by parameters or chosen + // to be the middle of the image by default if center_x, center_y resp. + // are not provided. + + int64_t imag_cx; + if (cparams.center_x != static_cast(-1)) { + JXL_RETURN_IF_ERROR(cparams.center_x < ib.xsize()); + imag_cx = cparams.center_x; + } else { + imag_cx = ib.xsize() / 2; + } + + int64_t imag_cy; + if (cparams.center_y != static_cast(-1)) { + JXL_RETURN_IF_ERROR(cparams.center_y < ib.ysize()); + imag_cy = cparams.center_y; + } else { + imag_cy = ib.ysize() / 2; + } + + // The center of the group containing the center of the image. + int64_t cx = (imag_cx / group_dim) * group_dim + group_dim / 2; + int64_t cy = (imag_cy / group_dim) * group_dim + group_dim / 2; + // This identifies in what area of the central group the center of the image + // lies in. + double direction = -std::atan2(imag_cy - cy, imag_cx - cx); + // This identifies the side of the central group the center of the image + // lies closest to. This can take values 0, 1, 2, 3 corresponding to left, + // bottom, right, top. + int64_t side = std::fmod((direction + 5 * kPi / 4), 2 * kPi) * 2 / kPi; + auto get_distance_from_center = [&](size_t gid) { + Rect r = passes_enc_state->shared.GroupRect(gid); + int64_t gcx = r.x0() + group_dim / 2; + int64_t gcy = r.y0() + group_dim / 2; + int64_t dx = gcx - cx; + int64_t dy = gcy - cy; + // The angle is determined by taking atan2 and adding an appropriate + // starting point depending on the side we want to start on. + double angle = std::remainder( + std::atan2(dy, dx) + kPi / 4 + side * (kPi / 2), 2 * kPi); + // Concentric squares in clockwise order. + return std::make_pair(std::max(std::abs(dx), std::abs(dy)), angle); + }; + std::sort(ac_group_order.begin(), ac_group_order.end(), + [&](coeff_order_t a, coeff_order_t b) { + return get_distance_from_center(a) < + get_distance_from_center(b); + }); + std::vector inv_ac_group_order(ac_group_order.size(), 0); + for (size_t i = 0; i < ac_group_order.size(); i++) { + inv_ac_group_order[ac_group_order[i]] = i; + } + for (size_t i = 0; i < num_passes; i++) { + size_t pass_start = permutation.size(); + for (coeff_order_t v : inv_ac_group_order) { + permutation.push_back(pass_start + v); + } + } + std::vector new_group_codes(group_codes.size()); + for (size_t i = 0; i < permutation.size(); i++) { + new_group_codes[permutation[i]] = std::move(group_codes[i]); + } + group_codes = std::move(new_group_codes); + } + + JXL_RETURN_IF_ERROR( + WriteGroupOffsets(group_codes, permutation_ptr, writer, aux_out)); + writer->AppendByteAligned(group_codes); + + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_frame.h b/media/libjxl/src/lib/jxl/enc_frame.h new file mode 100644 index 000000000..c046014f8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_frame.h @@ -0,0 +1,78 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_FRAME_H_ +#define LIB_JXL_ENC_FRAME_H_ + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Information needed for encoding a frame that is not contained elsewhere and +// does not belong to `cparams`. +// TODO(lode): if possible, it might be better to replace FrameInfo and several +// fields from ImageBundle (such as frame name and duration) by direct usage of +// jxl::FrameHeader itself. +struct FrameInfo { + // TODO(veluca): consider adding more parameters, such as custom patches. + bool save_before_color_transform = false; + // Whether or not the input image bundle is already in the codestream + // colorspace (as deduced by cparams). + // TODO(veluca): this is a hack - ImageBundle doesn't have a simple way to say + // "this is already in XYB". + bool ib_needs_color_transform = true; + FrameType frame_type = FrameType::kRegularFrame; + size_t dc_level = 0; + // Only used for kRegularFrame. + bool is_last = true; + bool is_preview = false; + // Information for storing this frame for future use (only for non-DC frames). + size_t save_as_reference = 0; + // The source frame for blending of a next frame, matching the + // save_as_reference value of a previous frame. Animated frames can use + // save_as_reference values 1, 2 and 3, while composite still frames can use + // save_as_reference values 0, 1, 2 and 3. The current C++ encoder + // implementation is assuming and using 1 for all frames of animations, so + // using that as the default value here. + // Corresponds to BlendingInfo::source from the FrameHeader. + size_t source = 1; + // Corresponds to BlendingInfo::clamp from the FrameHeader. + size_t clamp = 1; + // Corresponds to BlendingInfo::alpha_channel from the FrameHeader, or set to + // -1 to automatically choose it as the index of the first extra channel of + // type alpha. + int alpha_channel = -1; + + // If non-empty, uses this blending info for the extra channels, otherwise + // automatically chooses it. The encoder API will fill this vector with the + // extra channel info and allows more options. The non-API cjxl leaves it + // empty and relies on the default behavior. + std::vector extra_channel_blending_info; +}; + +// Checks and adjusts CompressParams when they are all initialized. +Status ParamsPostInit(CompressParams* p); + +// Encodes a single frame (including its header) into a byte stream. Groups may +// be processed in parallel by `pool`. metadata is the ImageMetadata encoded in +// the codestream, and must be used for the FrameHeaders, do not use +// ib.metadata. +Status EncodeFrame(const CompressParams& cparams_orig, + const FrameInfo& frame_info, const CodecMetadata* metadata, + const ImageBundle& ib, PassesEncoderState* passes_enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + BitWriter* writer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_FRAME_H_ diff --git a/media/libjxl/src/lib/jxl/enc_gamma_correct.h b/media/libjxl/src/lib/jxl/enc_gamma_correct.h new file mode 100644 index 000000000..0db7012bb --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_gamma_correct.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_GAMMA_CORRECT_H_ +#define LIB_JXL_ENC_GAMMA_CORRECT_H_ + +// Deprecated: sRGB transfer function. Use color_management.h instead. + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/transfer_functions-inl.h" + +namespace jxl { + +// Values are in [0, 1]. +static JXL_INLINE double Srgb8ToLinearDirect(double srgb) { + if (srgb <= 0.0) return 0.0; + if (srgb <= 0.04045) return srgb / 12.92; + if (srgb >= 1.0) return 1.0; + return std::pow((srgb + 0.055) / 1.055, 2.4); +} + +// Values are in [0, 1]. +static JXL_INLINE double LinearToSrgb8Direct(double linear) { + if (linear <= 0.0) return 0.0; + if (linear >= 1.0) return 1.0; + if (linear <= 0.0031308) return linear * 12.92; + return std::pow(linear, 1.0 / 2.4) * 1.055 - 0.055; +} + +} // namespace jxl + +#endif // LIB_JXL_ENC_GAMMA_CORRECT_H_ diff --git a/media/libjxl/src/lib/jxl/enc_group.cc b/media/libjxl/src/lib/jxl/enc_group.cc new file mode 100644 index 000000000..bf853064e --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_group.cc @@ -0,0 +1,361 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_group.h" + +#include + +#include "hwy/aligned_allocator.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_group.cc" +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quantizer-inl.h" +#include "lib/jxl/quantizer.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Ge; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::MaskFromVec; +using hwy::HWY_NAMESPACE::Round; + +// NOTE: caller takes care of extracting quant from rect of RawQuantField. +void QuantizeBlockAC(const Quantizer& quantizer, const bool error_diffusion, + size_t c, int32_t quant, float qm_multiplier, + size_t quant_kind, size_t xsize, size_t ysize, + const float* JXL_RESTRICT block_in, + int32_t* JXL_RESTRICT block_out) { + PROFILER_FUNC; + const float* JXL_RESTRICT qm = quantizer.InvDequantMatrix(quant_kind, c); + const float qac = quantizer.Scale() * quant; + // Not SIMD-fied for now. + float thres[4] = {0.58f, 0.635f, 0.66f, 0.7f}; + if (c == 0) { + for (int i = 1; i < 4; ++i) { + thres[i] += 0.08f; + } + } + if (c == 2) { + for (int i = 1; i < 4; ++i) { + thres[i] = 0.75f; + } + } + if (xsize > 1 || ysize > 1) { + for (int i = 0; i < 4; ++i) { + thres[i] -= Clamp1(0.003f * xsize * ysize, 0.f, (c > 0 ? 0.08f : 0.12f)); + } + } + + if (!error_diffusion) { + HWY_CAPPED(float, kBlockDim) df; + HWY_CAPPED(int32_t, kBlockDim) di; + HWY_CAPPED(uint32_t, kBlockDim) du; + const auto quant = Set(df, qac * qm_multiplier); + + for (size_t y = 0; y < ysize * kBlockDim; y++) { + size_t yfix = static_cast(y >= ysize * kBlockDim / 2) * 2; + const size_t off = y * kBlockDim * xsize; + for (size_t x = 0; x < xsize * kBlockDim; x += Lanes(df)) { + auto thr = Zero(df); + if (xsize == 1) { + HWY_ALIGN uint32_t kMask[kBlockDim] = {0, 0, 0, 0, + ~0u, ~0u, ~0u, ~0u}; + const auto mask = MaskFromVec(BitCast(df, Load(du, kMask + x))); + thr = + IfThenElse(mask, Set(df, thres[yfix + 1]), Set(df, thres[yfix])); + } else { + // Same for all lanes in the vector. + thr = Set( + df, + thres[yfix + static_cast(x >= xsize * kBlockDim / 2)]); + } + + const auto q = Mul(Load(df, qm + off + x), quant); + const auto in = Load(df, block_in + off + x); + const auto val = Mul(q, in); + const auto nzero_mask = Ge(Abs(val), thr); + const auto v = ConvertTo(di, IfThenElseZero(nzero_mask, Round(val))); + Store(v, di, block_out + off + x); + } + } + return; + } + +retry: + int hfNonZeros[4] = {}; + float hfError[4] = {}; + float hfMaxError[4] = {}; + size_t hfMaxErrorIx[4] = {}; + for (size_t y = 0; y < ysize * kBlockDim; y++) { + for (size_t x = 0; x < xsize * kBlockDim; x++) { + const size_t pos = y * kBlockDim * xsize + x; + if (x < xsize && y < ysize) { + // Ensure block is initialized + block_out[pos] = 0; + continue; + } + const size_t hfix = (static_cast(y >= ysize * kBlockDim / 2) * 2 + + static_cast(x >= xsize * kBlockDim / 2)); + const float val = block_in[pos] * (qm[pos] * qac * qm_multiplier); + float v = (std::abs(val) < thres[hfix]) ? 0 : rintf(val); + const float error = std::abs(val) - std::abs(v); + hfError[hfix] += error * error; + if (hfMaxError[hfix] < error) { + hfMaxError[hfix] = error; + hfMaxErrorIx[hfix] = pos; + } + if (v != 0.0f) { + hfNonZeros[hfix] += std::abs(v); + } + block_out[pos] = static_cast(rintf(v)); + } + } + if (c != 1) return; + constexpr size_t kPartialBlockKinds = + (1 << AcStrategy::Type::IDENTITY) | (1 << AcStrategy::Type::DCT2X2) | + (1 << AcStrategy::Type::DCT4X4) | (1 << AcStrategy::Type::DCT4X8) | + (1 << AcStrategy::Type::DCT8X4) | (1 << AcStrategy::Type::AFV0) | + (1 << AcStrategy::Type::AFV1) | (1 << AcStrategy::Type::AFV2) | + (1 << AcStrategy::Type::AFV3); + if ((1 << quant_kind) & kPartialBlockKinds) return; + float hfErrorLimit = 0.029f * (xsize * ysize) * kDCTBlockSize * 0.25f; + bool goretry = false; + for (int i = 1; i < 4; ++i) { + if (hfError[i] >= hfErrorLimit && + hfNonZeros[i] <= (xsize + ysize) * 0.25f) { + if (thres[i] >= 0.4f) { + thres[i] -= 0.01f; + goretry = true; + } + } + } + if (goretry) goto retry; + for (int i = 1; i < 4; ++i) { + if (hfError[i] >= hfErrorLimit && hfNonZeros[i] == 0) { + const size_t pos = hfMaxErrorIx[i]; + if (hfMaxError[i] >= 0.4f) { + block_out[pos] = block_in[pos] > 0.0f ? 1.0f : -1.0f; + } + } + } +} + +// NOTE: caller takes care of extracting quant from rect of RawQuantField. +void QuantizeRoundtripYBlockAC(const Quantizer& quantizer, + const bool error_diffusion, int32_t quant, + size_t quant_kind, size_t xsize, size_t ysize, + const float* JXL_RESTRICT biases, + float* JXL_RESTRICT inout, + int32_t* JXL_RESTRICT quantized) { + QuantizeBlockAC(quantizer, error_diffusion, 1, quant, 1.0f, quant_kind, xsize, + ysize, inout, quantized); + + PROFILER_ZONE("enc quant adjust bias"); + const float* JXL_RESTRICT dequant_matrix = + quantizer.DequantMatrix(quant_kind, 1); + + HWY_CAPPED(float, kDCTBlockSize) df; + HWY_CAPPED(int32_t, kDCTBlockSize) di; + const auto inv_qac = Set(df, quantizer.inv_quant_ac(quant)); + for (size_t k = 0; k < kDCTBlockSize * xsize * ysize; k += Lanes(df)) { + const auto quant = Load(di, quantized + k); + const auto adj_quant = AdjustQuantBias(di, 1, quant, biases); + const auto dequantm = Load(df, dequant_matrix + k); + Store(Mul(Mul(adj_quant, dequantm), inv_qac), df, inout + k); + } +} + +void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state, + const Image3F& opsin, Image3F* dc) { + PROFILER_FUNC; + const Rect block_group_rect = enc_state->shared.BlockGroupRect(group_idx); + const Rect group_rect = enc_state->shared.GroupRect(group_idx); + const Rect cmap_rect( + block_group_rect.x0() / kColorTileDimInBlocks, + block_group_rect.y0() / kColorTileDimInBlocks, + DivCeil(block_group_rect.xsize(), kColorTileDimInBlocks), + DivCeil(block_group_rect.ysize(), kColorTileDimInBlocks)); + + const size_t xsize_blocks = block_group_rect.xsize(); + const size_t ysize_blocks = block_group_rect.ysize(); + + const size_t dc_stride = static_cast(dc->PixelsPerRow()); + const size_t opsin_stride = static_cast(opsin.PixelsPerRow()); + + const ImageI& full_quant_field = enc_state->shared.raw_quant_field; + const CompressParams& cparams = enc_state->cparams; + + // TODO(veluca): consider strategies to reduce this memory. + auto mem = hwy::AllocateAligned(3 * AcStrategy::kMaxCoeffArea); + auto fmem = hwy::AllocateAligned(5 * AcStrategy::kMaxCoeffArea); + float* JXL_RESTRICT scratch_space = + fmem.get() + 3 * AcStrategy::kMaxCoeffArea; + { + // Only use error diffusion in Squirrel mode or slower. + const bool error_diffusion = cparams.speed_tier <= SpeedTier::kSquirrel; + constexpr HWY_CAPPED(float, kDCTBlockSize) d; + + int32_t* JXL_RESTRICT coeffs[kMaxNumPasses][3] = {}; + size_t num_passes = enc_state->progressive_splitter.GetNumPasses(); + JXL_DASSERT(num_passes > 0); + for (size_t i = 0; i < num_passes; i++) { + // TODO(veluca): 16-bit quantized coeffs are not implemented yet. + JXL_ASSERT(enc_state->coeffs[i]->Type() == ACType::k32); + for (size_t c = 0; c < 3; c++) { + coeffs[i][c] = enc_state->coeffs[i]->PlaneRow(c, group_idx, 0).ptr32; + } + } + + HWY_ALIGN float* coeffs_in = fmem.get(); + HWY_ALIGN int32_t* quantized = mem.get(); + + size_t offset = 0; + + for (size_t by = 0; by < ysize_blocks; ++by) { + const int32_t* JXL_RESTRICT row_quant_ac = + block_group_rect.ConstRow(full_quant_field, by); + size_t ty = by / kColorTileDimInBlocks; + const int8_t* JXL_RESTRICT row_cmap[3] = { + cmap_rect.ConstRow(enc_state->shared.cmap.ytox_map, ty), + nullptr, + cmap_rect.ConstRow(enc_state->shared.cmap.ytob_map, ty), + }; + const float* JXL_RESTRICT opsin_rows[3] = { + group_rect.ConstPlaneRow(opsin, 0, by * kBlockDim), + group_rect.ConstPlaneRow(opsin, 1, by * kBlockDim), + group_rect.ConstPlaneRow(opsin, 2, by * kBlockDim), + }; + float* JXL_RESTRICT dc_rows[3] = { + block_group_rect.PlaneRow(dc, 0, by), + block_group_rect.PlaneRow(dc, 1, by), + block_group_rect.PlaneRow(dc, 2, by), + }; + AcStrategyRow ac_strategy_row = + enc_state->shared.ac_strategy.ConstRow(block_group_rect, by); + for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks); + tx++) { + const auto x_factor = + Set(d, enc_state->shared.cmap.YtoXRatio(row_cmap[0][tx])); + const auto b_factor = + Set(d, enc_state->shared.cmap.YtoBRatio(row_cmap[2][tx])); + for (size_t bx = tx * kColorTileDimInBlocks; + bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks; ++bx) { + const AcStrategy acs = ac_strategy_row[bx]; + if (!acs.IsFirstBlock()) continue; + + size_t xblocks = acs.covered_blocks_x(); + size_t yblocks = acs.covered_blocks_y(); + + CoefficientLayout(&yblocks, &xblocks); + + size_t size = kDCTBlockSize * xblocks * yblocks; + + // DCT Y channel, roundtrip-quantize it and set DC. + const int32_t quant_ac = row_quant_ac[bx]; + TransformFromPixels(acs.Strategy(), opsin_rows[1] + bx * kBlockDim, + opsin_stride, coeffs_in + size, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), coeffs_in + size, + dc_rows[1] + bx, dc_stride); + QuantizeRoundtripYBlockAC( + enc_state->shared.quantizer, error_diffusion, quant_ac, + acs.RawStrategy(), xblocks, yblocks, kDefaultQuantBias, + coeffs_in + size, quantized + size); + + // DCT X and B channels + for (size_t c : {0, 2}) { + TransformFromPixels(acs.Strategy(), opsin_rows[c] + bx * kBlockDim, + opsin_stride, coeffs_in + c * size, + scratch_space); + } + + // Unapply color correlation + for (size_t k = 0; k < size; k += Lanes(d)) { + const auto in_x = Load(d, coeffs_in + k); + const auto in_y = Load(d, coeffs_in + size + k); + const auto in_b = Load(d, coeffs_in + 2 * size + k); + const auto out_x = NegMulAdd(x_factor, in_y, in_x); + const auto out_b = NegMulAdd(b_factor, in_y, in_b); + Store(out_x, d, coeffs_in + k); + Store(out_b, d, coeffs_in + 2 * size + k); + } + + // Quantize X and B channels and set DC. + for (size_t c : {0, 2}) { + QuantizeBlockAC(enc_state->shared.quantizer, error_diffusion, c, + quant_ac, + c == 0 ? enc_state->x_qm_multiplier + : enc_state->b_qm_multiplier, + acs.RawStrategy(), xblocks, yblocks, + coeffs_in + c * size, quantized + c * size); + DCFromLowestFrequencies(acs.Strategy(), coeffs_in + c * size, + dc_rows[c] + bx, dc_stride); + } + enc_state->progressive_splitter.SplitACCoefficients( + quantized, size, acs, bx, by, offset, coeffs); + offset += size; + } + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ComputeCoefficients); +void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state, + const Image3F& opsin, Image3F* dc) { + return HWY_DYNAMIC_DISPATCH(ComputeCoefficients)(group_idx, enc_state, opsin, + dc); +} + +Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx, + size_t histogram_idx, + const PassesEncoderState& enc_state, + BitWriter* writer, AuxOut* aux_out) { + // Select which histogram to use among those of the current pass. + const size_t num_histograms = enc_state.shared.num_histograms; + // num_histograms is 0 only for lossless. + JXL_ASSERT(num_histograms == 0 || histogram_idx < num_histograms); + size_t histo_selector_bits = CeilLog2Nonzero(num_histograms); + + if (histo_selector_bits != 0) { + BitWriter::Allotment allotment(writer, histo_selector_bits); + writer->Write(histo_selector_bits, histogram_idx); + ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out); + } + WriteTokens(enc_state.passes[pass_idx].ac_tokens[group_idx], + enc_state.passes[pass_idx].codes, + enc_state.passes[pass_idx].context_map, writer, kLayerACTokens, + aux_out); + + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_group.h b/media/libjxl/src/lib/jxl/enc_group.h new file mode 100644 index 000000000..62468ddf9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_group.h @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_GROUP_H_ +#define LIB_JXL_ENC_GROUP_H_ + +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" + +namespace jxl { + +// Fills DC +void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state, + const Image3F& opsin, Image3F* dc); + +Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx, + size_t histogram_idx, + const PassesEncoderState& enc_state, + BitWriter* writer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_GROUP_H_ diff --git a/media/libjxl/src/lib/jxl/enc_heuristics.cc b/media/libjxl/src/lib/jxl/enc_heuristics.cc new file mode 100644 index 000000000..1ab4ea56c --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_heuristics.cc @@ -0,0 +1,935 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_heuristics.h" + +#include +#include + +#include +#include +#include + +#include "lib/jxl/enc_ac_strategy.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_ar_control_field.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_chroma_from_luma.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_noise.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/enc_photon_noise.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/enc_splines.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/gaborish.h" + +namespace jxl { +namespace { +void FindBestBlockEntropyModel(PassesEncoderState& enc_state) { + if (enc_state.cparams.decoding_speed_tier >= 1) { + static constexpr uint8_t kSimpleCtxMap[] = { + // Cluster all blocks together + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // + }; + static_assert( + 3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap, + "Update simple context map"); + + auto bcm = enc_state.shared.block_ctx_map; + bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap)); + bcm.num_ctxs = 2; + bcm.num_dc_ctxs = 1; + return; + } + if (enc_state.cparams.speed_tier >= SpeedTier::kFalcon) { + return; + } + const ImageI& rqf = enc_state.shared.raw_quant_field; + // No need to change context modeling for small images. + size_t tot = rqf.xsize() * rqf.ysize(); + size_t size_for_ctx_model = + (1 << 10) * enc_state.cparams.butteraugli_distance; + if (tot < size_for_ctx_model) return; + + struct OccCounters { + // count the occurrences of each qf value and each strategy type. + OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) { + for (size_t y = 0; y < rqf.ysize(); y++) { + const int32_t* qf_row = rqf.Row(y); + AcStrategyRow acs_row = ac_strategy.ConstRow(y); + for (size_t x = 0; x < rqf.xsize(); x++) { + int ord = kStrategyOrder[acs_row[x].RawStrategy()]; + int qf = qf_row[x] - 1; + qf_counts[qf]++; + qf_ord_counts[ord][qf]++; + ord_counts[ord]++; + } + } + } + + size_t qf_counts[256] = {}; + size_t qf_ord_counts[kNumOrders][256] = {}; + size_t ord_counts[kNumOrders] = {}; + }; + // The OccCounters struct is too big to allocate on the stack. + std::unique_ptr counters( + new OccCounters(rqf, enc_state.shared.ac_strategy)); + + // Splitting the context model according to the quantization field seems to + // mostly benefit only large images. + size_t size_for_qf_split = (1 << 13) * enc_state.cparams.butteraugli_distance; + size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2; + std::vector& qft = enc_state.shared.block_ctx_map.qf_thresholds; + qft.clear(); + // Divide the quant field in up to num_qf_segments segments. + size_t cumsum = 0; + size_t next = 1; + size_t last_cut = 256; + size_t cut = tot * next / num_qf_segments; + for (uint32_t j = 0; j < 256; j++) { + cumsum += counters->qf_counts[j]; + if (cumsum > cut) { + if (j != 0) { + qft.push_back(j); + } + last_cut = j; + while (cumsum > cut) { + next++; + cut = tot * next / num_qf_segments; + } + } else if (next > qft.size() + 1) { + if (j - 1 == last_cut && j != 0) { + qft.push_back(j); + } + } + } + + // Count the occurrences of each segment. + std::vector counts(kNumOrders * (qft.size() + 1)); + size_t qft_pos = 0; + for (size_t j = 0; j < 256; j++) { + if (qft_pos < qft.size() && j == qft[qft_pos]) { + qft_pos++; + } + for (size_t i = 0; i < kNumOrders; i++) { + counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j]; + } + } + + // Repeatedly merge the lowest-count pair. + std::vector remap((qft.size() + 1) * kNumOrders); + std::iota(remap.begin(), remap.end(), 0); + std::vector clusters(remap); + size_t nb_clusters = Clamp1((int)(tot / size_for_ctx_model / 2), 2, 9); + size_t nb_clusters_chroma = Clamp1((int)(tot / size_for_ctx_model / 3), 1, 5); + // This is O(n^2 log n), but n is small. + while (clusters.size() > nb_clusters) { + std::sort(clusters.begin(), clusters.end(), + [&](int a, int b) { return counts[a] > counts[b]; }); + counts[clusters[clusters.size() - 2]] += counts[clusters.back()]; + counts[clusters.back()] = 0; + remap[clusters.back()] = clusters[clusters.size() - 2]; + clusters.pop_back(); + } + for (size_t i = 0; i < remap.size(); i++) { + while (remap[remap[i]] != remap[i]) { + remap[i] = remap[remap[i]]; + } + } + // Relabel starting from 0. + std::vector remap_remap(remap.size(), remap.size()); + size_t num = 0; + for (size_t i = 0; i < remap.size(); i++) { + if (remap_remap[remap[i]] == remap.size()) { + remap_remap[remap[i]] = num++; + } + remap[i] = remap_remap[remap[i]]; + } + // Write the block context map. + auto& ctx_map = enc_state.shared.block_ctx_map.ctx_map; + ctx_map = remap; + ctx_map.resize(remap.size() * 3); + // for chroma, only use up to nb_clusters_chroma separate block contexts + // (those for the biggest clusters) + for (size_t i = remap.size(); i < remap.size() * 3; i++) { + ctx_map[i] = num + Clamp1((int)remap[i % remap.size()], 0, + (int)nb_clusters_chroma - 1); + } + enc_state.shared.block_ctx_map.num_ctxs = + *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; +} + +} // namespace + +void FindBestDequantMatrices(const CompressParams& cparams, + const Image3F& opsin, + ModularFrameEncoder* modular_frame_encoder, + DequantMatrices* dequant_matrices) { + // TODO(veluca): quant matrices for no-gaborish. + // TODO(veluca): heuristics for in-bitstream quant tables. + *dequant_matrices = DequantMatrices(); + if (cparams.max_error_mode) { + // Set numerators of all quantization matrices to constant values. + float weights[3][1] = {{1.0f / cparams.max_error[0]}, + {1.0f / cparams.max_error[1]}, + {1.0f / cparams.max_error[2]}}; + DctQuantWeightParams dct_params(weights); + std::vector encodings(DequantMatrices::kNum, + QuantEncoding::DCT(dct_params)); + DequantMatricesSetCustom(dequant_matrices, encodings, + modular_frame_encoder); + float dc_weights[3] = {1.0f / cparams.max_error[0], + 1.0f / cparams.max_error[1], + 1.0f / cparams.max_error[2]}; + DequantMatricesSetCustomDC(dequant_matrices, dc_weights); + } +} + +bool DefaultEncoderHeuristics::HandlesColorConversion( + const CompressParams& cparams, const ImageBundle& ib) { + return cparams.noise != Override::kOn && cparams.patches != Override::kOn && + cparams.speed_tier >= SpeedTier::kWombat && cparams.resampling == 1 && + cparams.color_transform == ColorTransform::kXYB && + !cparams.modular_mode && !ib.HasAlpha(); +} + +namespace { + +void StoreMin2(const float v, float& min1, float& min2) { + if (v < min2) { + if (v < min1) { + min2 = min1; + min1 = v; + } else { + min2 = v; + } + } +} + +void CreateMask(const ImageF& image, ImageF& mask) { + for (size_t y = 0; y < image.ysize(); y++) { + auto* row_n = y > 0 ? image.Row(y - 1) : image.Row(y); + auto* row_in = image.Row(y); + auto* row_s = y + 1 < image.ysize() ? image.Row(y + 1) : image.Row(y); + auto* row_out = mask.Row(y); + for (size_t x = 0; x < image.xsize(); x++) { + // Center, west, east, north, south values and their absolute difference + float c = row_in[x]; + float w = x > 0 ? row_in[x - 1] : row_in[x]; + float e = x + 1 < image.xsize() ? row_in[x + 1] : row_in[x]; + float n = row_n[x]; + float s = row_s[x]; + float dw = std::abs(c - w); + float de = std::abs(c - e); + float dn = std::abs(c - n); + float ds = std::abs(c - s); + float min = std::numeric_limits::max(); + float min2 = std::numeric_limits::max(); + StoreMin2(dw, min, min2); + StoreMin2(de, min, min2); + StoreMin2(dn, min, min2); + StoreMin2(ds, min, min2); + row_out[x] = min2; + } + } +} + +// Downsamples the image by a factor of 2 with a kernel that's sharper than +// the standard 2x2 box kernel used by DownsampleImage. +// The kernel is optimized against the result of the 2x2 upsampling kernel used +// by the decoder. Ringing is slightly reduced by clamping the values of the +// resulting pixels within certain bounds of a small region in the original +// image. +void DownsampleImage2_Sharper(const ImageF& input, ImageF* output) { + const int64_t kernelx = 12; + const int64_t kernely = 12; + + static const float kernel[144] = { + -0.000314256996835, -0.000314256996835, -0.000897597057705, + -0.000562751488849, -0.000176807273646, 0.001864627368902, + 0.001864627368902, -0.000176807273646, -0.000562751488849, + -0.000897597057705, -0.000314256996835, -0.000314256996835, + -0.000314256996835, -0.001527942804748, -0.000121760530512, + 0.000191123989093, 0.010193185932466, 0.058637519197110, + 0.058637519197110, 0.010193185932466, 0.000191123989093, + -0.000121760530512, -0.001527942804748, -0.000314256996835, + -0.000897597057705, -0.000121760530512, 0.000946363683751, + 0.007113577630288, 0.000437956841058, -0.000372823835211, + -0.000372823835211, 0.000437956841058, 0.007113577630288, + 0.000946363683751, -0.000121760530512, -0.000897597057705, + -0.000562751488849, 0.000191123989093, 0.007113577630288, + 0.044592622228814, 0.000222278879007, -0.162864473015945, + -0.162864473015945, 0.000222278879007, 0.044592622228814, + 0.007113577630288, 0.000191123989093, -0.000562751488849, + -0.000176807273646, 0.010193185932466, 0.000437956841058, + 0.000222278879007, -0.000913092543974, -0.017071696107902, + -0.017071696107902, -0.000913092543974, 0.000222278879007, + 0.000437956841058, 0.010193185932466, -0.000176807273646, + 0.001864627368902, 0.058637519197110, -0.000372823835211, + -0.162864473015945, -0.017071696107902, 0.414660099370354, + 0.414660099370354, -0.017071696107902, -0.162864473015945, + -0.000372823835211, 0.058637519197110, 0.001864627368902, + 0.001864627368902, 0.058637519197110, -0.000372823835211, + -0.162864473015945, -0.017071696107902, 0.414660099370354, + 0.414660099370354, -0.017071696107902, -0.162864473015945, + -0.000372823835211, 0.058637519197110, 0.001864627368902, + -0.000176807273646, 0.010193185932466, 0.000437956841058, + 0.000222278879007, -0.000913092543974, -0.017071696107902, + -0.017071696107902, -0.000913092543974, 0.000222278879007, + 0.000437956841058, 0.010193185932466, -0.000176807273646, + -0.000562751488849, 0.000191123989093, 0.007113577630288, + 0.044592622228814, 0.000222278879007, -0.162864473015945, + -0.162864473015945, 0.000222278879007, 0.044592622228814, + 0.007113577630288, 0.000191123989093, -0.000562751488849, + -0.000897597057705, -0.000121760530512, 0.000946363683751, + 0.007113577630288, 0.000437956841058, -0.000372823835211, + -0.000372823835211, 0.000437956841058, 0.007113577630288, + 0.000946363683751, -0.000121760530512, -0.000897597057705, + -0.000314256996835, -0.001527942804748, -0.000121760530512, + 0.000191123989093, 0.010193185932466, 0.058637519197110, + 0.058637519197110, 0.010193185932466, 0.000191123989093, + -0.000121760530512, -0.001527942804748, -0.000314256996835, + -0.000314256996835, -0.000314256996835, -0.000897597057705, + -0.000562751488849, -0.000176807273646, 0.001864627368902, + 0.001864627368902, -0.000176807273646, -0.000562751488849, + -0.000897597057705, -0.000314256996835, -0.000314256996835}; + + int64_t xsize = input.xsize(); + int64_t ysize = input.ysize(); + + ImageF box_downsample = CopyImage(input); + DownsampleImage(&box_downsample, 2); + + ImageF mask(box_downsample.xsize(), box_downsample.ysize()); + CreateMask(box_downsample, mask); + + for (size_t y = 0; y < output->ysize(); y++) { + float* row_out = output->Row(y); + const float* row_in[kernely]; + const float* row_mask = mask.Row(y); + // get the rows in the support + for (size_t ky = 0; ky < kernely; ky++) { + int64_t iy = y * 2 + ky - (kernely - 1) / 2; + if (iy < 0) iy = 0; + if (iy >= ysize) iy = ysize - 1; + row_in[ky] = input.Row(iy); + } + + for (size_t x = 0; x < output->xsize(); x++) { + // get min and max values of the original image in the support + float min = std::numeric_limits::max(); + float max = std::numeric_limits::min(); + // kernelx - R and kernely - R are the radius of a rectangular region in + // which the values of a pixel are bounded to reduce ringing. + static constexpr int64_t R = 5; + for (int64_t ky = R; ky + R < kernely; ky++) { + for (int64_t kx = R; kx + R < kernelx; kx++) { + int64_t ix = x * 2 + kx - (kernelx - 1) / 2; + if (ix < 0) ix = 0; + if (ix >= xsize) ix = xsize - 1; + min = std::min(min, row_in[ky][ix]); + max = std::max(max, row_in[ky][ix]); + } + } + + float sum = 0; + for (int64_t ky = 0; ky < kernely; ky++) { + for (int64_t kx = 0; kx < kernelx; kx++) { + int64_t ix = x * 2 + kx - (kernelx - 1) / 2; + if (ix < 0) ix = 0; + if (ix >= xsize) ix = xsize - 1; + sum += row_in[ky][ix] * kernel[ky * kernelx + kx]; + } + } + + row_out[x] = sum; + + // Clamp the pixel within the value of a small area to prevent ringning. + // The mask determines how much to clamp, clamp more to reduce more + // ringing in smooth areas, clamp less in noisy areas to get more + // sharpness. Higher mask_multiplier gives less clamping, so less + // ringing reduction. + const constexpr float mask_multiplier = 1; + float a = row_mask[x] * mask_multiplier; + float clip_min = min - a; + float clip_max = max + a; + if (row_out[x] < clip_min) { + row_out[x] = clip_min; + } else if (row_out[x] > clip_max) { + row_out[x] = clip_max; + } + } + } +} + +void DownsampleImage2_Sharper(Image3F* opsin) { + // Allocate extra space to avoid a reallocation when padding. + Image3F downsampled(DivCeil(opsin->xsize(), 2) + kBlockDim, + DivCeil(opsin->ysize(), 2) + kBlockDim); + downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, + downsampled.ysize() - kBlockDim); + + for (size_t c = 0; c < 3; c++) { + DownsampleImage2_Sharper(opsin->Plane(c), &downsampled.Plane(c)); + } + *opsin = std::move(downsampled); +} + +// The default upsampling kernels used by Upsampler in the decoder. +static const constexpr int64_t kSize = 5; + +static const float kernel00[25] = { + -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, + -0.03452303f, 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, + -0.04022174f, 0.28896755f, 0.56661550f, 0.03777607f, -0.01986694f, + -0.02921014f, 0.00278718f, 0.03777607f, -0.03144731f, -0.01185068f, + -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f, +}; +static const float kernel01[25] = { + -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f, + -0.02921014f, 0.00278718f, 0.03777607f, -0.03144731f, -0.01185068f, + -0.04022174f, 0.28896755f, 0.56661550f, 0.03777607f, -0.01986694f, + -0.03452303f, 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, + -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, +}; +static const float kernel10[25] = { + -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f, + -0.01610267f, 0.00278718f, 0.28896755f, 0.14111091f, -0.03452303f, + -0.01986694f, 0.03777607f, 0.56661550f, 0.28896755f, -0.04022174f, + -0.01185068f, -0.03144731f, 0.03777607f, 0.00278718f, -0.02921014f, + -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f, +}; +static const float kernel11[25] = { + -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f, + -0.01185068f, -0.03144731f, 0.03777607f, 0.00278718f, -0.02921014f, + -0.01986694f, 0.03777607f, 0.56661550f, 0.28896755f, -0.04022174f, + -0.01610267f, 0.00278718f, 0.28896755f, 0.14111091f, -0.03452303f, + -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f, +}; + +// Does exactly the same as the Upsampler in dec_upsampler for 2x2 pixels, with +// default CustomTransformData. +// TODO(lode): use Upsampler instead. However, it requires pre-initialization +// and padding on the left side of the image which requires refactoring the +// other code using this. +static void UpsampleImage(const ImageF& input, ImageF* output) { + int64_t xsize = input.xsize(); + int64_t ysize = input.ysize(); + int64_t xsize2 = output->xsize(); + int64_t ysize2 = output->ysize(); + for (int64_t y = 0; y < ysize2; y++) { + for (int64_t x = 0; x < xsize2; x++) { + auto kernel = kernel00; + if ((x & 1) && (y & 1)) { + kernel = kernel11; + } else if (x & 1) { + kernel = kernel10; + } else if (y & 1) { + kernel = kernel01; + } + float sum = 0; + int64_t x2 = x / 2; + int64_t y2 = y / 2; + + // get min and max values of the original image in the support + float min = std::numeric_limits::max(); + float max = std::numeric_limits::min(); + + for (int64_t ky = 0; ky < kSize; ky++) { + for (int64_t kx = 0; kx < kSize; kx++) { + int64_t xi = x2 - kSize / 2 + kx; + int64_t yi = y2 - kSize / 2 + ky; + if (xi < 0) xi = 0; + if (xi >= xsize) xi = input.xsize() - 1; + if (yi < 0) yi = 0; + if (yi >= ysize) yi = input.ysize() - 1; + min = std::min(min, input.Row(yi)[xi]); + max = std::max(max, input.Row(yi)[xi]); + } + } + + for (int64_t ky = 0; ky < kSize; ky++) { + for (int64_t kx = 0; kx < kSize; kx++) { + int64_t xi = x2 - kSize / 2 + kx; + int64_t yi = y2 - kSize / 2 + ky; + if (xi < 0) xi = 0; + if (xi >= xsize) xi = input.xsize() - 1; + if (yi < 0) yi = 0; + if (yi >= ysize) yi = input.ysize() - 1; + sum += input.Row(yi)[xi] * kernel[ky * kSize + kx]; + } + } + output->Row(y)[x] = sum; + if (output->Row(y)[x] < min) output->Row(y)[x] = min; + if (output->Row(y)[x] > max) output->Row(y)[x] = max; + } + } +} + +// Returns the derivative of Upsampler, with respect to input pixel x2, y2, to +// output pixel x, y (ignoring the clamping). +float UpsamplerDeriv(int64_t x2, int64_t y2, int64_t x, int64_t y) { + auto kernel = kernel00; + if ((x & 1) && (y & 1)) { + kernel = kernel11; + } else if (x & 1) { + kernel = kernel10; + } else if (y & 1) { + kernel = kernel01; + } + + int64_t ix = x / 2; + int64_t iy = y / 2; + int64_t kx = x2 - ix + kSize / 2; + int64_t ky = y2 - iy + kSize / 2; + + // This should not happen. + if (kx < 0 || kx >= kSize || ky < 0 || ky >= kSize) return 0; + + return kernel[ky * kSize + kx]; +} + +// Apply the derivative of the Upsampler to the input, reversing the effect of +// its coefficients. The output image is 2x2 times smaller than the input. +void AntiUpsample(const ImageF& input, ImageF* d) { + int64_t xsize = input.xsize(); + int64_t ysize = input.ysize(); + int64_t xsize2 = d->xsize(); + int64_t ysize2 = d->ysize(); + int64_t k0 = kSize - 1; + int64_t k1 = kSize; + for (int64_t y2 = 0; y2 < ysize2; ++y2) { + auto* row = d->Row(y2); + for (int64_t x2 = 0; x2 < xsize2; ++x2) { + int64_t x0 = x2 * 2 - k0; + if (x0 < 0) x0 = 0; + int64_t x1 = x2 * 2 + k1 + 1; + if (x1 > xsize) x1 = xsize; + int64_t y0 = y2 * 2 - k0; + if (y0 < 0) y0 = 0; + int64_t y1 = y2 * 2 + k1 + 1; + if (y1 > ysize) y1 = ysize; + + float sum = 0; + for (int64_t y = y0; y < y1; ++y) { + const auto* row_in = input.Row(y); + for (int64_t x = x0; x < x1; ++x) { + double deriv = UpsamplerDeriv(x2, y2, x, y); + sum += deriv * row_in[x]; + } + } + row[x2] = sum; + } + } +} + +// Element-wise multiplies two images. +template +void ElwiseMul(const Plane& image1, const Plane& image2, Plane* out) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + JXL_CHECK(xsize == out->xsize()); + JXL_CHECK(ysize == out->ysize()); + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] * row2[x]; + } + } +} + +// Element-wise divides two images. +template +void ElwiseDiv(const Plane& image1, const Plane& image2, Plane* out) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + JXL_CHECK(xsize == out->xsize()); + JXL_CHECK(ysize == out->ysize()); + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] / row2[x]; + } + } +} + +void ReduceRinging(const ImageF& initial, const ImageF& mask, ImageF& down) { + int64_t xsize2 = down.xsize(); + int64_t ysize2 = down.ysize(); + + for (size_t y = 0; y < down.ysize(); y++) { + const float* row_mask = mask.Row(y); + float* row_out = down.Row(y); + for (size_t x = 0; x < down.xsize(); x++) { + float v = down.Row(y)[x]; + float min = initial.Row(y)[x]; + float max = initial.Row(y)[x]; + for (int64_t yi = -1; yi < 2; yi++) { + for (int64_t xi = -1; xi < 2; xi++) { + int64_t x2 = (int64_t)x + xi; + int64_t y2 = (int64_t)y + yi; + if (x2 < 0 || y2 < 0 || x2 >= (int64_t)xsize2 || + y2 >= (int64_t)ysize2) + continue; + min = std::min(min, initial.Row(y2)[x2]); + max = std::max(max, initial.Row(y2)[x2]); + } + } + + row_out[x] = v; + + // Clamp the pixel within the value of a small area to prevent ringning. + // The mask determines how much to clamp, clamp more to reduce more + // ringing in smooth areas, clamp less in noisy areas to get more + // sharpness. Higher mask_multiplier gives less clamping, so less + // ringing reduction. + const constexpr float mask_multiplier = 2; + float a = row_mask[x] * mask_multiplier; + float clip_min = min - a; + float clip_max = max + a; + if (row_out[x] < clip_min) row_out[x] = clip_min; + if (row_out[x] > clip_max) row_out[x] = clip_max; + } + } +} + +// TODO(lode): move this to a separate file enc_downsample.cc +void DownsampleImage2_Iterative(const ImageF& orig, ImageF* output) { + int64_t xsize = orig.xsize(); + int64_t ysize = orig.ysize(); + int64_t xsize2 = DivCeil(orig.xsize(), 2); + int64_t ysize2 = DivCeil(orig.ysize(), 2); + + ImageF box_downsample = CopyImage(orig); + DownsampleImage(&box_downsample, 2); + ImageF mask(box_downsample.xsize(), box_downsample.ysize()); + CreateMask(box_downsample, mask); + + output->ShrinkTo(xsize2, ysize2); + + // Initial result image using the sharper downsampling. + // Allocate extra space to avoid a reallocation when padding. + ImageF initial(DivCeil(orig.xsize(), 2) + kBlockDim, + DivCeil(orig.ysize(), 2) + kBlockDim); + initial.ShrinkTo(initial.xsize() - kBlockDim, initial.ysize() - kBlockDim); + DownsampleImage2_Sharper(orig, &initial); + + ImageF down = CopyImage(initial); + ImageF up(xsize, ysize); + ImageF corr(xsize, ysize); + ImageF corr2(xsize2, ysize2); + + // In the weights map, relatively higher values will allow less ringing but + // also less sharpness. With all constant values, it optimizes equally + // everywhere. Even in this case, the weights2 computed from + // this is still used and differs at the borders of the image. + // TODO(lode): Make use of the weights field for anti-ringing and clamping, + // the values are all set to 1 for now, but it is intended to be used for + // reducing ringing based on the mask, and taking clamping into account. + ImageF weights(xsize, ysize); + for (size_t y = 0; y < weights.ysize(); y++) { + auto* row = weights.Row(y); + for (size_t x = 0; x < weights.xsize(); x++) { + row[x] = 1; + } + } + ImageF weights2(xsize2, ysize2); + AntiUpsample(weights, &weights2); + + const size_t num_it = 3; + for (size_t it = 0; it < num_it; ++it) { + UpsampleImage(down, &up); + corr = LinComb(1, orig, -1, up); + ElwiseMul(corr, weights, &corr); + AntiUpsample(corr, &corr2); + ElwiseDiv(corr2, weights2, &corr2); + + down = LinComb(1, down, 1, corr2); + } + + ReduceRinging(initial, mask, down); + + // can't just use CopyImage, because the output image was prepared with + // padding. + for (size_t y = 0; y < down.ysize(); y++) { + for (size_t x = 0; x < down.xsize(); x++) { + float v = down.Row(y)[x]; + output->Row(y)[x] = v; + } + } +} + +void DownsampleImage2_Iterative(Image3F* opsin) { + // Allocate extra space to avoid a reallocation when padding. + Image3F downsampled(DivCeil(opsin->xsize(), 2) + kBlockDim, + DivCeil(opsin->ysize(), 2) + kBlockDim); + downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, + downsampled.ysize() - kBlockDim); + + Image3F rgb(opsin->xsize(), opsin->ysize()); + OpsinParams opsin_params; // TODO: use the ones that are actually used + opsin_params.Init(kDefaultIntensityTarget); + OpsinToLinear(*opsin, Rect(rgb), nullptr, &rgb, opsin_params); + + ImageF mask(opsin->xsize(), opsin->ysize()); + ButteraugliParams butter_params; + ButteraugliComparator butter(rgb, butter_params); + butter.Mask(&mask); + ImageF mask_fuzzy(opsin->xsize(), opsin->ysize()); + + for (size_t c = 0; c < 3; c++) { + DownsampleImage2_Iterative(opsin->Plane(c), &downsampled.Plane(c)); + } + *opsin = std::move(downsampled); +} +} // namespace + +Status DefaultEncoderHeuristics::LossyFrameHeuristics( + PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder, + const ImageBundle* original_pixels, Image3F* opsin, + const JxlCmsInterface& cms, ThreadPool* pool, AuxOut* aux_out) { + PROFILER_ZONE("JxlLossyFrameHeuristics uninstrumented"); + + CompressParams& cparams = enc_state->cparams; + PassesSharedState& shared = enc_state->shared; + + // Compute parameters for noise synthesis. + if (shared.frame_header.flags & FrameHeader::kNoise) { + PROFILER_ZONE("enc GetNoiseParam"); + if (cparams.photon_noise_iso == 0) { + // Don't start at zero amplitude since adding noise is expensive -- it + // significantly slows down decoding, and this is unlikely to + // completely go away even with advanced optimizations. After the + // kNoiseModelingRampUpDistanceRange we have reached the full level, + // i.e. noise is no longer represented by the compressed image, so we + // can add full noise by the noise modeling itself. + static const float kNoiseModelingRampUpDistanceRange = 0.6; + static const float kNoiseLevelAtStartOfRampUp = 0.25; + static const float kNoiseRampupStart = 1.0; + // TODO(user) test and properly select quality_coef with smooth + // filter + float quality_coef = 1.0f; + const float rampup = (cparams.butteraugli_distance - kNoiseRampupStart) / + kNoiseModelingRampUpDistanceRange; + if (rampup < 1.0f) { + quality_coef = kNoiseLevelAtStartOfRampUp + + (1.0f - kNoiseLevelAtStartOfRampUp) * rampup; + } + if (rampup < 0.0f) { + quality_coef = kNoiseRampupStart; + } + if (!GetNoiseParameter(*opsin, &shared.image_features.noise_params, + quality_coef)) { + shared.frame_header.flags &= ~FrameHeader::kNoise; + } + } + } + if (enc_state->shared.frame_header.upsampling != 1 && + !cparams.already_downsampled) { + // In VarDCT mode, LossyFrameHeuristics takes care of running downsampling + // after noise, if necessary. + if (cparams.resampling == 2) { + // TODO(lode): use the regular DownsampleImage, or adapt to the custom + // coefficients, if there is are custom upscaling coefficients in + // CustomTransformData + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + // TODO(lode): DownsampleImage2_Iterative is currently too slow to + // be used for squirrel, make it faster, and / or enable it only for + // kitten. + DownsampleImage2_Iterative(opsin); + } else { + DownsampleImage2_Sharper(opsin); + } + } else { + DownsampleImage(opsin, cparams.resampling); + } + PadImageToBlockMultipleInPlace(opsin); + } + + if (cparams.butteraugli_distance < 0) { + return JXL_FAILURE("Expected non-negative distance"); + } + + // Find and subtract splines. + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + // If we do already have them, they were passed upstream to EncodeFile. + if (!shared.image_features.splines.HasAny()) { + shared.image_features.splines = FindSplines(*opsin); + } + JXL_RETURN_IF_ERROR(shared.image_features.splines.InitializeDrawCache( + opsin->xsize(), opsin->ysize(), shared.cmap)); + shared.image_features.splines.SubtractFrom(opsin); + } + + // Find and subtract patches/dots. + if (ApplyOverride(cparams.patches, + cparams.speed_tier <= SpeedTier::kSquirrel)) { + FindBestPatchDictionary(*opsin, enc_state, cms, pool, aux_out); + PatchDictionaryEncoder::SubtractFrom(shared.image_features.patches, opsin); + } + + static const float kAcQuant = 0.79f; + const float quant_dc = InitialQuantDC(cparams.butteraugli_distance); + Quantizer& quantizer = enc_state->shared.quantizer; + // We don't know the quant field yet, but for computing the global scale + // assuming that it will be the same as for Falcon mode is good enough. + quantizer.ComputeGlobalScaleAndQuant( + quant_dc, kAcQuant / cparams.butteraugli_distance, 0); + + // TODO(veluca): we can now run all the code from here to FindBestQuantizer + // (excluded) one rect at a time. Do that. + + // Dependency graph: + // + // input: either XYB or input image + // + // input image -> XYB [optional] + // XYB -> initial quant field + // XYB -> Gaborished XYB + // Gaborished XYB -> CfL1 + // initial quant field, Gaborished XYB, CfL1 -> ACS + // initial quant field, ACS, Gaborished XYB -> EPF control field + // initial quant field -> adjusted initial quant field + // adjusted initial quant field, ACS -> raw quant field + // raw quant field, ACS, Gaborished XYB -> CfL2 + // + // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field. + + ArControlFieldHeuristics ar_heuristics; + AcStrategyHeuristics acs_heuristics; + CfLHeuristics cfl_heuristics; + + if (!opsin->xsize()) { + JXL_ASSERT(HandlesColorConversion(cparams, *original_pixels)); + *opsin = Image3F(RoundUpToBlockDim(original_pixels->xsize()), + RoundUpToBlockDim(original_pixels->ysize())); + opsin->ShrinkTo(original_pixels->xsize(), original_pixels->ysize()); + ToXYB(*original_pixels, pool, opsin, cms, /*linear=*/nullptr); + PadImageToBlockMultipleInPlace(opsin); + } + + // Compute an initial estimate of the quantization field. + // Call InitialQuantField only in Hare mode or slower. Otherwise, rely + // on simple heuristics in FindBestAcStrategy, or set a constant for Falcon + // mode. + if (cparams.speed_tier > SpeedTier::kHare || cparams.uniform_quant > 0) { + enc_state->initial_quant_field = + ImageF(shared.frame_dim.xsize_blocks, shared.frame_dim.ysize_blocks); + float q = cparams.uniform_quant > 0 + ? cparams.uniform_quant + : kAcQuant / cparams.butteraugli_distance; + FillImage(q, &enc_state->initial_quant_field); + } else { + // Call this here, as it relies on pre-gaborish values. + float butteraugli_distance_for_iqf = cparams.butteraugli_distance; + if (!shared.frame_header.loop_filter.gab) { + butteraugli_distance_for_iqf *= 0.73f; + } + enc_state->initial_quant_field = InitialQuantField( + butteraugli_distance_for_iqf, *opsin, shared.frame_dim, pool, 1.0f, + &enc_state->initial_quant_masking); + quantizer.SetQuantField(quant_dc, enc_state->initial_quant_field, nullptr); + } + + // TODO(veluca): do something about animations. + + // Apply inverse-gaborish. + if (shared.frame_header.loop_filter.gab) { + GaborishInverse(opsin, 0.9908511000000001f, pool); + } + + FindBestDequantMatrices(cparams, *opsin, modular_frame_encoder, + &enc_state->shared.matrices); + + cfl_heuristics.Init(*opsin); + acs_heuristics.Init(*opsin, enc_state); + + auto process_tile = [&](const uint32_t tid, const size_t thread) { + size_t n_enc_tiles = + DivCeil(enc_state->shared.frame_dim.xsize_blocks, kEncTileDimInBlocks); + size_t tx = tid % n_enc_tiles; + size_t ty = tid / n_enc_tiles; + size_t by0 = ty * kEncTileDimInBlocks; + size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, + enc_state->shared.frame_dim.ysize_blocks); + size_t bx0 = tx * kEncTileDimInBlocks; + size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, + enc_state->shared.frame_dim.xsize_blocks); + Rect r(bx0, by0, bx1 - bx0, by1 - by0); + + // For speeds up to Wombat, we only compute the color correlation map + // once we know the transform type and the quantization map. + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + cfl_heuristics.ComputeTile(r, *opsin, enc_state->shared.matrices, + /*ac_strategy=*/nullptr, + /*quantizer=*/nullptr, /*fast=*/false, thread, + &enc_state->shared.cmap); + } + + // Choose block sizes. + acs_heuristics.ProcessRect(r); + + // Choose amount of post-processing smoothing. + // TODO(veluca): should this go *after* AdjustQuantField? + ar_heuristics.RunRect(r, *opsin, enc_state, thread); + + // Always set the initial quant field, so we can compute the CfL map with + // more accuracy. The initial quant field might change in slower modes, but + // adjusting the quant field with butteraugli when all the other encoding + // parameters are fixed is likely a more reliable choice anyway. + AdjustQuantField(enc_state->shared.ac_strategy, r, + &enc_state->initial_quant_field); + quantizer.SetQuantFieldRect(enc_state->initial_quant_field, r, + &enc_state->shared.raw_quant_field); + + // Compute a non-default CfL map if we are at Hare speed, or slower. + if (cparams.speed_tier <= SpeedTier::kHare) { + cfl_heuristics.ComputeTile( + r, *opsin, enc_state->shared.matrices, &enc_state->shared.ac_strategy, + &enc_state->shared.quantizer, + /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread, + &enc_state->shared.cmap); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, + DivCeil(enc_state->shared.frame_dim.xsize_blocks, kEncTileDimInBlocks) * + DivCeil(enc_state->shared.frame_dim.ysize_blocks, + kEncTileDimInBlocks), + [&](const size_t num_threads) { + ar_heuristics.PrepareForThreads(num_threads); + cfl_heuristics.PrepareForThreads(num_threads); + return true; + }, + process_tile, "Enc Heuristics")); + + acs_heuristics.Finalize(aux_out); + if (cparams.speed_tier <= SpeedTier::kHare) { + cfl_heuristics.ComputeDC(/*fast=*/cparams.speed_tier >= SpeedTier::kWombat, + &enc_state->shared.cmap); + } + + // Refine quantization levels. + FindBestQuantizer(original_pixels, *opsin, enc_state, cms, pool, aux_out); + + // Choose a context model that depends on the amount of quantization for AC. + if (cparams.speed_tier < SpeedTier::kFalcon) { + FindBestBlockEntropyModel(*enc_state); + } + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_heuristics.h b/media/libjxl/src/lib/jxl/enc_heuristics.h new file mode 100644 index 000000000..16509f00d --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_heuristics.h @@ -0,0 +1,79 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_HEURISTICS_H_ +#define LIB_JXL_ENC_HEURISTICS_H_ + +// Hook for custom encoder heuristics (VarDCT only for now). + +#include +#include + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/enc_ma.h" + +namespace jxl { + +struct PassesEncoderState; +class ImageBundle; +class ModularFrameEncoder; + +class EncoderHeuristics { + public: + virtual ~EncoderHeuristics() = default; + // Initializes encoder structures in `enc_state` using the original image data + // in `original_pixels`, and the XYB image data in `opsin`. Also modifies the + // `opsin` image by applying Gaborish, and doing other modifications if + // necessary. `pool` is used for running the computations on multiple threads. + // `aux_out` collects statistics and can be used to print debug images. + virtual Status LossyFrameHeuristics( + PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder, + const ImageBundle* original_pixels, Image3F* opsin, + const JxlCmsInterface& cms, ThreadPool* pool, AuxOut* aux_out) = 0; + + // Custom fixed tree for lossless mode. Must set `tree` to a valid tree if + // the function returns true. + virtual bool CustomFixedTreeLossless(const FrameDimensions& frame_dim, + Tree* tree) { + return false; + } + + // If this method returns `true`, the `opsin` parameter to + // LossyFrameHeuristics will not be initialized, and should be initialized + // during the call. Moreover, `original_pixels` may not be in a linear + // colorspace (but will be the same as the `ib` value passed to this + // function). + virtual bool HandlesColorConversion(const CompressParams& cparams, + const ImageBundle& ib) { + return false; + } +}; + +class DefaultEncoderHeuristics : public EncoderHeuristics { + public: + Status LossyFrameHeuristics(PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + const ImageBundle* original_pixels, + Image3F* opsin, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out) override; + bool HandlesColorConversion(const CompressParams& cparams, + const ImageBundle& ib) override; +}; + +// Exposed here since it may be used by other EncoderHeuristics implementations +// outside this project. +void FindBestDequantMatrices(const CompressParams& cparams, + const Image3F& opsin, + ModularFrameEncoder* modular_frame_encoder, + DequantMatrices* dequant_matrices); + +} // namespace jxl + +#endif // LIB_JXL_ENC_HEURISTICS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_huffman.cc b/media/libjxl/src/lib/jxl/enc_huffman.cc new file mode 100644 index 000000000..04b566998 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_huffman.cc @@ -0,0 +1,214 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_huffman.h" + +#include +#include + +#include "lib/jxl/huffman_tree.h" + +namespace jxl { + +namespace { + +constexpr int kCodeLengthCodes = 18; + +void StoreHuffmanTreeOfHuffmanTreeToBitMask(const int num_codes, + const uint8_t* code_length_bitdepth, + BitWriter* writer) { + static const uint8_t kStorageOrder[kCodeLengthCodes] = { + 1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + // The bit lengths of the Huffman code over the code length alphabet + // are compressed with the following static Huffman code: + // Symbol Code + // ------ ---- + // 0 00 + // 1 1110 + // 2 110 + // 3 01 + // 4 10 + // 5 1111 + static const uint8_t kHuffmanBitLengthHuffmanCodeSymbols[6] = {0, 7, 3, + 2, 1, 15}; + static const uint8_t kHuffmanBitLengthHuffmanCodeBitLengths[6] = {2, 4, 3, + 2, 2, 4}; + + // Throw away trailing zeros: + size_t codes_to_store = kCodeLengthCodes; + if (num_codes > 1) { + for (; codes_to_store > 0; --codes_to_store) { + if (code_length_bitdepth[kStorageOrder[codes_to_store - 1]] != 0) { + break; + } + } + } + size_t skip_some = 0; // skips none. + if (code_length_bitdepth[kStorageOrder[0]] == 0 && + code_length_bitdepth[kStorageOrder[1]] == 0) { + skip_some = 2; // skips two. + if (code_length_bitdepth[kStorageOrder[2]] == 0) { + skip_some = 3; // skips three. + } + } + writer->Write(2, skip_some); + for (size_t i = skip_some; i < codes_to_store; ++i) { + size_t l = code_length_bitdepth[kStorageOrder[i]]; + writer->Write(kHuffmanBitLengthHuffmanCodeBitLengths[l], + kHuffmanBitLengthHuffmanCodeSymbols[l]); + } +} + +void StoreHuffmanTreeToBitMask(const size_t huffman_tree_size, + const uint8_t* huffman_tree, + const uint8_t* huffman_tree_extra_bits, + const uint8_t* code_length_bitdepth, + const uint16_t* code_length_bitdepth_symbols, + BitWriter* writer) { + for (size_t i = 0; i < huffman_tree_size; ++i) { + size_t ix = huffman_tree[i]; + writer->Write(code_length_bitdepth[ix], code_length_bitdepth_symbols[ix]); + // Extra bits + switch (ix) { + case 16: + writer->Write(2, huffman_tree_extra_bits[i]); + break; + case 17: + writer->Write(3, huffman_tree_extra_bits[i]); + break; + } + } +} + +void StoreSimpleHuffmanTree(const uint8_t* depths, size_t symbols[4], + size_t num_symbols, size_t max_bits, + BitWriter* writer) { + // value of 1 indicates a simple Huffman code + writer->Write(2, 1); + writer->Write(2, num_symbols - 1); // NSYM - 1 + + // Sort + for (size_t i = 0; i < num_symbols; i++) { + for (size_t j = i + 1; j < num_symbols; j++) { + if (depths[symbols[j]] < depths[symbols[i]]) { + std::swap(symbols[j], symbols[i]); + } + } + } + + if (num_symbols == 2) { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + } else if (num_symbols == 3) { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + writer->Write(max_bits, symbols[2]); + } else { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + writer->Write(max_bits, symbols[2]); + writer->Write(max_bits, symbols[3]); + // tree-select + writer->Write(1, depths[symbols[0]] == 1 ? 1 : 0); + } +} + +// num = alphabet size +// depths = symbol depths +void StoreHuffmanTree(const uint8_t* depths, size_t num, BitWriter* writer) { + // Write the Huffman tree into the compact representation. + std::unique_ptr arena(new uint8_t[2 * num]); + uint8_t* huffman_tree = arena.get(); + uint8_t* huffman_tree_extra_bits = arena.get() + num; + size_t huffman_tree_size = 0; + WriteHuffmanTree(depths, num, &huffman_tree_size, huffman_tree, + huffman_tree_extra_bits); + + // Calculate the statistics of the Huffman tree in the compact representation. + uint32_t huffman_tree_histogram[kCodeLengthCodes] = {0}; + for (size_t i = 0; i < huffman_tree_size; ++i) { + ++huffman_tree_histogram[huffman_tree[i]]; + } + + int num_codes = 0; + int code = 0; + for (int i = 0; i < kCodeLengthCodes; ++i) { + if (huffman_tree_histogram[i]) { + if (num_codes == 0) { + code = i; + num_codes = 1; + } else if (num_codes == 1) { + num_codes = 2; + break; + } + } + } + + // Calculate another Huffman tree to use for compressing both the + // earlier Huffman tree with. + uint8_t code_length_bitdepth[kCodeLengthCodes] = {0}; + uint16_t code_length_bitdepth_symbols[kCodeLengthCodes] = {0}; + CreateHuffmanTree(&huffman_tree_histogram[0], kCodeLengthCodes, 5, + &code_length_bitdepth[0]); + ConvertBitDepthsToSymbols(code_length_bitdepth, kCodeLengthCodes, + &code_length_bitdepth_symbols[0]); + + // Now, we have all the data, let's start storing it + StoreHuffmanTreeOfHuffmanTreeToBitMask(num_codes, code_length_bitdepth, + writer); + + if (num_codes == 1) { + code_length_bitdepth[code] = 0; + } + + // Store the real huffman tree now. + StoreHuffmanTreeToBitMask(huffman_tree_size, huffman_tree, + huffman_tree_extra_bits, &code_length_bitdepth[0], + code_length_bitdepth_symbols, writer); +} + +} // namespace + +void BuildAndStoreHuffmanTree(const uint32_t* histogram, const size_t length, + uint8_t* depth, uint16_t* bits, + BitWriter* writer) { + size_t count = 0; + size_t s4[4] = {0}; + for (size_t i = 0; i < length; i++) { + if (histogram[i]) { + if (count < 4) { + s4[count] = i; + } else if (count > 4) { + break; + } + count++; + } + } + + size_t max_bits_counter = length - 1; + size_t max_bits = 0; + while (max_bits_counter) { + max_bits_counter >>= 1; + ++max_bits; + } + + if (count <= 1) { + // Output symbol bits and depths are initialized with 0, nothing to do. + writer->Write(4, 1); + writer->Write(max_bits, s4[0]); + return; + } + + CreateHuffmanTree(histogram, length, 15, depth); + ConvertBitDepthsToSymbols(depth, length, bits); + + if (count <= 4) { + StoreSimpleHuffmanTree(depth, s4, count, max_bits, writer); + } else { + StoreHuffmanTree(depth, length, writer); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_huffman.h b/media/libjxl/src/lib/jxl/enc_huffman.h new file mode 100644 index 000000000..d7a66584e --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_huffman.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_HUFFMAN_H_ +#define LIB_JXL_ENC_HUFFMAN_H_ + +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Builds a Huffman tree for the given histogram, and encodes it into writer +// in a format that can be read by HuffmanDecodingData::ReadFromBitstream. +// An allotment for `writer` must already have been created by the caller. +void BuildAndStoreHuffmanTree(const uint32_t* histogram, size_t length, + uint8_t* depth, uint16_t* bits, + BitWriter* writer); + +} // namespace jxl + +#endif // LIB_JXL_ENC_HUFFMAN_H_ diff --git a/media/libjxl/src/lib/jxl/enc_icc_codec.cc b/media/libjxl/src/lib/jxl/enc_icc_codec.cc new file mode 100644 index 000000000..32e9b6b47 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_icc_codec.cc @@ -0,0 +1,407 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_icc_codec.h" + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/icc_codec_common.h" + +namespace jxl { +namespace { + +// Unshuffles or de-interleaves bytes, for example with width 2, turns +// "AaBbCcDc" into "ABCDabcd", this for example de-interleaves UTF-16 bytes into +// first all the high order bytes, then all the low order bytes. +// Transposes a matrix of width columns and ceil(size / width) rows. There are +// size elements, size may be < width * height, if so the +// last elements of the bottom row are missing, the missing spots are +// transposed along with the filled spots, and the result has the missing +// elements at the bottom of the rightmost column. The input is the input matrix +// in scanline order, the output is the result matrix in scanline order, with +// missing elements skipped over (this may occur at multiple positions). +void Unshuffle(uint8_t* data, size_t size, size_t width) { + size_t height = (size + width - 1) / width; // amount of rows of input + PaddedBytes result(size); + // i = input index, j output index + size_t s = 0, j = 0; + for (size_t i = 0; i < size; i++) { + result[j] = data[i]; + j += height; + if (j >= size) j = ++s; + } + + for (size_t i = 0; i < size; i++) { + data[i] = result[i]; + } +} + +// This is performed by the encoder, the encoder must be able to encode any +// random byte stream (not just byte streams that are a valid ICC profile), so +// an error returned by this function is an implementation error. +Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num, + const uint8_t* data, size_t size, size_t* pos, + PaddedBytes* result) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size)); + // Required by the specification, see decoder. stride * 4 must be < *pos. + if (!*pos || ((*pos - 1u) >> 2u) < stride) { + return JXL_FAILURE("Invalid stride"); + } + if (*pos < stride * 4) return JXL_FAILURE("Too large stride"); + size_t start = result->size(); + for (size_t i = 0; i < num; i++) { + uint8_t predicted = + LinearPredictICCValue(data, *pos, i, stride, width, order); + result->push_back(data[*pos + i] - predicted); + } + *pos += num; + if (width > 1) Unshuffle(result->data() + start, num, width); + return true; +} +} // namespace + +// Outputs a transformed form of the given icc profile. The result itself is +// not particularly smaller than the input data in bytes, but it will be in a +// form that is easier to compress (more zeroes, ...) and will compress better +// with brotli. +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { + PaddedBytes commands; + PaddedBytes data; + + EncodeVarInt(size, result); + + // Header + PaddedBytes header = ICCInitialHeaderPrediction(); + EncodeUint32(0, size, &header); + for (size_t i = 0; i < kICCHeaderSize && i < size; i++) { + ICCPredictHeader(icc, size, header.data(), i); + data.push_back(icc[i] - header[i]); + } + if (size <= kICCHeaderSize) { + EncodeVarInt(0, result); // 0 commands + for (size_t i = 0; i < data.size(); i++) { + result->push_back(data[i]); + } + return true; + } + + std::vector tags; + std::vector tagstarts; + std::vector tagsizes; + std::map tagmap; + + // Tag list + size_t pos = kICCHeaderSize; + if (pos + 4 <= size) { + uint64_t numtags = DecodeUint32(icc, size, pos); + pos += 4; + EncodeVarInt(numtags + 1, &commands); + uint64_t prevtagstart = kICCHeaderSize + numtags * 12; + uint32_t prevtagsize = 0; + for (size_t i = 0; i < numtags; i++) { + if (pos + 12 > size) break; + + Tag tag = DecodeKeyword(icc, size, pos + 0); + uint32_t tagstart = DecodeUint32(icc, size, pos + 4); + uint32_t tagsize = DecodeUint32(icc, size, pos + 8); + pos += 12; + + tags.push_back(tag); + tagstarts.push_back(tagstart); + tagsizes.push_back(tagsize); + tagmap[tagstart] = tags.size() - 1; + + uint8_t tagcode = kCommandTagUnknown; + for (size_t j = 0; j < kNumTagStrings; j++) { + if (tag == *kTagStrings[j]) { + tagcode = j + kCommandTagStringFirst; + break; + } + } + + if (tag == kRtrcTag && pos + 24 < size) { + bool ok = true; + ok &= DecodeKeyword(icc, size, pos + 0) == kGtrcTag; + ok &= DecodeKeyword(icc, size, pos + 12) == kBtrcTag; + if (ok) { + for (size_t kk = 0; kk < 8; kk++) { + if (icc[pos - 8 + kk] != icc[pos + 4 + kk]) ok = false; + if (icc[pos - 8 + kk] != icc[pos + 16 + kk]) ok = false; + } + } + if (ok) { + tagcode = kCommandTagTRC; + pos += 24; + i += 2; + } + } + + if (tag == kRxyzTag && pos + 24 < size) { + bool ok = true; + ok &= DecodeKeyword(icc, size, pos + 0) == kGxyzTag; + ok &= DecodeKeyword(icc, size, pos + 12) == kBxyzTag; + uint32_t offsetr = tagstart; + uint32_t offsetg = DecodeUint32(icc, size, pos + 4); + uint32_t offsetb = DecodeUint32(icc, size, pos + 16); + uint32_t sizer = tagsize; + uint32_t sizeg = DecodeUint32(icc, size, pos + 8); + uint32_t sizeb = DecodeUint32(icc, size, pos + 20); + ok &= sizer == 20; + ok &= sizeg == 20; + ok &= sizeb == 20; + ok &= (offsetg == offsetr + 20); + ok &= (offsetb == offsetr + 40); + if (ok) { + tagcode = kCommandTagXYZ; + pos += 24; + i += 2; + } + } + + uint8_t command = tagcode; + uint64_t predicted_tagstart = prevtagstart + prevtagsize; + if (predicted_tagstart != tagstart) command |= kFlagBitOffset; + size_t predicted_tagsize = prevtagsize; + if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag || + tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag || + tag == kLumiTag) { + predicted_tagsize = 20; + } + if (predicted_tagsize != tagsize) command |= kFlagBitSize; + commands.push_back(command); + if (tagcode == 1) { + AppendKeyword(tag, &data); + } + if (command & kFlagBitOffset) EncodeVarInt(tagstart, &commands); + if (command & kFlagBitSize) EncodeVarInt(tagsize, &commands); + + prevtagstart = tagstart; + prevtagsize = tagsize; + } + } + // Indicate end of tag list or varint indicating there's none + commands.push_back(0); + + // Main content + // The main content in a valid ICC profile contains tagged elements, with the + // tag types (4 letter names) given by the tag list above, and the tag list + // pointing to the start and indicating the size of each tagged element. It is + // allowed for tagged elements to overlap, e.g. the curve for R, G and B could + // all point to the same one. + Tag tag; + size_t tagstart = 0, tagsize = 0, clutstart = 0; + + size_t last0 = pos; + // This loop appends commands to the output, processing some sub-section of a + // current tagged element each time. We need to keep track of the tagtype of + // the current element, and update it when we encounter the boundary of a + // next one. + // It is not required that the input data is a valid ICC profile, if the + // encoder does not recognize the data it will still be able to output bytes + // but will not predict as well. + while (pos <= size) { + size_t last1 = pos; + PaddedBytes commands_add; + PaddedBytes data_add; + + // This means the loop brought the position beyond the tag end. + if (pos > tagstart + tagsize) { + tag = {{0, 0, 0, 0}}; // nonsensical value + } + + if (commands_add.empty() && data_add.empty() && tagmap.count(pos) && + pos + 4 <= size) { + size_t index = tagmap[pos]; + tag = DecodeKeyword(icc, size, pos); + tagstart = tagstarts[index]; + tagsize = tagsizes[index]; + + if (tag == kMlucTag && pos + tagsize <= size && tagsize > 8 && + icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 && + icc[pos + 7] == 0) { + size_t num = tagsize - 8; + commands_add.push_back(kCommandTypeStartFirst + 3); + pos += 8; + commands_add.push_back(kCommandShuffle2); + EncodeVarInt(num, &commands_add); + size_t start = data_add.size(); + for (size_t i = 0; i < num; i++) { + data_add.push_back(icc[pos]); + pos++; + } + Unshuffle(data_add.data() + start, num, 2); + } + + if (tag == kCurvTag && pos + tagsize <= size && tagsize > 8 && + icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 && + icc[pos + 7] == 0) { + size_t num = tagsize - 8; + if (num > 16 && num < (1 << 28) && pos + num <= size && pos > 0) { + commands_add.push_back(kCommandTypeStartFirst + 5); + pos += 8; + commands_add.push_back(kCommandPredict); + int order = 1, width = 2, stride = width; + commands_add.push_back((order << 2) | (width - 1)); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + } + + if (tag == kMab_Tag || tag == kMba_Tag) { + Tag subTag = DecodeKeyword(icc, size, pos); + if (pos + 12 < size && (subTag == kCurvTag || subTag == kVcgtTag) && + DecodeUint32(icc, size, pos + 4) == 0) { + uint32_t num = DecodeUint32(icc, size, pos + 8) * 2; + if (num > 16 && num < (1 << 28) && pos + 12 + num <= size) { + pos += 12; + last1 = pos; + commands_add.push_back(kCommandPredict); + int order = 1, width = 2, stride = width; + commands_add.push_back((order << 2) | (width - 1)); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + + if (pos == tagstart + 24 && pos + 4 < size) { + // Note that this value can be remembered for next iterations of the + // loop, so the "pos == clutstart" if below can trigger during a later + // iteration. + clutstart = tagstart + DecodeUint32(icc, size, pos); + } + + if (pos == clutstart && clutstart + 16 < size) { + size_t numi = icc[tagstart + 8]; + size_t numo = icc[tagstart + 9]; + size_t width = icc[clutstart + 16]; + size_t stride = width * numo; + size_t num = width * numo; + for (size_t i = 0; i < numi && clutstart + i < size; i++) { + num *= icc[clutstart + i]; + } + if ((width == 1 || width == 2) && num > 64 && num < (1 << 28) && + pos + num <= size && pos > stride * 4) { + commands_add.push_back(kCommandPredict); + int order = 1; + uint8_t flags = + (order << 2) | (width - 1) | (stride == width ? 0 : 16); + commands_add.push_back(flags); + if (flags & 16) EncodeVarInt(stride, &commands_add); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + } + + if (commands_add.empty() && data_add.empty() && tag == kGbd_Tag && + pos == tagstart + 8 && pos + tagsize - 8 <= size && pos > 16 && + tagsize > 8) { + size_t width = 4, order = 0, stride = width; + size_t num = tagsize - 8; + uint8_t flags = (order << 2) | (width - 1) | (stride == width ? 0 : 16); + commands_add.push_back(kCommandPredict); + commands_add.push_back(flags); + if (flags & 16) EncodeVarInt(stride, &commands_add); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + + if (commands_add.empty() && data_add.empty() && pos + 20 <= size) { + Tag subTag = DecodeKeyword(icc, size, pos); + if (subTag == kXyz_Tag && DecodeUint32(icc, size, pos + 4) == 0) { + commands_add.push_back(kCommandXYZ); + pos += 8; + for (size_t j = 0; j < 12; j++) data_add.push_back(icc[pos++]); + } + } + + if (commands_add.empty() && data_add.empty() && pos + 8 <= size) { + if (DecodeUint32(icc, size, pos + 4) == 0) { + Tag subTag = DecodeKeyword(icc, size, pos); + for (size_t i = 0; i < kNumTypeStrings; i++) { + if (subTag == *kTypeStrings[i]) { + commands_add.push_back(kCommandTypeStartFirst + i); + pos += 8; + break; + } + } + } + } + + if (!(commands_add.empty() && data_add.empty()) || pos == size) { + if (last0 < last1) { + commands.push_back(kCommandInsert); + EncodeVarInt(last1 - last0, &commands); + while (last0 < last1) { + data.push_back(icc[last0++]); + } + } + for (size_t i = 0; i < commands_add.size(); i++) { + commands.push_back(commands_add[i]); + } + for (size_t i = 0; i < data_add.size(); i++) { + data.push_back(data_add[i]); + } + last0 = pos; + } + if (commands_add.empty() && data_add.empty()) { + pos++; + } + } + + EncodeVarInt(commands.size(), result); + for (size_t i = 0; i < commands.size(); i++) { + result->push_back(commands[i]); + } + for (size_t i = 0; i < data.size(); i++) { + result->push_back(data[i]); + } + + return true; +} + +Status WriteICC(const PaddedBytes& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out) { + if (icc.empty()) return JXL_FAILURE("ICC must be non-empty"); + PaddedBytes enc; + JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc)); + std::vector> tokens(1); + BitWriter::Allotment allotment(writer, 128); + JXL_RETURN_IF_ERROR(U64Coder::Write(enc.size(), writer)); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + + for (size_t i = 0; i < enc.size(); i++) { + tokens[0].emplace_back( + ICCANSContext(i, i > 0 ? enc[i - 1] : 0, i > 1 ? enc[i - 2] : 0), + enc[i]); + } + HistogramParams params; + params.lz77_method = enc.size() < 4096 ? HistogramParams::LZ77Method::kOptimal + : HistogramParams::LZ77Method::kLZ77; + EntropyEncodingData code; + std::vector context_map; + params.force_huffman = true; + BuildAndEncodeHistograms(params, kNumICCContexts, tokens, &code, &context_map, + writer, layer, aux_out); + WriteTokens(tokens[0], code, context_map, writer, layer, aux_out); + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_icc_codec.h b/media/libjxl/src/lib/jxl/enc_icc_codec.h new file mode 100644 index 000000000..2480e3ae9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_icc_codec.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ICC_CODEC_H_ +#define LIB_JXL_ENC_ICC_CODEC_H_ + +// Compressed representation of ICC profiles. + +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Should still be called if `icc.empty()` - if so, writes only 1 bit. +Status WriteICC(const PaddedBytes& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out); + +// Exposed only for testing +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result); + +} // namespace jxl + +#endif // LIB_JXL_ENC_ICC_CODEC_H_ diff --git a/media/libjxl/src/lib/jxl/enc_image_bundle.cc b/media/libjxl/src/lib/jxl/enc_image_bundle.cc new file mode 100644 index 000000000..fe6d282a8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_image_bundle.cc @@ -0,0 +1,154 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_image_bundle.h" + +#include +#include +#include + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/luminance.h" + +namespace jxl { + +namespace { + +// Copies ib:rect, converts, and copies into out. +Status CopyToT(const ImageMetadata* metadata, const ImageBundle* ib, + const Rect& rect, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, Image3F* out) { + PROFILER_FUNC; + ColorSpaceTransform c_transform(cms); + // Changing IsGray is probably a bug. + JXL_CHECK(ib->IsGray() == c_desired.IsGray()); + bool is_gray = ib->IsGray(); + if (out->xsize() < rect.xsize() || out->ysize() < rect.ysize()) { + *out = Image3F(rect.xsize(), rect.ysize()); + } else { + out->ShrinkTo(rect.xsize(), rect.ysize()); + } + std::atomic ok{true}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, rect.ysize(), + [&](const size_t num_threads) { + return c_transform.Init(ib->c_current(), c_desired, + metadata->IntensityTarget(), rect.xsize(), + num_threads); + }, + [&](const uint32_t y, const size_t thread) { + float* mutable_src_buf = c_transform.BufSrc(thread); + const float* src_buf = mutable_src_buf; + // Interleave input. + if (is_gray) { + src_buf = rect.ConstPlaneRow(ib->color(), 0, y); + } else if (ib->c_current().IsCMYK()) { + if (!ib->HasBlack()) { + ok.store(false); + return; + } + const float* JXL_RESTRICT row_in0 = + rect.ConstPlaneRow(ib->color(), 0, y); + const float* JXL_RESTRICT row_in1 = + rect.ConstPlaneRow(ib->color(), 1, y); + const float* JXL_RESTRICT row_in2 = + rect.ConstPlaneRow(ib->color(), 2, y); + const float* JXL_RESTRICT row_in3 = rect.ConstRow(ib->black(), y); + for (size_t x = 0; x < rect.xsize(); x++) { + // CMYK convention in JXL: 0 = max ink, 1 = white + mutable_src_buf[4 * x + 0] = row_in0[x]; + mutable_src_buf[4 * x + 1] = row_in1[x]; + mutable_src_buf[4 * x + 2] = row_in2[x]; + mutable_src_buf[4 * x + 3] = row_in3[x]; + } + } else { + const float* JXL_RESTRICT row_in0 = + rect.ConstPlaneRow(ib->color(), 0, y); + const float* JXL_RESTRICT row_in1 = + rect.ConstPlaneRow(ib->color(), 1, y); + const float* JXL_RESTRICT row_in2 = + rect.ConstPlaneRow(ib->color(), 2, y); + for (size_t x = 0; x < rect.xsize(); x++) { + mutable_src_buf[3 * x + 0] = row_in0[x]; + mutable_src_buf[3 * x + 1] = row_in1[x]; + mutable_src_buf[3 * x + 2] = row_in2[x]; + } + } + float* JXL_RESTRICT dst_buf = c_transform.BufDst(thread); + if (!c_transform.Run(thread, src_buf, dst_buf)) { + ok.store(false); + return; + } + float* JXL_RESTRICT row_out0 = out->PlaneRow(0, y); + float* JXL_RESTRICT row_out1 = out->PlaneRow(1, y); + float* JXL_RESTRICT row_out2 = out->PlaneRow(2, y); + // De-interleave output and convert type. + if (is_gray) { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = dst_buf[x]; + row_out1[x] = dst_buf[x]; + row_out2[x] = dst_buf[x]; + } + } else { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = dst_buf[3 * x + 0]; + row_out1[x] = dst_buf[3 * x + 1]; + row_out2[x] = dst_buf[3 * x + 2]; + } + } + }, + "Colorspace transform")); + return ok.load(); +} + +} // namespace + +Status ImageBundle::TransformTo(const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool) { + PROFILER_FUNC; + JXL_RETURN_IF_ERROR(CopyTo(Rect(color_), c_desired, cms, &color_, pool)); + c_current_ = c_desired; + return true; +} +Status ImageBundle::CopyTo(const Rect& rect, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, Image3F* out, + ThreadPool* pool) const { + return CopyToT(metadata_, this, rect, c_desired, cms, pool, out); +} +Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, + ImageBundle* store, const ImageBundle** out) { + if (in.c_current().SameColorEncoding(c_desired) && !in.HasBlack()) { + *out = ∈ + return true; + } + // TODO(janwas): avoid copying via createExternal+copyBackToIO + // instead of copy+createExternal+copyBackToIO + store->SetFromImage(CopyImage(in.color()), in.c_current()); + + // Must at least copy the alpha channel for use by external_image. + if (in.HasExtraChannels()) { + std::vector extra_channels; + for (const ImageF& extra_channel : in.extra_channels()) { + extra_channels.emplace_back(CopyImage(extra_channel)); + } + store->SetExtraChannels(std::move(extra_channels)); + } + + if (!store->TransformTo(c_desired, cms, pool)) { + return false; + } + *out = store; + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_image_bundle.h b/media/libjxl/src/lib/jxl/enc_image_bundle.h new file mode 100644 index 000000000..85f8e14e1 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_image_bundle.h @@ -0,0 +1,25 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_IMAGE_BUNDLE_H_ +#define LIB_JXL_ENC_IMAGE_BUNDLE_H_ + +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Does color transformation from in.c_current() to c_desired if the color +// encodings are different, or nothing if they are already the same. +// If color transformation is done, stores the transformed values into store and +// sets the out pointer to store, else leaves store untouched and sets the out +// pointer to &in. +// Returns false if color transform fails. +Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, + ImageBundle* store, const ImageBundle** out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_IMAGE_BUNDLE_H_ diff --git a/media/libjxl/src/lib/jxl/enc_jxl_skcms.h b/media/libjxl/src/lib/jxl/enc_jxl_skcms.h new file mode 100644 index 000000000..4be54205b --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_jxl_skcms.h @@ -0,0 +1,54 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_JXL_SKCMS_H_ +#define LIB_JXL_ENC_JXL_SKCMS_H_ + +// skcms wrapper to rename the skcms symbols to avoid conflicting names with +// other projects using skcms as well. When using JPEGXL_BUNDLE_SKCMS the +// bundled functions will be renamed from skcms_ to jxl_skcms_ + +#ifdef SKCMS_API +#error "Must include jxl_skcms.h and not skcms.h directly" +#endif // SKCMS_API + +#if JPEGXL_BUNDLE_SKCMS + +#define skcms_252_random_bytes jxl_skcms_252_random_bytes +#define skcms_AdaptToXYZD50 jxl_skcms_AdaptToXYZD50 +#define skcms_ApproximateCurve jxl_skcms_ApproximateCurve +#define skcms_ApproximatelyEqualProfiles jxl_skcms_ApproximatelyEqualProfiles +#define skcms_AreApproximateInverses jxl_skcms_AreApproximateInverses +#define skcms_GetCHAD jxl_skcms_GetCHAD +#define skcms_GetTagByIndex jxl_skcms_GetTagByIndex +#define skcms_GetTagBySignature jxl_skcms_GetTagBySignature +#define skcms_GetWTPT jxl_skcms_GetWTPT +#define skcms_Identity_TransferFunction jxl_skcms_Identity_TransferFunction +#define skcms_MakeUsableAsDestination jxl_skcms_MakeUsableAsDestination +#define skcms_MakeUsableAsDestinationWithSingleCurve \ + jxl_skcms_MakeUsableAsDestinationWithSingleCurve +#define skcms_Matrix3x3_concat jxl_skcms_Matrix3x3_concat +#define skcms_Matrix3x3_invert jxl_skcms_Matrix3x3_invert +#define skcms_MaxRoundtripError jxl_skcms_MaxRoundtripError +#define skcms_Parse jxl_skcms_Parse +#define skcms_PrimariesToXYZD50 jxl_skcms_PrimariesToXYZD50 +#define skcms_sRGB_Inverse_TransferFunction \ + jxl_skcms_sRGB_Inverse_TransferFunction +#define skcms_sRGB_profile jxl_skcms_sRGB_profile +#define skcms_sRGB_TransferFunction jxl_skcms_sRGB_TransferFunction +#define skcms_TransferFunction_eval jxl_skcms_TransferFunction_eval +#define skcms_TransferFunction_invert jxl_skcms_TransferFunction_invert +#define skcms_TransferFunction_makeHLGish jxl_skcms_TransferFunction_makeHLGish +#define skcms_TransferFunction_makePQish jxl_skcms_TransferFunction_makePQish +#define skcms_Transform jxl_skcms_Transform +#define skcms_TransformWithPalette jxl_skcms_TransformWithPalette +#define skcms_TRCs_AreApproximateInverse jxl_skcms_TRCs_AreApproximateInverse +#define skcms_XYZD50_profile jxl_skcms_XYZD50_profile + +#endif // JPEGXL_BUNDLE_SKCMS + +#include "skcms.h" + +#endif // LIB_JXL_ENC_JXL_SKCMS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_modular.cc b/media/libjxl/src/lib/jxl/enc_modular.cc new file mode 100644 index 000000000..9e34fe875 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_modular.cc @@ -0,0 +1,1744 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_modular.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cluster.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/gaborish.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_debug_tree.h" +#include "lib/jxl/modular/encoding/enc_encoding.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/enc_transform.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// Squeeze default quantization factors +// these quantization factors are for -Q 50 (other qualities simply scale the +// factors; things are rounded down and obviously cannot get below 1) +static const float squeeze_quality_factor = + 0.35; // for easy tweaking of the quality range (decrease this number for + // higher quality) +static const float squeeze_luma_factor = + 1.1; // for easy tweaking of the balance between luma (or anything + // non-chroma) and chroma (decrease this number for higher quality + // luma) +static const float squeeze_quality_factor_xyb = 2.4f; +static const float squeeze_xyb_qtable[3][16] = { + {163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56, 1.28, 0.64, 0.32, 0.16, + 0.08, 0.04, 0.02, 0.01, 0.005}, // Y + {1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, + 0.5}, // X + {2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, + 0.5}, // B-Y +}; + +static const float squeeze_luma_qtable[16] = { + 163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56, 1.28, + 0.64, 0.32, 0.16, 0.08, 0.04, 0.02, 0.01, 0.005}; +// for 8-bit input, the range of YCoCg chroma is -255..255 so basically this +// does 4:2:0 subsampling (two most fine grained layers get quantized away) +static const float squeeze_chroma_qtable[16] = { + 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, 0.5}; + +// `cutoffs` must be sorted. +Tree MakeFixedTree(int property, const std::vector& cutoffs, + Predictor pred, size_t num_pixels) { + size_t log_px = CeilLog2Nonzero(num_pixels); + size_t min_gap = 0; + // Reduce fixed tree height when encoding small images. + if (log_px < 14) { + min_gap = 8 * (14 - log_px); + } + Tree tree; + struct NodeInfo { + size_t begin, end, pos; + }; + std::queue q; + // Leaf IDs will be set by roundtrip decoding the tree. + tree.push_back(PropertyDecisionNode::Leaf(pred)); + q.push(NodeInfo{0, cutoffs.size(), 0}); + while (!q.empty()) { + NodeInfo info = q.front(); + q.pop(); + if (info.begin + min_gap >= info.end) continue; + uint32_t split = (info.begin + info.end) / 2; + tree[info.pos] = + PropertyDecisionNode::Split(property, cutoffs[split], tree.size()); + q.push(NodeInfo{split + 1, info.end, tree.size()}); + tree.push_back(PropertyDecisionNode::Leaf(pred)); + q.push(NodeInfo{info.begin, split, tree.size()}); + tree.push_back(PropertyDecisionNode::Leaf(pred)); + } + return tree; +} + +Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels) { + if (tree_kind == ModularOptions::TreeKind::kJpegTranscodeACMeta || + tree_kind == ModularOptions::TreeKind::kTrivialTreeNoPredictor) { + // All the data is 0, so no need for a fancy tree. + return {PropertyDecisionNode::Leaf(Predictor::Zero)}; + } + if (tree_kind == ModularOptions::TreeKind::kFalconACMeta) { + // All the data is 0 except the quant field. TODO(veluca): make that 0 too. + return {PropertyDecisionNode::Leaf(Predictor::Left)}; + } + if (tree_kind == ModularOptions::TreeKind::kACMeta) { + // Small image. + if (total_pixels < 1024) { + return {PropertyDecisionNode::Leaf(Predictor::Left)}; + } + Tree tree; + // 0: c > 1 + tree.push_back(PropertyDecisionNode::Split(0, 1, 1)); + // 1: c > 2 + tree.push_back(PropertyDecisionNode::Split(0, 2, 3)); + // 2: c > 0 + tree.push_back(PropertyDecisionNode::Split(0, 0, 5)); + // 3: EPF control field (all 0 or 4), top > 0 + tree.push_back(PropertyDecisionNode::Split(6, 0, 21)); + // 4: ACS+QF, y > 0 + tree.push_back(PropertyDecisionNode::Split(2, 0, 7)); + // 5: CfL x + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); + // 6: CfL b + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); + // 7: QF: split according to the left quant value. + tree.push_back(PropertyDecisionNode::Split(7, 5, 9)); + // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large + // rectangular 6-11, 8x8 12+), according to previous ACS value. + tree.push_back(PropertyDecisionNode::Split(7, 5, 15)); + // QF + tree.push_back(PropertyDecisionNode::Split(7, 11, 11)); + tree.push_back(PropertyDecisionNode::Split(7, 3, 13)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + // ACS + tree.push_back(PropertyDecisionNode::Split(7, 11, 17)); + tree.push_back(PropertyDecisionNode::Split(7, 3, 19)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + // EPF, left > 0 + tree.push_back(PropertyDecisionNode::Split(7, 0, 23)); + tree.push_back(PropertyDecisionNode::Split(7, 0, 25)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + return tree; + } + if (tree_kind == ModularOptions::TreeKind::kWPFixedDC) { + std::vector cutoffs = { + -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, + -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, + 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; + return MakeFixedTree(kWPProp, cutoffs, Predictor::Weighted, total_pixels); + } + if (tree_kind == ModularOptions::TreeKind::kGradientFixedDC) { + std::vector cutoffs = { + -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, + -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, + 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; + return MakeFixedTree(kGradientProp, cutoffs, Predictor::Gradient, + total_pixels); + } + JXL_ABORT("Unreachable"); + return {}; +} + +// Merges the trees in `trees` using nodes that decide on stream_id, as defined +// by `tree_splits`. +void MergeTrees(const std::vector& trees, + const std::vector& tree_splits, size_t begin, + size_t end, Tree* tree) { + JXL_ASSERT(trees.size() + 1 == tree_splits.size()); + JXL_ASSERT(end > begin); + JXL_ASSERT(end <= trees.size()); + if (end == begin + 1) { + // Insert the tree, adding the opportune offset to all child nodes. + // This will make the leaf IDs wrong, but subsequent roundtripping will fix + // them. + size_t sz = tree->size(); + tree->insert(tree->end(), trees[begin].begin(), trees[begin].end()); + for (size_t i = sz; i < tree->size(); i++) { + (*tree)[i].lchild += sz; + (*tree)[i].rchild += sz; + } + return; + } + size_t mid = (begin + end) / 2; + size_t splitval = tree_splits[mid] - 1; + size_t cur = tree->size(); + tree->emplace_back(1 /*stream_id*/, splitval, 0, 0, Predictor::Zero, 0, 1); + (*tree)[cur].lchild = tree->size(); + MergeTrees(trees, tree_splits, mid, end, tree); + (*tree)[cur].rchild = tree->size(); + MergeTrees(trees, tree_splits, begin, mid, tree); +} + +void QuantizeChannel(Channel& ch, const int q) { + if (q == 1) return; + for (size_t y = 0; y < ch.plane.ysize(); y++) { + pixel_type* row = ch.plane.Row(y); + for (size_t x = 0; x < ch.plane.xsize(); x++) { + if (row[x] < 0) { + row[x] = -((-row[x] + q / 2) / q) * q; + } else { + row[x] = ((row[x] + q / 2) / q) * q; + } + } + } +} + +// convert binary32 float that corresponds to custom [bits]-bit float (with +// [exp_bits] exponent bits) to a [bits]-bit integer representation that should +// fit in pixel_type +Status float_to_int(const float* const row_in, pixel_type* const row_out, + size_t xsize, unsigned int bits, unsigned int exp_bits, + bool fp, double dfactor) { + JXL_ASSERT(sizeof(pixel_type) * 8 >= bits); + if (!fp) { + if (bits > 22) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * dfactor + (row_in[x] < 0 ? -0.5 : 0.5); + } + } else { + float factor = dfactor; + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * factor + (row_in[x] < 0 ? -0.5f : 0.5f); + } + } + return true; + } + if (bits == 32 && fp) { + JXL_ASSERT(exp_bits == 8); + memcpy((void*)row_out, (const void*)row_in, 4 * xsize); + return true; + } + + int exp_bias = (1 << (exp_bits - 1)) - 1; + int max_exp = (1 << exp_bits) - 1; + uint32_t sign = (1u << (bits - 1)); + int mant_bits = bits - exp_bits - 1; + int mant_shift = 23 - mant_bits; + for (size_t x = 0; x < xsize; ++x) { + uint32_t f; + memcpy(&f, &row_in[x], 4); + int signbit = (f >> 31); + f &= 0x7fffffff; + if (f == 0) { + row_out[x] = (signbit ? sign : 0); + continue; + } + int exp = (f >> 23) - 127; + if (exp == 128) return JXL_FAILURE("Inf/NaN not allowed"); + int mantissa = (f & 0x007fffff); + // broke up the binary32 into its parts, now reassemble into + // arbitrary float + exp += exp_bias; + if (exp < 0) { // will become a subnormal number + // add implicit leading 1 to mantissa + mantissa |= 0x00800000; + if (exp < -mant_bits) { + return JXL_FAILURE( + "Invalid float number: %g cannot be represented with %i " + "exp_bits and %i mant_bits (exp %i)", + row_in[x], exp_bits, mant_bits, exp); + } + mantissa >>= 1 - exp; + exp = 0; + } + // exp should be representable in exp_bits, otherwise input was + // invalid + if (exp > max_exp) return JXL_FAILURE("Invalid float exponent"); + if (mantissa & ((1 << mant_shift) - 1)) { + return JXL_FAILURE("%g is losing precision (mant: %x)", row_in[x], + mantissa); + } + mantissa >>= mant_shift; + f = (signbit ? sign : 0); + f |= (exp << mant_bits); + f |= mantissa; + row_out[x] = (pixel_type)f; + } + return true; +} +} // namespace + +ModularFrameEncoder::ModularFrameEncoder(const FrameHeader& frame_header, + const CompressParams& cparams_orig) + : frame_dim_(frame_header.ToFrameDimensions()), cparams_(cparams_orig) { + size_t num_streams = + ModularStreamId::Num(frame_dim_, frame_header.passes.num_passes); + if (cparams_.IsLossless()) { + switch (cparams_.decoding_speed_tier) { + case 0: + break; + case 1: + cparams_.options.wp_tree_mode = ModularOptions::TreeMode::kWPOnly; + break; + case 2: { + cparams_.options.wp_tree_mode = ModularOptions::TreeMode::kGradientOnly; + cparams_.options.predictor = Predictor::Gradient; + break; + } + case 3: { // LZ77, no Gradient. + cparams_.options.nb_repeats = 0; + cparams_.options.predictor = Predictor::Gradient; + break; + } + default: { // LZ77, no predictor. + cparams_.options.nb_repeats = 0; + cparams_.options.predictor = Predictor::Zero; + break; + } + } + } + if (cparams_.decoding_speed_tier >= 1 && cparams_.responsive && + cparams_.IsLossless()) { + cparams_.options.tree_kind = + ModularOptions::TreeKind::kTrivialTreeNoPredictor; + cparams_.options.nb_repeats = 0; + } + stream_images_.resize(num_streams); + + // use a sensible default if nothing explicit is specified: + // Squeeze for lossy, no squeeze for lossless + if (cparams_.responsive < 0) { + if (cparams_.IsLossless()) { + cparams_.responsive = 0; + } else { + cparams_.responsive = 1; + } + } + + if (cparams_.speed_tier > SpeedTier::kWombat) { + cparams_.options.splitting_heuristics_node_threshold = 192; + } else { + cparams_.options.splitting_heuristics_node_threshold = 96; + } + { + // Set properties. + std::vector prop_order; + if (cparams_.responsive) { + // Properties in order of their likelihood of being useful for Squeeze + // residuals. + prop_order = {0, 1, 4, 5, 6, 7, 8, 15, 9, 10, 11, 12, 13, 14, 2, 3}; + } else { + // Same, but for the non-Squeeze case. + prop_order = {0, 1, 15, 9, 10, 11, 12, 13, 14, 2, 3, 4, 5, 6, 7, 8}; + } + switch (cparams_.speed_tier) { + case SpeedTier::kSquirrel: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 8); + cparams_.options.max_property_values = 32; + break; + case SpeedTier::kKitten: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 10); + cparams_.options.max_property_values = 64; + break; + case SpeedTier::kTortoise: + cparams_.options.splitting_heuristics_properties = prop_order; + cparams_.options.max_property_values = 256; + break; + default: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 6); + cparams_.options.max_property_values = 16; + break; + } + if (cparams_.speed_tier > SpeedTier::kTortoise) { + // Gradient in previous channels. + for (int i = 0; i < cparams_.options.max_properties; i++) { + cparams_.options.splitting_heuristics_properties.push_back( + kNumNonrefProperties + i * 4 + 3); + } + } else { + // All the extra properties in Tortoise mode. + for (int i = 0; i < cparams_.options.max_properties * 4; i++) { + cparams_.options.splitting_heuristics_properties.push_back( + kNumNonrefProperties + i); + } + } + } + + if (cparams_.options.predictor == static_cast(-1)) { + // no explicit predictor(s) given, set a good default + if ((cparams_.speed_tier <= SpeedTier::kTortoise || + cparams_.modular_mode == false) && + cparams_.IsLossless() && cparams_.responsive == false) { + // TODO(veluca): allow all predictors that don't break residual + // multipliers in lossy mode. + cparams_.options.predictor = Predictor::Variable; + } else if (cparams_.responsive || cparams_.lossy_palette) { + // zero predictor for Squeeze residues and lossy palette + cparams_.options.predictor = Predictor::Zero; + } else if (!cparams_.IsLossless()) { + // If not responsive and lossy. TODO(veluca): use near_lossless instead? + cparams_.options.predictor = Predictor::Gradient; + } else if (cparams_.speed_tier < SpeedTier::kFalcon) { + // try median and weighted predictor for anything else + cparams_.options.predictor = Predictor::Best; + } else if (cparams_.speed_tier == SpeedTier::kFalcon) { + // just weighted predictor in falcon mode + cparams_.options.predictor = Predictor::Weighted; + } else if (cparams_.speed_tier > SpeedTier::kFalcon) { + // just gradient predictor in thunder mode + cparams_.options.predictor = Predictor::Gradient; + } + } else { + delta_pred_ = cparams_.options.predictor; + if (cparams_.lossy_palette) cparams_.options.predictor = Predictor::Zero; + } + if (!cparams_.IsLossless()) { + if (cparams_.options.predictor == Predictor::Weighted || + cparams_.options.predictor == Predictor::Variable || + cparams_.options.predictor == Predictor::Best) + cparams_.options.predictor = Predictor::Zero; + } + tree_splits_.push_back(0); + if (cparams_.modular_mode == false) { + cparams_.options.fast_decode_multiplier = 1.0f; + tree_splits_.push_back(ModularStreamId::VarDCTDC(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::ModularDC(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::ACMetadata(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::QuantTable(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::ModularAC(0, 0).ID(frame_dim_)); + ac_metadata_size.resize(frame_dim_.num_dc_groups); + extra_dc_precision.resize(frame_dim_.num_dc_groups); + } + tree_splits_.push_back(num_streams); + cparams_.options.max_chan_size = frame_dim_.group_dim; + cparams_.options.group_dim = frame_dim_.group_dim; + + // TODO(veluca): figure out how to use different predictor sets per channel. + stream_options_.resize(num_streams, cparams_.options); +} + +bool do_transform(Image& image, const Transform& tr, + const weighted::Header& wp_header, + jxl::ThreadPool* pool = nullptr, bool force_jxlart = false) { + Transform t = tr; + bool did_it = true; + if (force_jxlart) { + if (!t.MetaApply(image)) return false; + } else { + did_it = TransformForward(t, image, wp_header, pool); + } + if (did_it) image.transform.push_back(t); + return did_it; +} + +Status ModularFrameEncoder::ComputeEncodingData( + const FrameHeader& frame_header, const ImageMetadata& metadata, + Image3F* JXL_RESTRICT color, const std::vector& extra_channels, + PassesEncoderState* JXL_RESTRICT enc_state, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out, bool do_color) { + JXL_DEBUG_V(6, "Computing modular encoding data for frame %s", + frame_header.DebugString().c_str()); + + if (do_color && frame_header.loop_filter.gab) { + GaborishInverse(color, 0.9908511000000001f, pool); + } + + if (do_color && metadata.bit_depth.bits_per_sample <= 16 && + cparams_.speed_tier < SpeedTier::kCheetah && + cparams_.decoding_speed_tier < 2) { + FindBestPatchDictionary(*color, enc_state, cms, nullptr, aux_out, + cparams_.color_transform == ColorTransform::kXYB); + PatchDictionaryEncoder::SubtractFrom( + enc_state->shared.image_features.patches, color); + } + + // Convert ImageBundle to modular Image object + const size_t xsize = frame_dim_.xsize; + const size_t ysize = frame_dim_.ysize; + + int nb_chans = 3; + if (metadata.color_encoding.IsGray() && + cparams_.color_transform == ColorTransform::kNone) { + nb_chans = 1; + } + if (!do_color) nb_chans = 0; + + nb_chans += extra_channels.size(); + + bool fp = metadata.bit_depth.floating_point_sample && + cparams_.color_transform != ColorTransform::kXYB; + + // bits_per_sample is just metadata for XYB images. + if (metadata.bit_depth.bits_per_sample >= 32 && do_color && + cparams_.color_transform != ColorTransform::kXYB) { + if (metadata.bit_depth.bits_per_sample == 32 && fp == false) { + return JXL_FAILURE("uint32_t not supported in enc_modular"); + } else if (metadata.bit_depth.bits_per_sample > 32) { + return JXL_FAILURE("bits_per_sample > 32 not supported"); + } + } + + // in the non-float case, there is an implicit 0 sign bit + int max_bitdepth = + do_color ? metadata.bit_depth.bits_per_sample + (fp ? 0 : 1) : 0; + Image& gi = stream_images_[0]; + gi = Image(xsize, ysize, metadata.bit_depth.bits_per_sample, nb_chans); + int c = 0; + if (cparams_.color_transform == ColorTransform::kXYB && + cparams_.modular_mode == true) { + float enc_factors[3] = {32768.0f, 2048.0f, 2048.0f}; + if (cparams_.butteraugli_distance > 0 && !cparams_.responsive) { + // quantize XYB here and then treat it as a lossless image + enc_factors[0] *= 1.f / (1.f + 23.f * cparams_.butteraugli_distance); + enc_factors[1] *= 1.f / (1.f + 14.f * cparams_.butteraugli_distance); + enc_factors[2] *= 1.f / (1.f + 14.f * cparams_.butteraugli_distance); + cparams_.butteraugli_distance = 0; + } + if (cparams_.manual_xyb_factors.size() == 3) { + DequantMatricesSetCustomDC(&enc_state->shared.matrices, + cparams_.manual_xyb_factors.data()); + // TODO(jon): update max_bitdepth in this case + } else { + DequantMatricesSetCustomDC(&enc_state->shared.matrices, enc_factors); + max_bitdepth = 12; + } + } + pixel_type maxval = gi.bitdepth < 32 ? (1u << gi.bitdepth) - 1 : 0; + if (do_color) { + for (; c < 3; c++) { + if (metadata.color_encoding.IsGray() && + cparams_.color_transform == ColorTransform::kNone && + c != (cparams_.color_transform == ColorTransform::kXYB ? 1 : 0)) + continue; + int c_out = c; + // XYB is encoded as YX(B-Y) + if (cparams_.color_transform == ColorTransform::kXYB && c < 2) + c_out = 1 - c_out; + double factor = maxval; + if (cparams_.color_transform == ColorTransform::kXYB) + factor = enc_state->shared.matrices.InvDCQuant(c); + if (c == 2 && cparams_.color_transform == ColorTransform::kXYB) { + JXL_ASSERT(!fp); + for (size_t y = 0; y < ysize; ++y) { + const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y); + pixel_type* const JXL_RESTRICT row_Y = gi.channel[0].Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * factor + 0.5f; + row_out[x] -= row_Y[x]; + // zero the lsb of B + row_out[x] = row_out[x] / 2 * 2; + } + } + } else { + int bits = metadata.bit_depth.bits_per_sample; + int exp_bits = metadata.bit_depth.exponent_bits_per_sample; + gi.channel[c_out].hshift = + enc_state->shared.frame_header.chroma_subsampling.HShift(c); + gi.channel[c_out].vshift = + enc_state->shared.frame_header.chroma_subsampling.VShift(c); + size_t xsize_shifted = DivCeil(xsize, 1 << gi.channel[c_out].hshift); + size_t ysize_shifted = DivCeil(ysize, 1 << gi.channel[c_out].vshift); + gi.channel[c_out].shrink(xsize_shifted, ysize_shifted); + std::atomic has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const int task, const int thread) { + const size_t y = task; + const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y); + if (!float_to_int(row_in, row_out, xsize_shifted, bits, exp_bits, + fp, factor)) { + has_error = true; + }; + }, + "float2int")); + if (has_error) { + return JXL_FAILURE("Error in float to integer conversion"); + } + } + } + if (metadata.color_encoding.IsGray() && + cparams_.color_transform == ColorTransform::kNone) + c = 1; + } + + for (size_t ec = 0; ec < extra_channels.size(); ec++, c++) { + const ExtraChannelInfo& eci = metadata.extra_channel_info[ec]; + size_t ecups = frame_header.extra_channel_upsampling[ec]; + gi.channel[c].shrink(DivCeil(frame_dim_.xsize_upsampled, ecups), + DivCeil(frame_dim_.ysize_upsampled, ecups)); + gi.channel[c].hshift = gi.channel[c].vshift = + CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling); + + int bits = eci.bit_depth.bits_per_sample; + int exp_bits = eci.bit_depth.exponent_bits_per_sample; + bool fp = eci.bit_depth.floating_point_sample; + double factor = (fp ? 1 : ((1u << eci.bit_depth.bits_per_sample) - 1)); + if (bits + (fp ? 0 : 1) > max_bitdepth) max_bitdepth = bits + (fp ? 0 : 1); + std::atomic has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, gi.channel[c].plane.ysize(), ThreadPool::NoInit, + [&](const int task, const int thread) { + const size_t y = task; + const float* const JXL_RESTRICT row_in = extra_channels[ec].Row(y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c].Row(y); + if (!float_to_int(row_in, row_out, gi.channel[c].plane.xsize(), bits, + exp_bits, fp, factor)) { + has_error = true; + }; + }, + "float2int")); + if (has_error) return JXL_FAILURE("Error in float to integer conversion"); + } + JXL_ASSERT(c == nb_chans); + + int level_max_bitdepth = (cparams_.level == 5 ? 16 : 32); + if (max_bitdepth > level_max_bitdepth) + return JXL_FAILURE( + "Bitdepth too high for level %i (need %i bits, have only %i in this " + "level)", + cparams_.level, max_bitdepth, level_max_bitdepth); + + // Set options and apply transformations + + if (cparams_.butteraugli_distance > 0) { + if (cparams_.palette_colors != 0) { + JXL_DEBUG_V(3, "Lossy encode, not doing palette transforms"); + } + if (cparams_.color_transform == ColorTransform::kXYB) { + cparams_.channel_colors_pre_transform_percent = 0; + } + cparams_.channel_colors_percent = 0; + cparams_.palette_colors = 0; + cparams_.lossy_palette = false; + } + + // if few colors, do all-channel palette before trying channel palette + // Logic is as follows: + // - if you can make a palette with few colors (arbitrary threshold: 200), + // then you can also make channel palettes, but they will just be extra + // signaling cost for almost no benefit + // - if the palette needs more colors, then channel palette might help to + // reduce palette signaling cost + if (cparams_.palette_colors != 0 && + cparams_.speed_tier < SpeedTier::kFalcon) { + // all-channel palette (e.g. RGBA) + if (gi.channel.size() > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.channel.size() - gi.nb_meta_channels; + maybe_palette.nb_colors = + std::min(std::min(200, (int)(xsize * ysize / 8)), + std::abs(cparams_.palette_colors) / 16); + maybe_palette.ordered_palette = cparams_.palette_colors >= 0; + maybe_palette.lossy_palette = false; + do_transform(gi, maybe_palette, weighted::Header(), pool); + } + } + + // Global channel palette + if (cparams_.channel_colors_pre_transform_percent > 0 && + !cparams_.lossy_palette && + (cparams_.speed_tier <= SpeedTier::kThunder || + (do_color && metadata.bit_depth.bits_per_sample > 8))) { + // single channel palette (like FLIF's ChannelCompact) + size_t nb_channels = gi.channel.size() - gi.nb_meta_channels; + int orig_bitdepth = max_bitdepth; + max_bitdepth = 0; + for (size_t i = 0; i < nb_channels; i++) { + int32_t min, max; + compute_minmax(gi.channel[gi.nb_meta_channels + i], &min, &max); + int64_t colors = max - min + 1; + JXL_DEBUG_V(10, "Channel %" PRIuS ": range=%i..%i", i, min, max); + Transform maybe_palette_1(TransformId::kPalette); + maybe_palette_1.begin_c = i + gi.nb_meta_channels; + maybe_palette_1.num_c = 1; + // simple heuristic: if less than X percent of the values in the range + // actually occur, it is probably worth it to do a compaction + // (but only if the channel palette is less than 6% the size of the + // image itself) + maybe_palette_1.nb_colors = std::min( + (int)(xsize * ysize / 16), + (int)(cparams_.channel_colors_pre_transform_percent / 100. * colors)); + if (do_transform(gi, maybe_palette_1, weighted::Header(), pool)) { + // effective bit depth is lower, adjust quantization accordingly + compute_minmax(gi.channel[gi.nb_meta_channels + i], &min, &max); + if (max < maxval) maxval = max; + int ch_bitdepth = + (max > 0 ? CeilLog2Nonzero(static_cast(max)) : 0); + if (ch_bitdepth > max_bitdepth) max_bitdepth = ch_bitdepth; + } else + max_bitdepth = orig_bitdepth; + } + } + + // Global palette + if ((cparams_.palette_colors != 0 || cparams_.lossy_palette) && + cparams_.speed_tier < SpeedTier::kFalcon) { + // all-channel palette (e.g. RGBA) + if (gi.channel.size() - gi.nb_meta_channels > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.channel.size() - gi.nb_meta_channels; + maybe_palette.nb_colors = + std::min((int)(xsize * ysize / 8), std::abs(cparams_.palette_colors)); + maybe_palette.ordered_palette = cparams_.palette_colors >= 0; + maybe_palette.lossy_palette = + (cparams_.lossy_palette && maybe_palette.num_c == 3); + if (maybe_palette.lossy_palette) { + maybe_palette.predictor = delta_pred_; + } + // TODO(veluca): use a custom weighted header if using the weighted + // predictor. + do_transform(gi, maybe_palette, weighted::Header(), pool, + cparams_.options.zero_tokens); + } + // all-minus-one-channel palette (RGB with separate alpha, or CMY with + // separate K) + if (gi.channel.size() - gi.nb_meta_channels > 3) { + Transform maybe_palette_3(TransformId::kPalette); + maybe_palette_3.begin_c = gi.nb_meta_channels; + maybe_palette_3.num_c = gi.channel.size() - gi.nb_meta_channels - 1; + maybe_palette_3.nb_colors = + std::min((int)(xsize * ysize / 8), std::abs(cparams_.palette_colors)); + maybe_palette_3.ordered_palette = cparams_.palette_colors >= 0; + maybe_palette_3.lossy_palette = cparams_.lossy_palette; + if (maybe_palette_3.lossy_palette) { + maybe_palette_3.predictor = delta_pred_; + } + do_transform(gi, maybe_palette_3, weighted::Header(), pool, + cparams_.options.zero_tokens); + } + } + + // don't do an RCT if we're short on bits + if (cparams_.color_transform == ColorTransform::kNone && do_color && + gi.channel.size() - gi.nb_meta_channels >= 3 && + max_bitdepth + 1 < level_max_bitdepth) { + if (cparams_.colorspace < 0 && + (!cparams_.IsLossless() || cparams_.speed_tier > SpeedTier::kHare)) { + Transform ycocg{TransformId::kRCT}; + ycocg.rct_type = 6; + ycocg.begin_c = gi.nb_meta_channels; + do_transform(gi, ycocg, weighted::Header(), pool); + max_bitdepth++; + } else if (cparams_.colorspace > 0) { + Transform sg(TransformId::kRCT); + sg.begin_c = gi.nb_meta_channels; + sg.rct_type = cparams_.colorspace; + do_transform(gi, sg, weighted::Header(), pool); + max_bitdepth++; + } + } + + // don't do squeeze if we don't have some spare bits + if (cparams_.responsive && !gi.channel.empty() && + max_bitdepth + 2 < level_max_bitdepth) { + Transform t(TransformId::kSqueeze); + t.squeezes = cparams_.squeezes; + do_transform(gi, t, weighted::Header(), pool); + max_bitdepth += 2; + } + + if (max_bitdepth + 1 > level_max_bitdepth) { + // force no group RCTs if we don't have a spare bit + cparams_.colorspace = 0; + } + JXL_ASSERT(max_bitdepth <= level_max_bitdepth); + + std::vector quants; + + if (cparams_.butteraugli_distance > 0) { + quants.resize(gi.channel.size(), 1); + float quality = 0.25f * cparams_.butteraugli_distance; + JXL_DEBUG_V(2, + "Adding quantization constants corresponding to distance %.3f ", + quality); + if (!cparams_.responsive) { + JXL_DEBUG_V(1, + "Warning: lossy compression without Squeeze " + "transform is just color quantization."); + quality *= 0.1f; + } + if (cparams_.color_transform != ColorTransform::kXYB) { + quality *= maxval / 255.f; + } + if (cparams_.options.nb_repeats == 0) { + return JXL_FAILURE("nb_repeats = 0 not supported with modular lossy!"); + } + for (uint32_t i = gi.nb_meta_channels; i < gi.channel.size(); i++) { + Channel& ch = gi.channel[i]; + int shift = ch.hshift + ch.vshift; // number of pixel halvings + if (shift > 16) shift = 16; + if (shift > 0) shift--; + int q; + // assuming default Squeeze here + int component = + (do_color ? 0 : 3) + ((i - gi.nb_meta_channels) % nb_chans); + // last 4 channels are final chroma residuals + if (nb_chans > 2 && i >= gi.channel.size() - 4 && cparams_.responsive) { + component = 1; + } + if (cparams_.color_transform == ColorTransform::kXYB && component < 3) { + q = quality * squeeze_quality_factor_xyb * + squeeze_xyb_qtable[component][shift]; + } else { + if (cparams_.colorspace != 0 && component > 0 && component < 3) { + q = quality * squeeze_quality_factor * squeeze_chroma_qtable[shift]; + } else { + q = quality * squeeze_quality_factor * squeeze_luma_factor * + squeeze_luma_qtable[shift]; + } + } + if (q < 1) q = 1; + QuantizeChannel(gi.channel[i], q); + quants[i] = q; + } + } + + // Fill other groups. + struct GroupParams { + Rect rect; + int minShift; + int maxShift; + ModularStreamId id; + }; + std::vector stream_params; + + stream_options_[0] = cparams_.options; + + // DC + for (size_t group_id = 0; group_id < frame_dim_.num_dc_groups; group_id++) { + const size_t gx = group_id % frame_dim_.xsize_dc_groups; + const size_t gy = group_id / frame_dim_.xsize_dc_groups; + const Rect rect(gx * frame_dim_.dc_group_dim, gy * frame_dim_.dc_group_dim, + frame_dim_.dc_group_dim, frame_dim_.dc_group_dim); + // minShift==3 because (frame_dim.dc_group_dim >> 3) == frame_dim.group_dim + // maxShift==1000 is infinity + stream_params.push_back( + GroupParams{rect, 3, 1000, ModularStreamId::ModularDC(group_id)}); + } + // AC global -> nothing. + // AC + for (size_t group_id = 0; group_id < frame_dim_.num_groups; group_id++) { + const size_t gx = group_id % frame_dim_.xsize_groups; + const size_t gy = group_id / frame_dim_.xsize_groups; + const Rect mrect(gx * frame_dim_.group_dim, gy * frame_dim_.group_dim, + frame_dim_.group_dim, frame_dim_.group_dim); + for (size_t i = 0; i < enc_state->progressive_splitter.GetNumPasses(); + i++) { + int maxShift, minShift; + frame_header.passes.GetDownsamplingBracket(i, minShift, maxShift); + stream_params.push_back(GroupParams{ + mrect, minShift, maxShift, ModularStreamId::ModularAC(group_id, i)}); + } + } + // if there's only one group, everything ends up in GlobalModular + // in that case, also try RCTs/WP params for the one group + if (stream_params.size() == 2) { + stream_params.push_back(GroupParams{Rect(0, 0, xsize, ysize), 0, 1000, + ModularStreamId::Global()}); + } + gi_channel_.resize(stream_images_.size()); + + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, stream_params.size(), ThreadPool::NoInit, + [&](const uint32_t i, size_t /* thread */) { + stream_options_[stream_params[i].id.ID(frame_dim_)] = cparams_.options; + JXL_CHECK(PrepareStreamParams( + stream_params[i].rect, cparams_, stream_params[i].minShift, + stream_params[i].maxShift, stream_params[i].id, do_color)); + }, + "ChooseParams")); + { + // Clear out channels that have been copied to groups. + Image& full_image = stream_images_[0]; + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim_.group_dim || fc.h > frame_dim_.group_dim) break; + } + for (; c < full_image.channel.size(); c++) { + full_image.channel[c].plane = ImageI(); + } + } + + if (!quants.empty()) { + for (uint32_t stream_id = 0; stream_id < stream_images_.size(); + stream_id++) { + // skip non-modular stream_ids + if (stream_id > 0 && gi_channel_[stream_id].empty()) continue; + const Image& image = stream_images_[stream_id]; + const ModularOptions& options = stream_options_[stream_id]; + for (uint32_t i = image.nb_meta_channels; i < image.channel.size(); i++) { + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + continue; + } + if (stream_id > 0 && gi_channel_[stream_id].empty()) continue; + size_t ch_id = stream_id == 0 + ? i + : gi_channel_[stream_id][i - image.nb_meta_channels]; + uint32_t q = quants[ch_id]; + // Inform the tree splitting heuristics that each channel in each group + // used this quantization factor. This will produce a tree with the + // given multipliers. + if (multiplier_info_.empty() || + multiplier_info_.back().range[1][0] != stream_id || + multiplier_info_.back().multiplier != q) { + StaticPropRange range; + range[0] = {{i, i + 1}}; + range[1] = {{stream_id, stream_id + 1}}; + multiplier_info_.push_back({range, (uint32_t)q}); + } else { + // Previous channel in the same group had the same quantization + // factor. Don't provide two different ranges, as that creates + // unnecessary nodes. + multiplier_info_.back().range[0][1] = i + 1; + } + } + } + // Merge group+channel settings that have the same channels and quantization + // factors, to avoid unnecessary nodes. + std::sort(multiplier_info_.begin(), multiplier_info_.end(), + [](ModularMultiplierInfo a, ModularMultiplierInfo b) { + return std::make_tuple(a.range, a.multiplier) < + std::make_tuple(b.range, b.multiplier); + }); + size_t new_num = 1; + for (size_t i = 1; i < multiplier_info_.size(); i++) { + ModularMultiplierInfo& prev = multiplier_info_[new_num - 1]; + ModularMultiplierInfo& cur = multiplier_info_[i]; + if (prev.range[0] == cur.range[0] && prev.multiplier == cur.multiplier && + prev.range[1][1] == cur.range[1][0]) { + prev.range[1][1] = cur.range[1][1]; + } else { + multiplier_info_[new_num++] = multiplier_info_[i]; + } + } + multiplier_info_.resize(new_num); + } + + JXL_RETURN_IF_ERROR(ValidateChannelDimensions(gi, stream_options_[0])); + + return PrepareEncoding(frame_header, pool, enc_state->heuristics.get(), + aux_out); +} + +Status ModularFrameEncoder::PrepareEncoding(const FrameHeader& frame_header, + ThreadPool* pool, + EncoderHeuristics* heuristics, + AuxOut* aux_out) { + if (!tree_.empty()) return true; + + // Compute tree. + size_t num_streams = stream_images_.size(); + stream_headers_.resize(num_streams); + tokens_.resize(num_streams); + + if (heuristics->CustomFixedTreeLossless(frame_dim_, &tree_)) { + // Using a fixed tree. + } else if (cparams_.speed_tier < SpeedTier::kFalcon || + !cparams_.modular_mode) { + // Avoid creating a tree with leaves that don't correspond to any pixels. + std::vector useful_splits; + useful_splits.reserve(tree_splits_.size()); + for (size_t chunk = 0; chunk < tree_splits_.size() - 1; chunk++) { + bool has_pixels = false; + size_t start = tree_splits_[chunk]; + size_t stop = tree_splits_[chunk + 1]; + for (size_t i = start; i < stop; i++) { + if (!stream_images_[i].empty()) has_pixels = true; + } + if (has_pixels) { + useful_splits.push_back(tree_splits_[chunk]); + } + } + // Don't do anything if modular mode does not have any pixels in this image + if (useful_splits.empty()) return true; + useful_splits.push_back(tree_splits_.back()); + + std::atomic_flag invalid_force_wp = ATOMIC_FLAG_INIT; + + std::vector trees(useful_splits.size() - 1); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, useful_splits.size() - 1, ThreadPool::NoInit, + [&](const uint32_t chunk, size_t /* thread */) { + // TODO(veluca): parallelize more. + size_t total_pixels = 0; + uint32_t start = useful_splits[chunk]; + uint32_t stop = useful_splits[chunk + 1]; + while (start < stop && stream_images_[start].empty()) ++start; + while (start < stop && stream_images_[stop - 1].empty()) --stop; + uint32_t max_c = 0; + if (stream_options_[start].tree_kind != + ModularOptions::TreeKind::kLearn) { + for (size_t i = start; i < stop; i++) { + for (const Channel& ch : stream_images_[i].channel) { + total_pixels += ch.w * ch.h; + } + } + trees[chunk] = + PredefinedTree(stream_options_[start].tree_kind, total_pixels); + return; + } + TreeSamples tree_samples; + if (!tree_samples.SetPredictor(stream_options_[start].predictor, + stream_options_[start].wp_tree_mode)) { + invalid_force_wp.test_and_set(std::memory_order_acq_rel); + return; + } + if (!tree_samples.SetProperties( + stream_options_[start].splitting_heuristics_properties, + stream_options_[start].wp_tree_mode)) { + invalid_force_wp.test_and_set(std::memory_order_acq_rel); + return; + } + std::vector pixel_samples; + std::vector diff_samples; + std::vector group_pixel_count; + std::vector channel_pixel_count; + for (size_t i = start; i < stop; i++) { + max_c = std::max(stream_images_[i].channel.size(), max_c); + CollectPixelSamples(stream_images_[i], stream_options_[i], i, + group_pixel_count, channel_pixel_count, + pixel_samples, diff_samples); + } + StaticPropRange range; + range[0] = {{0, max_c}}; + range[1] = {{start, stop}}; + auto local_multiplier_info = multiplier_info_; + + tree_samples.PreQuantizeProperties( + range, local_multiplier_info, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples, + stream_options_[start].max_property_values); + for (size_t i = start; i < stop; i++) { + JXL_CHECK(ModularGenericCompress( + stream_images_[i], stream_options_[i], /*writer=*/nullptr, + /*aux_out=*/nullptr, 0, i, &tree_samples, &total_pixels)); + } + + // TODO(veluca): parallelize more. + trees[chunk] = + LearnTree(std::move(tree_samples), total_pixels, + stream_options_[start], local_multiplier_info, range); + }, + "LearnTrees")); + if (invalid_force_wp.test_and_set(std::memory_order_acq_rel)) { + return JXL_FAILURE("PrepareEncoding: force_no_wp with {Weighted}"); + } + tree_.clear(); + MergeTrees(trees, useful_splits, 0, useful_splits.size() - 1, &tree_); + } else { + // Fixed tree. + size_t total_pixels = 0; + for (const Image& img : stream_images_) { + for (const Channel& ch : img.channel) { + total_pixels += ch.w * ch.h; + } + } + if (cparams_.speed_tier <= SpeedTier::kFalcon) { + tree_ = + PredefinedTree(ModularOptions::TreeKind::kWPFixedDC, total_pixels); + } else if (cparams_.speed_tier <= SpeedTier::kThunder) { + tree_ = PredefinedTree(ModularOptions::TreeKind::kGradientFixedDC, + total_pixels); + } else { + tree_ = {PropertyDecisionNode::Leaf(Predictor::Gradient)}; + } + } + tree_tokens_.resize(1); + tree_tokens_[0].clear(); + Tree decoded_tree; + TokenizeTree(tree_, &tree_tokens_[0], &decoded_tree); + JXL_ASSERT(tree_.size() == decoded_tree.size()); + tree_ = std::move(decoded_tree); + + if (WantDebugOutput(aux_out)) { + if (frame_header.dc_level > 0) { + PrintTree(tree_, aux_out->debug_prefix + "/dc_frame_level" + + std::to_string(frame_header.dc_level) + "_tree"); + } else { + PrintTree(tree_, aux_out->debug_prefix + "/global_tree"); + } + } + + image_widths_.resize(num_streams); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, num_streams, ThreadPool::NoInit, + [&](const uint32_t stream_id, size_t /* thread */) { + AuxOut my_aux_out; + if (aux_out) { + my_aux_out.dump_image = aux_out->dump_image; + my_aux_out.debug_prefix = aux_out->debug_prefix; + } + tokens_[stream_id].clear(); + JXL_CHECK(ModularGenericCompress( + stream_images_[stream_id], stream_options_[stream_id], + /*writer=*/nullptr, &my_aux_out, 0, stream_id, + /*tree_samples=*/nullptr, + /*total_pixels=*/nullptr, + /*tree=*/&tree_, /*header=*/&stream_headers_[stream_id], + /*tokens=*/&tokens_[stream_id], + /*widths=*/&image_widths_[stream_id])); + }, + "ComputeTokens")); + return true; +} + +Status ModularFrameEncoder::EncodeGlobalInfo(BitWriter* writer, + AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, 1); + // If we are using brotli, or not using modular mode. + if (tree_tokens_.empty() || tree_tokens_[0].empty()) { + writer->Write(1, 0); + ReclaimAndCharge(writer, &allotment, kLayerModularTree, aux_out); + return true; + } + writer->Write(1, 1); + ReclaimAndCharge(writer, &allotment, kLayerModularTree, aux_out); + + // Write tree + HistogramParams params; + if (cparams_.speed_tier > SpeedTier::kKitten) { + params.clustering = HistogramParams::ClusteringType::kFast; + params.ans_histogram_strategy = + cparams_.speed_tier > SpeedTier::kThunder + ? HistogramParams::ANSHistogramStrategy::kFast + : HistogramParams::ANSHistogramStrategy::kApproximate; + params.lz77_method = + cparams_.decoding_speed_tier >= 3 && cparams_.modular_mode + ? (cparams_.speed_tier >= SpeedTier::kFalcon + ? HistogramParams::LZ77Method::kRLE + : HistogramParams::LZ77Method::kLZ77) + : HistogramParams::LZ77Method::kNone; + // Near-lossless DC, as well as modular mode, require choosing hybrid uint + // more carefully. + if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) || + (cparams_.modular_mode && cparams_.speed_tier < SpeedTier::kCheetah)) { + params.uint_method = HistogramParams::HybridUintMethod::kFast; + } else { + params.uint_method = HistogramParams::HybridUintMethod::kNone; + } + } else if (cparams_.speed_tier <= SpeedTier::kTortoise) { + params.lz77_method = HistogramParams::LZ77Method::kOptimal; + } else { + params.lz77_method = HistogramParams::LZ77Method::kLZ77; + } + if (cparams_.decoding_speed_tier >= 1) { + params.max_histograms = 12; + } + if (cparams_.decoding_speed_tier >= 1 && cparams_.responsive) { + params.lz77_method = cparams_.speed_tier >= SpeedTier::kCheetah + ? HistogramParams::LZ77Method::kRLE + : cparams_.speed_tier >= SpeedTier::kKitten + ? HistogramParams::LZ77Method::kLZ77 + : HistogramParams::LZ77Method::kOptimal; + } + if (cparams_.decoding_speed_tier >= 2 && cparams_.responsive) { + params.uint_method = HistogramParams::HybridUintMethod::k000; + params.force_huffman = true; + } + BuildAndEncodeHistograms(params, kNumTreeContexts, tree_tokens_, &code_, + &context_map_, writer, kLayerModularTree, aux_out); + WriteTokens(tree_tokens_[0], code_, context_map_, writer, kLayerModularTree, + aux_out); + params.image_widths = image_widths_; + // Write histograms. + BuildAndEncodeHistograms(params, (tree_.size() + 1) / 2, tokens_, &code_, + &context_map_, writer, kLayerModularGlobal, aux_out); + return true; +} + +Status ModularFrameEncoder::EncodeStream(BitWriter* writer, AuxOut* aux_out, + size_t layer, + const ModularStreamId& stream) { + size_t stream_id = stream.ID(frame_dim_); + if (stream_images_[stream_id].channel.empty()) { + return true; // Image with no channels, header never gets decoded. + } + JXL_RETURN_IF_ERROR( + Bundle::Write(stream_headers_[stream_id], writer, layer, aux_out)); + WriteTokens(tokens_[stream_id], code_, context_map_, writer, layer, aux_out); + return true; +} + +namespace { +float EstimateWPCost(const Image& img, size_t i) { + size_t extra_bits = 0; + float histo_cost = 0; + HybridUintConfig config; + int32_t cutoffs[] = {-500, -392, -255, -191, -127, -95, -63, -47, -31, + -23, -15, -11, -7, -4, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, 63, + 95, 127, 191, 255, 392, 500}; + constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1; + Histogram histo[nc] = {}; + weighted::Header wp_header; + PredictorMode(i, &wp_header); + for (const Channel& ch : img.channel) { + const intptr_t onerow = ch.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, ch.w, ch.h); + Properties properties(1); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type* JXL_RESTRICT r = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < ch.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + pixel_type guess = wp_state.Predict( + x, y, ch.w, top, left, topright, topleft, toptop, &properties, + offset); + size_t ctx = 0; + for (int c : cutoffs) { + ctx += c >= properties[0]; + } + pixel_type res = r[x] - guess; + uint32_t token, nbits, bits; + config.Encode(PackSigned(res), &token, &nbits, &bits); + histo[ctx].Add(token); + extra_bits += nbits; + wp_state.UpdateErrors(r[x], x, y, ch.w); + } + } + for (size_t h = 0; h < nc; h++) { + histo_cost += histo[h].ShannonEntropy(); + histo[h].Clear(); + } + } + return histo_cost + extra_bits; +} + +float EstimateCost(const Image& img) { + // TODO(veluca): consider SIMDfication of this code. + size_t extra_bits = 0; + float histo_cost = 0; + HybridUintConfig config; + uint32_t cutoffs[] = {0, 1, 3, 5, 7, 11, 15, 23, 31, + 47, 63, 95, 127, 191, 255, 392, 500}; + constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1; + Histogram histo[nc] = {}; + for (const Channel& ch : img.channel) { + const intptr_t onerow = ch.plane.PixelsPerRow(); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type* JXL_RESTRICT r = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + size_t maxdiff = std::max(std::max(left, top), topleft) - + std::min(std::min(left, top), topleft); + size_t ctx = 0; + for (uint32_t c : cutoffs) { + ctx += c > maxdiff; + } + pixel_type res = r[x] - ClampedGradient(top, left, topleft); + uint32_t token, nbits, bits; + config.Encode(PackSigned(res), &token, &nbits, &bits); + histo[ctx].Add(token); + extra_bits += nbits; + } + } + for (size_t h = 0; h < nc; h++) { + histo_cost += histo[h].ShannonEntropy(); + histo[h].Clear(); + } + } + return histo_cost + extra_bits; +} + +} // namespace + +Status ModularFrameEncoder::PrepareStreamParams(const Rect& rect, + const CompressParams& cparams_, + int minShift, int maxShift, + const ModularStreamId& stream, + bool do_color) { + size_t stream_id = stream.ID(frame_dim_); + Image& full_image = stream_images_[0]; + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + Image& gi = stream_images_[stream_id]; + if (stream_id > 0) { + gi = Image(xsize, ysize, full_image.bitdepth, 0); + // start at the first bigger-than-frame_dim.group_dim non-metachannel + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim_.group_dim || fc.h > frame_dim_.group_dim) break; + } + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + gi_channel_[stream_id].push_back(c); + Channel gc(r.xsize(), r.ysize()); + gc.hshift = fc.hshift; + gc.vshift = fc.vshift; + for (size_t y = 0; y < r.ysize(); ++y) { + memcpy(gc.Row(y), r.ConstRow(fc.plane, y), + r.xsize() * sizeof(pixel_type)); + } + gi.channel.emplace_back(std::move(gc)); + } + + if (gi.channel.empty()) return true; + // Do some per-group transforms + + // Local palette + // TODO(veluca): make this work with quantize-after-prediction in lossy + // mode. + if (cparams_.butteraugli_distance == 0.f && cparams_.palette_colors != 0 && + cparams_.speed_tier < SpeedTier::kCheetah) { + // all-channel palette (e.g. RGBA) + if (gi.channel.size() - gi.nb_meta_channels > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.channel.size() - gi.nb_meta_channels; + maybe_palette.nb_colors = std::abs(cparams_.palette_colors); + maybe_palette.ordered_palette = cparams_.palette_colors >= 0; + do_transform(gi, maybe_palette, weighted::Header()); + } + // all-minus-one-channel palette (RGB with separate alpha, or CMY with + // separate K) + if (gi.channel.size() - gi.nb_meta_channels > 3) { + Transform maybe_palette_3(TransformId::kPalette); + maybe_palette_3.begin_c = gi.nb_meta_channels; + maybe_palette_3.num_c = gi.channel.size() - gi.nb_meta_channels - 1; + maybe_palette_3.nb_colors = std::abs(cparams_.palette_colors); + maybe_palette_3.ordered_palette = cparams_.palette_colors >= 0; + maybe_palette_3.lossy_palette = cparams_.lossy_palette; + if (maybe_palette_3.lossy_palette) { + maybe_palette_3.predictor = Predictor::Weighted; + } + do_transform(gi, maybe_palette_3, weighted::Header()); + } + } + + // Local channel palette + if (cparams_.channel_colors_percent > 0 && + cparams_.butteraugli_distance == 0.f && !cparams_.lossy_palette && + cparams_.speed_tier < SpeedTier::kCheetah && + !(cparams_.responsive && cparams_.decoding_speed_tier >= 1)) { + // single channel palette (like FLIF's ChannelCompact) + size_t nb_channels = gi.channel.size() - gi.nb_meta_channels; + for (size_t i = 0; i < nb_channels; i++) { + int32_t min, max; + compute_minmax(gi.channel[gi.nb_meta_channels + i], &min, &max); + int colors = max - min + 1; + JXL_DEBUG_V(10, "Channel %" PRIuS ": range=%i..%i", i, min, max); + Transform maybe_palette_1(TransformId::kPalette); + maybe_palette_1.begin_c = i + gi.nb_meta_channels; + maybe_palette_1.num_c = 1; + // simple heuristic: if less than X percent of the values in the range + // actually occur, it is probably worth it to do a compaction + // (but only if the channel palette is less than 80% the size of the + // image itself) + maybe_palette_1.nb_colors = + std::min((int)(xsize * ysize * 0.8), + (int)(cparams_.channel_colors_percent / 100. * colors)); + do_transform(gi, maybe_palette_1, weighted::Header()); + } + } + } + + // lossless and no specific color transform specified: try Nothing, YCoCg, + // and 17 RCTs + if (cparams_.color_transform == ColorTransform::kNone && + cparams_.IsLossless() && cparams_.colorspace < 0 && + gi.channel.size() - gi.nb_meta_channels >= 3 && + cparams_.responsive == false && do_color && + cparams_.speed_tier <= SpeedTier::kHare) { + Transform sg(TransformId::kRCT); + sg.begin_c = gi.nb_meta_channels; + size_t nb_rcts_to_try = 0; + switch (cparams_.speed_tier) { + case SpeedTier::kLightning: + case SpeedTier::kThunder: + case SpeedTier::kFalcon: + case SpeedTier::kCheetah: + nb_rcts_to_try = 0; // Just do global YCoCg + break; + case SpeedTier::kHare: + nb_rcts_to_try = 4; + break; + case SpeedTier::kWombat: + nb_rcts_to_try = 5; + break; + case SpeedTier::kSquirrel: + nb_rcts_to_try = 7; + break; + case SpeedTier::kKitten: + nb_rcts_to_try = 9; + break; + case SpeedTier::kTortoise: + nb_rcts_to_try = 19; + break; + } + float best_cost = std::numeric_limits::max(); + size_t best_rct = 0; + // These should be 19 actually different transforms; the remaining ones + // are equivalent to one of these (note that the first two are do-nothing + // and YCoCg) modulo channel reordering (which only matters in the case of + // MA-with-prev-channels-properties) and/or sign (e.g. RmG vs GmR) + for (int i : {0 * 7 + 0, 0 * 7 + 6, 0 * 7 + 5, 1 * 7 + 3, 3 * 7 + 5, + 5 * 7 + 5, 1 * 7 + 5, 2 * 7 + 5, 1 * 7 + 1, 0 * 7 + 4, + 1 * 7 + 2, 2 * 7 + 1, 2 * 7 + 2, 2 * 7 + 3, 4 * 7 + 4, + 4 * 7 + 5, 0 * 7 + 2, 0 * 7 + 1, 0 * 7 + 3}) { + if (nb_rcts_to_try == 0) break; + sg.rct_type = i; + nb_rcts_to_try--; + if (do_transform(gi, sg, weighted::Header())) { + float cost = EstimateCost(gi); + if (cost < best_cost) { + best_rct = i; + best_cost = cost; + } + Transform t = gi.transform.back(); + JXL_RETURN_IF_ERROR(t.Inverse(gi, weighted::Header(), nullptr)); + gi.transform.pop_back(); + } + } + // Apply the best RCT to the image for future encoding. + sg.rct_type = best_rct; + do_transform(gi, sg, weighted::Header()); + } else { + // No need to try anything, just use the default options. + } + size_t nb_wp_modes = 1; + if (cparams_.speed_tier <= SpeedTier::kTortoise) { + nb_wp_modes = 5; + } else if (cparams_.speed_tier <= SpeedTier::kKitten) { + nb_wp_modes = 2; + } + if (nb_wp_modes > 1 && + (stream_options_[stream_id].predictor == Predictor::Weighted || + stream_options_[stream_id].predictor == Predictor::Best || + stream_options_[stream_id].predictor == Predictor::Variable)) { + float best_cost = std::numeric_limits::max(); + stream_options_[stream_id].wp_mode = 0; + for (size_t i = 0; i < nb_wp_modes; i++) { + float cost = EstimateWPCost(gi, i); + if (cost < best_cost) { + best_cost = cost; + stream_options_[stream_id].wp_mode = i; + } + } + } + return true; +} + +constexpr float q_deadzone = 0.62f; +int QuantizeWP(const int32_t* qrow, size_t onerow, size_t c, size_t x, size_t y, + size_t w, weighted::State* wp_state, float value, + float inv_factor) { + float svalue = value * inv_factor; + PredictionResult pred = + PredictNoTreeWP(w, qrow + x, onerow, x, y, Predictor::Weighted, wp_state); + svalue -= pred.guess; + if (svalue > -q_deadzone && svalue < q_deadzone) svalue = 0; + int residual = roundf(svalue); + if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2; + return residual + pred.guess; +} + +int QuantizeGradient(const int32_t* qrow, size_t onerow, size_t c, size_t x, + size_t y, size_t w, float value, float inv_factor) { + float svalue = value * inv_factor; + PredictionResult pred = + PredictNoTreeNoWP(w, qrow + x, onerow, x, y, Predictor::Gradient); + svalue -= pred.guess; + if (svalue > -q_deadzone && svalue < q_deadzone) svalue = 0; + int residual = roundf(svalue); + if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2; + return residual + pred.guess; +} + +void ModularFrameEncoder::AddVarDCTDC(const Image3F& dc, size_t group_index, + bool nl_dc, PassesEncoderState* enc_state, + bool jpeg_transcode) { + const Rect r = enc_state->shared.DCGroupRect(group_index); + extra_dc_precision[group_index] = nl_dc ? 1 : 0; + float mul = 1 << extra_dc_precision[group_index]; + + size_t stream_id = ModularStreamId::VarDCTDC(group_index).ID(frame_dim_); + stream_options_[stream_id].max_chan_size = 0xFFFFFF; + stream_options_[stream_id].predictor = Predictor::Weighted; + stream_options_[stream_id].wp_tree_mode = ModularOptions::TreeMode::kWPOnly; + if (cparams_.speed_tier >= SpeedTier::kSquirrel) { + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kWPFixedDC; + } + if (cparams_.speed_tier < SpeedTier::kSquirrel && !nl_dc) { + stream_options_[stream_id].predictor = + (cparams_.speed_tier < SpeedTier::kKitten ? Predictor::Variable + : Predictor::Best); + stream_options_[stream_id].wp_tree_mode = + ModularOptions::TreeMode::kDefault; + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kLearn; + } + if (cparams_.decoding_speed_tier >= 1) { + stream_options_[stream_id].tree_kind = + ModularOptions::TreeKind::kGradientFixedDC; + } + + stream_images_[stream_id] = Image(r.xsize(), r.ysize(), 8, 3); + if (nl_dc && stream_options_[stream_id].tree_kind == + ModularOptions::TreeKind::kGradientFixedDC) { + JXL_ASSERT(enc_state->shared.frame_header.chroma_subsampling.Is444()); + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + size_t stride = stream_images_[stream_id] + .channel[c < 2 ? c ^ 1 : c] + .plane.PixelsPerRow(); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeGradient(quant_row, stride, c, x, y, + r.xsize(), row[x], inv_factor); + } + } else { + int32_t* quant_row_y = + stream_images_[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeGradient( + quant_row, stride, c, x, y, r.xsize(), + row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor); + } + } + } + } + } else if (nl_dc) { + JXL_ASSERT(enc_state->shared.frame_header.chroma_subsampling.Is444()); + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + weighted::Header header; + weighted::State wp_state(header, r.xsize(), r.ysize()); + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + size_t stride = stream_images_[stream_id] + .channel[c < 2 ? c ^ 1 : c] + .plane.PixelsPerRow(); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeWP(quant_row, stride, c, x, y, r.xsize(), + &wp_state, row[x], inv_factor); + wp_state.UpdateErrors(quant_row[x], x, y, r.xsize()); + } + } else { + int32_t* quant_row_y = + stream_images_[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeWP( + quant_row, stride, c, x, y, r.xsize(), &wp_state, + row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor); + wp_state.UpdateErrors(quant_row[x], x, y, r.xsize()); + } + } + } + } + } else if (enc_state->shared.frame_header.chroma_subsampling.Is444()) { + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = roundf(row[x] * inv_factor); + } + } else { + int32_t* quant_row_y = + stream_images_[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = + roundf((row[x] - quant_row_y[x] * (y_factor * cfl_factor)) * + inv_factor); + } + } + } + } + } else { + for (size_t c : {1, 0, 2}) { + Rect rect( + r.x0() >> enc_state->shared.frame_header.chroma_subsampling.HShift(c), + r.y0() >> enc_state->shared.frame_header.chroma_subsampling.VShift(c), + r.xsize() >> + enc_state->shared.frame_header.chroma_subsampling.HShift(c), + r.ysize() >> + enc_state->shared.frame_header.chroma_subsampling.VShift(c)); + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + size_t ys = rect.ysize(); + size_t xs = rect.xsize(); + Channel& ch = stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c]; + ch.w = xs; + ch.h = ys; + ch.shrink(); + for (size_t y = 0; y < ys; y++) { + int32_t* quant_row = ch.plane.Row(y); + const float* row = rect.ConstPlaneRow(dc, c, y); + for (size_t x = 0; x < xs; x++) { + quant_row[x] = roundf(row[x] * inv_factor); + } + } + } + } + + DequantDC(r, &enc_state->shared.dc_storage, &enc_state->shared.quant_dc, + stream_images_[stream_id], enc_state->shared.quantizer.MulDC(), + 1.0 / mul, enc_state->shared.cmap.DCFactors(), + enc_state->shared.frame_header.chroma_subsampling, + enc_state->shared.block_ctx_map); +} + +void ModularFrameEncoder::AddACMetadata(size_t group_index, bool jpeg_transcode, + PassesEncoderState* enc_state) { + const Rect r = enc_state->shared.DCGroupRect(group_index); + size_t stream_id = ModularStreamId::ACMetadata(group_index).ID(frame_dim_); + stream_options_[stream_id].max_chan_size = 0xFFFFFF; + stream_options_[stream_id].wp_tree_mode = ModularOptions::TreeMode::kNoWP; + if (jpeg_transcode) { + stream_options_[stream_id].tree_kind = + ModularOptions::TreeKind::kJpegTranscodeACMeta; + } else if (cparams_.speed_tier >= SpeedTier::kFalcon) { + stream_options_[stream_id].tree_kind = + ModularOptions::TreeKind::kFalconACMeta; + } else if (cparams_.speed_tier > SpeedTier::kKitten) { + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kACMeta; + } + // If we are using a non-constant CfL field, and are in a slow enough mode, + // re-enable tree computation for it. + if (cparams_.speed_tier < SpeedTier::kSquirrel && + cparams_.force_cfl_jpeg_recompression) { + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kLearn; + } + // YToX, YToB, ACS + QF, EPF + Image& image = stream_images_[stream_id]; + image = Image(r.xsize(), r.ysize(), 8, 4); + static_assert(kColorTileDimInBlocks == 8, "Color tile size changed"); + Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3); + image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[2] = Channel(r.xsize() * r.ysize(), 2, 0, 0); + ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytox_map, + Rect(image.channel[0].plane), &image.channel[0].plane); + ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytob_map, + Rect(image.channel[1].plane), &image.channel[1].plane); + size_t num = 0; + for (size_t y = 0; y < r.ysize(); y++) { + AcStrategyRow row_acs = enc_state->shared.ac_strategy.ConstRow(r, y); + const int32_t* row_qf = r.ConstRow(enc_state->shared.raw_quant_field, y); + const uint8_t* row_epf = r.ConstRow(enc_state->shared.epf_sharpness, y); + int32_t* out_acs = image.channel[2].plane.Row(0); + int32_t* out_qf = image.channel[2].plane.Row(1); + int32_t* row_out_epf = image.channel[3].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + row_out_epf[x] = row_epf[x]; + if (!row_acs[x].IsFirstBlock()) continue; + out_acs[num] = row_acs[x].RawStrategy(); + out_qf[num] = row_qf[x] - 1; + num++; + } + } + image.channel[2].w = num; + ac_metadata_size[group_index] = num; +} + +void ModularFrameEncoder::EncodeQuantTable( + size_t size_x, size_t size_y, BitWriter* writer, + const QuantEncoding& encoding, size_t idx, + ModularFrameEncoder* modular_frame_encoder) { + JXL_ASSERT(encoding.qraw.qtable != nullptr); + JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size()); + JXL_CHECK(F16Coder::Write(encoding.qraw.qtable_den, writer)); + if (modular_frame_encoder) { + JXL_CHECK(modular_frame_encoder->EncodeStream( + writer, nullptr, 0, ModularStreamId::QuantTable(idx))); + return; + } + Image image(size_x, size_y, 8, 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < size_y; y++) { + int32_t* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < size_x; x++) { + row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x]; + } + } + } + ModularOptions cfopts; + JXL_CHECK(ModularGenericCompress(image, cfopts, writer)); +} + +void ModularFrameEncoder::AddQuantTable(size_t size_x, size_t size_y, + const QuantEncoding& encoding, + size_t idx) { + size_t stream_id = ModularStreamId::QuantTable(idx).ID(frame_dim_); + JXL_ASSERT(encoding.qraw.qtable != nullptr); + JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size()); + Image& image = stream_images_[stream_id]; + image = Image(size_x, size_y, 8, 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < size_y; y++) { + int32_t* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < size_x; x++) { + row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x]; + } + } + } +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_modular.h b/media/libjxl/src/lib/jxl/enc_modular.h new file mode 100644 index 000000000..02477ee65 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_modular.h @@ -0,0 +1,92 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_MODULAR_H_ +#define LIB_JXL_ENC_MODULAR_H_ + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +class ModularFrameEncoder { + public: + ModularFrameEncoder(const FrameHeader& frame_header, + const CompressParams& cparams_orig); + Status ComputeEncodingData(const FrameHeader& frame_header, + const ImageMetadata& metadata, + Image3F* JXL_RESTRICT color, + const std::vector& extra_channels, + PassesEncoderState* JXL_RESTRICT enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, bool do_color); + // Encodes global info (tree + histograms) in the `writer`. + Status EncodeGlobalInfo(BitWriter* writer, AuxOut* aux_out); + // Encodes a specific modular image (identified by `stream`) in the `writer`, + // assigning bits to the provided `layer`. + Status EncodeStream(BitWriter* writer, AuxOut* aux_out, size_t layer, + const ModularStreamId& stream); + // Creates a modular image for a given DC group of VarDCT mode. `dc` is the + // input DC image, not quantized; the group is specified by `group_index`, and + // `nl_dc` decides whether to apply a near-lossless processing to the DC or + // not. + void AddVarDCTDC(const Image3F& dc, size_t group_index, bool nl_dc, + PassesEncoderState* enc_state, bool jpeg_transcode); + // Creates a modular image for the AC metadata of the given group + // (`group_index`). + void AddACMetadata(size_t group_index, bool jpeg_transcode, + PassesEncoderState* enc_state); + // Encodes a RAW quantization table in `writer`. If `modular_frame_encoder` is + // null, the quantization table in `encoding` is used, with dimensions `size_x + // x size_y`. Otherwise, the table with ID `idx` is encoded from the given + // `modular_frame_encoder`. + static void EncodeQuantTable(size_t size_x, size_t size_y, BitWriter* writer, + const QuantEncoding& encoding, size_t idx, + ModularFrameEncoder* modular_frame_encoder); + // Stores a quantization table for future usage with `EncodeQuantTable`. + void AddQuantTable(size_t size_x, size_t size_y, + const QuantEncoding& encoding, size_t idx); + + std::vector ac_metadata_size; + std::vector extra_dc_precision; + + private: + Status PrepareEncoding(const FrameHeader& frame_header, ThreadPool* pool, + EncoderHeuristics* heuristics, + AuxOut* aux_out = nullptr); + Status PrepareStreamParams(const Rect& rect, const CompressParams& cparams, + int minShift, int maxShift, + const ModularStreamId& stream, bool do_color); + std::vector stream_images_; + std::vector stream_options_; + + Tree tree_; + std::vector> tree_tokens_; + std::vector stream_headers_; + std::vector> tokens_; + EntropyEncodingData code_; + std::vector context_map_; + FrameDimensions frame_dim_; + CompressParams cparams_; + std::vector tree_splits_; + std::vector multiplier_info_; + std::vector> gi_channel_; + std::vector image_widths_; + Predictor delta_pred_ = Predictor::Average4; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_MODULAR_H_ diff --git a/media/libjxl/src/lib/jxl/enc_noise.cc b/media/libjxl/src/lib/jxl/enc_noise.cc new file mode 100644 index 000000000..36287613b --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_noise.cc @@ -0,0 +1,373 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_noise.h" + +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/optimize.h" + +namespace jxl { +namespace { + +using OptimizeArray = optimize::Array; + +float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x, + const int y, const int block_size) { + const int small_bl_size_x = 3; + const int small_bl_size_y = 4; + const int kNumSAD = + (block_size - small_bl_size_x) * (block_size - small_bl_size_y); + // block_size x block_size reference pixels + int counter = 0; + const int offset = 2; + + std::vector sad(kNumSAD, 0); + for (int y_bl = 0; y_bl + small_bl_size_y < block_size; ++y_bl) { + for (int x_bl = 0; x_bl + small_bl_size_x < block_size; ++x_bl) { + float sad_sum = 0; + // size of the center patch, we compare all the patches inside window with + // the center one + for (int cy = 0; cy < small_bl_size_y; ++cy) { + for (int cx = 0; cx < small_bl_size_x; ++cx) { + float wnd = 0.5f * (opsin.PlaneRow(1, y + y_bl + cy)[x + x_bl + cx] + + opsin.PlaneRow(0, y + y_bl + cy)[x + x_bl + cx]); + float center = + 0.5f * (opsin.PlaneRow(1, y + offset + cy)[x + offset + cx] + + opsin.PlaneRow(0, y + offset + cy)[x + offset + cx]); + sad_sum += std::abs(center - wnd); + } + } + sad[counter++] = sad_sum; + } + } + const int kSamples = (kNumSAD) / 2; + // As with ROAD (rank order absolute distance), we keep the smallest half of + // the values in SAD (we use here the more robust patch SAD instead of + // absolute single-pixel differences). + std::sort(sad.begin(), sad.end()); + const float total_sad_sum = + std::accumulate(sad.begin(), sad.begin() + kSamples, 0.0f); + return total_sad_sum / kSamples; +} + +class NoiseHistogram { + public: + static constexpr int kBins = 256; + + NoiseHistogram() { std::fill(bins, bins + kBins, 0); } + + void Increment(const float x) { bins[Index(x)] += 1; } + int Get(const float x) const { return bins[Index(x)]; } + int Bin(const size_t bin) const { return bins[bin]; } + + int Mode() const { + size_t max_idx = 0; + for (size_t i = 0; i < kBins; i++) { + if (bins[i] > bins[max_idx]) max_idx = i; + } + return max_idx; + } + + double Quantile(double q01) const { + const int64_t total = std::accumulate(bins, bins + kBins, int64_t{1}); + const int64_t target = static_cast(q01 * total); + // Until sum >= target: + int64_t sum = 0; + size_t i = 0; + for (; i < kBins; ++i) { + sum += bins[i]; + // Exact match: assume middle of bin i + if (sum == target) { + return i + 0.5; + } + if (sum > target) break; + } + + // Next non-empty bin (in case histogram is sparsely filled) + size_t next = i + 1; + while (next < kBins && bins[next] == 0) { + ++next; + } + + // Linear interpolation according to how far into next we went + const double excess = target - sum; + const double weight_next = bins[Index(next)] / excess; + return ClampX(next * weight_next + i * (1.0 - weight_next)); + } + + // Inter-quartile range + double IQR() const { return Quantile(0.75) - Quantile(0.25); } + + private: + template + T ClampX(const T x) const { + return std::min(std::max(T(0), x), T(kBins - 1)); + } + size_t Index(const float x) const { return ClampX(static_cast(x)); } + + uint32_t bins[kBins]; +}; + +std::vector GetSADScoresForPatches(const Image3F& opsin, + const size_t block_s, + const size_t num_bin, + NoiseHistogram* sad_histogram) { + std::vector sad_scores( + (opsin.ysize() / block_s) * (opsin.xsize() / block_s), 0.0f); + + int block_index = 0; + + for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) { + for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) { + float sad_sc = GetScoreSumsOfAbsoluteDifferences(opsin, x, y, block_s); + sad_scores[block_index++] = sad_sc; + sad_histogram->Increment(sad_sc * num_bin); + } + } + return sad_scores; +} + +float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) { + // Here we assume that the most patches with similar SAD value is a "flat" + // patches. However, some images might contain regular texture part and + // generate second strong peak at the histogram + // TODO(user) handle bimodal and heavy-tailed case + const int mode = histogram.Mode(); + return static_cast(mode) / NoiseHistogram::kBins; +} + +// loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2 +// where asym = 1 if F(x) < nl, kAsym if F(x) > nl. +struct LossFunction { + explicit LossFunction(std::vector nl0) : nl(std::move(nl0)) {} + + double Compute(const OptimizeArray& w, OptimizeArray* df, + bool skip_regularization = false) const { + constexpr double kReg = 0.005; + constexpr double kAsym = 1.1; + double loss_function = 0; + for (size_t i = 0; i < w.size(); i++) { + (*df)[i] = 0; + } + for (auto ind : nl) { + std::pair pos = IndexAndFrac(ind.intensity); + JXL_DASSERT(pos.first >= 0 && static_cast(pos.first) < + NoiseParams::kNumNoisePoints - 1); + double low = w[pos.first]; + double hi = w[pos.first + 1]; + double val = low * (1.0f - pos.second) + hi * pos.second; + double dist = val - ind.noise_level; + if (dist > 0) { + loss_function += kAsym * dist * dist; + (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist; + (*df)[pos.first + 1] -= kAsym * pos.second * dist; + } else { + loss_function += dist * dist; + (*df)[pos.first] -= (1.0f - pos.second) * dist; + (*df)[pos.first + 1] -= pos.second * dist; + } + } + if (skip_regularization) return loss_function; + for (size_t i = 0; i + 1 < w.size(); i++) { + double diff = w[i] - w[i + 1]; + loss_function += kReg * nl.size() * diff * diff; + (*df)[i] -= kReg * diff * nl.size(); + (*df)[i + 1] += kReg * diff * nl.size(); + } + return loss_function; + } + + std::vector nl; +}; + +void OptimizeNoiseParameters(const std::vector& noise_level, + NoiseParams* noise_params) { + constexpr double kMaxError = 1e-3; + static const double kPrecision = 1e-8; + static const int kMaxIter = 40; + + float avg = 0; + for (const NoiseLevel& nl : noise_level) { + avg += nl.noise_level; + } + avg /= noise_level.size(); + + LossFunction loss_function(noise_level); + OptimizeArray parameter_vector; + for (size_t i = 0; i < parameter_vector.size(); i++) { + parameter_vector[i] = avg; + } + + parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod( + loss_function, parameter_vector, kPrecision, kMaxIter); + + OptimizeArray df = parameter_vector; + float loss = loss_function.Compute(parameter_vector, &df, + /*skip_regularization=*/true) / + noise_level.size(); + + // Approximation went too badly: escape with no noise at all. + if (loss > kMaxError) { + noise_params->Clear(); + return; + } + + for (size_t i = 0; i < parameter_vector.size(); i++) { + noise_params->lut[i] = std::max(parameter_vector[i], 0.0); + } +} + +std::vector GetNoiseLevel( + const Image3F& opsin, const std::vector& texture_strength, + const float threshold, const size_t block_s) { + std::vector noise_level_per_intensity; + + const int filt_size = 1; + static const float kLaplFilter[filt_size * 2 + 1][filt_size * 2 + 1] = { + {-0.25f, -1.0f, -0.25f}, + {-1.0f, 5.0f, -1.0f}, + {-0.25f, -1.0f, -0.25f}, + }; + + // The noise model is built based on channel 0.5 * (X+Y) as we notice that it + // is similar to the model 0.5 * (Y-X) + size_t patch_index = 0; + + for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) { + for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) { + if (texture_strength[patch_index] <= threshold) { + // Calculate mean value + float mean_int = 0; + for (size_t y_bl = 0; y_bl < block_s; ++y_bl) { + for (size_t x_bl = 0; x_bl < block_s; ++x_bl) { + mean_int += 0.5f * (opsin.PlaneRow(1, y + y_bl)[x + x_bl] + + opsin.PlaneRow(0, y + y_bl)[x + x_bl]); + } + } + mean_int /= block_s * block_s; + + // Calculate Noise level + float noise_level = 0; + size_t count = 0; + for (size_t y_bl = 0; y_bl < block_s; ++y_bl) { + for (size_t x_bl = 0; x_bl < block_s; ++x_bl) { + float filtered_value = 0; + for (int y_f = -1 * filt_size; y_f <= filt_size; ++y_f) { + if ((static_cast(y_bl) + y_f) >= 0 && + (y_bl + y_f) < block_s) { + for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) { + if ((static_cast(x_bl) + x_f) >= 0 && + (x_bl + x_f) < block_s) { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl + x_f] + + opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl + x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } else { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl - x_f] + + opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl - x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } + } + } else { + for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) { + if ((static_cast(x_bl) + x_f) >= 0 && + (x_bl + x_f) < block_s) { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl + x_f] + + opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl + x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } else { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl - x_f] + + opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl - x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } + } + } + } + noise_level += std::abs(filtered_value); + ++count; + } + } + noise_level /= count; + NoiseLevel nl; + nl.intensity = mean_int; + nl.noise_level = noise_level; + noise_level_per_intensity.push_back(nl); + } + ++patch_index; + } + } + return noise_level_per_intensity; +} + +void EncodeFloatParam(float val, float precision, BitWriter* writer) { + JXL_ASSERT(val >= 0); + const int absval_quant = static_cast(val * precision + 0.5f); + JXL_ASSERT(absval_quant < (1 << 10)); + writer->Write(10, absval_quant); +} + +} // namespace + +Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params, + float quality_coef) { + // The size of a patch in decoder might be different from encoder's patch + // size. + // For encoder: the patch size should be big enough to estimate + // noise level, but, at the same time, it should be not too big + // to be able to estimate intensity value of the patch + const size_t block_s = 8; + const size_t kNumBin = 256; + NoiseHistogram sad_histogram; + std::vector sad_scores = + GetSADScoresForPatches(opsin, block_s, kNumBin, &sad_histogram); + float sad_threshold = GetSADThreshold(sad_histogram, kNumBin); + // If threshold is too large, the image has a strong pattern. This pattern + // fools our model and it will add too much noise. Therefore, we do not add + // noise for such images + if (sad_threshold > 0.15f || sad_threshold <= 0.0f) { + noise_params->Clear(); + return false; + } + std::vector nl = + GetNoiseLevel(opsin, sad_scores, sad_threshold, block_s); + + OptimizeNoiseParameters(nl, noise_params); + for (float& i : noise_params->lut) { + i *= quality_coef * 1.4; + } + return noise_params->HasAny(); +} + +void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer, + size_t layer, AuxOut* aux_out) { + JXL_ASSERT(noise_params.HasAny()); + + BitWriter::Allotment allotment(writer, NoiseParams::kNumNoisePoints * 16); + for (float i : noise_params.lut) { + EncodeFloatParam(i, kNoisePrecision, writer); + } + ReclaimAndCharge(writer, &allotment, layer, aux_out); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_noise.h b/media/libjxl/src/lib/jxl/enc_noise.h new file mode 100644 index 000000000..15fb07a8c --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_noise.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_NOISE_H_ +#define LIB_JXL_ENC_NOISE_H_ + +// Noise parameter estimation. + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/noise.h" + +namespace jxl { + +// Get parameters of the noise for NoiseParams model +// Returns whether a valid noise model (with HasAny()) is set. +Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params, + float quality_coef); + +// Does not write anything if `noise_params` are empty. Otherwise, caller must +// set FrameHeader.flags.kNoise. +void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer, + size_t layer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_NOISE_H_ diff --git a/media/libjxl/src/lib/jxl/enc_params.h b/media/libjxl/src/lib/jxl/enc_params.h new file mode 100644 index 000000000..2e16fae86 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_params.h @@ -0,0 +1,287 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_PARAMS_H_ +#define LIB_JXL_ENC_PARAMS_H_ + +// Parameters and flags that govern JXL compression. + +#include +#include + +#include + +#include "lib/jxl/base/override.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +enum class SpeedTier { + // Turns on FindBestQuantizationHQ loop. Equivalent to "guetzli" mode. + kTortoise = 1, + // Turns on FindBestQuantization butteraugli loop. + kKitten = 2, + // Turns on dots, patches, and spline detection by default, as well as full + // context clustering. Default. + kSquirrel = 3, + // Turns on error diffusion and full AC strategy heuristics. Equivalent to + // "fast" mode. + kWombat = 4, + // Turns on gaborish by default, non-default cmap, initial quant field. + kHare = 5, + // Turns on simple heuristics for AC strategy, quant field, and clustering; + // also enables coefficient reordering. + kCheetah = 6, + // Turns off most encoder features. Does context clustering. + // Modular: uses fixed tree with Weighted predictor. + kFalcon = 7, + // Currently fastest possible setting for VarDCT. + // Modular: uses fixed tree with Gradient predictor. + kThunder = 8, + // VarDCT: same as kThunder. + // Modular: no tree, Gradient predictor, fast histograms + kLightning = 9 +}; + +inline bool ParseSpeedTier(const std::string& s, SpeedTier* out) { + if (s == "lightning") { + *out = SpeedTier::kLightning; + return true; + } else if (s == "thunder") { + *out = SpeedTier::kThunder; + return true; + } else if (s == "falcon") { + *out = SpeedTier::kFalcon; + return true; + } else if (s == "cheetah") { + *out = SpeedTier::kCheetah; + return true; + } else if (s == "hare") { + *out = SpeedTier::kHare; + return true; + } else if (s == "fast" || s == "wombat") { + *out = SpeedTier::kWombat; + return true; + } else if (s == "squirrel") { + *out = SpeedTier::kSquirrel; + return true; + } else if (s == "kitten") { + *out = SpeedTier::kKitten; + return true; + } else if (s == "guetzli" || s == "tortoise") { + *out = SpeedTier::kTortoise; + return true; + } + size_t st = 10 - static_cast(strtoull(s.c_str(), nullptr, 0)); + if (st <= static_cast(SpeedTier::kLightning) && + st >= static_cast(SpeedTier::kTortoise)) { + *out = SpeedTier(st); + return true; + } + return false; +} + +inline const char* SpeedTierName(SpeedTier speed_tier) { + switch (speed_tier) { + case SpeedTier::kLightning: + return "lightning"; + case SpeedTier::kThunder: + return "thunder"; + case SpeedTier::kFalcon: + return "falcon"; + case SpeedTier::kCheetah: + return "cheetah"; + case SpeedTier::kHare: + return "hare"; + case SpeedTier::kWombat: + return "wombat"; + case SpeedTier::kSquirrel: + return "squirrel"; + case SpeedTier::kKitten: + return "kitten"; + case SpeedTier::kTortoise: + return "tortoise"; + } + return "INVALID"; +} + +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct CompressParams { + float butteraugli_distance = 1.0f; + size_t target_size = 0; + float target_bitrate = 0.0f; + + // 0.0 means search for the adaptive quantization map that matches the + // butteraugli distance, positive values mean quantize everywhere with that + // value. + float uniform_quant = 0.0f; + float quant_border_bias = 0.0f; + + // Try to achieve a maximum pixel-by-pixel error on each channel. + bool max_error_mode = false; + float max_error[3] = {0.0, 0.0, 0.0}; + + SpeedTier speed_tier = SpeedTier::kSquirrel; + int brotli_effort = -1; + + // 0 = default. + // 1 = slightly worse quality. + // 4 = fastest speed, lowest quality + // TODO(veluca): hook this up to the C API. + size_t decoding_speed_tier = 0; + + int max_butteraugli_iters = 4; + + int max_butteraugli_iters_guetzli_mode = 100; + + ColorTransform color_transform = ColorTransform::kXYB; + YCbCrChromaSubsampling chroma_subsampling; + + // If true, the "modular mode options" members below are used. + bool modular_mode = false; + + // Change group size in modular mode (0=128, 1=256, 2=512, 3=1024). + size_t modular_group_size_shift = 1; + + Override preview = Override::kDefault; + Override noise = Override::kDefault; + Override dots = Override::kDefault; + Override patches = Override::kDefault; + Override gaborish = Override::kDefault; + int epf = -1; + + // Progressive mode. + bool progressive_mode = false; + + // Quantized-progressive mode. + bool qprogressive_mode = false; + + // Put center groups first in the bitstream. + bool centerfirst = false; + + // Pixel coordinates of the center. First group will contain that center. + size_t center_x = static_cast(-1); + size_t center_y = static_cast(-1); + + int progressive_dc = -1; + + // If on: preserve color of invisible pixels (if off: don't care) + // Default: on for lossless, off for lossy + Override keep_invisible = Override::kDefault; + + // Progressive-mode saliency. + // + // How many progressive saliency-encoding steps to perform. + // - 1: Encode only DC and lowest-frequency AC. Does not need a saliency-map. + // - 2: Encode only DC+LF, dropping all HF AC data. + // Does not need a saliency-map. + // - 3: Encode DC+LF+{salient HF}, dropping all non-salient HF data. + // - 4: Encode DC+LF+{salient HF}+{other HF}. + // - 5: Encode DC+LF+{quantized HF}+{low HF bits}. + size_t saliency_num_progressive_steps = 3; + // Every saliency-heatmap cell with saliency >= threshold will be considered + // as 'salient'. The default value of 0.0 will consider every AC-block + // as salient, hence not require a saliency-map, and not actually generate + // a 4th progressive step. + float saliency_threshold = 0.0f; + // Saliency-map (owned by caller). + ImageF* saliency_map = nullptr; + + // Input and output file name. Will be used to provide pluggable saliency + // extractor with paths. + const char* file_in = nullptr; + const char* file_out = nullptr; + + // Currently unused as of 2020-01. + bool clear_metadata = false; + + // Prints extra information during/after encoding. + bool verbose = false; + bool log_search_state = false; + + ButteraugliParams ba_params; + + // Force usage of CfL when doing JPEG recompression. This can have unexpected + // effects on the decoded pixels, while still being JPEG-compliant and + // allowing reconstruction of the original JPEG. + bool force_cfl_jpeg_recompression = true; + + // Set the noise to what it would approximately be if shooting at the nominal + // exposure for a given ISO setting on a 35mm camera. + float photon_noise_iso = 0; + + // modular mode options below + ModularOptions options; + int responsive = -1; + // empty for default squeeze + std::vector squeezes; + int colorspace = -1; + // Use Global channel palette if #colors < this percentage of range + float channel_colors_pre_transform_percent = 95.f; + // Use Local channel palette if #colors < this percentage of range + float channel_colors_percent = 80.f; + int palette_colors = 1 << 10; // up to 10-bit palette is probably worthwhile + bool lossy_palette = false; + + // Returns whether these params are lossless as defined by SetLossless(); + bool IsLossless() const { + // YCbCr is also considered lossless here since it's intended for + // source material that is already YCbCr (we don't do the fwd transform) + return modular_mode && butteraugli_distance == 0.0f && + color_transform != jxl::ColorTransform::kXYB; + } + + // Sets the parameters required to make the codec lossless. + void SetLossless() { + modular_mode = true; + butteraugli_distance = 0.0f; + color_transform = jxl::ColorTransform::kNone; + } + + // Down/upsample the image before encoding / after decoding by this factor. + // The resampling value can also be set to <= 0 to automatically choose based + // on distance, however EncodeFrame doesn't support this, so it is + // required to call PostInit() to set a valid positive resampling + // value and altered butteraugli score if this is used. + int resampling = -1; + int ec_resampling = -1; + // Skip the downsampling before encoding if this is true. + bool already_downsampled = false; + // Butteraugli target distance on the original full size image, this can be + // different from butteraugli_distance if resampling was used. + float original_butteraugli_distance = -1.0f; + + float quant_ac_rescale = 1.0; + + // Codestream level to conform to. + // -1: don't care + int level = -1; + + std::vector manual_noise; + std::vector manual_xyb_factors; +}; + +static constexpr float kMinButteraugliForDynamicAR = 0.5f; +static constexpr float kMinButteraugliForDots = 3.0f; +static constexpr float kMinButteraugliToSubtractOriginalPatches = 3.0f; +static constexpr float kMinButteraugliDistanceForProgressiveDc = 4.5f; + +// Always off +static constexpr float kMinButteraugliForNoise = 99.0f; + +// Minimum butteraugli distance the encoder accepts. +static constexpr float kMinButteraugliDistance = 0.01f; + +// Tile size for encoder-side processing. Must be equal to color tile dim in the +// current implementation. +static constexpr size_t kEncTileDim = 64; +static constexpr size_t kEncTileDimInBlocks = kEncTileDim / kBlockDim; + +} // namespace jxl + +#endif // LIB_JXL_ENC_PARAMS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_patch_dictionary.cc b/media/libjxl/src/lib/jxl/enc_patch_dictionary.cc new file mode 100644 index 000000000..ff57ff049 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_patch_dictionary.cc @@ -0,0 +1,812 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_patch_dictionary.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_dot_dictionary.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/patch_dictionary_internal.h" + +namespace jxl { + +// static +void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic, + BitWriter* writer, size_t layer, + AuxOut* aux_out) { + JXL_ASSERT(pdic.HasAny()); + std::vector> tokens(1); + size_t num_ec = pdic.shared_->metadata->m.num_extra_channels; + + auto add_num = [&](int context, size_t num) { + tokens[0].emplace_back(context, num); + }; + size_t num_ref_patch = 0; + for (size_t i = 0; i < pdic.positions_.size();) { + size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx; + while (i < pdic.positions_.size() && + pdic.positions_[i].ref_pos_idx == ref_pos_idx) { + i++; + } + num_ref_patch++; + } + add_num(kNumRefPatchContext, num_ref_patch); + size_t blend_pos = 0; + for (size_t i = 0; i < pdic.positions_.size();) { + size_t i_start = i; + size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx; + const auto& ref_pos = pdic.ref_positions_[ref_pos_idx]; + while (i < pdic.positions_.size() && + pdic.positions_[i].ref_pos_idx == ref_pos_idx) { + i++; + } + size_t num = i - i_start; + JXL_ASSERT(num > 0); + add_num(kReferenceFrameContext, ref_pos.ref); + add_num(kPatchReferencePositionContext, ref_pos.x0); + add_num(kPatchReferencePositionContext, ref_pos.y0); + add_num(kPatchSizeContext, ref_pos.xsize - 1); + add_num(kPatchSizeContext, ref_pos.ysize - 1); + add_num(kPatchCountContext, num - 1); + for (size_t j = i_start; j < i; j++) { + const PatchPosition& pos = pdic.positions_[j]; + if (j == i_start) { + add_num(kPatchPositionContext, pos.x); + add_num(kPatchPositionContext, pos.y); + } else { + add_num(kPatchOffsetContext, + PackSigned(pos.x - pdic.positions_[j - 1].x)); + add_num(kPatchOffsetContext, + PackSigned(pos.y - pdic.positions_[j - 1].y)); + } + for (size_t j = 0; j < num_ec + 1; ++j, ++blend_pos) { + const PatchBlending& info = pdic.blendings_[blend_pos]; + add_num(kPatchBlendModeContext, static_cast(info.mode)); + if (UsesAlpha(info.mode) && + pdic.shared_->metadata->m.extra_channel_info.size() > 1) { + add_num(kPatchAlphaChannelContext, info.alpha_channel); + } + if (UsesClamp(info.mode)) { + add_num(kPatchClampContext, info.clamp); + } + } + } + } + + EntropyEncodingData codes; + std::vector context_map; + BuildAndEncodeHistograms(HistogramParams(), kNumPatchDictionaryContexts, + tokens, &codes, &context_map, writer, layer, + aux_out); + WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out); +} + +// static +void PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic, + Image3F* opsin) { + size_t num_ec = pdic.shared_->metadata->m.num_extra_channels; + // TODO(veluca): this can likely be optimized knowing it runs on full images. + for (size_t y = 0; y < opsin->ysize(); y++) { + float* JXL_RESTRICT rows[3] = { + opsin->PlaneRow(0, y), + opsin->PlaneRow(1, y), + opsin->PlaneRow(2, y), + }; + for (size_t pos_idx : pdic.GetPatchesForRow(y)) { + const size_t blending_idx = pos_idx * (num_ec + 1); + const PatchPosition& pos = pdic.positions_[pos_idx]; + const PatchReferencePosition& ref_pos = + pdic.ref_positions_[pos.ref_pos_idx]; + const PatchBlendMode mode = pdic.blendings_[blending_idx].mode; + size_t by = pos.y; + size_t bx = pos.x; + size_t xsize = ref_pos.xsize; + JXL_DASSERT(y >= by); + JXL_DASSERT(y < by + ref_pos.ysize); + size_t iy = y - by; + size_t ref = ref_pos.ref; + const float* JXL_RESTRICT ref_rows[3] = { + pdic.shared_->reference_frames[ref].frame->color()->ConstPlaneRow( + 0, ref_pos.y0 + iy) + + ref_pos.x0, + pdic.shared_->reference_frames[ref].frame->color()->ConstPlaneRow( + 1, ref_pos.y0 + iy) + + ref_pos.x0, + pdic.shared_->reference_frames[ref].frame->color()->ConstPlaneRow( + 2, ref_pos.y0 + iy) + + ref_pos.x0, + }; + for (size_t ix = 0; ix < xsize; ix++) { + for (size_t c = 0; c < 3; c++) { + if (mode == PatchBlendMode::kAdd) { + rows[c][bx + ix] -= ref_rows[c][ix]; + } else if (mode == PatchBlendMode::kReplace) { + rows[c][bx + ix] = 0; + } else if (mode == PatchBlendMode::kNone) { + // Nothing to do. + } else { + JXL_ABORT("Blending mode %u not yet implemented", (uint32_t)mode); + } + } + } + } + } +} + +namespace { + +struct PatchColorspaceInfo { + float kChannelDequant[3]; + float kChannelWeights[3]; + + explicit PatchColorspaceInfo(bool is_xyb) { + if (is_xyb) { + kChannelDequant[0] = 0.01615; + kChannelDequant[1] = 0.08875; + kChannelDequant[2] = 0.1922; + kChannelWeights[0] = 30.0; + kChannelWeights[1] = 3.0; + kChannelWeights[2] = 1.0; + } else { + kChannelDequant[0] = 20.0f / 255; + kChannelDequant[1] = 22.0f / 255; + kChannelDequant[2] = 20.0f / 255; + kChannelWeights[0] = 0.017 * 255; + kChannelWeights[1] = 0.02 * 255; + kChannelWeights[2] = 0.017 * 255; + } + } + + float ScaleForQuantization(float val, size_t c) { + return val / kChannelDequant[c]; + } + + int Quantize(float val, size_t c) { + return truncf(ScaleForQuantization(val, c)); + } + + bool is_similar_v(const float v1[3], const float v2[3], float threshold) { + float distance = 0; + for (size_t c = 0; c < 3; c++) { + distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c]; + } + return distance <= threshold; + } +}; + +std::vector FindTextLikePatches( + const Image3F& opsin, const PassesEncoderState* JXL_RESTRICT state, + ThreadPool* pool, AuxOut* aux_out, bool is_xyb) { + if (state->cparams.patches == Override::kOff) return {}; + + PatchColorspaceInfo pci(is_xyb); + float kSimilarThreshold = 0.8f; + + auto is_similar_impl = [&pci](std::pair p1, + std::pair p2, + const float* JXL_RESTRICT rows[3], + size_t stride, float threshold) { + float v1[3], v2[3]; + for (size_t c = 0; c < 3; c++) { + v1[c] = rows[c][p1.second * stride + p1.first]; + v2[c] = rows[c][p2.second * stride + p2.first]; + } + return pci.is_similar_v(v1, v2, threshold); + }; + + std::atomic has_screenshot_areas{false}; + const size_t opsin_stride = opsin.PixelsPerRow(); + const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0), + opsin.ConstPlaneRow(1, 0), + opsin.ConstPlaneRow(2, 0)}; + + auto is_same = [&opsin_rows, opsin_stride](std::pair p1, + std::pair p2) { + for (size_t c = 0; c < 3; c++) { + float v1 = opsin_rows[c][p1.second * opsin_stride + p1.first]; + float v2 = opsin_rows[c][p2.second * opsin_stride + p2.first]; + if (std::fabs(v1 - v2) > 1e-4) { + return false; + } + } + return true; + }; + + auto is_similar = [&](std::pair p1, + std::pair p2) { + return is_similar_impl(p1, p2, opsin_rows, opsin_stride, kSimilarThreshold); + }; + + constexpr int64_t kPatchSide = 4; + constexpr int64_t kExtraSide = 4; + + // Look for kPatchSide size squares, naturally aligned, that all have the same + // pixel values. + ImageB is_screenshot_like(DivCeil(opsin.xsize(), kPatchSide), + DivCeil(opsin.ysize(), kPatchSide)); + ZeroFillImage(&is_screenshot_like); + uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0); + const size_t screenshot_stride = is_screenshot_like.PixelsPerRow(); + const auto process_row = [&](const uint32_t y, size_t /* thread */) { + for (uint64_t x = 0; x < opsin.xsize() / kPatchSide; x++) { + bool all_same = true; + for (size_t iy = 0; iy < static_cast(kPatchSide); iy++) { + for (size_t ix = 0; ix < static_cast(kPatchSide); ix++) { + size_t cx = x * kPatchSide + ix; + size_t cy = y * kPatchSide + iy; + if (!is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) { + all_same = false; + break; + } + } + } + if (!all_same) continue; + size_t num = 0; + size_t num_same = 0; + for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) { + for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) { + int64_t cx = x * kPatchSide + ix; + int64_t cy = y * kPatchSide + iy; + if (cx < 0 || static_cast(cx) >= opsin.xsize() || // + cy < 0 || static_cast(cy) >= opsin.ysize()) { + continue; + } + num++; + if (is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) num_same++; + } + } + // Too few equal pixels nearby. + if (num_same * 8 < num * 7) continue; + screenshot_row[y * screenshot_stride + x] = 1; + has_screenshot_areas = true; + } + }; + JXL_CHECK(RunOnPool(pool, 0, opsin.ysize() / kPatchSide, ThreadPool::NoInit, + process_row, "IsScreenshotLike")); + + // TODO(veluca): also parallelize the rest of this function. + if (WantDebugOutput(aux_out)) { + aux_out->DumpPlaneNormalized("screenshot_like", is_screenshot_like); + } + + constexpr int kSearchRadius = 1; + + if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) { + return {}; + } + + // Search for "similar enough" pixels near the screenshot-like areas. + ImageB is_background(opsin.xsize(), opsin.ysize()); + ZeroFillImage(&is_background); + Image3F background(opsin.xsize(), opsin.ysize()); + ZeroFillImage(&background); + constexpr size_t kDistanceLimit = 50; + float* JXL_RESTRICT background_rows[3] = { + background.PlaneRow(0, 0), + background.PlaneRow(1, 0), + background.PlaneRow(2, 0), + }; + const size_t background_stride = background.PixelsPerRow(); + uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0); + const size_t is_background_stride = is_background.PixelsPerRow(); + std::vector< + std::pair, std::pair>> + queue; + size_t queue_front = 0; + for (size_t y = 0; y < opsin.ysize(); y++) { + for (size_t x = 0; x < opsin.xsize(); x++) { + if (!screenshot_row[screenshot_stride * (y / kPatchSide) + + (x / kPatchSide)]) + continue; + queue.push_back({{x, y}, {x, y}}); + } + } + while (queue.size() != queue_front) { + std::pair cur = queue[queue_front].first; + std::pair src = queue[queue_front].second; + queue_front++; + if (is_background_row[cur.second * is_background_stride + cur.first]) + continue; + is_background_row[cur.second * is_background_stride + cur.first] = 1; + for (size_t c = 0; c < 3; c++) { + background_rows[c][cur.second * background_stride + cur.first] = + opsin_rows[c][src.second * opsin_stride + src.first]; + } + for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) { + for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) { + if (dx == 0 && dy == 0) continue; + int next_first = cur.first + dx; + int next_second = cur.second + dy; + if (next_first < 0 || next_second < 0 || + static_cast(next_first) >= opsin.xsize() || + static_cast(next_second) >= opsin.ysize()) { + continue; + } + if (static_cast( + std::abs(next_first - static_cast(src.first)) + + std::abs(next_second - static_cast(src.second))) > + kDistanceLimit) { + continue; + } + std::pair next{next_first, next_second}; + if (is_similar(src, next)) { + if (!screenshot_row[next.second / kPatchSide * screenshot_stride + + next.first / kPatchSide] || + is_same(src, next)) { + if (!is_background_row[next.second * is_background_stride + + next.first]) + queue.emplace_back(next, src); + } + } + } + } + } + queue.clear(); + + ImageF ccs; + Rng rng(0); + bool paint_ccs = false; + if (WantDebugOutput(aux_out)) { + aux_out->DumpPlaneNormalized("is_background", is_background); + if (is_xyb) { + aux_out->DumpXybImage("background", background); + } else { + aux_out->DumpImage("background", background); + } + ccs = ImageF(opsin.xsize(), opsin.ysize()); + ZeroFillImage(&ccs); + paint_ccs = true; + } + + constexpr float kVerySimilarThreshold = 0.03f; + constexpr float kHasSimilarThreshold = 0.03f; + + const float* JXL_RESTRICT const_background_rows[3] = { + background_rows[0], background_rows[1], background_rows[2]}; + auto is_similar_b = [&](std::pair p1, std::pair p2) { + return is_similar_impl(p1, p2, const_background_rows, background_stride, + kVerySimilarThreshold); + }; + + constexpr int kMinPeak = 2; + constexpr int kHasSimilarRadius = 2; + + std::vector info; + + // Find small CC outside the "similar enough" areas, compute bounding boxes, + // and run heuristics to exclude some patches. + ImageB visited(opsin.xsize(), opsin.ysize()); + ZeroFillImage(&visited); + uint8_t* JXL_RESTRICT visited_row = visited.Row(0); + const size_t visited_stride = visited.PixelsPerRow(); + std::vector> cc; + std::vector> stack; + for (size_t y = 0; y < opsin.ysize(); y++) { + for (size_t x = 0; x < opsin.xsize(); x++) { + if (is_background_row[y * is_background_stride + x]) continue; + cc.clear(); + stack.clear(); + stack.emplace_back(x, y); + size_t min_x = x; + size_t max_x = x; + size_t min_y = y; + size_t max_y = y; + std::pair reference; + bool found_border = false; + bool all_similar = true; + while (!stack.empty()) { + std::pair cur = stack.back(); + stack.pop_back(); + if (visited_row[cur.second * visited_stride + cur.first]) continue; + visited_row[cur.second * visited_stride + cur.first] = 1; + if (cur.first < min_x) min_x = cur.first; + if (cur.first > max_x) max_x = cur.first; + if (cur.second < min_y) min_y = cur.second; + if (cur.second > max_y) max_y = cur.second; + if (paint_ccs) { + cc.push_back(cur); + } + for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) { + for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) { + if (dx == 0 && dy == 0) continue; + int next_first = static_cast(cur.first) + dx; + int next_second = static_cast(cur.second) + dy; + if (next_first < 0 || next_second < 0 || + static_cast(next_first) >= opsin.xsize() || + static_cast(next_second) >= opsin.ysize()) { + continue; + } + std::pair next{next_first, next_second}; + if (!is_background_row[next.second * is_background_stride + + next.first]) { + stack.push_back(next); + } else { + if (!found_border) { + reference = next; + found_border = true; + } else { + if (!is_similar_b(next, reference)) all_similar = false; + } + } + } + } + } + if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize || + max_y - min_y >= kMaxPatchSize) { + continue; + } + size_t bpos = background_stride * reference.second + reference.first; + float ref[3] = {background_rows[0][bpos], background_rows[1][bpos], + background_rows[2][bpos]}; + bool has_similar = false; + for (size_t iy = std::max( + static_cast(min_y) - kHasSimilarRadius, 0); + iy < std::min(max_y + kHasSimilarRadius + 1, opsin.ysize()); iy++) { + for (size_t ix = std::max( + static_cast(min_x) - kHasSimilarRadius, 0); + ix < std::min(max_x + kHasSimilarRadius + 1, opsin.xsize()); + ix++) { + size_t opos = opsin_stride * iy + ix; + float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos], + opsin_rows[2][opos]}; + if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) { + has_similar = true; + } + } + } + if (!has_similar) continue; + info.emplace_back(); + info.back().second.emplace_back(min_x, min_y); + QuantizedPatch& patch = info.back().first; + patch.xsize = max_x - min_x + 1; + patch.ysize = max_y - min_y + 1; + int max_value = 0; + for (size_t c : {1, 0, 2}) { + for (size_t iy = min_y; iy <= max_y; iy++) { + for (size_t ix = min_x; ix <= max_x; ix++) { + size_t offset = (iy - min_y) * patch.xsize + ix - min_x; + patch.fpixels[c][offset] = + opsin_rows[c][iy * opsin_stride + ix] - ref[c]; + int val = pci.Quantize(patch.fpixels[c][offset], c); + patch.pixels[c][offset] = val; + if (std::abs(val) > max_value) max_value = std::abs(val); + } + } + } + if (max_value < kMinPeak) { + info.pop_back(); + continue; + } + if (paint_ccs) { + float cc_color = rng.UniformF(0.5, 1.0); + for (std::pair p : cc) { + ccs.Row(p.second)[p.first] = cc_color; + } + } + } + } + + if (paint_ccs) { + JXL_ASSERT(WantDebugOutput(aux_out)); + aux_out->DumpPlaneNormalized("ccs", ccs); + } + if (info.empty()) { + return {}; + } + + // Remove duplicates. + constexpr size_t kMinPatchOccurences = 2; + std::sort(info.begin(), info.end()); + size_t unique = 0; + for (size_t i = 1; i < info.size(); i++) { + if (info[i].first == info[unique].first) { + info[unique].second.insert(info[unique].second.end(), + info[i].second.begin(), info[i].second.end()); + } else { + if (info[unique].second.size() >= kMinPatchOccurences) { + unique++; + } + info[unique] = info[i]; + } + } + if (info[unique].second.size() >= kMinPatchOccurences) { + unique++; + } + info.resize(unique); + + size_t max_patch_size = 0; + + for (size_t i = 0; i < info.size(); i++) { + size_t pixels = info[i].first.xsize * info[i].first.ysize; + if (pixels > max_patch_size) max_patch_size = pixels; + } + + // don't use patches if all patches are smaller than this + constexpr size_t kMinMaxPatchSize = 20; + if (max_patch_size < kMinMaxPatchSize) return {}; + + return info; +} + +} // namespace + +void FindBestPatchDictionary(const Image3F& opsin, + PassesEncoderState* JXL_RESTRICT state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, bool is_xyb) { + std::vector info = + FindTextLikePatches(opsin, state, pool, aux_out, is_xyb); + + // TODO(veluca): this doesn't work if both dots and patches are enabled. + // For now, since dots and patches are not likely to occur in the same kind of + // images, disable dots if some patches were found. + if (info.empty() && + ApplyOverride( + state->cparams.dots, + state->cparams.speed_tier <= SpeedTier::kSquirrel && + state->cparams.butteraugli_distance >= kMinButteraugliForDots)) { + info = FindDotDictionary(state->cparams, opsin, state->shared.cmap, pool); + } + + if (info.empty()) return; + + std::sort( + info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) { + return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize; + }); + + size_t max_x_size = 0; + size_t max_y_size = 0; + size_t total_pixels = 0; + + for (size_t i = 0; i < info.size(); i++) { + size_t pixels = info[i].first.xsize * info[i].first.ysize; + if (max_x_size < info[i].first.xsize) max_x_size = info[i].first.xsize; + if (max_y_size < info[i].first.ysize) max_y_size = info[i].first.ysize; + total_pixels += pixels; + } + + // Bin-packing & conversion of patches. + constexpr float kBinPackingSlackness = 1.05f; + size_t ref_xsize = std::max(max_x_size, std::sqrt(total_pixels)); + size_t ref_ysize = std::max(max_y_size, std::sqrt(total_pixels)); + std::vector> ref_positions(info.size()); + // TODO(veluca): allow partial overlaps of patches that have the same pixels. + size_t max_y = 0; + do { + max_y = 0; + // Increase packed image size. + ref_xsize = ref_xsize * kBinPackingSlackness + 1; + ref_ysize = ref_ysize * kBinPackingSlackness + 1; + + ImageB occupied(ref_xsize, ref_ysize); + ZeroFillImage(&occupied); + uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0); + size_t occupied_stride = occupied.PixelsPerRow(); + + bool success = true; + // For every patch... + for (size_t patch = 0; patch < info.size(); patch++) { + size_t x0 = 0; + size_t y0 = 0; + size_t xsize = info[patch].first.xsize; + size_t ysize = info[patch].first.ysize; + bool found = false; + // For every possible start position ... + for (; y0 + ysize <= ref_ysize; y0++) { + x0 = 0; + for (; x0 + xsize <= ref_xsize; x0++) { + bool has_occupied_pixel = false; + size_t x = x0; + // Check if it is possible to place the patch in this position in the + // reference frame. + for (size_t y = y0; y < y0 + ysize; y++) { + x = x0; + for (; x < x0 + xsize; x++) { + if (occupied_rows[y * occupied_stride + x]) { + has_occupied_pixel = true; + break; + } + } + } // end of positioning check + if (!has_occupied_pixel) { + found = true; + break; + } + x0 = x; // Jump to next pixel after the occupied one. + } + if (found) break; + } // end of start position checking + + // We didn't find a possible position: repeat from the beginning with a + // larger reference frame size. + if (!found) { + success = false; + break; + } + + // We found a position: mark the corresponding positions in the reference + // image as used. + ref_positions[patch] = {x0, y0}; + for (size_t y = y0; y < y0 + ysize; y++) { + for (size_t x = x0; x < x0 + xsize; x++) { + occupied_rows[y * occupied_stride + x] = true; + } + } + max_y = std::max(max_y, y0 + ysize); + } + + if (success) break; + } while (true); + + JXL_ASSERT(ref_ysize >= max_y); + + ref_ysize = max_y; + + Image3F reference_frame(ref_xsize, ref_ysize); + // TODO(veluca): figure out a better way to fill the image. + ZeroFillImage(&reference_frame); + std::vector positions; + std::vector pref_positions; + std::vector blendings; + float* JXL_RESTRICT ref_rows[3] = { + reference_frame.PlaneRow(0, 0), + reference_frame.PlaneRow(1, 0), + reference_frame.PlaneRow(2, 0), + }; + size_t ref_stride = reference_frame.PixelsPerRow(); + size_t num_ec = state->shared.metadata->m.num_extra_channels; + + for (size_t i = 0; i < info.size(); i++) { + PatchReferencePosition ref_pos; + ref_pos.xsize = info[i].first.xsize; + ref_pos.ysize = info[i].first.ysize; + ref_pos.x0 = ref_positions[i].first; + ref_pos.y0 = ref_positions[i].second; + ref_pos.ref = 0; + for (size_t y = 0; y < ref_pos.ysize; y++) { + for (size_t x = 0; x < ref_pos.xsize; x++) { + for (size_t c = 0; c < 3; c++) { + ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] = + info[i].first.fpixels[c][y * ref_pos.xsize + x]; + } + } + } + for (const auto& pos : info[i].second) { + positions.emplace_back( + PatchPosition{pos.first, pos.second, pref_positions.size()}); + // Add blending for color channels, ignore other channels. + blendings.push_back({PatchBlendMode::kAdd, 0, false}); + for (size_t j = 0; j < num_ec; ++j) { + blendings.push_back({PatchBlendMode::kNone, 0, false}); + } + } + pref_positions.emplace_back(std::move(ref_pos)); + } + + CompressParams cparams = state->cparams; + // Recursive application of patches could create very weird issues. + cparams.patches = Override::kOff; + + RoundtripPatchFrame(&reference_frame, state, 0, cparams, cms, pool, aux_out, + /*subtract=*/true); + + // TODO(veluca): this assumes that applying patches is commutative, which is + // not true for all blending modes. This code only produces kAdd patches, so + // this works out. + PatchDictionaryEncoder::SetPositions( + &state->shared.image_features.patches, std::move(positions), + std::move(pref_positions), std::move(blendings)); +} + +void RoundtripPatchFrame(Image3F* reference_frame, + PassesEncoderState* JXL_RESTRICT state, int idx, + CompressParams& cparams, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out, bool subtract) { + FrameInfo patch_frame_info; + cparams.resampling = 1; + cparams.ec_resampling = 1; + cparams.dots = Override::kOff; + cparams.noise = Override::kOff; + cparams.modular_mode = true; + cparams.responsive = 0; + cparams.progressive_dc = 0; + cparams.progressive_mode = false; + cparams.qprogressive_mode = false; + // Use gradient predictor and not Predictor::Best. + cparams.options.predictor = Predictor::Gradient; + patch_frame_info.save_as_reference = idx; // always saved. + patch_frame_info.frame_type = FrameType::kReferenceOnly; + patch_frame_info.save_before_color_transform = true; + ImageBundle ib(&state->shared.metadata->m); + // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there is + // no simple way to express that yet. + patch_frame_info.ib_needs_color_transform = false; + ib.SetFromImage(std::move(*reference_frame), + state->shared.metadata->m.color_encoding); + if (!ib.metadata()->extra_channel_info.empty()) { + // Add dummy extra channels to the patch image: patch encoding does not yet + // support extra channels, but the codec expects that the amount of extra + // channels in frames matches that in the metadata of the codestream. + std::vector extra_channels; + extra_channels.reserve(ib.metadata()->extra_channel_info.size()); + for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) { + extra_channels.emplace_back(ib.xsize(), ib.ysize()); + // Must initialize the image with data to not affect blending with + // uninitialized memory. + // TODO(lode): patches must copy and use the real extra channels instead. + ZeroFillImage(&extra_channels.back()); + } + ib.SetExtraChannels(std::move(extra_channels)); + } + PassesEncoderState roundtrip_state; + auto special_frame = std::unique_ptr(new BitWriter()); + AuxOut patch_aux_out; + JXL_CHECK(EncodeFrame(cparams, patch_frame_info, state->shared.metadata, ib, + &roundtrip_state, cms, pool, special_frame.get(), + aux_out ? &patch_aux_out : nullptr)); + if (aux_out) { + for (const auto& l : patch_aux_out.layers) { + aux_out->layers[kLayerDictionary].Assimilate(l); + } + } + const Span encoded = special_frame->GetSpan(); + state->special_frames.emplace_back(std::move(special_frame)); + if (subtract) { + ImageBundle decoded(&state->shared.metadata->m); + PassesDecoderState dec_state; + JXL_CHECK(dec_state.output_encoding_info.SetFromMetadata( + *state->shared.metadata)); + const uint8_t* frame_start = encoded.data(); + size_t encoded_size = encoded.size(); + JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size, &decoded, + *state->shared.metadata)); + frame_start += decoded.decoded_bytes(); + encoded_size -= decoded.decoded_bytes(); + size_t ref_xsize = + dec_state.shared_storage.reference_frames[idx].storage.color()->xsize(); + // if the frame itself uses patches, we need to decode another frame + if (!ref_xsize) { + JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size, + &decoded, *state->shared.metadata)); + } + JXL_CHECK(encoded_size == 0); + state->shared.reference_frames[idx] = + std::move(dec_state.shared_storage.reference_frames[idx]); + } else { + state->shared.reference_frames[idx].storage = std::move(ib); + } + state->shared.reference_frames[idx].frame = + &state->shared.reference_frames[idx].storage; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_patch_dictionary.h b/media/libjxl/src/lib/jxl/enc_patch_dictionary.h new file mode 100644 index 000000000..090827f68 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_patch_dictionary.h @@ -0,0 +1,108 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_PATCH_DICTIONARY_H_ +#define LIB_JXL_ENC_PATCH_DICTIONARY_H_ + +// Chooses reference patches, and avoids encoding them once per occurrence. + +#include +#include +#include + +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { + +constexpr size_t kMaxPatchSize = 32; + +struct QuantizedPatch { + size_t xsize; + size_t ysize; + QuantizedPatch() { + for (size_t i = 0; i < 3; i++) { + pixels[i].resize(kMaxPatchSize * kMaxPatchSize); + fpixels[i].resize(kMaxPatchSize * kMaxPatchSize); + } + } + std::vector pixels[3] = {}; + // Not compared. Used only to retrieve original pixels to construct the + // reference image. + std::vector fpixels[3] = {}; + bool operator==(const QuantizedPatch& other) const { + if (xsize != other.xsize) return false; + if (ysize != other.ysize) return false; + for (size_t c = 0; c < 3; c++) { + if (memcmp(pixels[c].data(), other.pixels[c].data(), + sizeof(int8_t) * xsize * ysize) != 0) + return false; + } + return true; + } + + bool operator<(const QuantizedPatch& other) const { + if (xsize != other.xsize) return xsize < other.xsize; + if (ysize != other.ysize) return ysize < other.ysize; + for (size_t c = 0; c < 3; c++) { + int cmp = memcmp(pixels[c].data(), other.pixels[c].data(), + sizeof(int8_t) * xsize * ysize); + if (cmp > 0) return false; + if (cmp < 0) return true; + } + return false; + } +}; + +// Pair (patch, vector of occurrences). +using PatchInfo = + std::pair>>; + +// Friend class of PatchDictionary. +class PatchDictionaryEncoder { + public: + // Only call if HasAny(). + static void Encode(const PatchDictionary& pdic, BitWriter* writer, + size_t layer, AuxOut* aux_out); + + static void SetPositions(PatchDictionary* pdic, + std::vector positions, + std::vector ref_positions, + std::vector blendings) { + pdic->positions_ = std::move(positions); + pdic->ref_positions_ = std::move(ref_positions); + pdic->blendings_ = std::move(blendings); + pdic->ComputePatchTree(); + } + + static void SubtractFrom(const PatchDictionary& pdic, Image3F* opsin); +}; + +void FindBestPatchDictionary(const Image3F& opsin, + PassesEncoderState* JXL_RESTRICT state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, bool is_xyb = true); + +void RoundtripPatchFrame(Image3F* reference_frame, + PassesEncoderState* JXL_RESTRICT state, int idx, + CompressParams& cparams, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out, bool subtract); + +} // namespace jxl + +#endif // LIB_JXL_ENC_PATCH_DICTIONARY_H_ diff --git a/media/libjxl/src/lib/jxl/enc_photon_noise.cc b/media/libjxl/src/lib/jxl/enc_photon_noise.cc new file mode 100644 index 000000000..3786ef5cf --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_photon_noise.cc @@ -0,0 +1,89 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_photon_noise.h" + +namespace jxl { + +namespace { + +// Assumes a daylight-like spectrum. +// https://www.strollswithmydog.com/effective-quantum-efficiency-of-sensor/#:~:text=11%2C260%20photons/um%5E2/lx-s +constexpr float kPhotonsPerLxSPerUm2 = 11260; + +// Order of magnitude for cameras in the 2010-2020 decade, taking the CFA into +// account. +constexpr float kEffectiveQuantumEfficiency = 0.20; + +// TODO(sboukortt): reevaluate whether these are good defaults, notably whether +// it would be worth making read noise higher at lower ISO settings. +constexpr float kPhotoResponseNonUniformity = 0.005; +constexpr float kInputReferredReadNoise = 3; + +// Assumes a 35mm sensor. +constexpr float kSensorAreaUm2 = 36000.f * 24000; + +template +inline constexpr T Square(const T x) { + return x * x; +} +template +inline constexpr T Cube(const T x) { + return x * x * x; +} + +} // namespace + +NoiseParams SimulatePhotonNoise(const size_t xsize, const size_t ysize, + const float iso) { + const float kOpsinAbsorbanceBiasCbrt = std::cbrt(kOpsinAbsorbanceBias[1]); + + // Focal plane exposure for 18% of kDefaultIntensityTarget, in lx·s. + // (ISO = 10 lx·s ÷ H) + const float h_18 = 10 / iso; + + const float pixel_area_um2 = kSensorAreaUm2 / (xsize * ysize); + + const float electrons_per_pixel_18 = kEffectiveQuantumEfficiency * + kPhotonsPerLxSPerUm2 * h_18 * + pixel_area_um2; + + NoiseParams params; + + for (size_t i = 0; i < NoiseParams::kNumNoisePoints; ++i) { + const float scaled_index = i / (NoiseParams::kNumNoisePoints - 2.f); + // scaled_index is used for XYB = (0, 2·scaled_index, 2·scaled_index) + const float y = 2 * scaled_index; + // 1 = default intensity target + const float linear = std::max( + 0.f, Cube(y - kOpsinAbsorbanceBiasCbrt) + kOpsinAbsorbanceBias[1]); + const float electrons_per_pixel = electrons_per_pixel_18 * (linear / 0.18f); + // Quadrature sum of read noise, photon shot noise (sqrt(S) so simply not + // squared here) and photo response non-uniformity. + // https://doi.org/10.1117/3.725073 + // Units are electrons rms. + const float noise = + std::sqrt(Square(kInputReferredReadNoise) + electrons_per_pixel + + Square(kPhotoResponseNonUniformity * electrons_per_pixel)); + const float linear_noise = noise * (0.18f / electrons_per_pixel_18); + const float opsin_derivative = + (1.f / 3) / Square(std::cbrt(linear - kOpsinAbsorbanceBias[1])); + const float opsin_noise = linear_noise * opsin_derivative; + + // TODO(sboukortt): verify more thoroughly whether the denominator is + // correct. + params.lut[i] = + Clamp1(opsin_noise / + (0.22f // norm_const + * std::sqrt(2.f) // red_noise + green_noise + * 1.13f // standard deviation of a plane of generated noise + ), + 0.f, 1.f); + } + + return params; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_photon_noise.h b/media/libjxl/src/lib/jxl/enc_photon_noise.h new file mode 100644 index 000000000..f43e14d56 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_photon_noise.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_PHOTON_NOISE_H_ +#define LIB_JXL_ENC_PHOTON_NOISE_H_ + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/noise.h" + +namespace jxl { + +// Constructs a NoiseParams representing the noise that would be seen at the +// selected nominal exposure on a last-decade (as of 2021) color camera with a +// 36×24mm sensor (“35mm format”). +NoiseParams SimulatePhotonNoise(size_t xsize, size_t ysize, float iso); + +} // namespace jxl + +#endif // LIB_JXL_ENC_PHOTON_NOISE_H_ diff --git a/media/libjxl/src/lib/jxl/enc_photon_noise_test.cc b/media/libjxl/src/lib/jxl/enc_photon_noise_test.cc new file mode 100644 index 000000000..83707255d --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_photon_noise_test.cc @@ -0,0 +1,50 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_photon_noise.h" + +#include "lib/jxl/test_utils.h" + +namespace jxl { +namespace { + +using ::testing::FloatNear; +using ::testing::Pointwise; + +MATCHER(AreApproximatelyEqual, "") { + constexpr float kTolerance = 1e-6; + const float actual = std::get<0>(arg); + const float expected = std::get<1>(arg); + return testing::ExplainMatchResult(FloatNear(expected, kTolerance), actual, + result_listener); +} + +TEST(EncPhotonNoiseTest, LUTs) { + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/6000, /*ysize=*/4000, /*iso=*/100).lut, + Pointwise(AreApproximatelyEqual(), + {0.00259652, 0.0139648, 0.00681551, 0.00632582, 0.00694917, + 0.00803922, 0.00934574, 0.0107607})); + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/6000, /*ysize=*/4000, /*iso=*/800).lut, + Pointwise(AreApproximatelyEqual(), + {0.02077220, 0.0420923, 0.01820690, 0.01439020, 0.01293670, + 0.01254030, 0.01277390, 0.0134161})); + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/6000, /*ysize=*/4000, /*iso=*/6400).lut, + Pointwise(AreApproximatelyEqual(), + {0.1661770, 0.1691120, 0.05309080, 0.03963960, 0.03357410, + 0.03001650, 0.02776740, 0.0263478})); + + // Lower when measured on a per-pixel basis as there are fewer of them. + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/4000, /*ysize=*/3000, /*iso=*/6400).lut, + Pointwise(AreApproximatelyEqual(), + {0.0830886, 0.1008720, 0.0367748, 0.0280305, 0.0240236, + 0.0218040, 0.0205771, 0.0200058})); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_quant_weights.cc b/media/libjxl/src/lib/jxl/enc_quant_weights.cc new file mode 100644 index 000000000..d8a9931a5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_quant_weights.cc @@ -0,0 +1,210 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_quant_weights.h" + +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +namespace { + +Status EncodeDctParams(const DctQuantWeightParams& params, BitWriter* writer) { + JXL_ASSERT(params.num_distance_bands >= 1); + writer->Write(DctQuantWeightParams::kLog2MaxDistanceBands, + params.num_distance_bands - 1); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < params.num_distance_bands; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Write( + params.distance_bands[c][i] * (i == 0 ? (1 / 64.0f) : 1.0f), writer)); + } + } + return true; +} + +Status EncodeQuant(const QuantEncoding& encoding, size_t idx, size_t size_x, + size_t size_y, BitWriter* writer, + ModularFrameEncoder* modular_frame_encoder) { + writer->Write(kLog2NumQuantModes, encoding.mode); + size_x *= kBlockDim; + size_y *= kBlockDim; + switch (encoding.mode) { + case QuantEncoding::kQuantModeLibrary: { + writer->Write(kCeilLog2NumPredefinedTables, encoding.predefined); + break; + } + case QuantEncoding::kQuantModeID: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 3; i++) { + JXL_RETURN_IF_ERROR( + F16Coder::Write(encoding.idweights[c][i] * (1.0f / 64), writer)); + } + } + break; + } + case QuantEncoding::kQuantModeDCT2: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 6; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Write( + encoding.dct2weights[c][i] * (1.0f / 64), writer)); + } + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR( + F16Coder::Write(encoding.dct4x8multipliers[c], writer)); + } + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + break; + } + case QuantEncoding::kQuantModeDCT4: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 2; i++) { + JXL_RETURN_IF_ERROR( + F16Coder::Write(encoding.dct4multipliers[c][i], writer)); + } + } + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + break; + } + case QuantEncoding::kQuantModeRAW: { + ModularFrameEncoder::EncodeQuantTable(size_x, size_y, writer, encoding, + idx, modular_frame_encoder); + break; + } + case QuantEncoding::kQuantModeAFV: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 9; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Write( + encoding.afv_weights[c][i] * (i < 6 ? 1.0f / 64 : 1.0f), writer)); + } + } + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params_afv_4x4, writer)); + break; + } + } + return true; +} + +} // namespace + +Status DequantMatricesEncode(const DequantMatrices* matrices, BitWriter* writer, + size_t layer, AuxOut* aux_out, + ModularFrameEncoder* modular_frame_encoder) { + bool all_default = true; + const std::vector& encodings = matrices->encodings(); + + for (size_t i = 0; i < encodings.size(); i++) { + if (encodings[i].mode != QuantEncoding::kQuantModeLibrary || + encodings[i].predefined != 0) { + all_default = false; + } + } + // TODO(janwas): better bound + BitWriter::Allotment allotment(writer, 512 * 1024); + writer->Write(1, all_default); + if (!all_default) { + for (size_t i = 0; i < encodings.size(); i++) { + JXL_RETURN_IF_ERROR(EncodeQuant( + encodings[i], i, DequantMatrices::required_size_x[i], + DequantMatrices::required_size_y[i], writer, modular_frame_encoder)); + } + } + ReclaimAndCharge(writer, &allotment, layer, aux_out); + return true; +} + +Status DequantMatricesEncodeDC(const DequantMatrices* matrices, + BitWriter* writer, size_t layer, + AuxOut* aux_out) { + bool all_default = true; + const float* dc_quant = matrices->DCQuants(); + for (size_t c = 0; c < 3; c++) { + if (dc_quant[c] != kDCQuant[c]) { + all_default = false; + } + } + BitWriter::Allotment allotment(writer, 1 + sizeof(float) * kBitsPerByte * 3); + writer->Write(1, all_default); + if (!all_default) { + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR(F16Coder::Write(dc_quant[c] * 128.0f, writer)); + } + } + ReclaimAndCharge(writer, &allotment, layer, aux_out); + return true; +} + +void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc) { + matrices->SetDCQuant(dc); + // Roundtrip encode/decode DC to ensure same values as decoder. + BitWriter writer; + JXL_CHECK(DequantMatricesEncodeDC(matrices, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + BitReader br(writer.GetSpan()); + // Called only in the encoder: should fail only for programmer errors. + JXL_CHECK(matrices->DecodeDC(&br)); + JXL_CHECK(br.Close()); +} + +void DequantMatricesScaleDC(DequantMatrices* matrices, const float scale) { + float dc[3]; + for (size_t c = 0; c < 3; ++c) { + dc[c] = matrices->InvDCQuant(c) * (1.0f / scale); + } + DequantMatricesSetCustomDC(matrices, dc); +} + +void DequantMatricesSetCustom(DequantMatrices* matrices, + const std::vector& encodings, + ModularFrameEncoder* encoder) { + JXL_ASSERT(encodings.size() == DequantMatrices::kNum); + matrices->SetEncodings(encodings); + for (size_t i = 0; i < encodings.size(); i++) { + if (encodings[i].mode == QuantEncodingInternal::kQuantModeRAW) { + encoder->AddQuantTable(DequantMatrices::required_size_x[i] * kBlockDim, + DequantMatrices::required_size_y[i] * kBlockDim, + encodings[i], i); + } + } + // Roundtrip encode/decode the matrices to ensure same values as decoder. + // Do not pass modular en/decoder, as they only change entropy and not + // values. + BitWriter writer; + JXL_CHECK(DequantMatricesEncode(matrices, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + BitReader br(writer.GetSpan()); + // Called only in the encoder: should fail only for programmer errors. + JXL_CHECK(matrices->Decode(&br)); + JXL_CHECK(br.Close()); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_quant_weights.h b/media/libjxl/src/lib/jxl/enc_quant_weights.h new file mode 100644 index 000000000..fe5273cf7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_quant_weights.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_QUANT_WEIGHTS_H_ +#define LIB_JXL_ENC_QUANT_WEIGHTS_H_ + +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +Status DequantMatricesEncode( + const DequantMatrices* matrices, BitWriter* writer, size_t layer, + AuxOut* aux_out, ModularFrameEncoder* modular_frame_encoder = nullptr); +Status DequantMatricesEncodeDC(const DequantMatrices* matrices, + BitWriter* writer, size_t layer, + AuxOut* aux_out); +// For consistency with QuantEncoding, higher values correspond to more +// precision. +void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc); + +void DequantMatricesScaleDC(DequantMatrices* matrices, const float scale); + +void DequantMatricesSetCustom(DequantMatrices* matrices, + const std::vector& encodings, + ModularFrameEncoder* encoder); + +} // namespace jxl + +#endif // LIB_JXL_ENC_QUANT_WEIGHTS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_splines.cc b/media/libjxl/src/lib/jxl/enc_splines.cc new file mode 100644 index 000000000..cdb797dc6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_splines.cc @@ -0,0 +1,96 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +class QuantizedSplineEncoder { + public: + // Only call if HasAny(). + static void Tokenize(const QuantizedSpline& spline, + std::vector* const tokens) { + tokens->emplace_back(kNumControlPointsContext, + spline.control_points_.size()); + for (const auto& point : spline.control_points_) { + tokens->emplace_back(kControlPointsContext, PackSigned(point.first)); + tokens->emplace_back(kControlPointsContext, PackSigned(point.second)); + } + const auto encode_dct = [tokens](const int dct[32]) { + for (int i = 0; i < 32; ++i) { + tokens->emplace_back(kDCTContext, PackSigned(dct[i])); + } + }; + for (int c = 0; c < 3; ++c) { + encode_dct(spline.color_dct_[c]); + } + encode_dct(spline.sigma_dct_); + } +}; + +namespace { + +void EncodeAllStartingPoints(const std::vector& points, + std::vector* tokens) { + int64_t last_x = 0; + int64_t last_y = 0; + for (size_t i = 0; i < points.size(); i++) { + const int64_t x = lroundf(points[i].x); + const int64_t y = lroundf(points[i].y); + if (i == 0) { + tokens->emplace_back(kStartingPositionContext, x); + tokens->emplace_back(kStartingPositionContext, y); + } else { + tokens->emplace_back(kStartingPositionContext, PackSigned(x - last_x)); + tokens->emplace_back(kStartingPositionContext, PackSigned(y - last_y)); + } + last_x = x; + last_y = y; + } +} + +} // namespace + +void EncodeSplines(const Splines& splines, BitWriter* writer, + const size_t layer, const HistogramParams& histogram_params, + AuxOut* aux_out) { + JXL_ASSERT(splines.HasAny()); + + const std::vector& quantized_splines = + splines.QuantizedSplines(); + std::vector> tokens(1); + tokens[0].emplace_back(kNumSplinesContext, quantized_splines.size() - 1); + EncodeAllStartingPoints(splines.StartingPoints(), &tokens[0]); + + tokens[0].emplace_back(kQuantizationAdjustmentContext, + PackSigned(splines.GetQuantizationAdjustment())); + + for (const QuantizedSpline& spline : quantized_splines) { + QuantizedSplineEncoder::Tokenize(spline, &tokens[0]); + } + + EntropyEncodingData codes; + std::vector context_map; + BuildAndEncodeHistograms(histogram_params, kNumSplineContexts, tokens, &codes, + &context_map, writer, layer, aux_out); + WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out); +} + +Splines FindSplines(const Image3F& opsin) { + // TODO: implement spline detection. + return {}; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_splines.h b/media/libjxl/src/lib/jxl/enc_splines.h new file mode 100644 index 000000000..732d77ac2 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_splines.h @@ -0,0 +1,39 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_SPLINES_H_ +#define LIB_JXL_ENC_SPLINES_H_ + +#include +#include + +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +// Only call if splines.HasAny(). +void EncodeSplines(const Splines& splines, BitWriter* writer, + const size_t layer, const HistogramParams& histogram_params, + AuxOut* aux_out); + +Splines FindSplines(const Image3F& opsin); + +} // namespace jxl + +#endif // LIB_JXL_ENC_SPLINES_H_ diff --git a/media/libjxl/src/lib/jxl/enc_toc.cc b/media/libjxl/src/lib/jxl/enc_toc.cc new file mode 100644 index 000000000..c877b0c83 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_toc.cc @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_toc.h" + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_coeff_order.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/toc.h" + +namespace jxl { +Status WriteGroupOffsets(const std::vector& group_codes, + const std::vector* permutation, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, MaxBits(group_codes.size())); + if (permutation && !group_codes.empty()) { + // Don't write a permutation at all for an empty group_codes. + writer->Write(1, 1); // permutation + JXL_DASSERT(permutation->size() == group_codes.size()); + EncodePermutation(permutation->data(), /*skip=*/0, permutation->size(), + writer, /* layer= */ 0, aux_out); + + } else { + writer->Write(1, 0); // no permutation + } + writer->ZeroPadToByte(); // before TOC entries + + for (size_t i = 0; i < group_codes.size(); i++) { + JXL_ASSERT(group_codes[i].BitsWritten() % kBitsPerByte == 0); + const size_t group_size = group_codes[i].BitsWritten() / kBitsPerByte; + JXL_RETURN_IF_ERROR(U32Coder::Write(kTocDist, group_size, writer)); + } + writer->ZeroPadToByte(); // before first group + ReclaimAndCharge(writer, &allotment, kLayerTOC, aux_out); + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_toc.h b/media/libjxl/src/lib/jxl/enc_toc.h new file mode 100644 index 000000000..dc81a5d12 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_toc.h @@ -0,0 +1,29 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_TOC_H_ +#define LIB_JXL_ENC_TOC_H_ + +#include +#include + +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Writes the group offsets. If the permutation vector is nullptr, the identity +// permutation will be used. +Status WriteGroupOffsets(const std::vector& group_codes, + const std::vector* permutation, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_TOC_H_ diff --git a/media/libjxl/src/lib/jxl/enc_transforms-inl.h b/media/libjxl/src/lib/jxl/enc_transforms-inl.h new file mode 100644 index 000000000..ef6dc2bbd --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_transforms-inl.h @@ -0,0 +1,827 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_ENC_TRANSFORMS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_ENC_TRANSFORMS_INL_H_ +#undef LIB_JXL_ENC_TRANSFORMS_INL_H_ +#else +#define LIB_JXL_ENC_TRANSFORMS_INL_H_ +#endif + +#include + +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/dct_scales.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// Inverse of ReinterpretingDCT. +template +HWY_INLINE void ReinterpretingIDCT(const float* input, + const size_t input_stride, float* output, + const size_t output_stride) { + HWY_ALIGN float block[ROWS * COLS] = {}; + if (ROWS < COLS) { + for (size_t y = 0; y < LF_ROWS; y++) { + for (size_t x = 0; x < LF_COLS; x++) { + block[y * COLS + x] = input[y * input_stride + x] * + DCTTotalResampleScale(y) * + DCTTotalResampleScale(x); + } + } + } else { + for (size_t y = 0; y < LF_COLS; y++) { + for (size_t x = 0; x < LF_ROWS; x++) { + block[y * ROWS + x] = input[y * input_stride + x] * + DCTTotalResampleScale(y) * + DCTTotalResampleScale(x); + } + } + } + + // ROWS, COLS <= 8, so we can put scratch space on the stack. + HWY_ALIGN float scratch_space[ROWS * COLS]; + ComputeScaledIDCT()(block, DCTTo(output, output_stride), + scratch_space); +} + +template +void DCT2TopBlock(const float* block, size_t stride, float* out) { + static_assert(kBlockDim % S == 0, "S should be a divisor of kBlockDim"); + static_assert(S % 2 == 0, "S should be even"); + float temp[kDCTBlockSize]; + constexpr size_t num_2x2 = S / 2; + for (size_t y = 0; y < num_2x2; y++) { + for (size_t x = 0; x < num_2x2; x++) { + float c00 = block[y * 2 * stride + x * 2]; + float c01 = block[y * 2 * stride + x * 2 + 1]; + float c10 = block[(y * 2 + 1) * stride + x * 2]; + float c11 = block[(y * 2 + 1) * stride + x * 2 + 1]; + float r00 = c00 + c01 + c10 + c11; + float r01 = c00 + c01 - c10 - c11; + float r10 = c00 - c01 + c10 - c11; + float r11 = c00 - c01 - c10 + c11; + r00 *= 0.25f; + r01 *= 0.25f; + r10 *= 0.25f; + r11 *= 0.25f; + temp[y * kBlockDim + x] = r00; + temp[y * kBlockDim + num_2x2 + x] = r01; + temp[(y + num_2x2) * kBlockDim + x] = r10; + temp[(y + num_2x2) * kBlockDim + num_2x2 + x] = r11; + } + } + for (size_t y = 0; y < S; y++) { + for (size_t x = 0; x < S; x++) { + out[y * kBlockDim + x] = temp[y * kBlockDim + x]; + } + } +} + +void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs) { + HWY_ALIGN static constexpr float k4x4AFVBasisTranspose[16][16] = { + { + 0.2500000000000000, + 0.8769029297991420f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + -0.4105377591765233f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + }, + { + 0.2500000000000000, + 0.2206518106944235f, + 0.0000000000000000, + 0.0000000000000000, + -0.7071067811865474f, + 0.6235485373547691f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + 0.4067007583026075f, + -0.2125574805828875f, + 0.0000000000000000, + -0.0643507165794627f, + -0.4517556589999482f, + -0.3046847507248690f, + 0.3017929516615495f, + 0.4082482904638627f, + 0.1747866975480809f, + -0.2110560104933578f, + -0.1426608480880726f, + -0.1381354035075859f, + -0.1743760259965107f, + 0.1135498731499434f, + }, + { + 0.2500000000000000, + -0.1014005039375375f, + 0.4444481661973445f, + 0.3085497062849767f, + 0.0000000000000000f, + -0.0643507165794627f, + 0.1585450355184006f, + 0.5112616136591823f, + 0.2579236279634118f, + 0.0000000000000000, + 0.0812611176717539f, + 0.1856718091610980f, + -0.3416446842253372f, + 0.3302282550303788f, + 0.0702790691196284f, + -0.0741750459581035f, + }, + { + 0.2500000000000000, + 0.2206518106944236f, + 0.0000000000000000, + 0.0000000000000000, + 0.7071067811865476f, + 0.6235485373547694f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + }, + { + 0.2500000000000000, + -0.1014005039375378f, + 0.0000000000000000, + 0.4706702258572536f, + 0.0000000000000000, + -0.0643507165794628f, + -0.0403851516082220f, + 0.0000000000000000, + 0.1627234014286620f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.7367497537172237f, + 0.0875511500058708f, + -0.2921026642334881f, + 0.1940289303259434f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + 0.1957439937204294f, + -0.1621205195722993f, + 0.0000000000000000, + -0.0643507165794628f, + 0.0074182263792424f, + -0.2904801297289980f, + 0.0952002265347504f, + 0.0000000000000000, + -0.3675398009862027f, + 0.4921585901373873f, + 0.2462710772207515f, + -0.0794670660590957f, + 0.3623817333531167f, + -0.4351904965232280f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + 0.2929100136981264f, + 0.0000000000000000, + 0.0000000000000000, + -0.0643507165794627f, + 0.3935103426921017f, + -0.0657870154914280f, + 0.0000000000000000, + -0.4082482904638628f, + -0.3078822139579090f, + -0.3852501370925192f, + -0.0857401903551931f, + -0.4613374887461511f, + 0.0000000000000000, + 0.2191868483885747f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + -0.4067007583026072f, + -0.2125574805828705f, + 0.0000000000000000, + -0.0643507165794627f, + -0.4517556589999464f, + 0.3046847507248840f, + 0.3017929516615503f, + -0.4082482904638635f, + -0.1747866975480813f, + 0.2110560104933581f, + -0.1426608480880734f, + -0.1381354035075829f, + -0.1743760259965108f, + 0.1135498731499426f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + -0.1957439937204287f, + -0.1621205195722833f, + 0.0000000000000000, + -0.0643507165794628f, + 0.0074182263792444f, + 0.2904801297290076f, + 0.0952002265347505f, + 0.0000000000000000, + 0.3675398009862011f, + -0.4921585901373891f, + 0.2462710772207514f, + -0.0794670660591026f, + 0.3623817333531165f, + -0.4351904965232251f, + }, + { + 0.2500000000000000, + -0.1014005039375375f, + 0.0000000000000000, + -0.4706702258572528f, + 0.0000000000000000, + -0.0643507165794627f, + 0.1107416575309343f, + 0.0000000000000000, + -0.1627234014286617f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.1488339922711357f, + 0.4972464710953509f, + 0.2921026642334879f, + 0.5550443808910661f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + 0.1137907446044809f, + -0.1464291867126764f, + 0.0000000000000000, + -0.0643507165794628f, + 0.0829816309488205f, + -0.2388977352334460f, + -0.3531238544981630f, + -0.4082482904638630f, + 0.4826689115059883f, + 0.1741941265991622f, + -0.0476868035022925f, + 0.1253805944856366f, + -0.4326608024727445f, + -0.2546827712406646f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + -0.4444481661973438f, + 0.3085497062849487f, + 0.0000000000000000, + -0.0643507165794628f, + 0.1585450355183970f, + -0.5112616136592012f, + 0.2579236279634129f, + 0.0000000000000000, + -0.0812611176717504f, + -0.1856718091610990f, + -0.3416446842253373f, + 0.3302282550303805f, + 0.0702790691196282f, + -0.0741750459581023f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + -0.2929100136981264f, + 0.0000000000000000, + 0.0000000000000000, + -0.0643507165794627f, + 0.3935103426921022f, + 0.0657870154914254f, + 0.0000000000000000, + 0.4082482904638634f, + 0.3078822139579031f, + 0.3852501370925211f, + -0.0857401903551927f, + -0.4613374887461554f, + 0.0000000000000000, + 0.2191868483885728f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + -0.1137907446044814f, + -0.1464291867126654f, + 0.0000000000000000, + -0.0643507165794627f, + 0.0829816309488214f, + 0.2388977352334547f, + -0.3531238544981624f, + 0.4082482904638630f, + -0.4826689115059858f, + -0.1741941265991621f, + -0.0476868035022928f, + 0.1253805944856431f, + -0.4326608024727457f, + -0.2546827712406641f, + }, + { + 0.2500000000000000, + -0.1014005039375374f, + 0.0000000000000000, + 0.4251149611657548f, + 0.0000000000000000, + -0.0643507165794626f, + -0.4517556589999480f, + 0.0000000000000000, + -0.6035859033230976f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + -0.1426608480880724f, + -0.1381354035075845f, + 0.3487520519930227f, + 0.1135498731499429f, + }, + }; + + const HWY_CAPPED(float, 16) d; + for (size_t i = 0; i < 16; i += Lanes(d)) { + auto scalar = Zero(d); + for (size_t j = 0; j < 16; j++) { + auto px = Set(d, pixels[j]); + auto basis = Load(d, k4x4AFVBasisTranspose[j] + i); + scalar = MulAdd(px, basis, scalar); + } + Store(scalar, d, coeffs + i); + } +} + +// Coefficient layout: +// - (even, even) positions hold AFV coefficients +// - (odd, even) positions hold DCT4x4 coefficients +// - (any, odd) positions hold DCT4x8 coefficients +template +void AFVTransformFromPixels(const float* JXL_RESTRICT pixels, + size_t pixels_stride, + float* JXL_RESTRICT coefficients) { + HWY_ALIGN float scratch_space[4 * 8 * 2]; + size_t afv_x = afv_kind & 1; + size_t afv_y = afv_kind / 2; + HWY_ALIGN float block[4 * 8]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + block[(afv_y == 1 ? 3 - iy : iy) * 4 + (afv_x == 1 ? 3 - ix : ix)] = + pixels[(iy + 4 * afv_y) * pixels_stride + ix + 4 * afv_x]; + } + } + // AFV coefficients in (even, even) positions. + HWY_ALIGN float coeff[4 * 4]; + AFVDCT4x4(block, coeff); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + coefficients[iy * 2 * 8 + ix * 2] = coeff[iy * 4 + ix]; + } + } + // 4x4 DCT of the block with same y and different x. + ComputeScaledDCT<4, 4>()( + DCTFrom(pixels + afv_y * 4 * pixels_stride + (afv_x == 1 ? 0 : 4), + pixels_stride), + block, scratch_space); + // ... in (odd, even) positions. + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + coefficients[iy * 2 * 8 + ix * 2 + 1] = block[iy * 4 + ix]; + } + } + // 4x8 DCT of the other half of the block. + ComputeScaledDCT<4, 8>()( + DCTFrom(pixels + (afv_y == 1 ? 0 : 4) * pixels_stride, pixels_stride), + block, scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + coefficients[(1 + iy * 2) * 8 + ix] = block[iy * 8 + ix]; + } + } + float block00 = coefficients[0] * 0.25f; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + coefficients[0] = (block00 + block01 + 2 * block10) * 0.25f; + coefficients[1] = (block00 - block01) * 0.5f; + coefficients[8] = (block00 + block01 - 2 * block10) * 0.25f; +} + +HWY_MAYBE_UNUSED void TransformFromPixels(const AcStrategy::Type strategy, + const float* JXL_RESTRICT pixels, + size_t pixels_stride, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT scratch_space) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::IDENTITY: { + PROFILER_ZONE("DCT Identity"); + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + float block_dc = 0; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + block_dc += pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix]; + } + } + block_dc *= 1.0f / 16; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 1 && iy == 1) continue; + coefficients[(y + iy * 2) * 8 + x + ix * 2] = + pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix] - + pixels[(y * 4 + 1) * pixels_stride + x * 4 + 1]; + } + } + coefficients[(y + 2) * 8 + x + 2] = coefficients[y * 8 + x]; + coefficients[y * 8 + x] = block_dc; + } + } + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + coefficients[0] = (block00 + block01 + block10 + block11) * 0.25f; + coefficients[1] = (block00 + block01 - block10 - block11) * 0.25f; + coefficients[8] = (block00 - block01 + block10 - block11) * 0.25f; + coefficients[9] = (block00 - block01 - block10 + block11) * 0.25f; + break; + } + case Type::DCT8X4: { + PROFILER_ZONE("DCT 8x4"); + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 8]; + ComputeScaledDCT<8, 4>()(DCTFrom(pixels + x * 4, pixels_stride), block, + scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + // Store transposed. + coefficients[(x + iy * 2) * 8 + ix] = block[iy * 8 + ix]; + } + } + } + float block0 = coefficients[0]; + float block1 = coefficients[8]; + coefficients[0] = (block0 + block1) * 0.5f; + coefficients[8] = (block0 - block1) * 0.5f; + break; + } + case Type::DCT4X8: { + PROFILER_ZONE("DCT 4x8"); + for (size_t y = 0; y < 2; y++) { + HWY_ALIGN float block[4 * 8]; + ComputeScaledDCT<4, 8>()( + DCTFrom(pixels + y * 4 * pixels_stride, pixels_stride), block, + scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + coefficients[(y + iy * 2) * 8 + ix] = block[iy * 8 + ix]; + } + } + } + float block0 = coefficients[0]; + float block1 = coefficients[8]; + coefficients[0] = (block0 + block1) * 0.5f; + coefficients[8] = (block0 - block1) * 0.5f; + break; + } + case Type::DCT4X4: { + PROFILER_ZONE("DCT 4"); + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 4]; + ComputeScaledDCT<4, 4>()( + DCTFrom(pixels + y * 4 * pixels_stride + x * 4, pixels_stride), + block, scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + coefficients[(y + iy * 2) * 8 + x + ix * 2] = block[iy * 4 + ix]; + } + } + } + } + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + coefficients[0] = (block00 + block01 + block10 + block11) * 0.25f; + coefficients[1] = (block00 + block01 - block10 - block11) * 0.25f; + coefficients[8] = (block00 - block01 + block10 - block11) * 0.25f; + coefficients[9] = (block00 - block01 - block10 + block11) * 0.25f; + break; + } + case Type::DCT2X2: { + PROFILER_ZONE("DCT 2"); + DCT2TopBlock<8>(pixels, pixels_stride, coefficients); + DCT2TopBlock<4>(coefficients, kBlockDim, coefficients); + DCT2TopBlock<2>(coefficients, kBlockDim, coefficients); + break; + } + case Type::DCT16X16: { + PROFILER_ZONE("DCT 16"); + ComputeScaledDCT<16, 16>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT16X8: { + PROFILER_ZONE("DCT 16x8"); + ComputeScaledDCT<16, 8>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT8X16: { + PROFILER_ZONE("DCT 8x16"); + ComputeScaledDCT<8, 16>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X8: { + PROFILER_ZONE("DCT 32x8"); + ComputeScaledDCT<32, 8>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT8X32: { + PROFILER_ZONE("DCT 8x32"); + ComputeScaledDCT<8, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X16: { + PROFILER_ZONE("DCT 32x16"); + ComputeScaledDCT<32, 16>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT16X32: { + PROFILER_ZONE("DCT 16x32"); + ComputeScaledDCT<16, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X32: { + PROFILER_ZONE("DCT 32"); + ComputeScaledDCT<32, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT: { + PROFILER_ZONE("DCT 8"); + ComputeScaledDCT<8, 8>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::AFV0: { + PROFILER_ZONE("AFV0"); + AFVTransformFromPixels<0>(pixels, pixels_stride, coefficients); + break; + } + case Type::AFV1: { + PROFILER_ZONE("AFV1"); + AFVTransformFromPixels<1>(pixels, pixels_stride, coefficients); + break; + } + case Type::AFV2: { + PROFILER_ZONE("AFV2"); + AFVTransformFromPixels<2>(pixels, pixels_stride, coefficients); + break; + } + case Type::AFV3: { + PROFILER_ZONE("AFV3"); + AFVTransformFromPixels<3>(pixels, pixels_stride, coefficients); + break; + } + case Type::DCT64X64: { + PROFILER_ZONE("DCT 64x64"); + ComputeScaledDCT<64, 64>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT64X32: { + PROFILER_ZONE("DCT 64x32"); + ComputeScaledDCT<64, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X64: { + PROFILER_ZONE("DCT 32x64"); + ComputeScaledDCT<32, 64>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT128X128: { + PROFILER_ZONE("DCT 128x128"); + ComputeScaledDCT<128, 128>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT128X64: { + PROFILER_ZONE("DCT 128x64"); + ComputeScaledDCT<128, 64>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT64X128: { + PROFILER_ZONE("DCT 64x128"); + ComputeScaledDCT<64, 128>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT256X256: { + PROFILER_ZONE("DCT 256x256"); + ComputeScaledDCT<256, 256>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT256X128: { + PROFILER_ZONE("DCT 256x128"); + ComputeScaledDCT<256, 128>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT128X256: { + PROFILER_ZONE("DCT 128x256"); + ComputeScaledDCT<128, 256>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::kNumValidStrategies: + JXL_ABORT("Invalid strategy"); + } +} + +HWY_MAYBE_UNUSED void DCFromLowestFrequencies(const AcStrategy::Type strategy, + const float* block, float* dc, + size_t dc_stride) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::DCT16X8: { + ReinterpretingIDCT( + block, 2 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT8X16: { + ReinterpretingIDCT( + block, 2 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT16X16: { + ReinterpretingIDCT( + block, 2 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X8: { + ReinterpretingIDCT( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT8X32: { + ReinterpretingIDCT( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X16: { + ReinterpretingIDCT( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT16X32: { + ReinterpretingIDCT( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X32: { + ReinterpretingIDCT( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT64X32: { + ReinterpretingIDCT( + block, 8 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X64: { + ReinterpretingIDCT( + block, 8 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT64X64: { + ReinterpretingIDCT( + block, 8 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT128X64: { + ReinterpretingIDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/8, /*ROWS=*/16, /*COLS=*/8>( + block, 16 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT64X128: { + ReinterpretingIDCT< + /*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/8, /*LF_COLS=*/16, /*ROWS=*/8, /*COLS=*/16>( + block, 16 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT128X128: { + ReinterpretingIDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/16, /*ROWS=*/16, /*COLS=*/16>( + block, 16 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT256X128: { + ReinterpretingIDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/16, /*ROWS=*/32, /*COLS=*/16>( + block, 32 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT128X256: { + ReinterpretingIDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/32, /*ROWS=*/16, /*COLS=*/32>( + block, 32 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT256X256: { + ReinterpretingIDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/32, /*ROWS=*/32, /*COLS=*/32>( + block, 32 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT: + case Type::DCT2X2: + case Type::DCT4X4: + case Type::DCT4X8: + case Type::DCT8X4: + case Type::AFV0: + case Type::AFV1: + case Type::AFV2: + case Type::AFV3: + case Type::IDENTITY: + dc[0] = block[0]; + break; + case Type::kNumValidStrategies: + JXL_ABORT("Invalid strategy"); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_ENC_TRANSFORMS_INL_H_ diff --git a/media/libjxl/src/lib/jxl/enc_transforms.cc b/media/libjxl/src/lib/jxl/enc_transforms.cc new file mode 100644 index 000000000..8978ba1dc --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_transforms.cc @@ -0,0 +1,41 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_transforms.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_transforms.cc" +#include +#include + +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/enc_transforms-inl.h" + +namespace jxl { + +#if HWY_ONCE +HWY_EXPORT(TransformFromPixels); +void TransformFromPixels(const AcStrategy::Type strategy, + const float* JXL_RESTRICT pixels, size_t pixels_stride, + float* JXL_RESTRICT coefficients, + float* scratch_space) { + return HWY_DYNAMIC_DISPATCH(TransformFromPixels)( + strategy, pixels, pixels_stride, coefficients, scratch_space); +} + +HWY_EXPORT(DCFromLowestFrequencies); +void DCFromLowestFrequencies(AcStrategy::Type strategy, const float* block, + float* dc, size_t dc_stride) { + return HWY_DYNAMIC_DISPATCH(DCFromLowestFrequencies)(strategy, block, dc, + dc_stride); +} + +HWY_EXPORT(AFVDCT4x4); +void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs) { + return HWY_DYNAMIC_DISPATCH(AFVDCT4x4)(pixels, coeffs); +} +#endif // HWY_ONCE + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/enc_transforms.h b/media/libjxl/src/lib/jxl/enc_transforms.h new file mode 100644 index 000000000..039ccc389 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_transforms.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_TRANSFORMS_H_ +#define LIB_JXL_ENC_TRANSFORMS_H_ + +// Facade for (non-inlined) integral transforms. + +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +void TransformFromPixels(const AcStrategy::Type strategy, + const float* JXL_RESTRICT pixels, size_t pixels_stride, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT scratch_space); + +// Equivalent of the above for DC image. +void DCFromLowestFrequencies(AcStrategy::Type strategy, const float* block, + float* dc, size_t dc_stride); + +void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs); + +} // namespace jxl + +#endif // LIB_JXL_ENC_TRANSFORMS_H_ diff --git a/media/libjxl/src/lib/jxl/enc_xyb.cc b/media/libjxl/src/lib/jxl/enc_xyb.cc new file mode 100644 index 000000000..577e29686 --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_xyb.cc @@ -0,0 +1,381 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_xyb.h" + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_xyb.cc" +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_image_bundle.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +// 4x3 matrix * 3x1 SIMD vectors +template +JXL_INLINE void OpsinAbsorbance(const V r, const V g, const V b, + const float* JXL_RESTRICT premul_absorb, + V* JXL_RESTRICT mixed0, V* JXL_RESTRICT mixed1, + V* JXL_RESTRICT mixed2) { + const float* bias = &kOpsinAbsorbanceBias[0]; + const HWY_FULL(float) d; + const size_t N = Lanes(d); + const auto m0 = Load(d, premul_absorb + 0 * N); + const auto m1 = Load(d, premul_absorb + 1 * N); + const auto m2 = Load(d, premul_absorb + 2 * N); + const auto m3 = Load(d, premul_absorb + 3 * N); + const auto m4 = Load(d, premul_absorb + 4 * N); + const auto m5 = Load(d, premul_absorb + 5 * N); + const auto m6 = Load(d, premul_absorb + 6 * N); + const auto m7 = Load(d, premul_absorb + 7 * N); + const auto m8 = Load(d, premul_absorb + 8 * N); + *mixed0 = MulAdd(m0, r, MulAdd(m1, g, MulAdd(m2, b, Set(d, bias[0])))); + *mixed1 = MulAdd(m3, r, MulAdd(m4, g, MulAdd(m5, b, Set(d, bias[1])))); + *mixed2 = MulAdd(m6, r, MulAdd(m7, g, MulAdd(m8, b, Set(d, bias[2])))); +} + +template +void StoreXYB(const V r, V g, const V b, float* JXL_RESTRICT valx, + float* JXL_RESTRICT valy, float* JXL_RESTRICT valz) { + const HWY_FULL(float) d; + const V half = Set(d, 0.5f); + Store(Mul(half, Sub(r, g)), d, valx); + Store(Mul(half, Add(r, g)), d, valy); + Store(b, d, valz); +} + +// Converts one RGB vector to XYB. +template +void LinearRGBToXYB(const V r, const V g, const V b, + const float* JXL_RESTRICT premul_absorb, + float* JXL_RESTRICT valx, float* JXL_RESTRICT valy, + float* JXL_RESTRICT valz) { + V mixed0, mixed1, mixed2; + OpsinAbsorbance(r, g, b, premul_absorb, &mixed0, &mixed1, &mixed2); + + // mixed* should be non-negative even for wide-gamut, so clamp to zero. + mixed0 = ZeroIfNegative(mixed0); + mixed1 = ZeroIfNegative(mixed1); + mixed2 = ZeroIfNegative(mixed2); + + const HWY_FULL(float) d; + const size_t N = Lanes(d); + mixed0 = CubeRootAndAdd(mixed0, Load(d, premul_absorb + 9 * N)); + mixed1 = CubeRootAndAdd(mixed1, Load(d, premul_absorb + 10 * N)); + mixed2 = CubeRootAndAdd(mixed2, Load(d, premul_absorb + 11 * N)); + StoreXYB(mixed0, mixed1, mixed2, valx, valy, valz); + + // For wide-gamut inputs, r/g/b and valx (but not y/z) are often negative. +} + +// Input/output uses the codec.h scaling: nominally 0-1 if in-gamut. +template +V LinearFromSRGB(V encoded) { + return TF_SRGB().DisplayFromEncoded(encoded); +} + +Status LinearSRGBToXYB(const Image3F& linear, + const float* JXL_RESTRICT premul_absorb, + ThreadPool* pool, Image3F* JXL_RESTRICT xyb) { + const size_t xsize = linear.xsize(); + + const HWY_FULL(float) d; + return RunOnPool( + pool, 0, static_cast(linear.ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast(task); + const float* JXL_RESTRICT row_in0 = linear.ConstPlaneRow(0, y); + const float* JXL_RESTRICT row_in1 = linear.ConstPlaneRow(1, y); + const float* JXL_RESTRICT row_in2 = linear.ConstPlaneRow(2, y); + float* JXL_RESTRICT row_xyb0 = xyb->PlaneRow(0, y); + float* JXL_RESTRICT row_xyb1 = xyb->PlaneRow(1, y); + float* JXL_RESTRICT row_xyb2 = xyb->PlaneRow(2, y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_r = Load(d, row_in0 + x); + const auto in_g = Load(d, row_in1 + x); + const auto in_b = Load(d, row_in2 + x); + LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row_xyb0 + x, + row_xyb1 + x, row_xyb2 + x); + } + }, + "LinearToXYB"); +} + +Status SRGBToXYB(const Image3F& srgb, const float* JXL_RESTRICT premul_absorb, + ThreadPool* pool, Image3F* JXL_RESTRICT xyb) { + const size_t xsize = srgb.xsize(); + + const HWY_FULL(float) d; + return RunOnPool( + pool, 0, static_cast(srgb.ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast(task); + const float* JXL_RESTRICT row_srgb0 = srgb.ConstPlaneRow(0, y); + const float* JXL_RESTRICT row_srgb1 = srgb.ConstPlaneRow(1, y); + const float* JXL_RESTRICT row_srgb2 = srgb.ConstPlaneRow(2, y); + float* JXL_RESTRICT row_xyb0 = xyb->PlaneRow(0, y); + float* JXL_RESTRICT row_xyb1 = xyb->PlaneRow(1, y); + float* JXL_RESTRICT row_xyb2 = xyb->PlaneRow(2, y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_r = LinearFromSRGB(Load(d, row_srgb0 + x)); + const auto in_g = LinearFromSRGB(Load(d, row_srgb1 + x)); + const auto in_b = LinearFromSRGB(Load(d, row_srgb2 + x)); + LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row_xyb0 + x, + row_xyb1 + x, row_xyb2 + x); + } + }, + "SRGBToXYB"); +} + +Status SRGBToXYBAndLinear(const Image3F& srgb, + const float* JXL_RESTRICT premul_absorb, + ThreadPool* pool, Image3F* JXL_RESTRICT xyb, + Image3F* JXL_RESTRICT linear) { + const size_t xsize = srgb.xsize(); + + const HWY_FULL(float) d; + return RunOnPool( + pool, 0, static_cast(srgb.ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast(task); + const float* JXL_RESTRICT row_srgb0 = srgb.ConstPlaneRow(0, y); + const float* JXL_RESTRICT row_srgb1 = srgb.ConstPlaneRow(1, y); + const float* JXL_RESTRICT row_srgb2 = srgb.ConstPlaneRow(2, y); + + float* JXL_RESTRICT row_linear0 = linear->PlaneRow(0, y); + float* JXL_RESTRICT row_linear1 = linear->PlaneRow(1, y); + float* JXL_RESTRICT row_linear2 = linear->PlaneRow(2, y); + + float* JXL_RESTRICT row_xyb0 = xyb->PlaneRow(0, y); + float* JXL_RESTRICT row_xyb1 = xyb->PlaneRow(1, y); + float* JXL_RESTRICT row_xyb2 = xyb->PlaneRow(2, y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_r = LinearFromSRGB(Load(d, row_srgb0 + x)); + const auto in_g = LinearFromSRGB(Load(d, row_srgb1 + x)); + const auto in_b = LinearFromSRGB(Load(d, row_srgb2 + x)); + + Store(in_r, d, row_linear0 + x); + Store(in_g, d, row_linear1 + x); + Store(in_b, d, row_linear2 + x); + + LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row_xyb0 + x, + row_xyb1 + x, row_xyb2 + x); + } + }, + "SRGBToXYBAndLinear"); +} + +// This is different from Butteraugli's OpsinDynamicsImage() in the sense that +// it does not contain a sensitivity multiplier based on the blurred image. +const ImageBundle* ToXYB(const ImageBundle& in, ThreadPool* pool, + Image3F* JXL_RESTRICT xyb, const JxlCmsInterface& cms, + ImageBundle* const JXL_RESTRICT linear) { + PROFILER_FUNC; + + const size_t xsize = in.xsize(); + const size_t ysize = in.ysize(); + JXL_ASSERT(SameSize(in, *xyb)); + + const HWY_FULL(float) d; + // Pre-broadcasted constants + HWY_ALIGN float premul_absorb[MaxLanes(d) * 12]; + const size_t N = Lanes(d); + for (size_t i = 0; i < 9; ++i) { + const auto absorb = Set(d, kOpsinAbsorbanceMatrix[i] * + (in.metadata()->IntensityTarget() / 255.0f)); + Store(absorb, d, premul_absorb + i * N); + } + for (size_t i = 0; i < 3; ++i) { + const auto neg_bias_cbrt = Set(d, -cbrtf(kOpsinAbsorbanceBias[i])); + Store(neg_bias_cbrt, d, premul_absorb + (9 + i) * N); + } + + const bool want_linear = linear != nullptr; + + const ColorEncoding& c_linear_srgb = ColorEncoding::LinearSRGB(in.IsGray()); + // Linear sRGB inputs are rare but can be useful for the fastest encoders, for + // which undoing the sRGB transfer function would be a large part of the cost. + if (c_linear_srgb.SameColorEncoding(in.c_current())) { + JXL_CHECK(LinearSRGBToXYB(in.color(), premul_absorb, pool, xyb)); + // This only happens if kitten or slower, moving ImageBundle might be + // possible but the encoder is much slower than this copy. + if (want_linear) { + *linear = in.Copy(); + return linear; + } + return ∈ + } + + // Common case: already sRGB, can avoid the color transform + if (in.IsSRGB()) { + // Common case: can avoid allocating/copying + if (!want_linear) { + JXL_CHECK(SRGBToXYB(in.color(), premul_absorb, pool, xyb)); + return ∈ + } + + // Slow encoder also wants linear sRGB. + linear->SetFromImage(Image3F(xsize, ysize), c_linear_srgb); + JXL_CHECK(SRGBToXYBAndLinear(in.color(), premul_absorb, pool, xyb, + linear->color())); + return linear; + } + + // General case: not sRGB, need color transform. + ImageBundle linear_storage; // Local storage only used if !want_linear. + + ImageBundle* linear_storage_ptr; + if (want_linear) { + // Caller asked for linear, use that storage directly. + linear_storage_ptr = linear; + } else { + // Caller didn't ask for linear, create our own local storage + // OK to reuse metadata, it will not be changed. + linear_storage = ImageBundle(const_cast(in.metadata())); + linear_storage_ptr = &linear_storage; + } + + const ImageBundle* ptr; + JXL_CHECK(TransformIfNeeded(in, c_linear_srgb, cms, pool, linear_storage_ptr, + &ptr)); + // If no transform was necessary, should have taken the above codepath. + JXL_ASSERT(ptr == linear_storage_ptr); + + JXL_CHECK( + LinearSRGBToXYB(*linear_storage_ptr->color(), premul_absorb, pool, xyb)); + return want_linear ? linear : ∈ +} + +// Transform RGB to YCbCr. +// Could be performed in-place (i.e. Y, Cb and Cr could alias R, B and B). +Status RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane, + const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane, + ImageF* cr_plane, ThreadPool* pool) { + const HWY_FULL(float) df; + const size_t S = Lanes(df); // Step. + + const size_t xsize = r_plane.xsize(); + const size_t ysize = r_plane.ysize(); + if ((xsize == 0) || (ysize == 0)) return true; + + // Full-range BT.601 as defined by JFIF Clause 7: + // https://www.itu.int/rec/T-REC-T.871-201105-I/en + const auto k128 = Set(df, 128.0f / 255); + const auto kR = Set(df, 0.299f); // NTSC luma + const auto kG = Set(df, 0.587f); + const auto kB = Set(df, 0.114f); + const auto kAmpR = Set(df, 0.701f); + const auto kAmpB = Set(df, 0.886f); + const auto kDiffR = Add(kAmpR, kR); + const auto kDiffB = Add(kAmpB, kB); + const auto kNormR = Div(Set(df, 1.0f), (Add(kAmpR, Add(kG, kB)))); + const auto kNormB = Div(Set(df, 1.0f), (Add(kR, Add(kG, kAmpB)))); + + constexpr size_t kGroupArea = kGroupDim * kGroupDim; + const size_t lines_per_group = DivCeil(kGroupArea, xsize); + const size_t num_stripes = DivCeil(ysize, lines_per_group); + const auto transform = [&](int idx, int /* thread*/) { + const size_t y0 = idx * lines_per_group; + const size_t y1 = std::min(y0 + lines_per_group, ysize); + for (size_t y = y0; y < y1; ++y) { + const float* r_row = r_plane.ConstRow(y); + const float* g_row = g_plane.ConstRow(y); + const float* b_row = b_plane.ConstRow(y); + float* y_row = y_plane->Row(y); + float* cb_row = cb_plane->Row(y); + float* cr_row = cr_plane->Row(y); + for (size_t x = 0; x < xsize; x += S) { + const auto r = Load(df, r_row + x); + const auto g = Load(df, g_row + x); + const auto b = Load(df, b_row + x); + const auto r_base = Mul(r, kR); + const auto r_diff = Mul(r, kDiffR); + const auto g_base = Mul(g, kG); + const auto b_base = Mul(b, kB); + const auto b_diff = Mul(b, kDiffB); + const auto y_base = Add(r_base, Add(g_base, b_base)); + const auto y_vec = Sub(y_base, k128); + const auto cb_vec = Mul(Sub(b_diff, y_base), kNormB); + const auto cr_vec = Mul(Sub(r_diff, y_base), kNormR); + Store(y_vec, df, y_row + x); + Store(cb_vec, df, cb_row + x); + Store(cr_vec, df, cr_row + x); + } + } + }; + return RunOnPool(pool, 0, static_cast(num_stripes), ThreadPool::NoInit, + transform, "RgbToYcbCr"); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ToXYB); +const ImageBundle* ToXYB(const ImageBundle& in, ThreadPool* pool, + Image3F* JXL_RESTRICT xyb, const JxlCmsInterface& cms, + ImageBundle* JXL_RESTRICT linear_storage) { + return HWY_DYNAMIC_DISPATCH(ToXYB)(in, pool, xyb, cms, linear_storage); +} + +HWY_EXPORT(RgbToYcbcr); +Status RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane, + const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane, + ImageF* cr_plane, ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(RgbToYcbcr)(r_plane, g_plane, b_plane, y_plane, + cb_plane, cr_plane, pool); +} + +// DEPRECATED +Image3F OpsinDynamicsImage(const Image3B& srgb8, const JxlCmsInterface& cms) { + ImageMetadata metadata; + metadata.SetUintSamples(8); + metadata.color_encoding = ColorEncoding::SRGB(); + ImageBundle ib(&metadata); + ib.SetFromImage(ConvertToFloat(srgb8), metadata.color_encoding); + JXL_CHECK(ib.TransformTo(ColorEncoding::LinearSRGB(ib.IsGray()), cms)); + ThreadPool* null_pool = nullptr; + Image3F xyb(srgb8.xsize(), srgb8.ysize()); + + ImageBundle linear_storage(&metadata); + (void)ToXYB(ib, null_pool, &xyb, cms, &linear_storage); + return xyb; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/enc_xyb.h b/media/libjxl/src/lib/jxl/enc_xyb.h new file mode 100644 index 000000000..de8f2e3ff --- /dev/null +++ b/media/libjxl/src/lib/jxl/enc_xyb.h @@ -0,0 +1,42 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_XYB_H_ +#define LIB_JXL_ENC_XYB_H_ + +// Converts to XYB color space. + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Converts any color space to XYB. If `linear` is not null, returns `linear` +// after filling it with a linear sRGB copy of `in`. Otherwise, returns `&in`. +// +// NOTE this return value can avoid an extra color conversion if `in` would +// later be passed to JxlButteraugliComparator. +const ImageBundle* ToXYB(const ImageBundle& in, ThreadPool* pool, + Image3F* JXL_RESTRICT xyb, const JxlCmsInterface& cms, + ImageBundle* JXL_RESTRICT linear = nullptr); + +// Bt.601 to match JPEG/JFIF. Outputs _signed_ YCbCr values suitable for DCT, +// see F.1.1.3 of T.81 (because our data type is float, there is no need to add +// a bias to make the values unsigned). +Status RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane, + const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane, + ImageF* cr_plane, ThreadPool* pool); + +// DEPRECATED, used by opsin_image_wrapper. +Image3F OpsinDynamicsImage(const Image3B& srgb8, const JxlCmsInterface& cms); + +} // namespace jxl + +#endif // LIB_JXL_ENC_XYB_H_ diff --git a/media/libjxl/src/lib/jxl/encode.cc b/media/libjxl/src/lib/jxl/encode.cc new file mode 100644 index 000000000..8e02dd630 --- /dev/null +++ b/media/libjxl/src/lib/jxl/encode.cc @@ -0,0 +1,1822 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/encode.h" + +#include + +#include +#include +#include + +#include "jxl/codestream_header.h" +#include "jxl/types.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_icc_codec.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/exif.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/sanitizers.h" + +// Debug-printing failure macro similar to JXL_FAILURE, but for the status code +// JXL_ENC_ERROR +#ifdef JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(enc, error_code, format, ...) \ + (enc->error = error_code, \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort(), JXL_ENC_ERROR) +#define JXL_API_ERROR_NOSET(format, ...) \ + (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort(), JXL_ENC_ERROR) +#else // JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(enc, error_code, format, ...) \ + (enc->error = error_code, \ + ((JXL_DEBUG_ON_ERROR) && \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__)), \ + JXL_ENC_ERROR) +#define JXL_API_ERROR_NOSET(format, ...) \ + (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + JXL_ENC_ERROR) +#endif // JXL_CRASH_ON_ERROR + +namespace jxl {} // namespace jxl + +uint32_t JxlEncoderVersion(void) { + return JPEGXL_MAJOR_VERSION * 1000000 + JPEGXL_MINOR_VERSION * 1000 + + JPEGXL_PATCH_VERSION; +} + +namespace { +template +void AppendJxlpBoxCounter(uint32_t counter, bool last, T* output) { + if (last) counter |= 0x80000000; + for (size_t i = 0; i < 4; i++) { + output->push_back(counter >> (8 * (3 - i)) & 0xff); + } +} + +void QueueFrame( + const JxlEncoderFrameSettings* frame_settings, + jxl::MemoryManagerUniquePtr& frame) { + if (frame_settings->values.lossless) { + frame->option_values.cparams.SetLossless(); + } + + jxl::JxlEncoderQueuedInput queued_input(frame_settings->enc->memory_manager); + queued_input.frame = std::move(frame); + frame_settings->enc->input_queue.emplace_back(std::move(queued_input)); + frame_settings->enc->num_queued_frames++; +} + +void QueueBox(JxlEncoder* enc, + jxl::MemoryManagerUniquePtr& box) { + jxl::JxlEncoderQueuedInput queued_input(enc->memory_manager); + queued_input.box = std::move(box); + enc->input_queue.emplace_back(std::move(queued_input)); + enc->num_queued_boxes++; +} + +// TODO(lode): share this code and the Brotli compression code in enc_jpeg_data +JxlEncoderStatus BrotliCompress(int quality, const uint8_t* in, size_t in_size, + jxl::PaddedBytes* out) { + std::unique_ptr + enc(BrotliEncoderCreateInstance(nullptr, nullptr, nullptr), + BrotliEncoderDestroyInstance); + if (!enc) return JXL_API_ERROR_NOSET("BrotliEncoderCreateInstance failed"); + + BrotliEncoderSetParameter(enc.get(), BROTLI_PARAM_QUALITY, quality); + BrotliEncoderSetParameter(enc.get(), BROTLI_PARAM_SIZE_HINT, in_size); + + constexpr size_t kBufferSize = 128 * 1024; + jxl::PaddedBytes temp_buffer(kBufferSize); + + size_t avail_in = in_size; + const uint8_t* next_in = in; + + size_t total_out = 0; + + for (;;) { + size_t avail_out = kBufferSize; + uint8_t* next_out = temp_buffer.data(); + jxl::msan::MemoryIsInitialized(next_in, avail_in); + if (!BrotliEncoderCompressStream(enc.get(), BROTLI_OPERATION_FINISH, + &avail_in, &next_in, &avail_out, &next_out, + &total_out)) { + return JXL_API_ERROR_NOSET("Brotli compression failed"); + } + size_t out_size = next_out - temp_buffer.data(); + jxl::msan::UnpoisonMemory(next_out - out_size, out_size); + out->resize(out->size() + out_size); + memcpy(out->data() + out->size() - out_size, temp_buffer.data(), out_size); + if (BrotliEncoderIsFinished(enc.get())) break; + } + + return JXL_ENC_SUCCESS; +} + +// The JXL codestream can have level 5 or level 10. Levels have certain +// restrictions such as max allowed image dimensions. This function checks the +// level required to support the current encoder settings. The debug_string is +// intended to be used for developer API error messages, and may be set to +// nullptr. +int VerifyLevelSettings(const JxlEncoder* enc, std::string* debug_string) { + const auto& m = enc->metadata.m; + + uint64_t xsize = enc->metadata.size.xsize(); + uint64_t ysize = enc->metadata.size.ysize(); + // The uncompressed ICC size, if it is used. + size_t icc_size = 0; + if (m.color_encoding.WantICC()) { + icc_size = m.color_encoding.ICC().size(); + } + + // Level 10 checks + + if (xsize > (1ull << 30ull) || ysize > (1ull << 30ull) || + xsize * ysize > (1ull << 40ull)) { + if (debug_string) *debug_string = "Too large image dimensions"; + return -1; + } + if (icc_size > (1ull << 28)) { + if (debug_string) *debug_string = "Too large ICC profile size"; + return -1; + } + if (m.num_extra_channels > 256) { + if (debug_string) *debug_string = "Too many extra channels"; + return -1; + } + + // Level 5 checks + + if (!m.modular_16_bit_buffer_sufficient) { + if (debug_string) *debug_string = "Too high modular bit depth"; + return 10; + } + if (xsize > (1ull << 18ull) || ysize > (1ull << 18ull) || + xsize * ysize > (1ull << 28ull)) { + if (debug_string) *debug_string = "Too large image dimensions"; + return 10; + } + if (icc_size > (1ull << 22)) { + if (debug_string) *debug_string = "Too large ICC profile"; + return 10; + } + if (m.num_extra_channels > 4) { + if (debug_string) *debug_string = "Too many extra channels"; + return 10; + } + for (size_t i = 0; i < m.extra_channel_info.size(); ++i) { + if (m.extra_channel_info[i].type == jxl::ExtraChannel::kBlack) { + if (debug_string) *debug_string = "CMYK channel not allowed"; + return 10; + } + } + + // TODO(lode): also need to check if consecutive composite-still frames total + // pixel amount doesn't exceed 2**28 in the case of level 5. This should be + // done when adding frame and requires ability to add composite still frames + // to be added first. + + // TODO(lode): also need to check animation duration of a frame. This should + // be done when adding frame, but first requires implementing setting the + // JxlFrameHeader for a frame. + + // TODO(lode): also need to check properties such as num_splines, num_patches, + // modular_16bit_buffers and multiple properties of modular trees. However + // these are not user-set properties so cannot be checked here, but decisions + // the C++ encoder should be able to make based on the level. + + // All level 5 checks passes, so can return the more compatible level 5 + return 5; +} +JxlEncoderStatus CheckValidBitdepth(uint32_t bits_per_sample, + uint32_t exponent_bits_per_sample) { + if (!exponent_bits_per_sample) { + // The spec allows up to 31 for bits_per_sample here, but + // the code does not (yet) support it. + if (!(bits_per_sample > 0 && bits_per_sample <= 24)) { + return JXL_API_ERROR_NOSET("Invalid value for bits_per_sample"); + } + } else if ((exponent_bits_per_sample > 8) || + (bits_per_sample > 24 + exponent_bits_per_sample) || + (bits_per_sample < 3 + exponent_bits_per_sample)) { + return JXL_API_ERROR_NOSET("Invalid float description"); + } + return JXL_ENC_SUCCESS; +} + +bool EncodeFrameIndexBox(const jxl::JxlEncoderFrameIndexBox& frame_index_box, + jxl::BitWriter& writer) { + bool ok = true; + int NF = 0; + for (size_t i = 0; i < frame_index_box.entries.size(); ++i) { + if (i == 0 || frame_index_box.entries[i].to_be_indexed) { + ++NF; + } + } + // Frame index box contents varint + 8 bytes + // continue with NF * 3 * varint + // varint max length is 10 for 64 bit numbers, and these numbers + // are limited to 63 bits. + static const int kVarintMaxLength = 10; + static const int kFrameIndexBoxHeaderLength = kVarintMaxLength + 8; + static const int kFrameIndexBoxElementLength = 3 * kVarintMaxLength; + const int buffer_size = + kFrameIndexBoxHeaderLength + NF * kFrameIndexBoxElementLength; + std::vector buffer_vec(buffer_size); + uint8_t* buffer = buffer_vec.data(); + size_t output_pos = 0; + ok &= jxl::EncodeVarInt(NF, buffer_vec.size(), &output_pos, buffer); + StoreBE32(frame_index_box.TNUM, &buffer[output_pos]); + output_pos += 4; + StoreBE32(frame_index_box.TDEN, &buffer[output_pos]); + output_pos += 4; + // When we record a frame in the index, the record needs to know + // how many frames until the next indexed frame. That is why + // we store the 'prev' record. That 'prev' record needs to store + // the offset byte position to previously recorded indexed frame, + // that's why we also trace previous to the previous frame. + int prev_prev_ix = -1; // For position offset (OFFi) delta coding. + int prev_ix = 0; + int T_prev = 0; + int T = 0; + for (size_t i = 1; i < frame_index_box.entries.size(); ++i) { + if (frame_index_box.entries[i].to_be_indexed) { + // Now we can record the previous entry, since we need to store + // there how many frames until the next one. + int64_t OFFi = frame_index_box.entries[prev_ix].OFFi; + if (prev_prev_ix != -1) { + // Offi needs to be offset of start byte of this frame compared to start + // byte of previous frame from this index in the JPEG XL codestream. For + // the first frame, this is the offset from the first byte of the JPEG + // XL codestream. + OFFi -= frame_index_box.entries[prev_prev_ix].OFFi; + } + int32_t Ti = T_prev; + int32_t Fi = i - prev_ix; + ok &= jxl::EncodeVarInt(OFFi, buffer_vec.size(), &output_pos, buffer); + ok &= jxl::EncodeVarInt(Ti, buffer_vec.size(), &output_pos, buffer); + ok &= jxl::EncodeVarInt(Fi, buffer_vec.size(), &output_pos, buffer); + prev_prev_ix = prev_ix; + prev_ix = i; + T_prev = T; + T += frame_index_box.entries[i].duration; + } + } + { + // Last frame. + size_t i = frame_index_box.entries.size(); + int64_t OFFi = frame_index_box.entries[prev_ix].OFFi; + if (prev_prev_ix != -1) { + OFFi -= frame_index_box.entries[prev_prev_ix].OFFi; + } + int32_t Ti = T_prev; + int32_t Fi = i - prev_ix; + ok &= jxl::EncodeVarInt(OFFi, buffer_vec.size(), &output_pos, buffer); + ok &= jxl::EncodeVarInt(Ti, buffer_vec.size(), &output_pos, buffer); + ok &= jxl::EncodeVarInt(Fi, buffer_vec.size(), &output_pos, buffer); + } + // Enough buffer has been allocated, this function should never fail in + // writing. + JXL_ASSERT(ok); + return ok; +} + +} // namespace + +JxlEncoderStatus JxlEncoderStruct::RefillOutputByteQueue() { + jxl::PaddedBytes bytes; + + jxl::JxlEncoderQueuedInput& input = input_queue[0]; + + // TODO(lode): split this into 3 functions: for adding the signature and other + // initial headers (jbrd, ...), one for adding frame, and one for adding user + // box. + + if (!wrote_bytes) { + // First time encoding any data, verify the level 5 vs level 10 settings + std::string level_message; + int required_level = VerifyLevelSettings(this, &level_message); + // Only level 5 and 10 are defined, and the function can return -1 to + // indicate full incompatibility. + JXL_ASSERT(required_level == -1 || required_level == 5 || + required_level == 10); + // codestream_level == -1 means auto-set to the required level + if (codestream_level == -1) codestream_level = required_level; + if (codestream_level == 5 && required_level != 5) { + // If the required level is 10, return error rather than automatically + // setting the level to 10, to avoid inadvertently creating a level 10 + // JXL file while intending to target a level 5 decoder. + return JXL_API_ERROR( + this, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level 5 failed: " + level_message) + .c_str()); + } + if (required_level == -1) { + return JXL_API_ERROR( + this, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level 10 failed: " + + level_message) + .c_str()); + } + + jxl::BitWriter writer; + if (!WriteHeaders(&metadata, &writer, nullptr)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Failed to write codestream header"); + } + // Only send ICC (at least several hundred bytes) if fields aren't enough. + if (metadata.m.color_encoding.WantICC()) { + if (!jxl::WriteICC(metadata.m.color_encoding.ICC(), &writer, + jxl::kLayerHeader, nullptr)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Failed to write ICC profile"); + } + } + // TODO(lode): preview should be added here if a preview image is added + + writer.ZeroPadToByte(); + + // Not actually the end of frame, but the end of metadata/ICC, but helps + // the next frame to start here for indexing purposes. + codestream_bytes_written_end_of_frame += + jxl::DivCeil(writer.BitsWritten(), 8); + + bytes = std::move(writer).TakeBytes(); + + if (MustUseContainer()) { + // Add "JXL " and ftyp box. + output_byte_queue.insert( + output_byte_queue.end(), jxl::kContainerHeader, + jxl::kContainerHeader + sizeof(jxl::kContainerHeader)); + if (codestream_level != 5) { + // Add jxll box directly after the ftyp box to indicate the codestream + // level. + output_byte_queue.insert( + output_byte_queue.end(), jxl::kLevelBoxHeader, + jxl::kLevelBoxHeader + sizeof(jxl::kLevelBoxHeader)); + output_byte_queue.push_back(codestream_level); + } + + // Whether to write the basic info and color profile header of the + // codestream into an early separate jxlp box, so that it comes before + // metadata or jpeg reconstruction boxes. In theory this could simply + // always be done, but there's no reason to add an extra box with box + // header overhead if the codestream will already come immediately after + // the signature and level boxes. + bool partial_header = store_jpeg_metadata || (use_boxes && !input.frame); + + if (partial_header) { + jxl::AppendBoxHeader(jxl::MakeBoxType("jxlp"), bytes.size() + 4, + /*unbounded=*/false, &output_byte_queue); + AppendJxlpBoxCounter(jxlp_counter++, /*last=*/false, + &output_byte_queue); + output_byte_queue.insert(output_byte_queue.end(), bytes.data(), + bytes.data() + bytes.size()); + bytes.clear(); + } + + if (store_jpeg_metadata && !jpeg_metadata.empty()) { + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_metadata.size(), + false, &output_byte_queue); + output_byte_queue.insert(output_byte_queue.end(), jpeg_metadata.begin(), + jpeg_metadata.end()); + } + } + wrote_bytes = true; + } + + // Choose frame or box processing: exactly one of the two unique pointers (box + // or frame) in the input queue item is non-null. + if (input.frame) { + jxl::MemoryManagerUniquePtr input_frame = + std::move(input.frame); + input_queue.erase(input_queue.begin()); + num_queued_frames--; + for (unsigned idx = 0; idx < input_frame->ec_initialized.size(); idx++) { + if (!input_frame->ec_initialized[idx]) { + return JXL_API_ERROR(this, JXL_ENC_ERR_API_USAGE, + "Extra channel %u is not initialized", idx); + } + } + + // TODO(zond): If the input queue is empty and the frames_closed is true, + // then mark this frame as the last. + + // TODO(zond): Handle progressive mode like EncodeFile does it. + // TODO(zond): Handle animation like EncodeFile does it, by checking if + // JxlEncoderCloseFrames has been called and if the frame queue + // is empty (to see if it's the last animation frame). + + if (metadata.m.xyb_encoded) { + input_frame->option_values.cparams.color_transform = + jxl::ColorTransform::kXYB; + } else { + // TODO(zond): Figure out when to use kYCbCr instead. + input_frame->option_values.cparams.color_transform = + jxl::ColorTransform::kNone; + } + + jxl::BitWriter writer; + jxl::PassesEncoderState enc_state; + + // EncodeFrame creates jxl::FrameHeader object internally based on the + // FrameInfo, imagebundle, cparams and metadata. Copy the information to + // these. + jxl::ImageBundle& ib = input_frame->frame; + ib.name = input_frame->option_values.frame_name; + if (metadata.m.have_animation) { + ib.duration = input_frame->option_values.header.duration; + ib.timecode = input_frame->option_values.header.timecode; + } else { + // If have_animation is false, the encoder should ignore the duration and + // timecode values. However, assigning them to ib will cause the encoder + // to write an invalid frame header that can't be decoded so ensure + // they're the default value of 0 here. + ib.duration = 0; + ib.timecode = 0; + } + frame_index_box.AddFrame(codestream_bytes_written_end_of_frame, ib.duration, + input_frame->option_values.frame_index_box); + ib.blendmode = static_cast( + input_frame->option_values.header.layer_info.blend_info.blendmode); + ib.blend = + input_frame->option_values.header.layer_info.blend_info.blendmode != + JXL_BLEND_REPLACE; + + size_t save_as_reference = + input_frame->option_values.header.layer_info.save_as_reference; + ib.use_for_next_frame = !!save_as_reference; + + jxl::FrameInfo frame_info; + bool last_frame = frames_closed && !num_queued_frames; + frame_info.is_last = last_frame; + frame_info.save_as_reference = save_as_reference; + frame_info.source = + input_frame->option_values.header.layer_info.blend_info.source; + frame_info.clamp = + input_frame->option_values.header.layer_info.blend_info.clamp; + frame_info.alpha_channel = + input_frame->option_values.header.layer_info.blend_info.alpha; + frame_info.extra_channel_blending_info.resize( + metadata.m.num_extra_channels); + // If extra channel blend info has not been set, use the blend mode from the + // layer_info. + JxlBlendInfo default_blend_info = + input_frame->option_values.header.layer_info.blend_info; + for (size_t i = 0; i < metadata.m.num_extra_channels; ++i) { + auto& to = frame_info.extra_channel_blending_info[i]; + const auto& from = + i < input_frame->option_values.extra_channel_blend_info.size() + ? input_frame->option_values.extra_channel_blend_info[i] + : default_blend_info; + to.mode = static_cast(from.blendmode); + to.source = from.source; + to.alpha_channel = from.alpha; + to.clamp = (from.clamp != 0); + } + + if (input_frame->option_values.header.layer_info.have_crop) { + ib.origin.x0 = input_frame->option_values.header.layer_info.crop_x0; + ib.origin.y0 = input_frame->option_values.header.layer_info.crop_y0; + } + JXL_ASSERT(writer.BitsWritten() == 0); + if (!jxl::EncodeFrame(input_frame->option_values.cparams, frame_info, + &metadata, input_frame->frame, &enc_state, cms, + thread_pool.get(), &writer, + /*aux_out=*/nullptr)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, "Failed to encode frame"); + } + codestream_bytes_written_beginning_of_frame = + codestream_bytes_written_end_of_frame; + codestream_bytes_written_end_of_frame += + jxl::DivCeil(writer.BitsWritten(), 8); + + // Possibly bytes already contains the codestream header: in case this is + // the first frame, and the codestream header was not encoded as jxlp above. + bytes.append(std::move(writer).TakeBytes()); + if (MustUseContainer()) { + if (last_frame && jxlp_counter == 0) { + // If this is the last frame and no jxlp boxes were used yet, it's + // slighly more efficient to write a jxlc box since it has 4 bytes less + // overhead. + jxl::AppendBoxHeader(jxl::MakeBoxType("jxlc"), bytes.size(), + /*unbounded=*/false, &output_byte_queue); + } else { + jxl::AppendBoxHeader(jxl::MakeBoxType("jxlp"), bytes.size() + 4, + /*unbounded=*/false, &output_byte_queue); + AppendJxlpBoxCounter(jxlp_counter++, last_frame, &output_byte_queue); + } + } + + output_byte_queue.insert(output_byte_queue.end(), bytes.data(), + bytes.data() + bytes.size()); + + last_used_cparams = input_frame->option_values.cparams; + if (last_frame && frame_index_box.StoreFrameIndexBox()) { + bytes.clear(); + EncodeFrameIndexBox(frame_index_box, writer); + jxl::AppendBoxHeader(jxl::MakeBoxType("jxli"), bytes.size(), + /*unbounded=*/false, &output_byte_queue); + } + } else { + // Not a frame, so is a box instead + jxl::MemoryManagerUniquePtr box = + std::move(input.box); + input_queue.erase(input_queue.begin()); + num_queued_boxes--; + + if (box->compress_box) { + jxl::PaddedBytes compressed(4); + // Prepend the original box type in the brob box contents + for (size_t i = 0; i < 4; i++) { + compressed[i] = static_cast(box->type[i]); + } + if (JXL_ENC_SUCCESS != + BrotliCompress((brotli_effort >= 0 ? brotli_effort : 4), + box->contents.data(), box->contents.size(), + &compressed)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Brotli compression for brob box failed"); + } + jxl::AppendBoxHeader(jxl::MakeBoxType("brob"), compressed.size(), false, + &output_byte_queue); + output_byte_queue.insert(output_byte_queue.end(), compressed.data(), + compressed.data() + compressed.size()); + } else { + jxl::AppendBoxHeader(box->type, box->contents.size(), false, + &output_byte_queue); + output_byte_queue.insert(output_byte_queue.end(), box->contents.data(), + box->contents.data() + box->contents.size()); + } + } + + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderSetColorEncoding(JxlEncoder* enc, + const JxlColorEncoding* color) { + if (!enc->basic_info_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Basic info not yet set"); + } + if (enc->color_encoding_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Color encoding is already set"); + } + if (!jxl::ConvertExternalToInternalColorEncoding( + *color, &enc->metadata.m.color_encoding)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_GENERIC, "Error in color conversion"); + } + if (enc->metadata.m.color_encoding.GetColorSpace() == + jxl::ColorSpace::kGray) { + if (enc->basic_info.num_color_channels != 1) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "Cannot use grayscale color encoding with num_color_channels != 1"); + } else { + if (enc->basic_info.num_color_channels != 3) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "Cannot use RGB color encoding with num_color_channels != 3"); + } + enc->color_encoding_set = true; + if (!enc->intensity_target_set) { + jxl::SetIntensityTarget(&enc->metadata.m); + } + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderSetICCProfile(JxlEncoder* enc, + const uint8_t* icc_profile, + size_t size) { + if (!enc->basic_info_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Basic info not yet set"); + } + if (enc->color_encoding_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "ICC profile is already set"); + } + jxl::PaddedBytes icc; + icc.assign(icc_profile, icc_profile + size); + if (!enc->metadata.m.color_encoding.SetICC(std::move(icc))) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_BAD_INPUT, + "ICC profile could not be set"); + } + if (enc->metadata.m.color_encoding.GetColorSpace() == + jxl::ColorSpace::kGray) { + if (enc->basic_info.num_color_channels != 1) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_BAD_INPUT, + "Cannot use grayscale ICC profile with num_color_channels != 1"); + } else { + if (enc->basic_info.num_color_channels != 3) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_BAD_INPUT, + "Cannot use RGB ICC profile with num_color_channels != 3"); + // TODO(jon): also check that a kBlack extra channel is provided in the CMYK + // case + } + enc->color_encoding_set = true; + if (!enc->intensity_target_set) { + jxl::SetIntensityTarget(&enc->metadata.m); + } + + if (!enc->basic_info.uses_original_profile) { + enc->metadata.m.color_encoding.DecideIfWantICC(); + } + + return JXL_ENC_SUCCESS; +} + +void JxlEncoderInitBasicInfo(JxlBasicInfo* info) { + info->have_container = JXL_FALSE; + info->xsize = 0; + info->ysize = 0; + info->bits_per_sample = 8; + info->exponent_bits_per_sample = 0; + info->intensity_target = 0.f; + info->min_nits = 0.f; + info->relative_to_max_display = JXL_FALSE; + info->linear_below = 0.f; + info->uses_original_profile = JXL_FALSE; + info->have_preview = JXL_FALSE; + info->have_animation = JXL_FALSE; + info->orientation = JXL_ORIENT_IDENTITY; + info->num_color_channels = 3; + info->num_extra_channels = 0; + info->alpha_bits = 0; + info->alpha_exponent_bits = 0; + info->alpha_premultiplied = JXL_FALSE; + info->preview.xsize = 0; + info->preview.ysize = 0; + info->intrinsic_xsize = 0; + info->intrinsic_ysize = 0; + info->animation.tps_numerator = 10; + info->animation.tps_denominator = 1; + info->animation.num_loops = 0; + info->animation.have_timecodes = JXL_FALSE; +} + +void JxlEncoderInitFrameHeader(JxlFrameHeader* frame_header) { + // For each field, the default value of the specification is used. Depending + // on wheter an animation frame, or a composite still blending frame, is used, + // different fields have to be set up by the user after initing the frame + // header. + frame_header->duration = 0; + frame_header->timecode = 0; + frame_header->name_length = 0; + // In the specification, the default value of is_last is !frame_type, and the + // default frame_type is kRegularFrame which has value 0, so is_last is true + // by default. However, the encoder does not use this value (the field exists + // for the decoder to set) since last frame is determined by usage of + // JxlEncoderCloseFrames instead. + frame_header->is_last = JXL_TRUE; + frame_header->layer_info.have_crop = JXL_FALSE; + frame_header->layer_info.crop_x0 = 0; + frame_header->layer_info.crop_y0 = 0; + // These must be set if have_crop is enabled, but the default value has + // have_crop false, and these dimensions 0. The user must set these to the + // desired size after enabling have_crop (which is not yet implemented). + frame_header->layer_info.xsize = 0; + frame_header->layer_info.ysize = 0; + JxlEncoderInitBlendInfo(&frame_header->layer_info.blend_info); + frame_header->layer_info.save_as_reference = 0; +} + +void JxlEncoderInitBlendInfo(JxlBlendInfo* blend_info) { + // Default blend mode in the specification is 0. Note that combining + // blend mode of replace with a duration is not useful, but the user has to + // manually set duration in case of animation, or manually change the blend + // mode in case of composite stills, so initing to a combination that is not + // useful on its own is not an issue. + blend_info->blendmode = JXL_BLEND_REPLACE; + blend_info->source = 0; + blend_info->alpha = 0; + blend_info->clamp = 0; +} + +JxlEncoderStatus JxlEncoderSetBasicInfo(JxlEncoder* enc, + const JxlBasicInfo* info) { + if (!enc->metadata.size.Set(info->xsize, info->ysize)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Invalid dimensions"); + } + if (JXL_ENC_SUCCESS != CheckValidBitdepth(info->bits_per_sample, + info->exponent_bits_per_sample)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Invalid bit depth"); + } + enc->metadata.m.bit_depth.bits_per_sample = info->bits_per_sample; + enc->metadata.m.bit_depth.exponent_bits_per_sample = + info->exponent_bits_per_sample; + enc->metadata.m.bit_depth.floating_point_sample = + (info->exponent_bits_per_sample != 0u); + enc->metadata.m.modular_16_bit_buffer_sufficient = + (!info->uses_original_profile || info->bits_per_sample <= 12) && + info->alpha_bits <= 12; + + // The number of extra channels includes the alpha channel, so for example and + // RGBA with no other extra channels, has exactly num_extra_channels == 1 + enc->metadata.m.num_extra_channels = info->num_extra_channels; + enc->metadata.m.extra_channel_info.resize(enc->metadata.m.num_extra_channels); + if (info->num_extra_channels == 0 && info->alpha_bits) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "when alpha_bits is non-zero, the number of channels must be at least " + "1"); + } + // If the user provides non-zero alpha_bits, we make the channel info at index + // zero the appropriate alpha channel. + if (info->alpha_bits) { + JxlExtraChannelInfo channel_info; + JxlEncoderInitExtraChannelInfo(JXL_CHANNEL_ALPHA, &channel_info); + channel_info.bits_per_sample = info->alpha_bits; + channel_info.exponent_bits_per_sample = info->alpha_exponent_bits; + if (JxlEncoderSetExtraChannelInfo(enc, 0, &channel_info)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Problem setting extra channel info for alpha"); + } + } + + enc->metadata.m.xyb_encoded = !info->uses_original_profile; + if (info->orientation > 0 && info->orientation <= 8) { + enc->metadata.m.orientation = info->orientation; + } else { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for orientation field"); + } + if (info->num_color_channels != 1 && info->num_color_channels != 3) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid number of color channels"); + } + if (info->intensity_target != 0) { + enc->metadata.m.SetIntensityTarget(info->intensity_target); + enc->intensity_target_set = true; + } else if (enc->color_encoding_set || enc->metadata.m.xyb_encoded) { + // If both conditions are false, JxlEncoderSetColorEncoding will be called + // later and we will get one more chance to call jxl::SetIntensityTarget, + // after the color encoding is indeed set. + jxl::SetIntensityTarget(&enc->metadata.m); + enc->intensity_target_set = true; + } + enc->metadata.m.tone_mapping.min_nits = info->min_nits; + enc->metadata.m.tone_mapping.relative_to_max_display = + info->relative_to_max_display; + enc->metadata.m.tone_mapping.linear_below = info->linear_below; + enc->basic_info = *info; + enc->basic_info_set = true; + + enc->metadata.m.have_animation = info->have_animation; + if (info->have_animation) { + if (info->animation.tps_denominator < 1) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "If animation is used, tps_denominator must be >= 1"); + } + if (info->animation.tps_numerator < 1) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "If animation is used, tps_numerator must be >= 1"); + } + enc->metadata.m.animation.tps_numerator = info->animation.tps_numerator; + enc->metadata.m.animation.tps_denominator = info->animation.tps_denominator; + enc->metadata.m.animation.num_loops = info->animation.num_loops; + enc->metadata.m.animation.have_timecodes = info->animation.have_timecodes; + } + std::string level_message; + int required_level = VerifyLevelSettings(enc, &level_message); + if (required_level == -1 || + (static_cast(enc->codestream_level) < required_level && + enc->codestream_level != -1)) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level " + + std::to_string(enc->codestream_level) + " failed: " + level_message) + .c_str()); + } + return JXL_ENC_SUCCESS; +} + +void JxlEncoderInitExtraChannelInfo(JxlExtraChannelType type, + JxlExtraChannelInfo* info) { + info->type = type; + info->bits_per_sample = 8; + info->exponent_bits_per_sample = 0; + info->dim_shift = 0; + info->name_length = 0; + info->alpha_premultiplied = JXL_FALSE; + info->spot_color[0] = 0; + info->spot_color[1] = 0; + info->spot_color[2] = 0; + info->spot_color[3] = 0; + info->cfa_channel = 0; +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelInfo( + JxlEncoder* enc, size_t index, const JxlExtraChannelInfo* info) { + if (index >= enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + if (JXL_ENC_SUCCESS != CheckValidBitdepth(info->bits_per_sample, + info->exponent_bits_per_sample)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Invalid bit depth"); + } + + jxl::ExtraChannelInfo& channel = enc->metadata.m.extra_channel_info[index]; + channel.type = static_cast(info->type); + channel.bit_depth.bits_per_sample = info->bits_per_sample; + enc->metadata.m.modular_16_bit_buffer_sufficient &= + info->bits_per_sample <= 12; + channel.bit_depth.exponent_bits_per_sample = info->exponent_bits_per_sample; + channel.bit_depth.floating_point_sample = info->exponent_bits_per_sample != 0; + channel.dim_shift = info->dim_shift; + channel.name = ""; + channel.alpha_associated = (info->alpha_premultiplied != 0); + channel.cfa_channel = info->cfa_channel; + channel.spot_color[0] = info->spot_color[0]; + channel.spot_color[1] = info->spot_color[1]; + channel.spot_color[2] = info->spot_color[2]; + channel.spot_color[3] = info->spot_color[3]; + std::string level_message; + int required_level = VerifyLevelSettings(enc, &level_message); + if (required_level == -1 || + (static_cast(enc->codestream_level) < required_level && + enc->codestream_level != -1)) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level " + + std::to_string(enc->codestream_level) + " failed: " + level_message) + .c_str()); + } + return JXL_ENC_SUCCESS; +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelName(JxlEncoder* enc, + size_t index, + const char* name, + size_t size) { + if (index >= enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + enc->metadata.m.extra_channel_info[index].name = + std::string(name, name + size); + return JXL_ENC_SUCCESS; +} + +JxlEncoderFrameSettings* JxlEncoderFrameSettingsCreate( + JxlEncoder* enc, const JxlEncoderFrameSettings* source) { + auto opts = jxl::MemoryManagerMakeUnique( + &enc->memory_manager); + if (!opts) return nullptr; + opts->enc = enc; + if (source != nullptr) { + opts->values = source->values; + } else { + opts->values.lossless = false; + } + opts->values.cparams.level = enc->codestream_level; + JxlEncoderFrameSettings* ret = opts.get(); + enc->encoder_options.emplace_back(std::move(opts)); + return ret; +} + +JxlEncoderFrameSettings* JxlEncoderOptionsCreate( + JxlEncoder* enc, const JxlEncoderFrameSettings* source) { + // Deprecated function name, call the non-deprecated function + return JxlEncoderFrameSettingsCreate(enc, source); +} + +JxlEncoderStatus JxlEncoderSetFrameLossless( + JxlEncoderFrameSettings* frame_settings, const JXL_BOOL lossless) { + if (lossless && frame_settings->enc->basic_info_set && + frame_settings->enc->metadata.m.xyb_encoded) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Set use_original_profile=true for lossless encoding"); + } + frame_settings->values.lossless = lossless; + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderOptionsSetLossless( + JxlEncoderFrameSettings* frame_settings, JXL_BOOL lossless) { + // Deprecated function name, call the non-deprecated function + return JxlEncoderSetFrameLossless(frame_settings, lossless); +} + +JxlEncoderStatus JxlEncoderOptionsSetEffort( + JxlEncoderFrameSettings* frame_settings, const int effort) { + return JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_SETTING_EFFORT, effort); +} + +JxlEncoderStatus JxlEncoderSetFrameDistance( + JxlEncoderFrameSettings* frame_settings, float distance) { + if (distance < 0.f || distance > 25.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Distance has to be in [0.0..25.0]"); + } + if (distance > 0.f && distance < 0.01f) { + distance = 0.01f; + } + frame_settings->values.cparams.butteraugli_distance = distance; + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderOptionsSetDistance( + JxlEncoderFrameSettings* frame_settings, float distance) { + // Deprecated function name, call the non-deprecated function + return JxlEncoderSetFrameDistance(frame_settings, distance); +} + +JxlEncoderStatus JxlEncoderOptionsSetDecodingSpeed( + JxlEncoderFrameSettings* frame_settings, int tier) { + return JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_DECODING_SPEED, tier); +} + +JxlEncoderStatus JxlEncoderFrameSettingsSetOption( + JxlEncoderFrameSettings* frame_settings, JxlEncoderFrameSettingId option, + int64_t value) { + // check if value is -1, 0 or 1 for Override-type options + switch (option) { + case JXL_ENC_FRAME_SETTING_NOISE: + case JXL_ENC_FRAME_SETTING_DOTS: + case JXL_ENC_FRAME_SETTING_PATCHES: + case JXL_ENC_FRAME_SETTING_GABORISH: + case JXL_ENC_FRAME_SETTING_MODULAR: + case JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER: + case JXL_ENC_FRAME_SETTING_RESPONSIVE: + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_LOSSY_PALETTE: + case JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL: + if (value < -1 || value > 1) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be -1 (default), 0 (off) or 1 (on)"); + } + break; + default: + break; + } + + switch (option) { + case JXL_ENC_FRAME_SETTING_EFFORT: + if (value < 1 || value > 9) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Encode effort has to be in [1..9]"); + } + frame_settings->values.cparams.speed_tier = + static_cast(10 - value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_BROTLI_EFFORT: + if (value < -1 || value > 11) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Brotli effort has to be in [-1..11]"); + } + // set cparams for brotli use in JPEG frames + frame_settings->values.cparams.brotli_effort = value; + // set enc option for brotli use in brob boxes + frame_settings->enc->brotli_effort = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_DECODING_SPEED: + if (value < 0 || value > 4) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Decoding speed has to be in [0..4]"); + } + frame_settings->values.cparams.decoding_speed_tier = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_RESAMPLING: + if (value != -1 && value != 1 && value != 2 && value != 4 && value != 8) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Resampling factor has to be 1, 2, 4 or 8"); + } + frame_settings->values.cparams.resampling = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING: + // TOOD(lode): the jxl codestream allows choosing a different resampling + // factor for each extra channel, independently per frame. Move this + // option to a JxlEncoderFrameSettings-option that can be set per extra + // channel, so needs its own function rather than + // JxlEncoderFrameSettingsSetOption due to the extra channel index + // argument required. + if (value != -1 && value != 1 && value != 2 && value != 4 && value != 8) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Resampling factor has to be 1, 2, 4 or 8"); + } + frame_settings->values.cparams.ec_resampling = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED: + if (value < 0 || value > 1) { + return JXL_ENC_ERROR; + } + frame_settings->values.cparams.already_downsampled = (value == 1); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_NOISE: + frame_settings->values.cparams.noise = static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_DOTS: + frame_settings->values.cparams.dots = static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_PATCHES: + frame_settings->values.cparams.patches = + static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_EPF: + if (value < -1 || value > 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "EPF value has to be in [-1..3]"); + } + frame_settings->values.cparams.epf = static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_GABORISH: + frame_settings->values.cparams.gaborish = + static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_MODULAR: + frame_settings->values.cparams.modular_mode = (value == 1); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE: + frame_settings->values.cparams.keep_invisible = + static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_GROUP_ORDER: + frame_settings->values.cparams.centerfirst = (value == 1); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X: + if (value < -1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Center x coordinate has to be -1 or positive"); + } + frame_settings->values.cparams.center_x = static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_Y: + if (value < -1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Center y coordinate has to be -1 or positive"); + } + frame_settings->values.cparams.center_y = static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_RESPONSIVE: + frame_settings->values.cparams.responsive = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC: + frame_settings->values.cparams.progressive_mode = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC: + frame_settings->values.cparams.qprogressive_mode = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC: + if (value < -1 || value > 2) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Progressive DC has to be in [-1..2]"); + } + frame_settings->values.cparams.progressive_dc = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_PALETTE_COLORS: + if (value < -1 || value > 70913) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..70913]"); + } + if (value == -1) { + frame_settings->values.cparams.palette_colors = 1 << 10; + } else { + frame_settings->values.cparams.palette_colors = value; + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_LOSSY_PALETTE: + // TODO(lode): the defaults of some palette settings depend on others. + // See the logic in cjxl. Similar for other settings. This should be + // handled in the encoder during JxlEncoderProcessOutput (or, + // alternatively, in the cjxl binary like now) + frame_settings->values.cparams.lossy_palette = (value == 1); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_COLOR_TRANSFORM: + if (value < -1 || value > 2) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..2]"); + } + if (value == -1) { + frame_settings->values.cparams.color_transform = + jxl::ColorTransform::kXYB; + } else { + frame_settings->values.cparams.color_transform = + static_cast(value); + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE: + if (value < -1 || value > 41) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..41]"); + } + frame_settings->values.cparams.colorspace = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE: + if (value < -1 || value > 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..3]"); + } + // TODO(lode): the default behavior of this parameter for cjxl is + // to choose 1 or 2 depending on the situation. This behavior needs to be + // implemented either in the C++ library by allowing to set this to -1, or + // kept in cjxl and set it to 1 or 2 using this API. + if (value == -1) { + frame_settings->values.cparams.modular_group_size_shift = 1; + } else { + frame_settings->values.cparams.modular_group_size_shift = value; + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR: + if (value < -1 || value > 15) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..15]"); + } + frame_settings->values.cparams.options.predictor = + static_cast(value); + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS: + // The max allowed value can in theory be higher. However, it depends on + // the effort setting. 11 is the highest safe value that doesn't cause + // tree_samples to be >= 64 in the encoder. The specification may allow + // more than this. With more fine tuning higher values could be allowed. + // For N-channel images, the largest useful value is N-1. + if (value < -1 || value > 11) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..11]"); + } + if (value == -1) { + frame_settings->values.cparams.options.max_properties = 0; + } else { + frame_settings->values.cparams.options.max_properties = value; + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL: + if (value == -1) { + frame_settings->values.cparams.force_cfl_jpeg_recompression = true; + } else { + frame_settings->values.cparams.force_cfl_jpeg_recompression = value; + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_INDEX_BOX: + frame_settings->values.frame_index_box = true; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_PHOTON_NOISE: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Float option, try setting it with " + "JxlEncoderFrameSettingsSetFloatOption"); + default: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Unknown option"); + } +} + +JxlEncoderStatus JxlEncoderFrameSettingsSetFloatOption( + JxlEncoderFrameSettings* frame_settings, JxlEncoderFrameSettingId option, + float value) { + switch (option) { + case JXL_ENC_FRAME_SETTING_PHOTON_NOISE: + if (value < 0) return JXL_ENC_ERROR; + // TODO(lode): add encoder setting to set the 8 floating point values of + // the noise synthesis parameters per frame for more fine grained control. + frame_settings->values.cparams.photon_noise_iso = value; + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT: + if (value < -1.f || value > 100.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be smaller than 100"); + } + // This value is called "iterations" or "nb_repeats" in cjxl, but is in + // fact a fraction in range 0.0-1.0, with the default value 0.5. + // Convert from floating point percentage to floating point fraction here. + if (value < -.5f) { + // TODO(lode): for this and many other settings (also in + // JxlEncoderFrameSettingsSetOption), avoid duplicating the default + // values here and in enc_params.h and options.h, have one location + // where the defaults are specified. + frame_settings->values.cparams.options.nb_repeats = 0.5f; + } else { + frame_settings->values.cparams.options.nb_repeats = value * 0.01f; + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT: + if (value < -1.f || value > 100.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..100]"); + } + if (value < -.5f) { + frame_settings->values.cparams.channel_colors_pre_transform_percent = + 95.0f; + } else { + frame_settings->values.cparams.channel_colors_pre_transform_percent = + value; + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT: + if (value < -1.f || value > 100.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..100]"); + } + if (value < -.5f) { + frame_settings->values.cparams.channel_colors_percent = 80.0f; + } else { + frame_settings->values.cparams.channel_colors_percent = value; + } + return JXL_ENC_SUCCESS; + case JXL_ENC_FRAME_SETTING_EFFORT: + case JXL_ENC_FRAME_SETTING_DECODING_SPEED: + case JXL_ENC_FRAME_SETTING_RESAMPLING: + case JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING: + case JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED: + case JXL_ENC_FRAME_SETTING_NOISE: + case JXL_ENC_FRAME_SETTING_DOTS: + case JXL_ENC_FRAME_SETTING_PATCHES: + case JXL_ENC_FRAME_SETTING_EPF: + case JXL_ENC_FRAME_SETTING_GABORISH: + case JXL_ENC_FRAME_SETTING_MODULAR: + case JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_Y: + case JXL_ENC_FRAME_SETTING_RESPONSIVE: + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC: + case JXL_ENC_FRAME_SETTING_PALETTE_COLORS: + case JXL_ENC_FRAME_SETTING_LOSSY_PALETTE: + case JXL_ENC_FRAME_SETTING_COLOR_TRANSFORM: + case JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE: + case JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE: + case JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR: + case JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS: + case JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL: + case JXL_ENC_FRAME_INDEX_BOX: + case JXL_ENC_FRAME_SETTING_BROTLI_EFFORT: + case JXL_ENC_FRAME_SETTING_FILL_ENUM: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Int option, try setting it with " + "JxlEncoderFrameSettingsSetOption"); + default: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Unknown option"); + } +} +JxlEncoder* JxlEncoderCreate(const JxlMemoryManager* memory_manager) { + JxlMemoryManager local_memory_manager; + if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) { + return nullptr; + } + + void* alloc = + jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlEncoder)); + if (!alloc) return nullptr; + JxlEncoder* enc = new (alloc) JxlEncoder(); + enc->memory_manager = local_memory_manager; + // TODO(sboukortt): add an API function to set this. + enc->cms = jxl::GetJxlCms(); + + // Initialize all the field values. + JxlEncoderReset(enc); + + return enc; +} + +void JxlEncoderReset(JxlEncoder* enc) { + enc->thread_pool.reset(); + enc->input_queue.clear(); + enc->num_queued_frames = 0; + enc->num_queued_boxes = 0; + enc->encoder_options.clear(); + enc->output_byte_queue.clear(); + enc->codestream_bytes_written_beginning_of_frame = 0; + enc->codestream_bytes_written_end_of_frame = 0; + enc->wrote_bytes = false; + enc->jxlp_counter = 0; + enc->metadata = jxl::CodecMetadata(); + enc->last_used_cparams = jxl::CompressParams(); + enc->frames_closed = false; + enc->boxes_closed = false; + enc->basic_info_set = false; + enc->color_encoding_set = false; + enc->intensity_target_set = false; + enc->use_container = false; + enc->use_boxes = false; + enc->codestream_level = -1; + JxlEncoderInitBasicInfo(&enc->basic_info); +} + +void JxlEncoderDestroy(JxlEncoder* enc) { + if (enc) { + JxlMemoryManager local_memory_manager = enc->memory_manager; + // Call destructor directly since custom free function is used. + enc->~JxlEncoder(); + jxl::MemoryManagerFree(&local_memory_manager, enc); + } +} + +JxlEncoderError JxlEncoderGetError(JxlEncoder* enc) { return enc->error; } + +JxlEncoderStatus JxlEncoderUseContainer(JxlEncoder* enc, + JXL_BOOL use_container) { + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->use_container = static_cast(use_container); + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderStoreJPEGMetadata(JxlEncoder* enc, + JXL_BOOL store_jpeg_metadata) { + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->store_jpeg_metadata = static_cast(store_jpeg_metadata); + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderSetCodestreamLevel(JxlEncoder* enc, int level) { + if (level != -1 && level != 5 && level != 10) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_NOT_SUPPORTED, "invalid level"); + } + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->codestream_level = level; + return JXL_ENC_SUCCESS; +} + +int JxlEncoderGetRequiredCodestreamLevel(const JxlEncoder* enc) { + return VerifyLevelSettings(enc, nullptr); +} + +void JxlEncoderSetCms(JxlEncoder* enc, JxlCmsInterface cms) { + jxl::msan::MemoryIsInitialized(&cms, sizeof(cms)); + enc->cms = cms; +} + +JxlEncoderStatus JxlEncoderSetParallelRunner(JxlEncoder* enc, + JxlParallelRunner parallel_runner, + void* parallel_runner_opaque) { + if (enc->thread_pool) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "parallel runner already set"); + } + enc->thread_pool = jxl::MemoryManagerMakeUnique( + &enc->memory_manager, parallel_runner, parallel_runner_opaque); + if (!enc->thread_pool) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_GENERIC, + "error setting parallel runner"); + } + return JXL_ENC_SUCCESS; +} + +namespace { +JxlEncoderStatus GetCurrentDimensions( + const JxlEncoderFrameSettings* frame_settings, size_t& xsize, + size_t& ysize) { + xsize = frame_settings->enc->metadata.xsize(); + ysize = frame_settings->enc->metadata.ysize(); + if (frame_settings->values.header.layer_info.have_crop) { + xsize = frame_settings->values.header.layer_info.xsize; + ysize = frame_settings->values.header.layer_info.ysize; + } + if (frame_settings->values.cparams.already_downsampled) { + size_t factor = frame_settings->values.cparams.resampling; + xsize = jxl::DivCeil(xsize, factor); + ysize = jxl::DivCeil(ysize, factor); + } + if (xsize == 0 || ysize == 0) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "zero-sized frame is not allowed"); + } + return JXL_ENC_SUCCESS; +} +} // namespace + +JxlEncoderStatus JxlEncoderAddJPEGFrame( + const JxlEncoderFrameSettings* frame_settings, const uint8_t* buffer, + size_t size) { + if (frame_settings->enc->frames_closed) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Frame input is already closed"); + } + + jxl::CodecInOut io; + if (!jxl::jpeg::DecodeImageJPG(jxl::Span(buffer, size), &io)) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_BAD_INPUT, + "Error during decode of input JPEG"); + } + + if (!frame_settings->enc->color_encoding_set) { + if (!SetColorEncodingFromJpegData( + *io.Main().jpeg_data, + &frame_settings->enc->metadata.m.color_encoding)) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_BAD_INPUT, + "Error in input JPEG color space"); + } + } + + if (!frame_settings->enc->basic_info_set) { + JxlBasicInfo basic_info; + JxlEncoderInitBasicInfo(&basic_info); + basic_info.xsize = io.Main().jpeg_data->width; + basic_info.ysize = io.Main().jpeg_data->height; + basic_info.uses_original_profile = true; + if (JxlEncoderSetBasicInfo(frame_settings->enc, &basic_info) != + JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "Error setting basic info"); + } + } + + if (frame_settings->enc->metadata.m.xyb_encoded) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Can't XYB encode a lossless JPEG"); + } + if (!io.blobs.exif.empty()) { + JxlOrientation orientation = static_cast( + frame_settings->enc->metadata.m.orientation); + jxl::InterpretExif(io.blobs.exif, &orientation); + frame_settings->enc->metadata.m.orientation = orientation; + + size_t exif_size = io.blobs.exif.size(); + // Exif data in JPEG is limited to 64k + if (exif_size > 0xFFFF) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "Exif larger than possible in JPEG?"); + } + exif_size += 4; // prefix 4 zero bytes for tiff offset + std::vector exif(exif_size); + memcpy(exif.data() + 4, io.blobs.exif.data(), io.blobs.exif.size()); + JxlEncoderUseBoxes(frame_settings->enc); + JxlEncoderAddBox(frame_settings->enc, "Exif", exif.data(), exif_size, + /*compress_box=*/JXL_TRUE); + } + if (!io.blobs.xmp.empty()) { + JxlEncoderUseBoxes(frame_settings->enc); + JxlEncoderAddBox(frame_settings->enc, "xml ", io.blobs.xmp.data(), + io.blobs.xmp.size(), /*compress_box=*/JXL_TRUE); + } + if (!io.blobs.jumbf.empty()) { + JxlEncoderUseBoxes(frame_settings->enc); + JxlEncoderAddBox(frame_settings->enc, "jumb", io.blobs.jumbf.data(), + io.blobs.jumbf.size(), /*compress_box=*/JXL_TRUE); + } + if (frame_settings->enc->store_jpeg_metadata) { + jxl::jpeg::JPEGData data_in = *io.Main().jpeg_data; + jxl::PaddedBytes jpeg_data; + if (!jxl::jpeg::EncodeJPEGData(data_in, &jpeg_data, + frame_settings->values.cparams)) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_JBRD, + "JPEG bitstream reconstruction data cannot be encoded"); + } + frame_settings->enc->jpeg_metadata = std::vector( + jpeg_data.data(), jpeg_data.data() + jpeg_data.size()); + } + + auto queued_frame = jxl::MemoryManagerMakeUnique( + &frame_settings->enc->memory_manager, + // JxlEncoderQueuedFrame is a struct with no constructors, so we use the + // default move constructor there. + jxl::JxlEncoderQueuedFrame{ + frame_settings->values, + jxl::ImageBundle(&frame_settings->enc->metadata.m), + {}}); + if (!queued_frame) { + // TODO(jon): when can this happen? is this an API usage error? + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "No frame queued?"); + } + queued_frame->frame.SetFromImage(std::move(*io.Main().color()), + io.Main().c_current()); + size_t xsize, ysize; + if (GetCurrentDimensions(frame_settings, xsize, ysize) != JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "bad dimensions"); + } + if (xsize != static_cast(io.Main().jpeg_data->width) || + ysize != static_cast(io.Main().jpeg_data->height)) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "JPEG dimensions don't match frame dimensions"); + } + std::vector extra_channels( + frame_settings->enc->metadata.m.num_extra_channels); + for (auto& extra_channel : extra_channels) { + extra_channel = jxl::ImageF(xsize, ysize); + queued_frame->ec_initialized.push_back(0); + } + queued_frame->frame.SetExtraChannels(std::move(extra_channels)); + queued_frame->frame.jpeg_data = std::move(io.Main().jpeg_data); + queued_frame->frame.color_transform = io.Main().color_transform; + queued_frame->frame.chroma_subsampling = io.Main().chroma_subsampling; + + QueueFrame(frame_settings, queued_frame); + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderAddImageFrame( + const JxlEncoderFrameSettings* frame_settings, + const JxlPixelFormat* pixel_format, const void* buffer, size_t size) { + if (!frame_settings->enc->basic_info_set || + (!frame_settings->enc->color_encoding_set && + !frame_settings->enc->metadata.m.xyb_encoded)) { + // Basic Info must be set, and color encoding must be set directly, + // or set to XYB via JxlBasicInfo.uses_original_profile = JXL_FALSE + // Otherwise, this is an API misuse. + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Basic info or color encoding not set yet"); + } + + if (frame_settings->enc->frames_closed) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Frame input already closed"); + } + if (pixel_format->num_channels < 3) { + if (frame_settings->enc->basic_info.num_color_channels != 1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Grayscale pixel format input for an RGB image"); + } + } else { + if (frame_settings->enc->basic_info.num_color_channels != 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "RGB pixel format input for a grayscale image"); + } + } + + auto queued_frame = jxl::MemoryManagerMakeUnique( + &frame_settings->enc->memory_manager, + // JxlEncoderQueuedFrame is a struct with no constructors, so we use the + // default move constructor there. + jxl::JxlEncoderQueuedFrame{ + frame_settings->values, + jxl::ImageBundle(&frame_settings->enc->metadata.m), + {}}); + + if (!queued_frame) { + // TODO(jon): when can this happen? is this an API usage error? + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "No frame queued?"); + } + + jxl::ColorEncoding c_current; + if (!frame_settings->enc->color_encoding_set) { + if ((pixel_format->data_type == JXL_TYPE_FLOAT) || + (pixel_format->data_type == JXL_TYPE_FLOAT16)) { + c_current = + jxl::ColorEncoding::LinearSRGB(pixel_format->num_channels < 3); + } else { + c_current = jxl::ColorEncoding::SRGB(pixel_format->num_channels < 3); + } + } else { + c_current = frame_settings->enc->metadata.m.color_encoding; + } + uint32_t num_channels = pixel_format->num_channels; + size_t has_interleaved_alpha = + static_cast(num_channels == 2 || num_channels == 4); + if (has_interleaved_alpha > + frame_settings->enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "number of extra channels mismatch (need 1 extra channel for alpha)"); + } + size_t xsize, ysize; + if (GetCurrentDimensions(frame_settings, xsize, ysize) != JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "bad dimensions"); + } + std::vector extra_channels( + frame_settings->enc->metadata.m.num_extra_channels); + for (auto& extra_channel : extra_channels) { + extra_channel = jxl::ImageF(xsize, ysize); + } + queued_frame->frame.SetExtraChannels(std::move(extra_channels)); + for (auto& ec_info : frame_settings->enc->metadata.m.extra_channel_info) { + if (has_interleaved_alpha && ec_info.type == jxl::ExtraChannel::kAlpha) { + queued_frame->ec_initialized.push_back(1); + has_interleaved_alpha = 0; // only first Alpha is initialized + } else { + queued_frame->ec_initialized.push_back(0); + } + } + queued_frame->frame.origin.x0 = + frame_settings->values.header.layer_info.crop_x0; + queued_frame->frame.origin.y0 = + frame_settings->values.header.layer_info.crop_y0; + queued_frame->frame.use_for_next_frame = + (frame_settings->values.header.layer_info.save_as_reference != 0u); + queued_frame->frame.blendmode = + frame_settings->values.header.layer_info.blend_info.blendmode == + JXL_BLEND_REPLACE + ? jxl::BlendMode::kReplace + : jxl::BlendMode::kBlend; + queued_frame->frame.blend = + frame_settings->values.header.layer_info.blend_info.source > 0; + + if (!jxl::BufferToImageBundle(*pixel_format, xsize, ysize, buffer, size, + frame_settings->enc->thread_pool.get(), + c_current, &(queued_frame->frame))) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Invalid input buffer"); + } + if (frame_settings->values.lossless && + frame_settings->enc->metadata.m.xyb_encoded) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Set use_original_profile=true for lossless encoding"); + } + queued_frame->option_values.cparams.level = + frame_settings->enc->codestream_level; + + QueueFrame(frame_settings, queued_frame); + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderUseBoxes(JxlEncoder* enc) { + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->use_boxes = true; + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderAddBox(JxlEncoder* enc, const JxlBoxType type, + const uint8_t* contents, size_t size, + JXL_BOOL compress_box) { + if (!enc->use_boxes) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "must set JxlEncoderUseBoxes at the beginning to add boxes"); + } + if (compress_box) { + if (memcmp("jxl", type, 3) == 0) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "brob box may not contain a type starting with \"jxl\""); + } + if (memcmp("jbrd", type, 4) == 0) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "jbrd box may not be brob compressed"); + } + if (memcmp("brob", type, 4) == 0) { + // The compress_box will compress an existing non-brob box into a brob + // box. If already giving a valid brotli-compressed brob box, set + // compress_box to false since it is already compressed. + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "a brob box cannot contain another brob box"); + } + } + + auto box = jxl::MemoryManagerMakeUnique( + &enc->memory_manager); + + box->type = jxl::MakeBoxType(type); + box->contents.assign(contents, contents + size); + box->compress_box = !!compress_box; + QueueBox(enc, box); + return JXL_ENC_SUCCESS; +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelBuffer( + const JxlEncoderOptions* frame_settings, const JxlPixelFormat* pixel_format, + const void* buffer, size_t size, uint32_t index) { + if (index >= frame_settings->enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + if (!frame_settings->enc->basic_info_set || + !frame_settings->enc->color_encoding_set) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Basic info has to be set first"); + } + if (frame_settings->enc->input_queue.empty()) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "First add image frame, then extra channels"); + } + if (frame_settings->enc->frames_closed) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Frame input already closed"); + } + size_t xsize, ysize; + if (GetCurrentDimensions(frame_settings, xsize, ysize) != JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "bad dimensions"); + } + if (!jxl::BufferToImageF(*pixel_format, xsize, ysize, buffer, size, + frame_settings->enc->thread_pool.get(), + &frame_settings->enc->input_queue.back() + .frame->frame.extra_channels()[index])) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Failed to set buffer for extra channel"); + } + frame_settings->enc->input_queue.back().frame->ec_initialized[index] = 1; + + return JXL_ENC_SUCCESS; +} + +void JxlEncoderCloseFrames(JxlEncoder* enc) { enc->frames_closed = true; } + +void JxlEncoderCloseBoxes(JxlEncoder* enc) { enc->boxes_closed = true; } + +void JxlEncoderCloseInput(JxlEncoder* enc) { + JxlEncoderCloseFrames(enc); + JxlEncoderCloseBoxes(enc); +} +JxlEncoderStatus JxlEncoderProcessOutput(JxlEncoder* enc, uint8_t** next_out, + size_t* avail_out) { + while (*avail_out > 0 && + (!enc->output_byte_queue.empty() || !enc->input_queue.empty())) { + if (!enc->output_byte_queue.empty()) { + size_t to_copy = std::min(*avail_out, enc->output_byte_queue.size()); + std::copy_n(enc->output_byte_queue.begin(), to_copy, *next_out); + *next_out += to_copy; + *avail_out -= to_copy; + enc->output_byte_queue.erase(enc->output_byte_queue.begin(), + enc->output_byte_queue.begin() + to_copy); + } else if (!enc->input_queue.empty()) { + if (enc->RefillOutputByteQueue() != JXL_ENC_SUCCESS) { + return JXL_ENC_ERROR; + } + } + } + + if (!enc->output_byte_queue.empty() || !enc->input_queue.empty()) { + return JXL_ENC_NEED_MORE_OUTPUT; + } + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderSetFrameHeader(JxlEncoderOptions* frame_settings, + const JxlFrameHeader* frame_header) { + if (frame_header->layer_info.blend_info.source > 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "invalid blending source index"); + } + // If there are no extra channels, it's ok for the value to be 0. + if (frame_header->layer_info.blend_info.alpha != 0 && + frame_header->layer_info.blend_info.alpha >= + frame_settings->enc->metadata.m.extra_channel_info.size()) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "alpha blend channel index out of bounds"); + } + + frame_settings->values.header = *frame_header; + // Setting the frame header resets the frame name, it must be set again with + // JxlEncoderSetFrameName if desired. + frame_settings->values.frame_name = ""; + + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderSetExtraChannelBlendInfo( + JxlEncoderOptions* frame_settings, size_t index, + const JxlBlendInfo* blend_info) { + if (index >= frame_settings->enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + + if (frame_settings->values.extra_channel_blend_info.size() != + frame_settings->enc->metadata.m.num_extra_channels) { + JxlBlendInfo default_blend_info; + JxlEncoderInitBlendInfo(&default_blend_info); + frame_settings->values.extra_channel_blend_info.resize( + frame_settings->enc->metadata.m.num_extra_channels, default_blend_info); + } + frame_settings->values.extra_channel_blend_info[index] = *blend_info; + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderSetFrameName(JxlEncoderFrameSettings* frame_settings, + const char* frame_name) { + std::string str = frame_name ? frame_name : ""; + if (str.size() > 1071) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "frame name can be max 1071 bytes long"); + } + frame_settings->values.frame_name = str; + frame_settings->values.header.name_length = str.size(); + return JXL_ENC_SUCCESS; +} + +void JxlColorEncodingSetToSRGB(JxlColorEncoding* color_encoding, + JXL_BOOL is_gray) { + ConvertInternalToExternalColorEncoding(jxl::ColorEncoding::SRGB(is_gray), + color_encoding); +} + +void JxlColorEncodingSetToLinearSRGB(JxlColorEncoding* color_encoding, + JXL_BOOL is_gray) { + ConvertInternalToExternalColorEncoding( + jxl::ColorEncoding::LinearSRGB(is_gray), color_encoding); +} diff --git a/media/libjxl/src/lib/jxl/encode_internal.h b/media/libjxl/src/lib/jxl/encode_internal.h new file mode 100644 index 000000000..9f8254641 --- /dev/null +++ b/media/libjxl/src/lib/jxl/encode_internal.h @@ -0,0 +1,263 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +#ifndef LIB_JXL_ENCODE_INTERNAL_H_ +#define LIB_JXL_ENCODE_INTERNAL_H_ + +#include +#include + +#include "jxl/encode.h" +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" +#include "jxl/types.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/memory_manager_internal.h" + +namespace jxl { + +/* Frame index box 'jxli' will start with Varint() for +NF: has type Varint(): number of frames listed in the index. +TNUM: has type u32: numerator of tick unit. +TDEN: has type u32: denominator of tick unit. Value 0 means the file is +ill-formed. per frame i listed: OFFi: has type Varint(): offset of start byte of +this frame compared to start byte of previous frame from this index in the JPEG +XL codestream. For the first frame, this is the offset from the first byte of +the JPEG XL codestream. Ti: has type Varint(): duration in ticks between the +start of this frame and the start of the next frame in the index. If this is the +last frame in the index, this is the duration in ticks between the start of this +frame and the end of the stream. A tick lasts TNUM / TDEN seconds. Fi: has type +Varint(): amount of frames the next frame in the index occurs after this frame. +If this is the last frame in the index, this is the amount of frames after this +frame in the remainder of the stream. Only frames that are presented by the +decoder are counted for this purpose, this excludes frames that are not intended +for display but for compositing with other frames, such as frames that aren't +the last frame with a duration of 0 ticks. + +All the frames listed in jxli are keyframes and the first frame is +present in the list. +There shall be either zero or one Frame Index boxes in a JPEG XL file. +The offsets OFFi per frame are given as bytes in the codestream, not as +bytes in the file format using the box structure. This means if JPEG XL Partial +Codestream boxes are used, the offset is counted within the concatenated +codestream, bytes from box headers or non-codestream boxes are not counted. +*/ + +typedef struct JxlEncoderFrameIndexBoxEntryStruct { + bool to_be_indexed; + uint32_t duration; + uint64_t OFFi; +} JxlEncoderFrameIndexBoxEntry; + +typedef struct JxlEncoderFrameIndexBoxStruct { + // We always need to record the first frame entry, so presence of the + // first entry alone is not an indication if it was requested to be + // stored. + bool index_box_requested_through_api = false; + + int64_t NF() const { return entries.size(); } + bool StoreFrameIndexBox() { + for (auto e : entries) { + if (e.to_be_indexed) { + return true; + } + } + return false; + } + int32_t TNUM = 1; + int32_t TDEN = 1000; + + std::vector entries; + + // That way we can ensure that every index box will have the first frame. + // If the API user decides to mark it as an indexed frame, we call + // the AddFrame again, this time with requested. + void AddFrame(uint64_t OFFi, uint32_t duration, bool to_be_indexed) { + // We call AddFrame to every frame. + // Recording the first frame is required by the standard. + // Knowing the last frame is required, since the last indexed frame + // needs to know how many frames until the end. + // To be able to tell how many frames there are between each index + // entry we just record every frame here. + if (entries.size() == 1) { + if (OFFi == entries[0].OFFi) { + // API use for the first frame, let's clear the already recorded first + // frame. + entries.clear(); + } + } + JxlEncoderFrameIndexBoxEntry e; + e.to_be_indexed = to_be_indexed; + e.OFFi = OFFi; + e.duration = duration; + entries.push_back(e); + } +} JxlEncoderFrameIndexBox; + +// The encoder options (such as quality, compression speed, ...) for a single +// frame, but not encoder-wide options such as box-related options. +typedef struct JxlEncoderFrameSettingsValuesStruct { + // lossless is a separate setting from cparams because it is a combination + // setting that overrides multiple settings inside of cparams. + bool lossless; + CompressParams cparams; + JxlFrameHeader header; + std::vector extra_channel_blend_info; + std::string frame_name; + bool frame_index_box = false; +} JxlEncoderFrameSettingsValues; + +typedef std::array BoxType; + +// Utility function that makes a BoxType from a string literal. The string must +// have 4 characters, a 5th null termination character is optional. +constexpr BoxType MakeBoxType(const char* type) { + return BoxType( + {{static_cast(type[0]), static_cast(type[1]), + static_cast(type[2]), static_cast(type[3])}}); +} + +constexpr unsigned char kContainerHeader[] = { + 0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xd, 0xa, 0x87, + 0xa, 0, 0, 0, 0x14, 'f', 't', 'y', 'p', 'j', 'x', + 'l', ' ', 0, 0, 0, 0, 'j', 'x', 'l', ' '}; + +constexpr unsigned char kLevelBoxHeader[] = {0, 0, 0, 0x9, 'j', 'x', 'l', 'l'}; + +struct JxlEncoderQueuedFrame { + JxlEncoderFrameSettingsValues option_values; + ImageBundle frame; + std::vector ec_initialized; +}; + +struct JxlEncoderQueuedBox { + BoxType type; + std::vector contents; + bool compress_box; +}; + +// Either a frame, or a box, not both. +struct JxlEncoderQueuedInput { + explicit JxlEncoderQueuedInput(const JxlMemoryManager& memory_manager) + : frame(nullptr, jxl::MemoryManagerDeleteHelper(&memory_manager)), + box(nullptr, jxl::MemoryManagerDeleteHelper(&memory_manager)) {} + MemoryManagerUniquePtr frame; + MemoryManagerUniquePtr box; +}; + +// Appends a JXL container box header with given type, size, and unbounded +// properties to output. +template +void AppendBoxHeader(const jxl::BoxType& type, size_t size, bool unbounded, + T* output) { + uint64_t box_size = 0; + bool large_size = false; + if (!unbounded) { + box_size = size + 8; + if (box_size >= 0x100000000ull) { + large_size = true; + } + } + + { + const uint64_t store = large_size ? 1 : box_size; + for (size_t i = 0; i < 4; i++) { + output->push_back(store >> (8 * (3 - i)) & 0xff); + } + } + for (size_t i = 0; i < 4; i++) { + output->push_back(type[i]); + } + + if (large_size) { + for (size_t i = 0; i < 8; i++) { + output->push_back(box_size >> (8 * (7 - i)) & 0xff); + } + } +} + +} // namespace jxl + +// Internal use only struct, can only be initialized correctly by +// JxlEncoderCreate. +struct JxlEncoderStruct { + JxlEncoderError error = JxlEncoderError::JXL_ENC_ERR_OK; + JxlMemoryManager memory_manager; + jxl::MemoryManagerUniquePtr thread_pool{ + nullptr, jxl::MemoryManagerDeleteHelper(&memory_manager)}; + JxlCmsInterface cms; + std::vector> + encoder_options; + + size_t num_queued_frames; + size_t num_queued_boxes; + std::vector input_queue; + std::deque output_byte_queue; + + // How many codestream bytes have been written, i.e., + // content of jxlc and jxlp boxes. Frame index box jxli + // requires position indices to point to codestream bytes, + // so we need to keep track of the total of flushed or queue + // codestream bytes. These bytes may be in a single jxlc box + // or accross multiple jxlp boxes. + size_t codestream_bytes_written_beginning_of_frame; + size_t codestream_bytes_written_end_of_frame; + jxl::JxlEncoderFrameIndexBox frame_index_box; + + // Force using the container even if not needed + bool use_container; + // User declared they will add metadata boxes + bool use_boxes; + + // TODO(lode): move level into jxl::CompressParams since some C++ + // implementation decisions should be based on it: level 10 allows more + // features to be used. + int32_t codestream_level; + bool store_jpeg_metadata; + jxl::CodecMetadata metadata; + std::vector jpeg_metadata; + + // Wrote any output at all, so wrote the data before the first user added + // frame or box, such as signature, basic info, ICC profile or jpeg + // reconstruction box. + bool wrote_bytes; + jxl::CompressParams last_used_cparams; + JxlBasicInfo basic_info; + + // Encoder wrote a jxlp (partial codestream) box, so any next codestream + // parts must also be written in jxlp boxes, a single jxlc box cannot be + // used. The counter is used for the 4-byte jxlp box index header. + size_t jxlp_counter; + + bool frames_closed; + bool boxes_closed; + bool basic_info_set; + bool color_encoding_set; + bool intensity_target_set; + int brotli_effort = -1; + + // Takes the first frame in the input_queue, encodes it, and appends + // the bytes to the output_byte_queue. + JxlEncoderStatus RefillOutputByteQueue(); + + bool MustUseContainer() const { + return use_container || codestream_level != 5 || store_jpeg_metadata || + use_boxes; + } + + // Appends the bytes of a JXL box header with the provided type and size to + // the end of the output_byte_queue. If unbounded is true, the size won't be + // added to the header and the box will be assumed to continue until EOF. + void AppendBoxHeader(const jxl::BoxType& type, size_t size, bool unbounded); +}; + +struct JxlEncoderFrameSettingsStruct { + JxlEncoder* enc; + jxl::JxlEncoderFrameSettingsValues values; +}; + +#endif // LIB_JXL_ENCODE_INTERNAL_H_ diff --git a/media/libjxl/src/lib/jxl/encode_test.cc b/media/libjxl/src/lib/jxl/encode_test.cc new file mode 100644 index 000000000..4f1ef0b2b --- /dev/null +++ b/media/libjxl/src/lib/jxl/encode_test.cc @@ -0,0 +1,1359 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/encode.h" + +#include "enc_color_management.h" +#include "gtest/gtest.h" +#include "jxl/decode.h" +#include "jxl/decode_cxx.h" +#include "jxl/encode_cxx.h" +#include "lib/extras/codec.h" +#include "lib/extras/dec/jxl.h" +#include "lib/jxl/enc_butteraugli_pnorm.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/jpeg/dec_jpeg_data.h" +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +TEST(EncodeTest, AddFrameAfterCloseInputTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderCloseInput(enc.get()); + + size_t xsize = 64; + size_t ysize = 64; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); +} + +TEST(EncodeTest, AddJPEGAfterCloseTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderCloseInput(enc.get()); + + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddJPEGFrame(frame_settings, orig.data(), orig.size())); +} + +TEST(EncodeTest, AddFrameBeforeColorEncodingTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + size_t xsize = 64; + size_t ysize = 64; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = true; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); +} + +TEST(EncodeTest, AddFrameBeforeBasicInfoTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + size_t xsize = 64; + size_t ysize = 64; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); +} + +TEST(EncodeTest, DefaultAllocTest) { + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + JxlEncoderDestroy(enc); +} + +TEST(EncodeTest, CustomAllocTest) { + struct CalledCounters { + int allocs = 0; + int frees = 0; + } counters; + + JxlMemoryManager mm; + mm.opaque = &counters; + mm.alloc = [](void* opaque, size_t size) { + reinterpret_cast(opaque)->allocs++; + return malloc(size); + }; + mm.free = [](void* opaque, void* address) { + reinterpret_cast(opaque)->frees++; + free(address); + }; + + { + JxlEncoderPtr enc = JxlEncoderMake(&mm); + EXPECT_NE(nullptr, enc.get()); + EXPECT_LE(1, counters.allocs); + EXPECT_EQ(0, counters.frees); + } + EXPECT_LE(1, counters.frees); +} + +TEST(EncodeTest, DefaultParallelRunnerTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetParallelRunner(enc.get(), nullptr, nullptr)); +} + +void VerifyFrameEncoding(size_t xsize, size_t ysize, JxlEncoder* enc, + const JxlEncoderFrameSettings* frame_settings, + size_t max_compressed_size, + bool lossy_use_original_profile) { + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + if (frame_settings->values.lossless || lossy_use_original_profile) { + basic_info.uses_original_profile = true; + } else { + basic_info.uses_original_profile = false; + } + // 16-bit alpha means this requires level 10 + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, true); + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetColorEncoding(enc, &color_encoding)); + JxlColorEncodingSetToSRGB(&color_encoding, false); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding)); + pixel_format.num_channels = 1; + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + pixel_format.num_channels = 4; + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseInput(enc); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_LE(compressed.size(), max_compressed_size); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + jxl::CodecInOut decoded_io; + EXPECT_TRUE(jxl::test::DecodeFile( + {}, jxl::Span(compressed.data(), compressed.size()), + &decoded_io, /*pool=*/nullptr)); + + EXPECT_LE( + ComputeDistance2(input_io.Main(), decoded_io.Main(), jxl::GetJxlCms()), +#if JXL_HIGH_PRECISION + 1.8); +#else + 8.0); +#endif +} + +void VerifyFrameEncoding(JxlEncoder* enc, + const JxlEncoderFrameSettings* frame_settings) { + VerifyFrameEncoding(63, 129, enc, frame_settings, 2600, + /*lossy_use_original_profile=*/false); +} + +TEST(EncodeTest, FrameEncodingTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + VerifyFrameEncoding(enc.get(), + JxlEncoderFrameSettingsCreate(enc.get(), nullptr)); +} + +TEST(EncodeTest, EncoderResetTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + VerifyFrameEncoding(50, 200, enc.get(), + JxlEncoderFrameSettingsCreate(enc.get(), nullptr), 4300, + false); + // Encoder should become reusable for a new image from scratch after using + // reset. + JxlEncoderReset(enc.get()); + VerifyFrameEncoding(157, 77, enc.get(), + JxlEncoderFrameSettingsCreate(enc.get(), nullptr), 2300, + false); +} + +TEST(EncodeTest, CmsTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + bool cms_called = false; + JxlCmsInterface cms = jxl::GetJxlCms(); + struct InitData { + void* original_init_data; + jpegxl_cms_init_func original_init; + bool* cms_called; + }; + InitData init_data = {/*original_init_data=*/cms.init_data, + /*original_init=*/cms.init, + /*cms_called=*/&cms_called}; + cms.init_data = &init_data; + cms.init = +[](void* raw_init_data, size_t num_threads, + size_t pixels_per_thread, const JxlColorProfile* input_profile, + const JxlColorProfile* output_profile, + float intensity_target) { + const InitData* init_data = static_cast(raw_init_data); + *init_data->cms_called = true; + return init_data->original_init(init_data->original_init_data, num_threads, + pixels_per_thread, input_profile, + output_profile, intensity_target); + }; + JxlEncoderSetCms(enc.get(), cms); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), nullptr); + JxlEncoderSetFrameLossless(frame_settings, false); + ASSERT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_SETTING_EFFORT, 8)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_TRUE(cms_called); +} + +TEST(EncodeTest, frame_settingsTest) { + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 5)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(jxl::SpeedTier::kHare, enc->last_used_cparams.speed_tier); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + // Lower than currently supported values + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 0)); + // Higher than currently supported values + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 10)); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetFrameLossless(frame_settings, JXL_TRUE)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 3600, false); + EXPECT_EQ(true, enc->last_used_cparams.IsLossless()); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetFrameDistance(frame_settings, 0.5)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 3000, false); + EXPECT_EQ(0.5, enc->last_used_cparams.butteraugli_distance); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + // Disallowed negative distance + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetFrameDistance(frame_settings, -1)); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_DECODING_SPEED, 2)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(2u, enc->last_used_cparams.decoding_speed_tier); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_GROUP_ORDER, 100)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_GROUP_ORDER, 1)); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X, 5)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(true, enc->last_used_cparams.centerfirst); + EXPECT_EQ(5, enc->last_used_cparams.center_x); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_RESPONSIVE, 0)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC, 1)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC, -1)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, 2)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(false, enc->last_used_cparams.responsive); + EXPECT_EQ(true, enc->last_used_cparams.progressive_mode); + EXPECT_EQ(2, enc->last_used_cparams.progressive_dc); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, JXL_ENC_FRAME_SETTING_PHOTON_NOISE, 1777.777)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_NEAR(1777.777f, enc->last_used_cparams.photon_noise_iso, 1E-6); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT, 55.0f)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT, 25.0f)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PALETTE_COLORS, 70000)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_LOSSY_PALETTE, 1)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_NEAR(55.0f, + enc->last_used_cparams.channel_colors_pre_transform_percent, + 1E-6); + EXPECT_NEAR(25.0f, enc->last_used_cparams.channel_colors_percent, 1E-6); + EXPECT_EQ(70000, enc->last_used_cparams.palette_colors); + EXPECT_EQ(true, enc->last_used_cparams.lossy_palette); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE, 30)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE, 2)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR, 14)); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT, 77.0f)); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS, 7)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(30, enc->last_used_cparams.colorspace); + EXPECT_EQ(2, enc->last_used_cparams.modular_group_size_shift); + EXPECT_EQ(jxl::Predictor::Best, enc->last_used_cparams.options.predictor); + EXPECT_NEAR(0.77f, enc->last_used_cparams.options.nb_repeats, 1E-6); + EXPECT_EQ(7, enc->last_used_cparams.options.max_properties); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL, 0)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(false, enc->last_used_cparams.force_cfl_jpeg_recompression); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL, 1)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(true, enc->last_used_cparams.force_cfl_jpeg_recompression); + } +} + +TEST(EncodeTest, LossyEncoderUseOriginalProfileTest) { + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 4100, true); + } + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, 2)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 4500, true); + } + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + ASSERT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 8)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 3700, true); + } +} + +namespace { +// Returns a copy of buf from offset to offset+size, or a new zeroed vector if +// the result would have been out of bounds taking integer overflow into +// account. +const std::vector SliceSpan(const jxl::Span& buf, + size_t offset, size_t size) { + if (offset + size >= buf.size()) { + return std::vector(size, 0); + } + if (offset + size < offset) { + return std::vector(size, 0); + } + return std::vector(buf.data() + offset, buf.data() + offset + size); +} + +struct Box { + // The type of the box. + // If "uuid", use extended_type instead + char type[4] = {0, 0, 0, 0}; + + // The extended_type is only used when type == "uuid". + // Extended types are not used in JXL. However, the box format itself + // supports this so they are handled correctly. + char extended_type[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + // Box data. + jxl::Span data = jxl::Span(nullptr, 0); + + // If the size is not given, the datasize extends to the end of the file. + // If this field is false, the size field is not encoded when the box is + // serialized. + bool data_size_given = true; + + // If successful, returns true and sets `in` to be the rest data (if any). + // If `in` contains a box with a size larger than `in.size()`, will not + // modify `in`, and will return true but the data `Span` will + // remain set to nullptr. + // If unsuccessful, returns error and doesn't modify `in`. + jxl::Status Decode(jxl::Span* in) { + // Total box_size including this header itself. + uint64_t box_size = LoadBE32(SliceSpan(*in, 0, 4).data()); + size_t pos = 4; + + memcpy(type, SliceSpan(*in, pos, 4).data(), 4); + pos += 4; + + if (box_size == 1) { + // If the size is 1, it indicates extended size read from 64-bit integer. + box_size = LoadBE64(SliceSpan(*in, pos, 8).data()); + pos += 8; + } + + if (!memcmp("uuid", type, 4)) { + memcpy(extended_type, SliceSpan(*in, pos, 16).data(), 16); + pos += 16; + } + + // This is the end of the box header, the box data begins here. Handle + // the data size now. + const size_t header_size = pos; + + if (box_size != 0) { + if (box_size < header_size) { + return JXL_FAILURE("Invalid box size"); + } + if (box_size > in->size()) { + // The box is fine, but the input is too short. + return true; + } + data_size_given = true; + data = jxl::Span(in->data() + header_size, + box_size - header_size); + } else { + data_size_given = false; + data = jxl::Span(in->data() + header_size, + in->size() - header_size); + } + + *in = jxl::Span(in->data() + header_size + data.size(), + in->size() - header_size - data.size()); + return true; + } +}; + +struct Container { + std::vector boxes; + + // If successful, returns true and sets `in` to be the rest data (if any). + // If unsuccessful, returns error and doesn't modify `in`. + jxl::Status Decode(jxl::Span* in) { + boxes.clear(); + + Box signature_box; + JXL_RETURN_IF_ERROR(signature_box.Decode(in)); + if (memcmp("JXL ", signature_box.type, 4) != 0) { + return JXL_FAILURE("Invalid magic signature"); + } + if (signature_box.data.size() != 4) + return JXL_FAILURE("Invalid magic signature"); + if (signature_box.data[0] != 0xd || signature_box.data[1] != 0xa || + signature_box.data[2] != 0x87 || signature_box.data[3] != 0xa) { + return JXL_FAILURE("Invalid magic signature"); + } + + Box ftyp_box; + JXL_RETURN_IF_ERROR(ftyp_box.Decode(in)); + if (memcmp("ftyp", ftyp_box.type, 4) != 0) { + return JXL_FAILURE("Invalid ftyp"); + } + if (ftyp_box.data.size() != 12) return JXL_FAILURE("Invalid ftyp"); + const char* expected = "jxl \0\0\0\0jxl "; + if (memcmp(expected, ftyp_box.data.data(), 12) != 0) + return JXL_FAILURE("Invalid ftyp"); + + while (!in->empty()) { + Box box = {}; + JXL_RETURN_IF_ERROR(box.Decode(in)); + if (box.data.data() == nullptr) { + // The decoding encountered a box, but not enough data yet. + return true; + } + boxes.emplace_back(box); + } + + return true; + } +}; + +} // namespace + +TEST(EncodeTest, SingleFrameBoundedJXLCTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc.get(), true)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + size_t xsize = 71; + size_t ysize = 23; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + Container container = {}; + jxl::Span encoded_span = + jxl::Span(compressed.data(), compressed.size()); + EXPECT_TRUE(container.Decode(&encoded_span)); + EXPECT_EQ(0u, encoded_span.size()); + bool found_jxlc = false; + bool found_jxlp = false; + // The encoder is allowed to either emit a jxlc or one or more jxlp. + for (size_t i = 0; i < container.boxes.size(); ++i) { + if (memcmp("jxlc", container.boxes[i].type, 4) == 0) { + EXPECT_EQ(false, found_jxlc); // Max 1 jxlc + EXPECT_EQ(false, found_jxlp); // Can't mix jxlc and jxlp + found_jxlc = true; + } + if (memcmp("jxlp", container.boxes[i].type, 4) == 0) { + EXPECT_EQ(false, found_jxlc); // Can't mix jxlc and jxlp + found_jxlp = true; + } + // The encoder shouldn't create an unbounded box in this case, with the + // single frame it knows the full size in time, so can help make decoding + // more efficient by giving the full box size of the final box. + EXPECT_EQ(true, container.boxes[i].data_size_given); + } + EXPECT_EQ(true, found_jxlc || found_jxlp); +} + +TEST(EncodeTest, CodestreamLevelTest) { + size_t xsize = 64; + size_t ysize = 64; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + Container container = {}; + jxl::Span encoded_span = + jxl::Span(compressed.data(), compressed.size()); + EXPECT_TRUE(container.Decode(&encoded_span)); + EXPECT_EQ(0u, encoded_span.size()); + EXPECT_EQ(0, memcmp("jxll", container.boxes[0].type, 4)); +} + +TEST(EncodeTest, CodestreamLevelVerificationTest) { + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT8, JXL_BIG_ENDIAN, 0}; + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = 64; + basic_info.ysize = 64; + basic_info.uses_original_profile = false; + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + + EXPECT_EQ(5, JxlEncoderGetRequiredCodestreamLevel(enc.get())); + + // Set an image dimension that is too large for level 5, but fits in level 10 + + basic_info.xsize = 1ull << 30ull; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 5)); + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + EXPECT_EQ(10, JxlEncoderGetRequiredCodestreamLevel(enc.get())); + + // Set an image dimension that is too large even for level 10 + + basic_info.xsize = 1ull << 31ull; + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); +} + +TEST(EncodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructionTest)) { + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path); + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderStoreJPEGMetadata(enc.get(), JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddJPEGFrame(frame_settings, orig.data(), orig.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + jxl::extras::JXLDecompressParams dparams; + std::vector decoded_jpeg_bytes; + jxl::extras::PackedPixelFile ppf; + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + nullptr, &ppf, &decoded_jpeg_bytes)); + + EXPECT_EQ(decoded_jpeg_bytes.size(), orig.size()); + EXPECT_EQ(0, memcmp(decoded_jpeg_bytes.data(), orig.data(), orig.size())); +} + +static void ProcessEncoder(JxlEncoder* enc, std::vector& compressed, + uint8_t*& next_out, size_t& avail_out) { + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + size_t offset = next_out - compressed.data(); + compressed.resize(next_out - compressed.data()); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); +} + +TEST(EncodeTest, BasicInfoTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 1; + size_t ysize = 1; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + basic_info.have_animation = true; + basic_info.intensity_target = 123.4; + basic_info.min_nits = 5.0; + basic_info.linear_below = 12.7; + basic_info.orientation = JXL_ORIENT_ROTATE_90_CW; + basic_info.animation.tps_numerator = 55; + basic_info.animation.tps_denominator = 77; + basic_info.animation.num_loops = 10; + basic_info.animation.have_timecodes = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseFrames(enc.get()); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Decode to verify the boxes, we don't decode to pixels, only the boxes. + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_BASIC_INFO)); + // Allow testing the orientation field, without this setting it will be + // overridden to identity. + JxlDecoderSetKeepOrientation(dec.get(), JXL_TRUE); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo basic_info2; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetBasicInfo(dec.get(), &basic_info2)); + EXPECT_EQ(basic_info.xsize, basic_info2.xsize); + EXPECT_EQ(basic_info.ysize, basic_info2.ysize); + EXPECT_EQ(basic_info.bits_per_sample, basic_info2.bits_per_sample); + EXPECT_EQ(basic_info.exponent_bits_per_sample, + basic_info2.exponent_bits_per_sample); + EXPECT_NEAR(basic_info.intensity_target, basic_info2.intensity_target, + 0.5); + EXPECT_NEAR(basic_info.min_nits, basic_info2.min_nits, 0.5); + EXPECT_NEAR(basic_info.linear_below, basic_info2.linear_below, 0.5); + EXPECT_EQ(basic_info.relative_to_max_display, + basic_info2.relative_to_max_display); + EXPECT_EQ(basic_info.uses_original_profile, + basic_info2.uses_original_profile); + EXPECT_EQ(basic_info.orientation, basic_info2.orientation); + EXPECT_EQ(basic_info.num_color_channels, basic_info2.num_color_channels); + // TODO(lode): also test num_extra_channels, but currently there may be a + // mismatch between 0 and 1 if there is alpha, until encoder support for + // extra channels is fully implemented. + EXPECT_EQ(basic_info.alpha_bits, basic_info2.alpha_bits); + EXPECT_EQ(basic_info.alpha_exponent_bits, + basic_info2.alpha_exponent_bits); + EXPECT_EQ(basic_info.alpha_premultiplied, + basic_info2.alpha_premultiplied); + + EXPECT_EQ(basic_info.have_preview, basic_info2.have_preview); + if (basic_info.have_preview) { + EXPECT_EQ(basic_info.preview.xsize, basic_info2.preview.xsize); + EXPECT_EQ(basic_info.preview.ysize, basic_info2.preview.ysize); + } + + EXPECT_EQ(basic_info.have_animation, basic_info2.have_animation); + if (basic_info.have_animation) { + EXPECT_EQ(basic_info.animation.tps_numerator, + basic_info2.animation.tps_numerator); + EXPECT_EQ(basic_info.animation.tps_denominator, + basic_info2.animation.tps_denominator); + EXPECT_EQ(basic_info.animation.num_loops, + basic_info2.animation.num_loops); + EXPECT_EQ(basic_info.animation.have_timecodes, + basic_info2.animation.have_timecodes); + } + } else { + FAIL(); // unexpected status + } + } +} + +TEST(EncodeTest, AnimationHeaderTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 1; + size_t ysize = 1; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.have_animation = true; + basic_info.animation.tps_numerator = 1000; + basic_info.animation.tps_denominator = 1; + basic_info.animation.have_timecodes = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + std::string frame_name = "test frame"; + JxlFrameHeader header; + JxlEncoderInitFrameHeader(&header); + header.duration = 50; + header.timecode = 800; + header.layer_info.blend_info.blendmode = JXL_BLEND_BLEND; + header.layer_info.blend_info.source = 2; + header.layer_info.blend_info.clamp = 1; + JxlBlendInfo extra_channel_blend_info; + JxlEncoderInitBlendInfo(&extra_channel_blend_info); + extra_channel_blend_info.blendmode = JXL_BLEND_MULADD; + JxlEncoderSetFrameHeader(frame_settings, &header); + JxlEncoderSetExtraChannelBlendInfo(frame_settings, 0, + &extra_channel_blend_info); + JxlEncoderSetFrameName(frame_settings, frame_name.c_str()); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseFrames(enc.get()); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Decode to verify the boxes, we don't decode to pixels, only the boxes. + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + + // To test the blend_info fields, coalescing must be set to false in the + // decoder. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec.get(), JXL_FALSE)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_FRAME)); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + bool seen_frame = false; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_FRAME) { + seen_frame = true; + JxlFrameHeader header2; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec.get(), &header2)); + EXPECT_EQ(header.duration, header2.duration); + EXPECT_EQ(header.timecode, header2.timecode); + EXPECT_EQ(header.layer_info.blend_info.blendmode, + header2.layer_info.blend_info.blendmode); + EXPECT_EQ(header.layer_info.blend_info.clamp, + header2.layer_info.blend_info.clamp); + EXPECT_EQ(header.layer_info.blend_info.source, + header2.layer_info.blend_info.source); + EXPECT_EQ(frame_name.size(), header2.name_length); + JxlBlendInfo extra_channel_blend_info2; + JxlDecoderGetExtraChannelBlendInfo(dec.get(), 0, + &extra_channel_blend_info2); + EXPECT_EQ(extra_channel_blend_info.blendmode, + extra_channel_blend_info2.blendmode); + if (header2.name_length > 0) { + std::string frame_name2(header2.name_length + 1, '\0'); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetFrameName(dec.get(), &frame_name2.front(), + frame_name2.size())); + frame_name2.resize(header2.name_length); + EXPECT_EQ(frame_name, frame_name2); + } + } else { + FAIL(); // unexpected status + } + } + + EXPECT_EQ(true, seen_frame); +} +TEST(EncodeTest, CroppedFrameTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 300; + size_t ysize = 300; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + std::vector pixels2(pixels.size()); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + // Encoding a 300x300 frame in an image that is only 100x100 + basic_info.xsize = 100; + basic_info.ysize = 100; + basic_info.uses_original_profile = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + JxlFrameHeader header; + JxlEncoderInitFrameHeader(&header); + header.layer_info.have_crop = JXL_TRUE; + header.layer_info.xsize = xsize; + header.layer_info.ysize = ysize; + header.layer_info.crop_x0 = -50; + header.layer_info.crop_y0 = -250; + JxlEncoderSetFrameLossless(frame_settings, JXL_TRUE); + JxlEncoderSetFrameHeader(frame_settings, &header); + JxlEncoderFrameSettingsSetOption(frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, + 1); + + std::vector compressed = std::vector(100); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseFrames(enc.get()); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + // Non-coalesced decoding so we can get the full uncropped frame + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec.get(), JXL_FALSE)); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + bool seen_frame = false; + bool checked_frame = false; + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_FRAME) { + seen_frame = true; + JxlFrameHeader header2; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec.get(), &header2)); + EXPECT_EQ(header.layer_info.xsize, header2.layer_info.xsize); + EXPECT_EQ(header.layer_info.ysize, header2.layer_info.ysize); + EXPECT_EQ(header.layer_info.crop_x0, header2.layer_info.crop_x0); + EXPECT_EQ(header.layer_info.crop_y0, header2.layer_info.crop_y0); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec.get(), &pixel_format, + pixels2.data(), pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(0, memcmp(pixels.data(), pixels2.data(), pixels.size())); + checked_frame = true; + } else { + FAIL(); // unexpected status + } + } + EXPECT_EQ(true, checked_frame); + EXPECT_EQ(true, seen_frame); +} + +TEST(EncodeTest, BoxTest) { + // Test with uncompressed boxes and with brob boxes + for (int compress_box = 0; compress_box <= 1; ++compress_box) { + // Tests adding two metadata boxes with the encoder: an exif box before the + // image frame, and an xml box after the image frame. Then verifies the + // decoder can decode them, they are in the expected place, and have the + // correct content after decoding. + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseBoxes(enc.get())); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 50; + size_t ysize = 17; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + + // Add an early metadata box. Also add a valid 4-byte TIFF offset header + // before the fake exif data of these box contents. + constexpr const char* exif_test_string = "\0\0\0\0exif test data"; + const uint8_t* exif_data = + reinterpret_cast(exif_test_string); + // Skip the 4 zeroes for strlen + const size_t exif_size = 4 + strlen(exif_test_string + 4); + JxlEncoderAddBox(enc.get(), "Exif", exif_data, exif_size, compress_box); + + // Write to output + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Add image frame + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + // Indicate this is the last frame + JxlEncoderCloseFrames(enc.get()); + + // Write to output + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Add a late metadata box + constexpr const char* xml_test_string = ""; + const uint8_t* xml_data = reinterpret_cast(xml_test_string); + size_t xml_size = strlen(xml_test_string); + JxlEncoderAddBox(enc.get(), "XML ", xml_data, xml_size, compress_box); + + // Indicate this is the last box + JxlEncoderCloseBoxes(enc.get()); + + // Write to output + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Decode to verify the boxes, we don't decode to pixels, only the boxes. + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + + if (compress_box) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetDecompressBoxes(dec.get(), JXL_TRUE)); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec.get(), JXL_DEC_FRAME | JXL_DEC_BOX)); + + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + std::vector dec_exif_box(exif_size); + std::vector dec_xml_box(xml_size); + + for (bool post_frame = false;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_EQ(0, JxlDecoderReleaseBoxBuffer(dec.get())); + break; + } else if (status == JXL_DEC_FRAME) { + post_frame = true; + } else if (status == JXL_DEC_BOX) { + // Since we gave the exif/xml box output buffer of the exact known + // correct size, 0 bytes should be released. Same when no buffer was + // set. + EXPECT_EQ(0, JxlDecoderReleaseBoxBuffer(dec.get())); + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec.get(), type, true)); + if (!memcmp(type, "Exif", 4)) { + // This box should have been encoded before the image frame + EXPECT_EQ(false, post_frame); + JxlDecoderSetBoxBuffer(dec.get(), dec_exif_box.data(), + dec_exif_box.size()); + } else if (!memcmp(type, "XML ", 4)) { + // This box should have been encoded after the image frame + EXPECT_EQ(true, post_frame); + JxlDecoderSetBoxBuffer(dec.get(), dec_xml_box.data(), + dec_xml_box.size()); + } + } else { + FAIL(); // unexpected status + } + } + + EXPECT_EQ(0, memcmp(exif_data, dec_exif_box.data(), exif_size)); + EXPECT_EQ(0, memcmp(xml_data, dec_xml_box.data(), xml_size)); + } +} + +#if JPEGXL_ENABLE_JPEG // Loading .jpg files requires libjpeg support. +TEST(EncodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGFrameTest)) { + for (int skip_basic_info = 0; skip_basic_info < 2; skip_basic_info++) { + for (int skip_color_encoding = 0; skip_color_encoding < 2; + skip_color_encoding++) { + // cannot set color encoding if basic info is not set + if (skip_basic_info && !skip_color_encoding) continue; + const std::string jpeg_path = "jxl/flower/flower_cropped.jpg"; + const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path); + jxl::CodecInOut orig_io; + ASSERT_TRUE(SetFromBytes(jxl::Span(orig), &orig_io, + /*pool=*/nullptr)); + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_SETTING_EFFORT, 1); + if (!skip_basic_info) { + JxlBasicInfo basic_info; + JxlEncoderInitBasicInfo(&basic_info); + basic_info.xsize = orig_io.xsize(); + basic_info.ysize = orig_io.ysize(); + basic_info.uses_original_profile = true; + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + } + if (!skip_color_encoding) { + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + } + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderAddJPEGFrame( + frame_settings, orig.data(), orig.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector compressed = std::vector(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = + JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + jxl::CodecInOut decoded_io; + EXPECT_TRUE(jxl::test::DecodeFile( + {}, jxl::Span(compressed.data(), compressed.size()), + &decoded_io, /*pool=*/nullptr)); + + EXPECT_LE( + ComputeDistance2(orig_io.Main(), decoded_io.Main(), jxl::GetJxlCms()), + 3.5); + } + } +} +#endif // JPEGXL_ENABLE_JPEG diff --git a/media/libjxl/src/lib/jxl/entropy_coder.cc b/media/libjxl/src/lib/jxl/entropy_coder.cc new file mode 100644 index 000000000..0043c2d31 --- /dev/null +++ b/media/libjxl/src/lib/jxl/entropy_coder.cc @@ -0,0 +1,70 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/entropy_coder.h" + +#include +#include + +#include +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map) { + auto& dct = block_ctx_map->dc_thresholds; + auto& qft = block_ctx_map->qf_thresholds; + auto& ctx_map = block_ctx_map->ctx_map; + bool is_default = br->ReadFixedBits<1>(); + if (is_default) { + *block_ctx_map = BlockCtxMap(); + return true; + } + block_ctx_map->num_dc_ctxs = 1; + for (int j : {0, 1, 2}) { + dct[j].resize(br->ReadFixedBits<4>()); + block_ctx_map->num_dc_ctxs *= dct[j].size() + 1; + for (int& i : dct[j]) { + i = UnpackSigned(U32Coder::Read(kDCThresholdDist, br)); + } + } + qft.resize(br->ReadFixedBits<4>()); + for (uint32_t& i : qft) { + i = U32Coder::Read(kQFThresholdDist, br) + 1; + } + + if (block_ctx_map->num_dc_ctxs * (qft.size() + 1) > 64) { + return JXL_FAILURE("Invalid block context map: too big"); + } + + ctx_map.resize(3 * kNumOrders * block_ctx_map->num_dc_ctxs * + (qft.size() + 1)); + JXL_RETURN_IF_ERROR(DecodeContextMap(&ctx_map, &block_ctx_map->num_ctxs, br)); + if (block_ctx_map->num_ctxs > 16) { + return JXL_FAILURE("Invalid block context map: too many distinct contexts"); + } + return true; +} + +constexpr uint8_t BlockCtxMap::kDefaultCtxMap[]; // from ac_context.h + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/entropy_coder.h b/media/libjxl/src/lib/jxl/entropy_coder.h new file mode 100644 index 000000000..e4afa7a63 --- /dev/null +++ b/media/libjxl/src/lib/jxl/entropy_coder.h @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENTROPY_CODER_H_ +#define LIB_JXL_ENTROPY_CODER_H_ + +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +// Entropy coding and context modeling of DC and AC coefficients, as well as AC +// strategy and quantization field. + +namespace jxl { + +static JXL_INLINE int32_t PredictFromTopAndLeft( + const int32_t* const JXL_RESTRICT row_top, + const int32_t* const JXL_RESTRICT row, size_t x, int32_t default_val) { + if (x == 0) { + return row_top == nullptr ? default_val : row_top[x]; + } + if (row_top == nullptr) { + return row[x - 1]; + } + return (row_top[x] + row[x - 1] + 1) / 2; +} + +static constexpr U32Enc kDCThresholdDist(Bits(4), BitsOffset(8, 16), + BitsOffset(16, 272), + BitsOffset(32, 65808)); + +static constexpr U32Enc kQFThresholdDist(Bits(2), BitsOffset(3, 4), + BitsOffset(5, 12), BitsOffset(8, 44)); + +Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map); + +} // namespace jxl + +#endif // LIB_JXL_ENTROPY_CODER_H_ diff --git a/media/libjxl/src/lib/jxl/entropy_coder_test.cc b/media/libjxl/src/lib/jxl/entropy_coder_test.cc new file mode 100644 index 000000000..9b3ffa314 --- /dev/null +++ b/media/libjxl/src/lib/jxl/entropy_coder_test.cc @@ -0,0 +1,68 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO(deymo): Move these tests to dec_ans.h and common.h + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" + +namespace jxl { +namespace { + +TEST(EntropyCoderTest, PackUnpack) { + for (int32_t i = -31; i < 32; ++i) { + uint32_t packed = PackSigned(i); + EXPECT_LT(packed, 63u); + int32_t unpacked = UnpackSigned(packed); + EXPECT_EQ(i, unpacked); + } +} + +struct DummyBitReader { + uint32_t nbits, bits; + void Consume(uint32_t nbits) {} + uint32_t PeekBits(uint32_t n) { + EXPECT_EQ(n, nbits); + return bits; + } +}; + +void HybridUintRoundtrip(HybridUintConfig config, size_t limit = 1 << 24) { + Rng rng(0); + constexpr size_t kNumIntegers = 1 << 20; + std::vector integers(kNumIntegers); + std::vector token(kNumIntegers); + std::vector nbits(kNumIntegers); + std::vector bits(kNumIntegers); + for (size_t i = 0; i < kNumIntegers; i++) { + integers[i] = rng.UniformU(0, limit + 1); + config.Encode(integers[i], &token[i], &nbits[i], &bits[i]); + } + for (size_t i = 0; i < kNumIntegers; i++) { + DummyBitReader br{nbits[i], bits[i]}; + EXPECT_EQ(integers[i], + ANSSymbolReader::ReadHybridUintConfig(config, token[i], &br)); + } +} + +TEST(HybridUintTest, Test000) { + HybridUintRoundtrip(HybridUintConfig{0, 0, 0}); +} +TEST(HybridUintTest, Test411) { + HybridUintRoundtrip(HybridUintConfig{4, 1, 1}); +} +TEST(HybridUintTest, Test420) { + HybridUintRoundtrip(HybridUintConfig{4, 2, 0}); +} +TEST(HybridUintTest, Test421) { + HybridUintRoundtrip(HybridUintConfig{4, 2, 1}, 256); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/epf.cc b/media/libjxl/src/lib/jxl/epf.cc new file mode 100644 index 000000000..7288ed9ca --- /dev/null +++ b/media/libjxl/src/lib/jxl/epf.cc @@ -0,0 +1,146 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Edge-preserving smoothing: weighted average based on L1 patch similarity. + +#include "lib/jxl/epf.h" + +#include +#include +#include +#include +#include + +#include +#include +#include // std::accumulate +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +// Mirror n floats starting at *p and store them before p. +JXL_INLINE void LeftMirror(float* p, size_t n) { + for (size_t i = 0; i < n; i++) { + *(p - 1 - i) = p[i]; + } +} + +// Mirror n floats starting at *(p - n) and store them at *p. +JXL_INLINE void RightMirror(float* p, size_t n) { + for (size_t i = 0; i < n; i++) { + p[i] = *(p - 1 - i); + } +} + +void ComputeSigma(const Rect& block_rect, PassesDecoderState* state) { + const LoopFilter& lf = state->shared->frame_header.loop_filter; + JXL_CHECK(lf.epf_iters > 0); + const AcStrategyImage& ac_strategy = state->shared->ac_strategy; + const float quant_scale = state->shared->quantizer.Scale(); + + const size_t sigma_stride = state->sigma.PixelsPerRow(); + const size_t sharpness_stride = state->shared->epf_sharpness.PixelsPerRow(); + + for (size_t by = 0; by < block_rect.ysize(); ++by) { + float* JXL_RESTRICT sigma_row = block_rect.Row(&state->sigma, by); + const uint8_t* JXL_RESTRICT sharpness_row = + block_rect.ConstRow(state->shared->epf_sharpness, by); + AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by); + const int32_t* const JXL_RESTRICT row_quant = + block_rect.ConstRow(state->shared->raw_quant_field, by); + + for (size_t bx = 0; bx < block_rect.xsize(); bx++) { + AcStrategy acs = acs_row[bx]; + size_t llf_x = acs.covered_blocks_x(); + if (!acs.IsFirstBlock()) continue; + // quant_scale is smaller for low quality. + // quant_scale is roughly 0.08 / butteraugli score. + // + // row_quant is smaller for low quality. + // row_quant is a quantization multiplier of form 1.0 / + // row_quant[bx] + // + // lf.epf_quant_mul is a parameter in the format + // kInvSigmaNum is a constant + float sigma_quant = + lf.epf_quant_mul / (quant_scale * row_quant[bx] * kInvSigmaNum); + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + float sigma = + sigma_quant * + lf.epf_sharp_lut[sharpness_row[bx + ix + iy * sharpness_stride]]; + // Avoid infinities. + sigma = std::min(-1e-4f, sigma); // TODO(veluca): remove this. + sigma_row[bx + ix + kSigmaPadding + + (iy + kSigmaPadding) * sigma_stride] = 1.0f / sigma; + } + } + // TODO(veluca): remove this padding. + // Left padding with mirroring. + if (bx + block_rect.x0() == 0) { + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + LeftMirror( + sigma_row + kSigmaPadding + (iy + kSigmaPadding) * sigma_stride, + kSigmaBorder); + } + } + // Right padding with mirroring. + if (bx + block_rect.x0() + llf_x == + state->shared->frame_dim.xsize_blocks) { + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + RightMirror(sigma_row + kSigmaPadding + bx + llf_x + + (iy + kSigmaPadding) * sigma_stride, + kSigmaBorder); + } + } + // Offsets for row copying, in blocks. + size_t offset_before = bx + block_rect.x0() == 0 ? 1 : bx + kSigmaPadding; + size_t offset_after = + bx + block_rect.x0() + llf_x == state->shared->frame_dim.xsize_blocks + ? kSigmaPadding + llf_x + bx + kSigmaBorder + : kSigmaPadding + llf_x + bx; + size_t num = offset_after - offset_before; + // Above + if (by + block_rect.y0() == 0) { + for (size_t iy = 0; iy < kSigmaBorder; iy++) { + memcpy( + sigma_row + offset_before + + (kSigmaPadding - 1 - iy) * sigma_stride, + sigma_row + offset_before + (kSigmaPadding + iy) * sigma_stride, + num * sizeof(*sigma_row)); + } + } + // Below + if (by + block_rect.y0() + acs.covered_blocks_y() == + state->shared->frame_dim.ysize_blocks) { + for (size_t iy = 0; iy < kSigmaBorder; iy++) { + memcpy( + sigma_row + offset_before + + sigma_stride * (acs.covered_blocks_y() + kSigmaPadding + iy), + sigma_row + offset_before + + sigma_stride * + (acs.covered_blocks_y() + kSigmaPadding - 1 - iy), + num * sizeof(*sigma_row)); + } + } + } + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/epf.h b/media/libjxl/src/lib/jxl/epf.h new file mode 100644 index 000000000..7a0834ed9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/epf.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_EPF_H_ +#define LIB_JXL_EPF_H_ + +// Fast SIMD "in-loop" edge preserving filter (adaptive, nonlinear). + +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/passes_state.h" + +namespace jxl { + +// 4 * (sqrt(0.5)-1), so that Weight(sigma) = 0.5. +static constexpr float kInvSigmaNum = -1.1715728752538099024f; + +// kInvSigmaNum / 0.3 +constexpr float kMinSigma = -3.90524291751269967465540850526868f; + +// Fills the `state->filter_weights.sigma` image with the precomputed sigma +// values in the area inside `block_rect`. Accesses the AC strategy, quant field +// and epf_sharpness fields in the corresponding positions. +void ComputeSigma(const Rect& block_rect, PassesDecoderState* state); + +} // namespace jxl + +#endif // LIB_JXL_EPF_H_ diff --git a/media/libjxl/src/lib/jxl/exif.h b/media/libjxl/src/lib/jxl/exif.h new file mode 100644 index 000000000..06652fc73 --- /dev/null +++ b/media/libjxl/src/lib/jxl/exif.h @@ -0,0 +1,85 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_EXIF_H_ +#define LIB_JXL_EXIF_H_ + +// Basic parsing of Exif (just enough for the render-impacting things +// like orientation) + +#include "jxl/codestream_header.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/image_metadata.h" + +namespace jxl { + +constexpr uint16_t kExifOrientationTag = 274; + +// Checks if a blob looks like Exif, and if so, sets bigendian +// according to the tiff endianness +inline bool IsExif(const std::vector& exif, bool* bigendian) { + if (exif.size() < 12) return false; // not enough bytes for a valid exif blob + const uint8_t* t = exif.data(); + if (LoadLE32(t) == 0x2A004D4D) { + *bigendian = true; + return true; + } else if (LoadLE32(t) == 0x002A4949) { + *bigendian = false; + return true; + } + return false; // not a valid tiff header +} + +// Finds the position of an Exif tag, or 0 if it is not found +inline size_t FindExifTagPosition(const std::vector& exif, + uint16_t tagname) { + bool bigendian; + if (!IsExif(exif, &bigendian)) return 0; + const uint8_t* t = exif.data() + 4; + uint32_t offset = (bigendian ? LoadBE32(t) : LoadLE32(t)); + if (exif.size() < 12 + offset + 2 || offset < 8) return 0; + t += offset - 4; + uint16_t nb_tags = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + while (nb_tags > 0) { + if (t + 12 >= exif.data() + exif.size()) return 0; + uint16_t tag = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + if (tag == tagname) return static_cast(t - exif.data()); + t += 10; + nb_tags--; + } + return 0; +} + +// TODO (jon): tag 1 can be used to represent Adobe RGB 1998 if it has value +// "R03" +// TODO (jon): set intrinsic dimensions according to +// https://discourse.wicg.io/t/proposal-exif-image-resolution-auto-and-from-image/4326/24 +// Parses the Exif data just enough to extract any render-impacting info. +// If the Exif data is invalid or could not be parsed, then it is treated +// as a no-op. +inline void InterpretExif(const std::vector& exif, + JxlOrientation* orientation) { + bool bigendian; + if (!IsExif(exif, &bigendian)) return; + size_t o_pos = FindExifTagPosition(exif, kExifOrientationTag); + if (o_pos) { + const uint8_t* t = exif.data() + o_pos; + uint16_t type = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + uint32_t count = (bigendian ? LoadBE32(t) : LoadLE32(t)); + t += 4; + uint16_t value = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 4; + if (type == 3 && count == 1 && value >= 1 && value <= 8) { + *orientation = static_cast(value); + } + } +} + +} // namespace jxl + +#endif // LIB_JXL_EXIF_H_ diff --git a/media/libjxl/src/lib/jxl/fake_parallel_runner_testonly.h b/media/libjxl/src/lib/jxl/fake_parallel_runner_testonly.h new file mode 100644 index 000000000..3b5c16b54 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fake_parallel_runner_testonly.h @@ -0,0 +1,79 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FAKE_PARALLEL_RUNNER_TESTONLY_H_ +#define LIB_JXL_FAKE_PARALLEL_RUNNER_TESTONLY_H_ + +#include + +#include + +#include "jxl/parallel_runner.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/random.h" + +namespace jxl { + +// A parallel runner implementation that runs all the jobs in a single thread +// (the caller thread) but runs them pretending to use multiple threads and +// potentially out of order. This is useful for testing conditions that only +// occur under heavy load where the order of operations is different. +class FakeParallelRunner { + public: + FakeParallelRunner(uint32_t order_seed, uint32_t num_threads) + : order_seed_(order_seed), rng_(order_seed), num_threads_(num_threads) { + if (num_threads_ < 1) num_threads_ = 1; + } + + JxlParallelRetCode Run(void* jxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start, + uint32_t end) { + JxlParallelRetCode ret = init(jxl_opaque, num_threads_); + if (ret != 0) return ret; + + if (order_seed_ == 0) { + for (uint32_t i = start; i < end; i++) { + func(jxl_opaque, i, i % num_threads_); + } + } else { + std::vector order(end - start); + for (uint32_t i = start; i < end; i++) { + order[i - start] = i; + } + rng_.Shuffle(order.data(), order.size()); + for (uint32_t i = start; i < end; i++) { + func(jxl_opaque, order[i - start], i % num_threads_); + } + } + return ret; + } + + private: + // Seed for the RNG for defining the execution order. A value of 0 means + // sequential order from start to end. + uint32_t order_seed_; + + // The PRNG object, initialized with the order_seed_. Only used if the seed is + // not 0. + Rng rng_; + + // Number of fake threads. All the tasks are run on the same thread, but using + // different thread_id values based on this num_threads. + uint32_t num_threads_; +}; + +} // namespace jxl + +extern "C" { +// Function to pass as the parallel runner. +JXL_INLINE JxlParallelRetCode JxlFakeParallelRunner( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) { + return static_cast(runner_opaque) + ->Run(jpegxl_opaque, init, func, start_range, end_range); +} +} + +#endif // LIB_JXL_FAKE_PARALLEL_RUNNER_TESTONLY_H_ diff --git a/media/libjxl/src/lib/jxl/fast_dct-inl.h b/media/libjxl/src/lib/jxl/fast_dct-inl.h new file mode 100644 index 000000000..defdfcd12 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct-inl.h @@ -0,0 +1,236 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_FAST_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_FAST_DCT_INL_H_ +#undef LIB_JXL_FAST_DCT_INL_H_ +#else +#define LIB_JXL_FAST_DCT_INL_H_ +#endif + +#include +#include + +#include "lib/jxl/base/status.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +#if HWY_TARGET == HWY_NEON +HWY_NOINLINE void FastTransposeBlock(const int16_t* JXL_RESTRICT data_in, + size_t stride_in, size_t N, size_t M, + int16_t* JXL_RESTRICT data_out, + size_t stride_out) { + JXL_DASSERT(N % 8 == 0); + JXL_DASSERT(M % 8 == 0); + for (size_t i = 0; i < N; i += 8) { + for (size_t j = 0; j < M; j += 8) { + // TODO(veluca): one could optimize the M==8, stride_in==8 case further + // with vld4. + // This code is about 40% faster for N == M == stride_in == + // stride_out == 8 + // Using loads + stores to reshuffle things to be able to + // use vld4 doesn't help. + /* + auto a0 = vld4q_s16(data_in); auto a1 = vld4q_s16(data_in + 32); + int16x8x4_t out0; + int16x8x4_t out1; + out0.val[0] = vuzp1q_s16(a0.val[0], a1.val[0]); + out0.val[1] = vuzp1q_s16(a0.val[1], a1.val[1]); + out0.val[2] = vuzp1q_s16(a0.val[2], a1.val[2]); + out0.val[3] = vuzp1q_s16(a0.val[3], a1.val[3]); + out1.val[0] = vuzp2q_s16(a0.val[0], a1.val[0]); + out1.val[1] = vuzp2q_s16(a0.val[1], a1.val[1]); + out1.val[2] = vuzp2q_s16(a0.val[2], a1.val[2]); + out1.val[3] = vuzp2q_s16(a0.val[3], a1.val[3]); + vst1q_s16_x4(data_out, out0); + vst1q_s16_x4(data_out + 32, out1); + */ + auto a0 = vld1q_s16(data_in + i * stride_in + j); + auto a1 = vld1q_s16(data_in + (i + 1) * stride_in + j); + auto a2 = vld1q_s16(data_in + (i + 2) * stride_in + j); + auto a3 = vld1q_s16(data_in + (i + 3) * stride_in + j); + + auto a01 = vtrnq_s16(a0, a1); + auto a23 = vtrnq_s16(a2, a3); + + auto four0 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[0]), + vreinterpretq_s32_s16(a23.val[0])); + auto four1 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[1]), + vreinterpretq_s32_s16(a23.val[1])); + + auto a4 = vld1q_s16(data_in + (i + 4) * stride_in + j); + auto a5 = vld1q_s16(data_in + (i + 5) * stride_in + j); + auto a6 = vld1q_s16(data_in + (i + 6) * stride_in + j); + auto a7 = vld1q_s16(data_in + (i + 7) * stride_in + j); + + auto a45 = vtrnq_s16(a4, a5); + auto a67 = vtrnq_s16(a6, a7); + + auto four2 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[0]), + vreinterpretq_s32_s16(a67.val[0])); + auto four3 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[1]), + vreinterpretq_s32_s16(a67.val[1])); + + auto out0 = + vcombine_s32(vget_low_s32(four0.val[0]), vget_low_s32(four2.val[0])); + auto out1 = + vcombine_s32(vget_low_s32(four1.val[0]), vget_low_s32(four3.val[0])); + auto out2 = + vcombine_s32(vget_low_s32(four0.val[1]), vget_low_s32(four2.val[1])); + auto out3 = + vcombine_s32(vget_low_s32(four1.val[1]), vget_low_s32(four3.val[1])); + auto out4 = vcombine_s32(vget_high_s32(four0.val[0]), + vget_high_s32(four2.val[0])); + auto out5 = vcombine_s32(vget_high_s32(four1.val[0]), + vget_high_s32(four3.val[0])); + auto out6 = vcombine_s32(vget_high_s32(four0.val[1]), + vget_high_s32(four2.val[1])); + auto out7 = vcombine_s32(vget_high_s32(four1.val[1]), + vget_high_s32(four3.val[1])); + vst1q_s16(data_out + j * stride_out + i, vreinterpretq_s16_s32(out0)); + vst1q_s16(data_out + (j + 1) * stride_out + i, + vreinterpretq_s16_s32(out1)); + vst1q_s16(data_out + (j + 2) * stride_out + i, + vreinterpretq_s16_s32(out2)); + vst1q_s16(data_out + (j + 3) * stride_out + i, + vreinterpretq_s16_s32(out3)); + vst1q_s16(data_out + (j + 4) * stride_out + i, + vreinterpretq_s16_s32(out4)); + vst1q_s16(data_out + (j + 5) * stride_out + i, + vreinterpretq_s16_s32(out5)); + vst1q_s16(data_out + (j + 6) * stride_out + i, + vreinterpretq_s16_s32(out6)); + vst1q_s16(data_out + (j + 7) * stride_out + i, + vreinterpretq_s16_s32(out7)); + } + } +} + +template +struct FastDCTTag {}; + +#include "lib/jxl/fast_dct128-inl.h" +#include "lib/jxl/fast_dct16-inl.h" +#include "lib/jxl/fast_dct256-inl.h" +#include "lib/jxl/fast_dct32-inl.h" +#include "lib/jxl/fast_dct64-inl.h" +#include "lib/jxl/fast_dct8-inl.h" + +template +struct ComputeFastScaledIDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // int16_ts. + HWY_MAYBE_UNUSED void operator()(int16_t* JXL_RESTRICT from, int16_t* to, + size_t to_stride, + int16_t* JXL_RESTRICT scratch_space) { + // Reverse the steps done in ComputeScaledDCT. + if (ROWS < COLS) { + FastTransposeBlock(from, COLS, ROWS, COLS, scratch_space, ROWS); + FastIDCT(FastDCTTag(), scratch_space, ROWS, from, ROWS, ROWS); + FastTransposeBlock(from, ROWS, COLS, ROWS, scratch_space, COLS); + FastIDCT(FastDCTTag(), scratch_space, COLS, to, to_stride, COLS); + } else { + FastIDCT(FastDCTTag(), from, ROWS, scratch_space, ROWS, ROWS); + FastTransposeBlock(scratch_space, ROWS, COLS, ROWS, from, COLS); + FastIDCT(FastDCTTag(), from, COLS, to, to_stride, COLS); + } + } +}; +#endif + +template +HWY_NOINLINE void TestFastIDCT() { +#if HWY_TARGET == HWY_NEON + auto pixels_mem = hwy::AllocateAligned(N * M); + float* pixels = pixels_mem.get(); + auto dct_mem = hwy::AllocateAligned(N * M); + float* dct = dct_mem.get(); + auto dct_i_mem = hwy::AllocateAligned(N * M); + int16_t* dct_i = dct_i_mem.get(); + auto dct_in_mem = hwy::AllocateAligned(N * M); + int16_t* dct_in = dct_in_mem.get(); + auto idct_mem = hwy::AllocateAligned(N * M); + int16_t* idct = idct_mem.get(); + + auto scratch_space_mem = hwy::AllocateAligned(N * M * 2); + float* scratch_space = scratch_space_mem.get(); + auto scratch_space_i_mem = hwy::AllocateAligned(N * M * 2); + int16_t* scratch_space_i = scratch_space_i_mem.get(); + + Rng rng(0); + for (size_t i = 0; i < N * M; i++) { + pixels[i] = rng.UniformF(-1, 1); + } + ComputeScaledDCT()(DCTFrom(pixels, N), dct, scratch_space); + size_t integer_bits = std::max(FastIDCTIntegerBits(FastDCTTag()), + FastIDCTIntegerBits(FastDCTTag())); + // Enough range for [-2, 2] output values. + JXL_ASSERT(integer_bits <= 14); + float scale = (1 << (14 - integer_bits)); + for (size_t i = 0; i < N * M; i++) { + dct_i[i] = std::round(dct[i] * scale); + } + + for (size_t j = 0; j < 40000000 / (M * N); j++) { + memcpy(dct_in, dct_i, sizeof(*dct_i) * N * M); + ComputeFastScaledIDCT()(dct_in, idct, N, scratch_space_i); + } + float max_error = 0; + for (size_t i = 0; i < M * N; i++) { + float err = std::abs(idct[i] * (1.0f / scale) - pixels[i]); + if (std::abs(err) > max_error) { + max_error = std::abs(err); + } + } + printf("max error: %f mantissa bits: %d\n", max_error, + 14 - (int)integer_bits); +#endif +} + +template +HWY_NOINLINE void TestFloatIDCT() { + auto pixels_mem = hwy::AllocateAligned(N * M); + float* pixels = pixels_mem.get(); + auto dct_mem = hwy::AllocateAligned(N * M); + float* dct = dct_mem.get(); + auto idct_mem = hwy::AllocateAligned(N * M); + float* idct = idct_mem.get(); + + auto dct_in_mem = hwy::AllocateAligned(N * M); + float* dct_in = dct_mem.get(); + + auto scratch_space_mem = hwy::AllocateAligned(N * M * 2); + float* scratch_space = scratch_space_mem.get(); + + Rng rng(0); + for (size_t i = 0; i < N * M; i++) { + pixels[i] = rng.UniformF(-1, 1); + } + ComputeScaledDCT()(DCTFrom(pixels, N), dct, scratch_space); + + for (size_t j = 0; j < 40000000 / (M * N); j++) { + memcpy(dct_in, dct, sizeof(*dct) * N * M); + ComputeScaledIDCT()(dct_in, DCTTo(idct, N), scratch_space); + } + float max_error = 0; + for (size_t i = 0; i < M * N; i++) { + float err = std::abs(idct[i] - pixels[i]); + if (std::abs(err) > max_error) { + max_error = std::abs(err); + } + } + printf("max error: %e\n", max_error); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_FAST_DCT_INL_H_ diff --git a/media/libjxl/src/lib/jxl/fast_dct.cc b/media/libjxl/src/lib/jxl/fast_dct.cc new file mode 100644 index 000000000..d796018fd --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct.cc @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/fast_dct.cc" +#include +#include + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/fast_dct-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { +void BenchmarkFloatIDCT32x32() { TestFloatIDCT<32, 32>(); } +void BenchmarkFastIDCT32x32() { TestFastIDCT<32, 32>(); } +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(BenchmarkFloatIDCT32x32); +HWY_EXPORT(BenchmarkFastIDCT32x32); +void BenchmarkFloatIDCT32x32() { + HWY_DYNAMIC_DISPATCH(BenchmarkFloatIDCT32x32)(); +} +void BenchmarkFastIDCT32x32() { + HWY_DYNAMIC_DISPATCH(BenchmarkFastIDCT32x32)(); +} +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/fast_dct.h b/media/libjxl/src/lib/jxl/fast_dct.h new file mode 100644 index 000000000..641933d8a --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct.h @@ -0,0 +1,9 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +namespace jxl { +void BenchmarkFloatIDCT32x32(); +void BenchmarkFastIDCT32x32(); +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/fast_dct128-inl.h b/media/libjxl/src/lib/jxl/fast_dct128-inl.h new file mode 100644 index 000000000..1a94d3ee9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct128-inl.h @@ -0,0 +1,2137 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<128>) { return 2; } + +void FastIDCT(FastDCTTag<128>, const int16_t* in, size_t in_stride, + int16_t* out, size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 64 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 32 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 96 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 80 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 48 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17_tmp = vqrdmulhq_n_s16(v16, 13573); + int16x8_t v17 = vaddq_s16(v17_tmp, v16); + int16x8_t v18 = vld1q_s16(in + in_stride * 112 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v19, v16); + int16x8_t v21 = vaddq_s16(v17, v20); + int16x8_t v22 = vqrdmulhq_n_s16(v21, 17734); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 72 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 56 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 40 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35_tmp = vqrdmulhq_n_s16(v34, 13573); + int16x8_t v35 = vaddq_s16(v35_tmp, v34); + int16x8_t v36 = vld1q_s16(in + in_stride * 104 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 88 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vaddq_s16(v35, v39); + int16x8_t v41 = vqrdmulhq_n_s16(v40, 17734); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v37, v28); + int16x8_t v46 = vaddq_s16(v29, v32); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vaddq_s16(v46, v43); + int16x8_t v50_tmp = vqrdmulhq_n_s16(v49, 13573); + int16x8_t v50 = vaddq_s16(v50_tmp, v49); + int16x8_t v51 = vld1q_s16(in + in_stride * 120 + i); + int16x8_t v52 = vaddq_s16(v51, v36); + int16x8_t v53 = vaddq_s16(v52, v45); + int16x8_t v54 = vaddq_s16(v53, v49); + int16x8_t v55 = vaddq_s16(v50, v54); + int16x8_t v56 = vqrdmulhq_n_s16(v55, 17734); + int16x8_t v57 = vaddq_s16(v48, v56); + int16x8_t v58 = vqrdmulhq_n_s16(v57, 16705); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v63_tmp = vqrdmulhq_n_s16(v62, 13573); + int16x8_t v63 = vaddq_s16(v63_tmp, v62); + int16x8_t v64 = vld1q_s16(in + in_stride * 68 + i); + int16x8_t v65 = vld1q_s16(in + in_stride * 60 + i); + int16x8_t v66 = vaddq_s16(v64, v65); + int16x8_t v67 = vaddq_s16(v63, v66); + int16x8_t v68 = vld1q_s16(in + in_stride * 36 + i); + int16x8_t v69 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v70 = vaddq_s16(v68, v69); + int16x8_t v71_tmp = vqrdmulhq_n_s16(v70, 13573); + int16x8_t v71 = vaddq_s16(v71_tmp, v70); + int16x8_t v72 = vld1q_s16(in + in_stride * 100 + i); + int16x8_t v73 = vld1q_s16(in + in_stride * 92 + i); + int16x8_t v74 = vaddq_s16(v72, v73); + int16x8_t v75 = vaddq_s16(v74, v70); + int16x8_t v76 = vaddq_s16(v71, v75); + int16x8_t v77 = vqrdmulhq_n_s16(v76, 17734); + int16x8_t v78 = vaddq_s16(v67, v77); + int16x8_t v79 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v80 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v81 = vaddq_s16(v79, v80); + int16x8_t v82_tmp = vqrdmulhq_n_s16(v81, 13573); + int16x8_t v82 = vaddq_s16(v82_tmp, v81); + int16x8_t v83 = vld1q_s16(in + in_stride * 84 + i); + int16x8_t v84 = vld1q_s16(in + in_stride * 76 + i); + int16x8_t v85 = vaddq_s16(v83, v84); + int16x8_t v86 = vld1q_s16(in + in_stride * 52 + i); + int16x8_t v87 = vld1q_s16(in + in_stride * 44 + i); + int16x8_t v88 = vaddq_s16(v86, v87); + int16x8_t v89 = vaddq_s16(v85, v88); + int16x8_t v90 = vaddq_s16(v82, v89); + int16x8_t v91 = vaddq_s16(v88, v81); + int16x8_t v92_tmp = vqrdmulhq_n_s16(v91, 13573); + int16x8_t v92 = vaddq_s16(v92_tmp, v91); + int16x8_t v93 = vld1q_s16(in + in_stride * 116 + i); + int16x8_t v94 = vld1q_s16(in + in_stride * 108 + i); + int16x8_t v95 = vaddq_s16(v93, v94); + int16x8_t v96 = vaddq_s16(v95, v85); + int16x8_t v97 = vaddq_s16(v96, v91); + int16x8_t v98 = vaddq_s16(v92, v97); + int16x8_t v99 = vqrdmulhq_n_s16(v98, 17734); + int16x8_t v100 = vaddq_s16(v90, v99); + int16x8_t v101 = vqrdmulhq_n_s16(v100, 16705); + int16x8_t v102 = vaddq_s16(v78, v101); + int16x8_t v103 = vaddq_s16(v80, v62); + int16x8_t v104_tmp = vqrdmulhq_n_s16(v103, 13573); + int16x8_t v104 = vaddq_s16(v104_tmp, v103); + int16x8_t v105 = vaddq_s16(v84, v64); + int16x8_t v106 = vaddq_s16(v65, v86); + int16x8_t v107 = vaddq_s16(v105, v106); + int16x8_t v108 = vaddq_s16(v104, v107); + int16x8_t v109 = vaddq_s16(v87, v68); + int16x8_t v110 = vaddq_s16(v69, v79); + int16x8_t v111 = vaddq_s16(v109, v110); + int16x8_t v112_tmp = vqrdmulhq_n_s16(v111, 13573); + int16x8_t v112 = vaddq_s16(v112_tmp, v111); + int16x8_t v113 = vaddq_s16(v94, v72); + int16x8_t v114 = vaddq_s16(v73, v83); + int16x8_t v115 = vaddq_s16(v113, v114); + int16x8_t v116 = vaddq_s16(v115, v111); + int16x8_t v117 = vaddq_s16(v112, v116); + int16x8_t v118 = vqrdmulhq_n_s16(v117, 17734); + int16x8_t v119 = vaddq_s16(v108, v118); + int16x8_t v120 = vaddq_s16(v110, v103); + int16x8_t v121_tmp = vqrdmulhq_n_s16(v120, 13573); + int16x8_t v121 = vaddq_s16(v121_tmp, v120); + int16x8_t v122 = vaddq_s16(v114, v105); + int16x8_t v123 = vaddq_s16(v106, v109); + int16x8_t v124 = vaddq_s16(v122, v123); + int16x8_t v125 = vaddq_s16(v121, v124); + int16x8_t v126 = vaddq_s16(v123, v120); + int16x8_t v127_tmp = vqrdmulhq_n_s16(v126, 13573); + int16x8_t v127 = vaddq_s16(v127_tmp, v126); + int16x8_t v128 = vld1q_s16(in + in_stride * 124 + i); + int16x8_t v129 = vaddq_s16(v128, v93); + int16x8_t v130 = vaddq_s16(v129, v113); + int16x8_t v131 = vaddq_s16(v130, v122); + int16x8_t v132 = vaddq_s16(v131, v126); + int16x8_t v133 = vaddq_s16(v127, v132); + int16x8_t v134 = vqrdmulhq_n_s16(v133, 17734); + int16x8_t v135 = vaddq_s16(v125, v134); + int16x8_t v136 = vqrdmulhq_n_s16(v135, 16705); + int16x8_t v137 = vaddq_s16(v119, v136); + int16x8_t v138 = vqrdmulhq_n_s16(v137, 16463); + int16x8_t v139 = vaddq_s16(v102, v138); + int16x8_t v140 = vqrdmulhq_n_s16(v139, 16404); + int16x8_t v141 = vaddq_s16(v61, v140); + int16x8_t v142 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v143_tmp = vqrdmulhq_n_s16(v142, 13573); + int16x8_t v143 = vaddq_s16(v143_tmp, v142); + int16x8_t v144 = vld1q_s16(in + in_stride * 66 + i); + int16x8_t v145 = vld1q_s16(in + in_stride * 62 + i); + int16x8_t v146 = vaddq_s16(v144, v145); + int16x8_t v147 = vaddq_s16(v143, v146); + int16x8_t v148 = vld1q_s16(in + in_stride * 34 + i); + int16x8_t v149 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v150 = vaddq_s16(v148, v149); + int16x8_t v151_tmp = vqrdmulhq_n_s16(v150, 13573); + int16x8_t v151 = vaddq_s16(v151_tmp, v150); + int16x8_t v152 = vld1q_s16(in + in_stride * 98 + i); + int16x8_t v153 = vld1q_s16(in + in_stride * 94 + i); + int16x8_t v154 = vaddq_s16(v152, v153); + int16x8_t v155 = vaddq_s16(v154, v150); + int16x8_t v156 = vaddq_s16(v151, v155); + int16x8_t v157 = vqrdmulhq_n_s16(v156, 17734); + int16x8_t v158 = vaddq_s16(v147, v157); + int16x8_t v159 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v160 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v161 = vaddq_s16(v159, v160); + int16x8_t v162_tmp = vqrdmulhq_n_s16(v161, 13573); + int16x8_t v162 = vaddq_s16(v162_tmp, v161); + int16x8_t v163 = vld1q_s16(in + in_stride * 82 + i); + int16x8_t v164 = vld1q_s16(in + in_stride * 78 + i); + int16x8_t v165 = vaddq_s16(v163, v164); + int16x8_t v166 = vld1q_s16(in + in_stride * 50 + i); + int16x8_t v167 = vld1q_s16(in + in_stride * 46 + i); + int16x8_t v168 = vaddq_s16(v166, v167); + int16x8_t v169 = vaddq_s16(v165, v168); + int16x8_t v170 = vaddq_s16(v162, v169); + int16x8_t v171 = vaddq_s16(v168, v161); + int16x8_t v172_tmp = vqrdmulhq_n_s16(v171, 13573); + int16x8_t v172 = vaddq_s16(v172_tmp, v171); + int16x8_t v173 = vld1q_s16(in + in_stride * 114 + i); + int16x8_t v174 = vld1q_s16(in + in_stride * 110 + i); + int16x8_t v175 = vaddq_s16(v173, v174); + int16x8_t v176 = vaddq_s16(v175, v165); + int16x8_t v177 = vaddq_s16(v176, v171); + int16x8_t v178 = vaddq_s16(v172, v177); + int16x8_t v179 = vqrdmulhq_n_s16(v178, 17734); + int16x8_t v180 = vaddq_s16(v170, v179); + int16x8_t v181 = vqrdmulhq_n_s16(v180, 16705); + int16x8_t v182 = vaddq_s16(v158, v181); + int16x8_t v183 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v184 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v185 = vaddq_s16(v183, v184); + int16x8_t v186_tmp = vqrdmulhq_n_s16(v185, 13573); + int16x8_t v186 = vaddq_s16(v186_tmp, v185); + int16x8_t v187 = vld1q_s16(in + in_stride * 74 + i); + int16x8_t v188 = vld1q_s16(in + in_stride * 70 + i); + int16x8_t v189 = vaddq_s16(v187, v188); + int16x8_t v190 = vld1q_s16(in + in_stride * 58 + i); + int16x8_t v191 = vld1q_s16(in + in_stride * 54 + i); + int16x8_t v192 = vaddq_s16(v190, v191); + int16x8_t v193 = vaddq_s16(v189, v192); + int16x8_t v194 = vaddq_s16(v186, v193); + int16x8_t v195 = vld1q_s16(in + in_stride * 42 + i); + int16x8_t v196 = vld1q_s16(in + in_stride * 38 + i); + int16x8_t v197 = vaddq_s16(v195, v196); + int16x8_t v198 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v199 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v200 = vaddq_s16(v198, v199); + int16x8_t v201 = vaddq_s16(v197, v200); + int16x8_t v202_tmp = vqrdmulhq_n_s16(v201, 13573); + int16x8_t v202 = vaddq_s16(v202_tmp, v201); + int16x8_t v203 = vld1q_s16(in + in_stride * 106 + i); + int16x8_t v204 = vld1q_s16(in + in_stride * 102 + i); + int16x8_t v205 = vaddq_s16(v203, v204); + int16x8_t v206 = vld1q_s16(in + in_stride * 90 + i); + int16x8_t v207 = vld1q_s16(in + in_stride * 86 + i); + int16x8_t v208 = vaddq_s16(v206, v207); + int16x8_t v209 = vaddq_s16(v205, v208); + int16x8_t v210 = vaddq_s16(v209, v201); + int16x8_t v211 = vaddq_s16(v202, v210); + int16x8_t v212 = vqrdmulhq_n_s16(v211, 17734); + int16x8_t v213 = vaddq_s16(v194, v212); + int16x8_t v214 = vaddq_s16(v200, v185); + int16x8_t v215_tmp = vqrdmulhq_n_s16(v214, 13573); + int16x8_t v215 = vaddq_s16(v215_tmp, v214); + int16x8_t v216 = vaddq_s16(v208, v189); + int16x8_t v217 = vaddq_s16(v192, v197); + int16x8_t v218 = vaddq_s16(v216, v217); + int16x8_t v219 = vaddq_s16(v215, v218); + int16x8_t v220 = vaddq_s16(v217, v214); + int16x8_t v221_tmp = vqrdmulhq_n_s16(v220, 13573); + int16x8_t v221 = vaddq_s16(v221_tmp, v220); + int16x8_t v222 = vld1q_s16(in + in_stride * 122 + i); + int16x8_t v223 = vld1q_s16(in + in_stride * 118 + i); + int16x8_t v224 = vaddq_s16(v222, v223); + int16x8_t v225 = vaddq_s16(v224, v205); + int16x8_t v226 = vaddq_s16(v225, v216); + int16x8_t v227 = vaddq_s16(v226, v220); + int16x8_t v228 = vaddq_s16(v221, v227); + int16x8_t v229 = vqrdmulhq_n_s16(v228, 17734); + int16x8_t v230 = vaddq_s16(v219, v229); + int16x8_t v231 = vqrdmulhq_n_s16(v230, 16705); + int16x8_t v232 = vaddq_s16(v213, v231); + int16x8_t v233 = vqrdmulhq_n_s16(v232, 16463); + int16x8_t v234 = vaddq_s16(v182, v233); + int16x8_t v235 = vaddq_s16(v184, v142); + int16x8_t v236_tmp = vqrdmulhq_n_s16(v235, 13573); + int16x8_t v236 = vaddq_s16(v236_tmp, v235); + int16x8_t v237 = vaddq_s16(v188, v144); + int16x8_t v238 = vaddq_s16(v145, v190); + int16x8_t v239 = vaddq_s16(v237, v238); + int16x8_t v240 = vaddq_s16(v236, v239); + int16x8_t v241 = vaddq_s16(v196, v148); + int16x8_t v242 = vaddq_s16(v149, v198); + int16x8_t v243 = vaddq_s16(v241, v242); + int16x8_t v244_tmp = vqrdmulhq_n_s16(v243, 13573); + int16x8_t v244 = vaddq_s16(v244_tmp, v243); + int16x8_t v245 = vaddq_s16(v204, v152); + int16x8_t v246 = vaddq_s16(v153, v206); + int16x8_t v247 = vaddq_s16(v245, v246); + int16x8_t v248 = vaddq_s16(v247, v243); + int16x8_t v249 = vaddq_s16(v244, v248); + int16x8_t v250 = vqrdmulhq_n_s16(v249, 17734); + int16x8_t v251 = vaddq_s16(v240, v250); + int16x8_t v252 = vaddq_s16(v199, v159); + int16x8_t v253 = vaddq_s16(v160, v183); + int16x8_t v254 = vaddq_s16(v252, v253); + int16x8_t v255_tmp = vqrdmulhq_n_s16(v254, 13573); + int16x8_t v255 = vaddq_s16(v255_tmp, v254); + int16x8_t v256 = vaddq_s16(v207, v163); + int16x8_t v257 = vaddq_s16(v164, v187); + int16x8_t v258 = vaddq_s16(v256, v257); + int16x8_t v259 = vaddq_s16(v191, v166); + int16x8_t v260 = vaddq_s16(v167, v195); + int16x8_t v261 = vaddq_s16(v259, v260); + int16x8_t v262 = vaddq_s16(v258, v261); + int16x8_t v263 = vaddq_s16(v255, v262); + int16x8_t v264 = vaddq_s16(v261, v254); + int16x8_t v265_tmp = vqrdmulhq_n_s16(v264, 13573); + int16x8_t v265 = vaddq_s16(v265_tmp, v264); + int16x8_t v266 = vaddq_s16(v223, v173); + int16x8_t v267 = vaddq_s16(v174, v203); + int16x8_t v268 = vaddq_s16(v266, v267); + int16x8_t v269 = vaddq_s16(v268, v258); + int16x8_t v270 = vaddq_s16(v269, v264); + int16x8_t v271 = vaddq_s16(v265, v270); + int16x8_t v272 = vqrdmulhq_n_s16(v271, 17734); + int16x8_t v273 = vaddq_s16(v263, v272); + int16x8_t v274 = vqrdmulhq_n_s16(v273, 16705); + int16x8_t v275 = vaddq_s16(v251, v274); + int16x8_t v276 = vaddq_s16(v253, v235); + int16x8_t v277_tmp = vqrdmulhq_n_s16(v276, 13573); + int16x8_t v277 = vaddq_s16(v277_tmp, v276); + int16x8_t v278 = vaddq_s16(v257, v237); + int16x8_t v279 = vaddq_s16(v238, v259); + int16x8_t v280 = vaddq_s16(v278, v279); + int16x8_t v281 = vaddq_s16(v277, v280); + int16x8_t v282 = vaddq_s16(v260, v241); + int16x8_t v283 = vaddq_s16(v242, v252); + int16x8_t v284 = vaddq_s16(v282, v283); + int16x8_t v285_tmp = vqrdmulhq_n_s16(v284, 13573); + int16x8_t v285 = vaddq_s16(v285_tmp, v284); + int16x8_t v286 = vaddq_s16(v267, v245); + int16x8_t v287 = vaddq_s16(v246, v256); + int16x8_t v288 = vaddq_s16(v286, v287); + int16x8_t v289 = vaddq_s16(v288, v284); + int16x8_t v290 = vaddq_s16(v285, v289); + int16x8_t v291 = vqrdmulhq_n_s16(v290, 17734); + int16x8_t v292 = vaddq_s16(v281, v291); + int16x8_t v293 = vaddq_s16(v283, v276); + int16x8_t v294_tmp = vqrdmulhq_n_s16(v293, 13573); + int16x8_t v294 = vaddq_s16(v294_tmp, v293); + int16x8_t v295 = vaddq_s16(v287, v278); + int16x8_t v296 = vaddq_s16(v279, v282); + int16x8_t v297 = vaddq_s16(v295, v296); + int16x8_t v298 = vaddq_s16(v294, v297); + int16x8_t v299 = vaddq_s16(v296, v293); + int16x8_t v300_tmp = vqrdmulhq_n_s16(v299, 13573); + int16x8_t v300 = vaddq_s16(v300_tmp, v299); + int16x8_t v301 = vld1q_s16(in + in_stride * 126 + i); + int16x8_t v302 = vaddq_s16(v301, v222); + int16x8_t v303 = vaddq_s16(v302, v266); + int16x8_t v304 = vaddq_s16(v303, v286); + int16x8_t v305 = vaddq_s16(v304, v295); + int16x8_t v306 = vaddq_s16(v305, v299); + int16x8_t v307 = vaddq_s16(v300, v306); + int16x8_t v308 = vqrdmulhq_n_s16(v307, 17734); + int16x8_t v309 = vaddq_s16(v298, v308); + int16x8_t v310 = vqrdmulhq_n_s16(v309, 16705); + int16x8_t v311 = vaddq_s16(v292, v310); + int16x8_t v312 = vqrdmulhq_n_s16(v311, 16463); + int16x8_t v313 = vaddq_s16(v275, v312); + int16x8_t v314 = vqrdmulhq_n_s16(v313, 16404); + int16x8_t v315 = vaddq_s16(v234, v314); + int16x8_t v316 = vqrdmulhq_n_s16(v315, 16389); + int16x8_t v317 = vaddq_s16(v141, v316); + int16x8_t v318 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v319_tmp = vqrdmulhq_n_s16(v318, 13573); + int16x8_t v319 = vaddq_s16(v319_tmp, v318); + int16x8_t v320 = vld1q_s16(in + in_stride * 65 + i); + int16x8_t v321 = vld1q_s16(in + in_stride * 63 + i); + int16x8_t v322 = vaddq_s16(v320, v321); + int16x8_t v323 = vaddq_s16(v319, v322); + int16x8_t v324 = vld1q_s16(in + in_stride * 33 + i); + int16x8_t v325 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v326 = vaddq_s16(v324, v325); + int16x8_t v327_tmp = vqrdmulhq_n_s16(v326, 13573); + int16x8_t v327 = vaddq_s16(v327_tmp, v326); + int16x8_t v328 = vld1q_s16(in + in_stride * 97 + i); + int16x8_t v329 = vld1q_s16(in + in_stride * 95 + i); + int16x8_t v330 = vaddq_s16(v328, v329); + int16x8_t v331 = vaddq_s16(v330, v326); + int16x8_t v332 = vaddq_s16(v327, v331); + int16x8_t v333 = vqrdmulhq_n_s16(v332, 17734); + int16x8_t v334 = vaddq_s16(v323, v333); + int16x8_t v335 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v336 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v337 = vaddq_s16(v335, v336); + int16x8_t v338_tmp = vqrdmulhq_n_s16(v337, 13573); + int16x8_t v338 = vaddq_s16(v338_tmp, v337); + int16x8_t v339 = vld1q_s16(in + in_stride * 81 + i); + int16x8_t v340 = vld1q_s16(in + in_stride * 79 + i); + int16x8_t v341 = vaddq_s16(v339, v340); + int16x8_t v342 = vld1q_s16(in + in_stride * 49 + i); + int16x8_t v343 = vld1q_s16(in + in_stride * 47 + i); + int16x8_t v344 = vaddq_s16(v342, v343); + int16x8_t v345 = vaddq_s16(v341, v344); + int16x8_t v346 = vaddq_s16(v338, v345); + int16x8_t v347 = vaddq_s16(v344, v337); + int16x8_t v348_tmp = vqrdmulhq_n_s16(v347, 13573); + int16x8_t v348 = vaddq_s16(v348_tmp, v347); + int16x8_t v349 = vld1q_s16(in + in_stride * 113 + i); + int16x8_t v350 = vld1q_s16(in + in_stride * 111 + i); + int16x8_t v351 = vaddq_s16(v349, v350); + int16x8_t v352 = vaddq_s16(v351, v341); + int16x8_t v353 = vaddq_s16(v352, v347); + int16x8_t v354 = vaddq_s16(v348, v353); + int16x8_t v355 = vqrdmulhq_n_s16(v354, 17734); + int16x8_t v356 = vaddq_s16(v346, v355); + int16x8_t v357 = vqrdmulhq_n_s16(v356, 16705); + int16x8_t v358 = vaddq_s16(v334, v357); + int16x8_t v359 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v360 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v361 = vaddq_s16(v359, v360); + int16x8_t v362_tmp = vqrdmulhq_n_s16(v361, 13573); + int16x8_t v362 = vaddq_s16(v362_tmp, v361); + int16x8_t v363 = vld1q_s16(in + in_stride * 73 + i); + int16x8_t v364 = vld1q_s16(in + in_stride * 71 + i); + int16x8_t v365 = vaddq_s16(v363, v364); + int16x8_t v366 = vld1q_s16(in + in_stride * 57 + i); + int16x8_t v367 = vld1q_s16(in + in_stride * 55 + i); + int16x8_t v368 = vaddq_s16(v366, v367); + int16x8_t v369 = vaddq_s16(v365, v368); + int16x8_t v370 = vaddq_s16(v362, v369); + int16x8_t v371 = vld1q_s16(in + in_stride * 41 + i); + int16x8_t v372 = vld1q_s16(in + in_stride * 39 + i); + int16x8_t v373 = vaddq_s16(v371, v372); + int16x8_t v374 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v375 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v376 = vaddq_s16(v374, v375); + int16x8_t v377 = vaddq_s16(v373, v376); + int16x8_t v378_tmp = vqrdmulhq_n_s16(v377, 13573); + int16x8_t v378 = vaddq_s16(v378_tmp, v377); + int16x8_t v379 = vld1q_s16(in + in_stride * 105 + i); + int16x8_t v380 = vld1q_s16(in + in_stride * 103 + i); + int16x8_t v381 = vaddq_s16(v379, v380); + int16x8_t v382 = vld1q_s16(in + in_stride * 89 + i); + int16x8_t v383 = vld1q_s16(in + in_stride * 87 + i); + int16x8_t v384 = vaddq_s16(v382, v383); + int16x8_t v385 = vaddq_s16(v381, v384); + int16x8_t v386 = vaddq_s16(v385, v377); + int16x8_t v387 = vaddq_s16(v378, v386); + int16x8_t v388 = vqrdmulhq_n_s16(v387, 17734); + int16x8_t v389 = vaddq_s16(v370, v388); + int16x8_t v390 = vaddq_s16(v376, v361); + int16x8_t v391_tmp = vqrdmulhq_n_s16(v390, 13573); + int16x8_t v391 = vaddq_s16(v391_tmp, v390); + int16x8_t v392 = vaddq_s16(v384, v365); + int16x8_t v393 = vaddq_s16(v368, v373); + int16x8_t v394 = vaddq_s16(v392, v393); + int16x8_t v395 = vaddq_s16(v391, v394); + int16x8_t v396 = vaddq_s16(v393, v390); + int16x8_t v397_tmp = vqrdmulhq_n_s16(v396, 13573); + int16x8_t v397 = vaddq_s16(v397_tmp, v396); + int16x8_t v398 = vld1q_s16(in + in_stride * 121 + i); + int16x8_t v399 = vld1q_s16(in + in_stride * 119 + i); + int16x8_t v400 = vaddq_s16(v398, v399); + int16x8_t v401 = vaddq_s16(v400, v381); + int16x8_t v402 = vaddq_s16(v401, v392); + int16x8_t v403 = vaddq_s16(v402, v396); + int16x8_t v404 = vaddq_s16(v397, v403); + int16x8_t v405 = vqrdmulhq_n_s16(v404, 17734); + int16x8_t v406 = vaddq_s16(v395, v405); + int16x8_t v407 = vqrdmulhq_n_s16(v406, 16705); + int16x8_t v408 = vaddq_s16(v389, v407); + int16x8_t v409 = vqrdmulhq_n_s16(v408, 16463); + int16x8_t v410 = vaddq_s16(v358, v409); + int16x8_t v411 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v412 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v413 = vaddq_s16(v411, v412); + int16x8_t v414_tmp = vqrdmulhq_n_s16(v413, 13573); + int16x8_t v414 = vaddq_s16(v414_tmp, v413); + int16x8_t v415 = vld1q_s16(in + in_stride * 69 + i); + int16x8_t v416 = vld1q_s16(in + in_stride * 67 + i); + int16x8_t v417 = vaddq_s16(v415, v416); + int16x8_t v418 = vld1q_s16(in + in_stride * 61 + i); + int16x8_t v419 = vld1q_s16(in + in_stride * 59 + i); + int16x8_t v420 = vaddq_s16(v418, v419); + int16x8_t v421 = vaddq_s16(v417, v420); + int16x8_t v422 = vaddq_s16(v414, v421); + int16x8_t v423 = vld1q_s16(in + in_stride * 37 + i); + int16x8_t v424 = vld1q_s16(in + in_stride * 35 + i); + int16x8_t v425 = vaddq_s16(v423, v424); + int16x8_t v426 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v427 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v428 = vaddq_s16(v426, v427); + int16x8_t v429 = vaddq_s16(v425, v428); + int16x8_t v430_tmp = vqrdmulhq_n_s16(v429, 13573); + int16x8_t v430 = vaddq_s16(v430_tmp, v429); + int16x8_t v431 = vld1q_s16(in + in_stride * 101 + i); + int16x8_t v432 = vld1q_s16(in + in_stride * 99 + i); + int16x8_t v433 = vaddq_s16(v431, v432); + int16x8_t v434 = vld1q_s16(in + in_stride * 93 + i); + int16x8_t v435 = vld1q_s16(in + in_stride * 91 + i); + int16x8_t v436 = vaddq_s16(v434, v435); + int16x8_t v437 = vaddq_s16(v433, v436); + int16x8_t v438 = vaddq_s16(v437, v429); + int16x8_t v439 = vaddq_s16(v430, v438); + int16x8_t v440 = vqrdmulhq_n_s16(v439, 17734); + int16x8_t v441 = vaddq_s16(v422, v440); + int16x8_t v442 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v443 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v444 = vaddq_s16(v442, v443); + int16x8_t v445 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v446 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v447 = vaddq_s16(v445, v446); + int16x8_t v448 = vaddq_s16(v444, v447); + int16x8_t v449_tmp = vqrdmulhq_n_s16(v448, 13573); + int16x8_t v449 = vaddq_s16(v449_tmp, v448); + int16x8_t v450 = vld1q_s16(in + in_stride * 85 + i); + int16x8_t v451 = vld1q_s16(in + in_stride * 83 + i); + int16x8_t v452 = vaddq_s16(v450, v451); + int16x8_t v453 = vld1q_s16(in + in_stride * 77 + i); + int16x8_t v454 = vld1q_s16(in + in_stride * 75 + i); + int16x8_t v455 = vaddq_s16(v453, v454); + int16x8_t v456 = vaddq_s16(v452, v455); + int16x8_t v457 = vld1q_s16(in + in_stride * 53 + i); + int16x8_t v458 = vld1q_s16(in + in_stride * 51 + i); + int16x8_t v459 = vaddq_s16(v457, v458); + int16x8_t v460 = vld1q_s16(in + in_stride * 45 + i); + int16x8_t v461 = vld1q_s16(in + in_stride * 43 + i); + int16x8_t v462 = vaddq_s16(v460, v461); + int16x8_t v463 = vaddq_s16(v459, v462); + int16x8_t v464 = vaddq_s16(v456, v463); + int16x8_t v465 = vaddq_s16(v449, v464); + int16x8_t v466 = vaddq_s16(v463, v448); + int16x8_t v467_tmp = vqrdmulhq_n_s16(v466, 13573); + int16x8_t v467 = vaddq_s16(v467_tmp, v466); + int16x8_t v468 = vld1q_s16(in + in_stride * 117 + i); + int16x8_t v469 = vld1q_s16(in + in_stride * 115 + i); + int16x8_t v470 = vaddq_s16(v468, v469); + int16x8_t v471 = vld1q_s16(in + in_stride * 109 + i); + int16x8_t v472 = vld1q_s16(in + in_stride * 107 + i); + int16x8_t v473 = vaddq_s16(v471, v472); + int16x8_t v474 = vaddq_s16(v470, v473); + int16x8_t v475 = vaddq_s16(v474, v456); + int16x8_t v476 = vaddq_s16(v475, v466); + int16x8_t v477 = vaddq_s16(v467, v476); + int16x8_t v478 = vqrdmulhq_n_s16(v477, 17734); + int16x8_t v479 = vaddq_s16(v465, v478); + int16x8_t v480 = vqrdmulhq_n_s16(v479, 16705); + int16x8_t v481 = vaddq_s16(v441, v480); + int16x8_t v482 = vaddq_s16(v447, v413); + int16x8_t v483_tmp = vqrdmulhq_n_s16(v482, 13573); + int16x8_t v483 = vaddq_s16(v483_tmp, v482); + int16x8_t v484 = vaddq_s16(v455, v417); + int16x8_t v485 = vaddq_s16(v420, v459); + int16x8_t v486 = vaddq_s16(v484, v485); + int16x8_t v487 = vaddq_s16(v483, v486); + int16x8_t v488 = vaddq_s16(v462, v425); + int16x8_t v489 = vaddq_s16(v428, v444); + int16x8_t v490 = vaddq_s16(v488, v489); + int16x8_t v491_tmp = vqrdmulhq_n_s16(v490, 13573); + int16x8_t v491 = vaddq_s16(v491_tmp, v490); + int16x8_t v492 = vaddq_s16(v473, v433); + int16x8_t v493 = vaddq_s16(v436, v452); + int16x8_t v494 = vaddq_s16(v492, v493); + int16x8_t v495 = vaddq_s16(v494, v490); + int16x8_t v496 = vaddq_s16(v491, v495); + int16x8_t v497 = vqrdmulhq_n_s16(v496, 17734); + int16x8_t v498 = vaddq_s16(v487, v497); + int16x8_t v499 = vaddq_s16(v489, v482); + int16x8_t v500_tmp = vqrdmulhq_n_s16(v499, 13573); + int16x8_t v500 = vaddq_s16(v500_tmp, v499); + int16x8_t v501 = vaddq_s16(v493, v484); + int16x8_t v502 = vaddq_s16(v485, v488); + int16x8_t v503 = vaddq_s16(v501, v502); + int16x8_t v504 = vaddq_s16(v500, v503); + int16x8_t v505 = vaddq_s16(v502, v499); + int16x8_t v506_tmp = vqrdmulhq_n_s16(v505, 13573); + int16x8_t v506 = vaddq_s16(v506_tmp, v505); + int16x8_t v507 = vld1q_s16(in + in_stride * 125 + i); + int16x8_t v508 = vld1q_s16(in + in_stride * 123 + i); + int16x8_t v509 = vaddq_s16(v507, v508); + int16x8_t v510 = vaddq_s16(v509, v470); + int16x8_t v511 = vaddq_s16(v510, v492); + int16x8_t v512 = vaddq_s16(v511, v501); + int16x8_t v513 = vaddq_s16(v512, v505); + int16x8_t v514 = vaddq_s16(v506, v513); + int16x8_t v515 = vqrdmulhq_n_s16(v514, 17734); + int16x8_t v516 = vaddq_s16(v504, v515); + int16x8_t v517 = vqrdmulhq_n_s16(v516, 16705); + int16x8_t v518 = vaddq_s16(v498, v517); + int16x8_t v519 = vqrdmulhq_n_s16(v518, 16463); + int16x8_t v520 = vaddq_s16(v481, v519); + int16x8_t v521 = vqrdmulhq_n_s16(v520, 16404); + int16x8_t v522 = vaddq_s16(v410, v521); + int16x8_t v523 = vaddq_s16(v412, v318); + int16x8_t v524_tmp = vqrdmulhq_n_s16(v523, 13573); + int16x8_t v524 = vaddq_s16(v524_tmp, v523); + int16x8_t v525 = vaddq_s16(v416, v320); + int16x8_t v526 = vaddq_s16(v321, v418); + int16x8_t v527 = vaddq_s16(v525, v526); + int16x8_t v528 = vaddq_s16(v524, v527); + int16x8_t v529 = vaddq_s16(v424, v324); + int16x8_t v530 = vaddq_s16(v325, v426); + int16x8_t v531 = vaddq_s16(v529, v530); + int16x8_t v532_tmp = vqrdmulhq_n_s16(v531, 13573); + int16x8_t v532 = vaddq_s16(v532_tmp, v531); + int16x8_t v533 = vaddq_s16(v432, v328); + int16x8_t v534 = vaddq_s16(v329, v434); + int16x8_t v535 = vaddq_s16(v533, v534); + int16x8_t v536 = vaddq_s16(v535, v531); + int16x8_t v537 = vaddq_s16(v532, v536); + int16x8_t v538 = vqrdmulhq_n_s16(v537, 17734); + int16x8_t v539 = vaddq_s16(v528, v538); + int16x8_t v540 = vaddq_s16(v443, v335); + int16x8_t v541 = vaddq_s16(v336, v445); + int16x8_t v542 = vaddq_s16(v540, v541); + int16x8_t v543_tmp = vqrdmulhq_n_s16(v542, 13573); + int16x8_t v543 = vaddq_s16(v543_tmp, v542); + int16x8_t v544 = vaddq_s16(v451, v339); + int16x8_t v545 = vaddq_s16(v340, v453); + int16x8_t v546 = vaddq_s16(v544, v545); + int16x8_t v547 = vaddq_s16(v458, v342); + int16x8_t v548 = vaddq_s16(v343, v460); + int16x8_t v549 = vaddq_s16(v547, v548); + int16x8_t v550 = vaddq_s16(v546, v549); + int16x8_t v551 = vaddq_s16(v543, v550); + int16x8_t v552 = vaddq_s16(v549, v542); + int16x8_t v553_tmp = vqrdmulhq_n_s16(v552, 13573); + int16x8_t v553 = vaddq_s16(v553_tmp, v552); + int16x8_t v554 = vaddq_s16(v469, v349); + int16x8_t v555 = vaddq_s16(v350, v471); + int16x8_t v556 = vaddq_s16(v554, v555); + int16x8_t v557 = vaddq_s16(v556, v546); + int16x8_t v558 = vaddq_s16(v557, v552); + int16x8_t v559 = vaddq_s16(v553, v558); + int16x8_t v560 = vqrdmulhq_n_s16(v559, 17734); + int16x8_t v561 = vaddq_s16(v551, v560); + int16x8_t v562 = vqrdmulhq_n_s16(v561, 16705); + int16x8_t v563 = vaddq_s16(v539, v562); + int16x8_t v564 = vaddq_s16(v446, v359); + int16x8_t v565 = vaddq_s16(v360, v411); + int16x8_t v566 = vaddq_s16(v564, v565); + int16x8_t v567_tmp = vqrdmulhq_n_s16(v566, 13573); + int16x8_t v567 = vaddq_s16(v567_tmp, v566); + int16x8_t v568 = vaddq_s16(v454, v363); + int16x8_t v569 = vaddq_s16(v364, v415); + int16x8_t v570 = vaddq_s16(v568, v569); + int16x8_t v571 = vaddq_s16(v419, v366); + int16x8_t v572 = vaddq_s16(v367, v457); + int16x8_t v573 = vaddq_s16(v571, v572); + int16x8_t v574 = vaddq_s16(v570, v573); + int16x8_t v575 = vaddq_s16(v567, v574); + int16x8_t v576 = vaddq_s16(v461, v371); + int16x8_t v577 = vaddq_s16(v372, v423); + int16x8_t v578 = vaddq_s16(v576, v577); + int16x8_t v579 = vaddq_s16(v427, v374); + int16x8_t v580 = vaddq_s16(v375, v442); + int16x8_t v581 = vaddq_s16(v579, v580); + int16x8_t v582 = vaddq_s16(v578, v581); + int16x8_t v583_tmp = vqrdmulhq_n_s16(v582, 13573); + int16x8_t v583 = vaddq_s16(v583_tmp, v582); + int16x8_t v584 = vaddq_s16(v472, v379); + int16x8_t v585 = vaddq_s16(v380, v431); + int16x8_t v586 = vaddq_s16(v584, v585); + int16x8_t v587 = vaddq_s16(v435, v382); + int16x8_t v588 = vaddq_s16(v383, v450); + int16x8_t v589 = vaddq_s16(v587, v588); + int16x8_t v590 = vaddq_s16(v586, v589); + int16x8_t v591 = vaddq_s16(v590, v582); + int16x8_t v592 = vaddq_s16(v583, v591); + int16x8_t v593 = vqrdmulhq_n_s16(v592, 17734); + int16x8_t v594 = vaddq_s16(v575, v593); + int16x8_t v595 = vaddq_s16(v581, v566); + int16x8_t v596_tmp = vqrdmulhq_n_s16(v595, 13573); + int16x8_t v596 = vaddq_s16(v596_tmp, v595); + int16x8_t v597 = vaddq_s16(v589, v570); + int16x8_t v598 = vaddq_s16(v573, v578); + int16x8_t v599 = vaddq_s16(v597, v598); + int16x8_t v600 = vaddq_s16(v596, v599); + int16x8_t v601 = vaddq_s16(v598, v595); + int16x8_t v602_tmp = vqrdmulhq_n_s16(v601, 13573); + int16x8_t v602 = vaddq_s16(v602_tmp, v601); + int16x8_t v603 = vaddq_s16(v508, v398); + int16x8_t v604 = vaddq_s16(v399, v468); + int16x8_t v605 = vaddq_s16(v603, v604); + int16x8_t v606 = vaddq_s16(v605, v586); + int16x8_t v607 = vaddq_s16(v606, v597); + int16x8_t v608 = vaddq_s16(v607, v601); + int16x8_t v609 = vaddq_s16(v602, v608); + int16x8_t v610 = vqrdmulhq_n_s16(v609, 17734); + int16x8_t v611 = vaddq_s16(v600, v610); + int16x8_t v612 = vqrdmulhq_n_s16(v611, 16705); + int16x8_t v613 = vaddq_s16(v594, v612); + int16x8_t v614 = vqrdmulhq_n_s16(v613, 16463); + int16x8_t v615 = vaddq_s16(v563, v614); + int16x8_t v616 = vaddq_s16(v565, v523); + int16x8_t v617_tmp = vqrdmulhq_n_s16(v616, 13573); + int16x8_t v617 = vaddq_s16(v617_tmp, v616); + int16x8_t v618 = vaddq_s16(v569, v525); + int16x8_t v619 = vaddq_s16(v526, v571); + int16x8_t v620 = vaddq_s16(v618, v619); + int16x8_t v621 = vaddq_s16(v617, v620); + int16x8_t v622 = vaddq_s16(v577, v529); + int16x8_t v623 = vaddq_s16(v530, v579); + int16x8_t v624 = vaddq_s16(v622, v623); + int16x8_t v625_tmp = vqrdmulhq_n_s16(v624, 13573); + int16x8_t v625 = vaddq_s16(v625_tmp, v624); + int16x8_t v626 = vaddq_s16(v585, v533); + int16x8_t v627 = vaddq_s16(v534, v587); + int16x8_t v628 = vaddq_s16(v626, v627); + int16x8_t v629 = vaddq_s16(v628, v624); + int16x8_t v630 = vaddq_s16(v625, v629); + int16x8_t v631 = vqrdmulhq_n_s16(v630, 17734); + int16x8_t v632 = vaddq_s16(v621, v631); + int16x8_t v633 = vaddq_s16(v580, v540); + int16x8_t v634 = vaddq_s16(v541, v564); + int16x8_t v635 = vaddq_s16(v633, v634); + int16x8_t v636_tmp = vqrdmulhq_n_s16(v635, 13573); + int16x8_t v636 = vaddq_s16(v636_tmp, v635); + int16x8_t v637 = vaddq_s16(v588, v544); + int16x8_t v638 = vaddq_s16(v545, v568); + int16x8_t v639 = vaddq_s16(v637, v638); + int16x8_t v640 = vaddq_s16(v572, v547); + int16x8_t v641 = vaddq_s16(v548, v576); + int16x8_t v642 = vaddq_s16(v640, v641); + int16x8_t v643 = vaddq_s16(v639, v642); + int16x8_t v644 = vaddq_s16(v636, v643); + int16x8_t v645 = vaddq_s16(v642, v635); + int16x8_t v646_tmp = vqrdmulhq_n_s16(v645, 13573); + int16x8_t v646 = vaddq_s16(v646_tmp, v645); + int16x8_t v647 = vaddq_s16(v604, v554); + int16x8_t v648 = vaddq_s16(v555, v584); + int16x8_t v649 = vaddq_s16(v647, v648); + int16x8_t v650 = vaddq_s16(v649, v639); + int16x8_t v651 = vaddq_s16(v650, v645); + int16x8_t v652 = vaddq_s16(v646, v651); + int16x8_t v653 = vqrdmulhq_n_s16(v652, 17734); + int16x8_t v654 = vaddq_s16(v644, v653); + int16x8_t v655 = vqrdmulhq_n_s16(v654, 16705); + int16x8_t v656 = vaddq_s16(v632, v655); + int16x8_t v657 = vaddq_s16(v634, v616); + int16x8_t v658_tmp = vqrdmulhq_n_s16(v657, 13573); + int16x8_t v658 = vaddq_s16(v658_tmp, v657); + int16x8_t v659 = vaddq_s16(v638, v618); + int16x8_t v660 = vaddq_s16(v619, v640); + int16x8_t v661 = vaddq_s16(v659, v660); + int16x8_t v662 = vaddq_s16(v658, v661); + int16x8_t v663 = vaddq_s16(v641, v622); + int16x8_t v664 = vaddq_s16(v623, v633); + int16x8_t v665 = vaddq_s16(v663, v664); + int16x8_t v666_tmp = vqrdmulhq_n_s16(v665, 13573); + int16x8_t v666 = vaddq_s16(v666_tmp, v665); + int16x8_t v667 = vaddq_s16(v648, v626); + int16x8_t v668 = vaddq_s16(v627, v637); + int16x8_t v669 = vaddq_s16(v667, v668); + int16x8_t v670 = vaddq_s16(v669, v665); + int16x8_t v671 = vaddq_s16(v666, v670); + int16x8_t v672 = vqrdmulhq_n_s16(v671, 17734); + int16x8_t v673 = vaddq_s16(v662, v672); + int16x8_t v674 = vaddq_s16(v664, v657); + int16x8_t v675_tmp = vqrdmulhq_n_s16(v674, 13573); + int16x8_t v675 = vaddq_s16(v675_tmp, v674); + int16x8_t v676 = vaddq_s16(v668, v659); + int16x8_t v677 = vaddq_s16(v660, v663); + int16x8_t v678 = vaddq_s16(v676, v677); + int16x8_t v679 = vaddq_s16(v675, v678); + int16x8_t v680 = vaddq_s16(v677, v674); + int16x8_t v681_tmp = vqrdmulhq_n_s16(v680, 13573); + int16x8_t v681 = vaddq_s16(v681_tmp, v680); + int16x8_t v682 = vld1q_s16(in + in_stride * 127 + i); + int16x8_t v683 = vaddq_s16(v682, v507); + int16x8_t v684 = vaddq_s16(v683, v603); + int16x8_t v685 = vaddq_s16(v684, v647); + int16x8_t v686 = vaddq_s16(v685, v667); + int16x8_t v687 = vaddq_s16(v686, v676); + int16x8_t v688 = vaddq_s16(v687, v680); + int16x8_t v689 = vaddq_s16(v681, v688); + int16x8_t v690 = vqrdmulhq_n_s16(v689, 17734); + int16x8_t v691 = vaddq_s16(v679, v690); + int16x8_t v692 = vqrdmulhq_n_s16(v691, 16705); + int16x8_t v693 = vaddq_s16(v673, v692); + int16x8_t v694 = vqrdmulhq_n_s16(v693, 16463); + int16x8_t v695 = vaddq_s16(v656, v694); + int16x8_t v696 = vqrdmulhq_n_s16(v695, 16404); + int16x8_t v697 = vaddq_s16(v615, v696); + int16x8_t v698 = vqrdmulhq_n_s16(v697, 16389); + int16x8_t v699 = vaddq_s16(v522, v698); + int16x8_t v700 = vqrdmulhq_n_s16(v699, 16385); + int16x8_t v701 = vaddq_s16(v317, v700); + int16x8_t v702 = vsubq_s16(v0, v1); + int16x8_t v703 = vsubq_s16(v4, v6); + int16x8_t v704_tmp = vqrdmulhq_n_s16(v703, 10045); + int16x8_t v704 = vaddq_s16(v704_tmp, v703); + int16x8_t v705 = vaddq_s16(v702, v704); + int16x8_t v706 = vsubq_s16(v11, v14); + int16x8_t v707 = vsubq_s16(v17, v20); + int16x8_t v708_tmp = vqrdmulhq_n_s16(v707, 10045); + int16x8_t v708 = vaddq_s16(v708_tmp, v707); + int16x8_t v709 = vaddq_s16(v706, v708); + int16x8_t v710 = vqrdmulhq_n_s16(v709, 19705); + int16x8_t v711 = vaddq_s16(v705, v710); + int16x8_t v712 = vsubq_s16(v27, v30); + int16x8_t v713 = vsubq_s16(v35, v39); + int16x8_t v714_tmp = vqrdmulhq_n_s16(v713, 10045); + int16x8_t v714 = vaddq_s16(v714_tmp, v713); + int16x8_t v715 = vaddq_s16(v712, v714); + int16x8_t v716 = vsubq_s16(v44, v47); + int16x8_t v717 = vsubq_s16(v50, v54); + int16x8_t v718_tmp = vqrdmulhq_n_s16(v717, 10045); + int16x8_t v718 = vaddq_s16(v718_tmp, v717); + int16x8_t v719 = vaddq_s16(v716, v718); + int16x8_t v720 = vqrdmulhq_n_s16(v719, 19705); + int16x8_t v721 = vaddq_s16(v715, v720); + int16x8_t v722 = vqrdmulhq_n_s16(v721, 17121); + int16x8_t v723 = vaddq_s16(v711, v722); + int16x8_t v724 = vsubq_s16(v63, v66); + int16x8_t v725 = vsubq_s16(v71, v75); + int16x8_t v726_tmp = vqrdmulhq_n_s16(v725, 10045); + int16x8_t v726 = vaddq_s16(v726_tmp, v725); + int16x8_t v727 = vaddq_s16(v724, v726); + int16x8_t v728 = vsubq_s16(v82, v89); + int16x8_t v729 = vsubq_s16(v92, v97); + int16x8_t v730_tmp = vqrdmulhq_n_s16(v729, 10045); + int16x8_t v730 = vaddq_s16(v730_tmp, v729); + int16x8_t v731 = vaddq_s16(v728, v730); + int16x8_t v732 = vqrdmulhq_n_s16(v731, 19705); + int16x8_t v733 = vaddq_s16(v727, v732); + int16x8_t v734 = vsubq_s16(v104, v107); + int16x8_t v735 = vsubq_s16(v112, v116); + int16x8_t v736_tmp = vqrdmulhq_n_s16(v735, 10045); + int16x8_t v736 = vaddq_s16(v736_tmp, v735); + int16x8_t v737 = vaddq_s16(v734, v736); + int16x8_t v738 = vsubq_s16(v121, v124); + int16x8_t v739 = vsubq_s16(v127, v132); + int16x8_t v740_tmp = vqrdmulhq_n_s16(v739, 10045); + int16x8_t v740 = vaddq_s16(v740_tmp, v739); + int16x8_t v741 = vaddq_s16(v738, v740); + int16x8_t v742 = vqrdmulhq_n_s16(v741, 19705); + int16x8_t v743 = vaddq_s16(v737, v742); + int16x8_t v744 = vqrdmulhq_n_s16(v743, 17121); + int16x8_t v745 = vaddq_s16(v733, v744); + int16x8_t v746 = vqrdmulhq_n_s16(v745, 16563); + int16x8_t v747 = vaddq_s16(v723, v746); + int16x8_t v748 = vsubq_s16(v143, v146); + int16x8_t v749 = vsubq_s16(v151, v155); + int16x8_t v750_tmp = vqrdmulhq_n_s16(v749, 10045); + int16x8_t v750 = vaddq_s16(v750_tmp, v749); + int16x8_t v751 = vaddq_s16(v748, v750); + int16x8_t v752 = vsubq_s16(v162, v169); + int16x8_t v753 = vqrdmulhq_n_s16(v752, 19705); + int16x8_t v754 = vsubq_s16(v172, v177); + int16x8_t v755 = vqrdmulhq_n_s16(v754, 25746); + int16x8_t v756 = vaddq_s16(v753, v755); + int16x8_t v757 = vaddq_s16(v751, v756); + int16x8_t v758 = vsubq_s16(v186, v193); + int16x8_t v759 = vsubq_s16(v202, v210); + int16x8_t v760_tmp = vqrdmulhq_n_s16(v759, 10045); + int16x8_t v760 = vaddq_s16(v760_tmp, v759); + int16x8_t v761 = vaddq_s16(v758, v760); + int16x8_t v762 = vsubq_s16(v215, v218); + int16x8_t v763 = vsubq_s16(v221, v227); + int16x8_t v764_tmp = vqrdmulhq_n_s16(v763, 10045); + int16x8_t v764 = vaddq_s16(v764_tmp, v763); + int16x8_t v765 = vaddq_s16(v762, v764); + int16x8_t v766 = vqrdmulhq_n_s16(v765, 19705); + int16x8_t v767 = vaddq_s16(v761, v766); + int16x8_t v768 = vqrdmulhq_n_s16(v767, 17121); + int16x8_t v769 = vaddq_s16(v757, v768); + int16x8_t v770 = vsubq_s16(v236, v239); + int16x8_t v771 = vsubq_s16(v244, v248); + int16x8_t v772_tmp = vqrdmulhq_n_s16(v771, 10045); + int16x8_t v772 = vaddq_s16(v772_tmp, v771); + int16x8_t v773 = vaddq_s16(v770, v772); + int16x8_t v774 = vsubq_s16(v255, v262); + int16x8_t v775 = vsubq_s16(v265, v270); + int16x8_t v776_tmp = vqrdmulhq_n_s16(v775, 10045); + int16x8_t v776 = vaddq_s16(v776_tmp, v775); + int16x8_t v777 = vaddq_s16(v774, v776); + int16x8_t v778 = vqrdmulhq_n_s16(v777, 19705); + int16x8_t v779 = vaddq_s16(v773, v778); + int16x8_t v780 = vsubq_s16(v277, v280); + int16x8_t v781 = vsubq_s16(v285, v289); + int16x8_t v782_tmp = vqrdmulhq_n_s16(v781, 10045); + int16x8_t v782 = vaddq_s16(v782_tmp, v781); + int16x8_t v783 = vaddq_s16(v780, v782); + int16x8_t v784 = vsubq_s16(v294, v297); + int16x8_t v785 = vsubq_s16(v300, v306); + int16x8_t v786_tmp = vqrdmulhq_n_s16(v785, 10045); + int16x8_t v786 = vaddq_s16(v786_tmp, v785); + int16x8_t v787 = vaddq_s16(v784, v786); + int16x8_t v788 = vqrdmulhq_n_s16(v787, 19705); + int16x8_t v789 = vaddq_s16(v783, v788); + int16x8_t v790 = vqrdmulhq_n_s16(v789, 17121); + int16x8_t v791 = vaddq_s16(v779, v790); + int16x8_t v792 = vqrdmulhq_n_s16(v791, 16563); + int16x8_t v793 = vaddq_s16(v769, v792); + int16x8_t v794 = vqrdmulhq_n_s16(v793, 16429); + int16x8_t v795 = vaddq_s16(v747, v794); + int16x8_t v796 = vsubq_s16(v319, v322); + int16x8_t v797 = vsubq_s16(v327, v331); + int16x8_t v798_tmp = vqrdmulhq_n_s16(v797, 10045); + int16x8_t v798 = vaddq_s16(v798_tmp, v797); + int16x8_t v799 = vaddq_s16(v796, v798); + int16x8_t v800 = vsubq_s16(v338, v345); + int16x8_t v801 = vsubq_s16(v348, v353); + int16x8_t v802_tmp = vqrdmulhq_n_s16(v801, 10045); + int16x8_t v802 = vaddq_s16(v802_tmp, v801); + int16x8_t v803 = vaddq_s16(v800, v802); + int16x8_t v804 = vqrdmulhq_n_s16(v803, 19705); + int16x8_t v805 = vaddq_s16(v799, v804); + int16x8_t v806 = vsubq_s16(v362, v369); + int16x8_t v807 = vsubq_s16(v378, v386); + int16x8_t v808_tmp = vqrdmulhq_n_s16(v807, 10045); + int16x8_t v808 = vaddq_s16(v808_tmp, v807); + int16x8_t v809 = vaddq_s16(v806, v808); + int16x8_t v810 = vsubq_s16(v391, v394); + int16x8_t v811 = vsubq_s16(v397, v403); + int16x8_t v812_tmp = vqrdmulhq_n_s16(v811, 10045); + int16x8_t v812 = vaddq_s16(v812_tmp, v811); + int16x8_t v813 = vaddq_s16(v810, v812); + int16x8_t v814 = vqrdmulhq_n_s16(v813, 19705); + int16x8_t v815 = vaddq_s16(v809, v814); + int16x8_t v816 = vqrdmulhq_n_s16(v815, 17121); + int16x8_t v817 = vaddq_s16(v805, v816); + int16x8_t v818 = vsubq_s16(v414, v421); + int16x8_t v819 = vsubq_s16(v430, v438); + int16x8_t v820_tmp = vqrdmulhq_n_s16(v819, 10045); + int16x8_t v820 = vaddq_s16(v820_tmp, v819); + int16x8_t v821 = vaddq_s16(v818, v820); + int16x8_t v822 = vsubq_s16(v449, v464); + int16x8_t v823 = vsubq_s16(v467, v476); + int16x8_t v824_tmp = vqrdmulhq_n_s16(v823, 10045); + int16x8_t v824 = vaddq_s16(v824_tmp, v823); + int16x8_t v825 = vaddq_s16(v822, v824); + int16x8_t v826 = vqrdmulhq_n_s16(v825, 19705); + int16x8_t v827 = vaddq_s16(v821, v826); + int16x8_t v828 = vsubq_s16(v483, v486); + int16x8_t v829 = vsubq_s16(v491, v495); + int16x8_t v830_tmp = vqrdmulhq_n_s16(v829, 10045); + int16x8_t v830 = vaddq_s16(v830_tmp, v829); + int16x8_t v831 = vaddq_s16(v828, v830); + int16x8_t v832 = vsubq_s16(v500, v503); + int16x8_t v833 = vsubq_s16(v506, v513); + int16x8_t v834_tmp = vqrdmulhq_n_s16(v833, 10045); + int16x8_t v834 = vaddq_s16(v834_tmp, v833); + int16x8_t v835 = vaddq_s16(v832, v834); + int16x8_t v836 = vqrdmulhq_n_s16(v835, 19705); + int16x8_t v837 = vaddq_s16(v831, v836); + int16x8_t v838 = vqrdmulhq_n_s16(v837, 17121); + int16x8_t v839 = vaddq_s16(v827, v838); + int16x8_t v840 = vqrdmulhq_n_s16(v839, 16563); + int16x8_t v841 = vaddq_s16(v817, v840); + int16x8_t v842 = vsubq_s16(v524, v527); + int16x8_t v843 = vsubq_s16(v532, v536); + int16x8_t v844_tmp = vqrdmulhq_n_s16(v843, 10045); + int16x8_t v844 = vaddq_s16(v844_tmp, v843); + int16x8_t v845 = vaddq_s16(v842, v844); + int16x8_t v846 = vsubq_s16(v543, v550); + int16x8_t v847 = vsubq_s16(v553, v558); + int16x8_t v848_tmp = vqrdmulhq_n_s16(v847, 10045); + int16x8_t v848 = vaddq_s16(v848_tmp, v847); + int16x8_t v849 = vaddq_s16(v846, v848); + int16x8_t v850 = vqrdmulhq_n_s16(v849, 19705); + int16x8_t v851 = vaddq_s16(v845, v850); + int16x8_t v852 = vsubq_s16(v567, v574); + int16x8_t v853 = vsubq_s16(v583, v591); + int16x8_t v854_tmp = vqrdmulhq_n_s16(v853, 10045); + int16x8_t v854 = vaddq_s16(v854_tmp, v853); + int16x8_t v855 = vaddq_s16(v852, v854); + int16x8_t v856 = vsubq_s16(v596, v599); + int16x8_t v857 = vsubq_s16(v602, v608); + int16x8_t v858_tmp = vqrdmulhq_n_s16(v857, 10045); + int16x8_t v858 = vaddq_s16(v858_tmp, v857); + int16x8_t v859 = vaddq_s16(v856, v858); + int16x8_t v860 = vqrdmulhq_n_s16(v859, 19705); + int16x8_t v861 = vaddq_s16(v855, v860); + int16x8_t v862 = vqrdmulhq_n_s16(v861, 17121); + int16x8_t v863 = vaddq_s16(v851, v862); + int16x8_t v864 = vsubq_s16(v617, v620); + int16x8_t v865 = vsubq_s16(v625, v629); + int16x8_t v866_tmp = vqrdmulhq_n_s16(v865, 10045); + int16x8_t v866 = vaddq_s16(v866_tmp, v865); + int16x8_t v867 = vaddq_s16(v864, v866); + int16x8_t v868 = vsubq_s16(v636, v643); + int16x8_t v869 = vsubq_s16(v646, v651); + int16x8_t v870_tmp = vqrdmulhq_n_s16(v869, 10045); + int16x8_t v870 = vaddq_s16(v870_tmp, v869); + int16x8_t v871 = vaddq_s16(v868, v870); + int16x8_t v872 = vqrdmulhq_n_s16(v871, 19705); + int16x8_t v873 = vaddq_s16(v867, v872); + int16x8_t v874 = vsubq_s16(v658, v661); + int16x8_t v875 = vsubq_s16(v666, v670); + int16x8_t v876_tmp = vqrdmulhq_n_s16(v875, 10045); + int16x8_t v876 = vaddq_s16(v876_tmp, v875); + int16x8_t v877 = vaddq_s16(v874, v876); + int16x8_t v878 = vsubq_s16(v675, v678); + int16x8_t v879 = vsubq_s16(v681, v688); + int16x8_t v880_tmp = vqrdmulhq_n_s16(v879, 10045); + int16x8_t v880 = vaddq_s16(v880_tmp, v879); + int16x8_t v881 = vaddq_s16(v878, v880); + int16x8_t v882 = vqrdmulhq_n_s16(v881, 19705); + int16x8_t v883 = vaddq_s16(v877, v882); + int16x8_t v884 = vqrdmulhq_n_s16(v883, 17121); + int16x8_t v885 = vaddq_s16(v873, v884); + int16x8_t v886 = vqrdmulhq_n_s16(v885, 16563); + int16x8_t v887 = vaddq_s16(v863, v886); + int16x8_t v888 = vqrdmulhq_n_s16(v887, 16429); + int16x8_t v889 = vaddq_s16(v841, v888); + int16x8_t v890 = vqrdmulhq_n_s16(v889, 16395); + int16x8_t v891 = vaddq_s16(v795, v890); + int16x8_t v892 = vsubq_s16(v702, v704); + int16x8_t v893 = vsubq_s16(v706, v708); + int16x8_t v894 = vqrdmulhq_n_s16(v893, 29490); + int16x8_t v895 = vaddq_s16(v892, v894); + int16x8_t v896 = vsubq_s16(v712, v714); + int16x8_t v897 = vsubq_s16(v716, v718); + int16x8_t v898 = vqrdmulhq_n_s16(v897, 29490); + int16x8_t v899 = vaddq_s16(v896, v898); + int16x8_t v900 = vqrdmulhq_n_s16(v899, 18578); + int16x8_t v901 = vaddq_s16(v895, v900); + int16x8_t v902 = vsubq_s16(v724, v726); + int16x8_t v903 = vsubq_s16(v728, v730); + int16x8_t v904 = vqrdmulhq_n_s16(v903, 29490); + int16x8_t v905 = vaddq_s16(v902, v904); + int16x8_t v906 = vsubq_s16(v734, v736); + int16x8_t v907 = vsubq_s16(v738, v740); + int16x8_t v908 = vqrdmulhq_n_s16(v907, 29490); + int16x8_t v909 = vaddq_s16(v906, v908); + int16x8_t v910 = vqrdmulhq_n_s16(v909, 18578); + int16x8_t v911 = vaddq_s16(v905, v910); + int16x8_t v912 = vqrdmulhq_n_s16(v911, 16890); + int16x8_t v913 = vaddq_s16(v901, v912); + int16x8_t v914 = vsubq_s16(v748, v750); + int16x8_t v915_tmp = vqrdmulhq_n_s16(v754, 10045); + int16x8_t v915 = vaddq_s16(v915_tmp, v754); + int16x8_t v916 = vsubq_s16(v752, v915); + int16x8_t v917 = vqrdmulhq_n_s16(v916, 29490); + int16x8_t v918 = vaddq_s16(v914, v917); + int16x8_t v919 = vsubq_s16(v758, v760); + int16x8_t v920 = vsubq_s16(v762, v764); + int16x8_t v921 = vqrdmulhq_n_s16(v920, 29490); + int16x8_t v922 = vaddq_s16(v919, v921); + int16x8_t v923 = vqrdmulhq_n_s16(v922, 18578); + int16x8_t v924 = vaddq_s16(v918, v923); + int16x8_t v925 = vsubq_s16(v770, v772); + int16x8_t v926 = vsubq_s16(v774, v776); + int16x8_t v927 = vqrdmulhq_n_s16(v926, 29490); + int16x8_t v928 = vaddq_s16(v925, v927); + int16x8_t v929 = vsubq_s16(v780, v782); + int16x8_t v930 = vsubq_s16(v784, v786); + int16x8_t v931 = vqrdmulhq_n_s16(v930, 29490); + int16x8_t v932 = vaddq_s16(v929, v931); + int16x8_t v933 = vqrdmulhq_n_s16(v932, 18578); + int16x8_t v934 = vaddq_s16(v928, v933); + int16x8_t v935 = vqrdmulhq_n_s16(v934, 16890); + int16x8_t v936 = vaddq_s16(v924, v935); + int16x8_t v937 = vqrdmulhq_n_s16(v936, 16508); + int16x8_t v938 = vaddq_s16(v913, v937); + int16x8_t v939 = vsubq_s16(v796, v798); + int16x8_t v940 = vsubq_s16(v800, v802); + int16x8_t v941 = vqrdmulhq_n_s16(v940, 29490); + int16x8_t v942 = vaddq_s16(v939, v941); + int16x8_t v943 = vsubq_s16(v806, v808); + int16x8_t v944 = vsubq_s16(v810, v812); + int16x8_t v945 = vqrdmulhq_n_s16(v944, 29490); + int16x8_t v946 = vaddq_s16(v943, v945); + int16x8_t v947 = vqrdmulhq_n_s16(v946, 18578); + int16x8_t v948 = vaddq_s16(v942, v947); + int16x8_t v949 = vsubq_s16(v818, v820); + int16x8_t v950 = vsubq_s16(v822, v824); + int16x8_t v951 = vqrdmulhq_n_s16(v950, 29490); + int16x8_t v952 = vaddq_s16(v949, v951); + int16x8_t v953 = vsubq_s16(v828, v830); + int16x8_t v954 = vsubq_s16(v832, v834); + int16x8_t v955 = vqrdmulhq_n_s16(v954, 29490); + int16x8_t v956 = vaddq_s16(v953, v955); + int16x8_t v957 = vqrdmulhq_n_s16(v956, 18578); + int16x8_t v958 = vaddq_s16(v952, v957); + int16x8_t v959 = vqrdmulhq_n_s16(v958, 16890); + int16x8_t v960 = vaddq_s16(v948, v959); + int16x8_t v961 = vsubq_s16(v842, v844); + int16x8_t v962 = vsubq_s16(v846, v848); + int16x8_t v963 = vqrdmulhq_n_s16(v962, 29490); + int16x8_t v964 = vaddq_s16(v961, v963); + int16x8_t v965 = vsubq_s16(v852, v854); + int16x8_t v966 = vsubq_s16(v856, v858); + int16x8_t v967 = vqrdmulhq_n_s16(v966, 29490); + int16x8_t v968 = vaddq_s16(v965, v967); + int16x8_t v969 = vqrdmulhq_n_s16(v968, 18578); + int16x8_t v970 = vaddq_s16(v964, v969); + int16x8_t v971 = vsubq_s16(v864, v866); + int16x8_t v972 = vsubq_s16(v868, v870); + int16x8_t v973 = vqrdmulhq_n_s16(v972, 29490); + int16x8_t v974 = vaddq_s16(v971, v973); + int16x8_t v975 = vsubq_s16(v874, v876); + int16x8_t v976 = vsubq_s16(v878, v880); + int16x8_t v977 = vqrdmulhq_n_s16(v976, 29490); + int16x8_t v978 = vaddq_s16(v975, v977); + int16x8_t v979 = vqrdmulhq_n_s16(v978, 18578); + int16x8_t v980 = vaddq_s16(v974, v979); + int16x8_t v981 = vqrdmulhq_n_s16(v980, 16890); + int16x8_t v982 = vaddq_s16(v970, v981); + int16x8_t v983 = vqrdmulhq_n_s16(v982, 16508); + int16x8_t v984 = vaddq_s16(v960, v983); + int16x8_t v985 = vqrdmulhq_n_s16(v984, 16415); + int16x8_t v986 = vaddq_s16(v938, v985); + int16x8_t v987 = vsubq_s16(v2, v8); + int16x8_t v988 = vsubq_s16(v15, v22); + int16x8_t v989_tmp = vqrdmulhq_n_s16(v988, 18446); + int16x8_t v989 = vmlaq_n_s16(v989_tmp, v988, 2); + int16x8_t v990 = vaddq_s16(v987, v989); + int16x8_t v991 = vsubq_s16(v31, v41); + int16x8_t v992 = vsubq_s16(v48, v56); + int16x8_t v993_tmp = vqrdmulhq_n_s16(v992, 18446); + int16x8_t v993 = vmlaq_n_s16(v993_tmp, v992, 2); + int16x8_t v994 = vaddq_s16(v991, v993); + int16x8_t v995 = vqrdmulhq_n_s16(v994, 21195); + int16x8_t v996 = vaddq_s16(v990, v995); + int16x8_t v997 = vsubq_s16(v67, v77); + int16x8_t v998 = vsubq_s16(v90, v99); + int16x8_t v999_tmp = vqrdmulhq_n_s16(v998, 18446); + int16x8_t v999 = vmlaq_n_s16(v999_tmp, v998, 2); + int16x8_t v1000 = vaddq_s16(v997, v999); + int16x8_t v1001 = vsubq_s16(v108, v118); + int16x8_t v1002 = vsubq_s16(v125, v134); + int16x8_t v1003_tmp = vqrdmulhq_n_s16(v1002, 18446); + int16x8_t v1003 = vmlaq_n_s16(v1003_tmp, v1002, 2); + int16x8_t v1004 = vaddq_s16(v1001, v1003); + int16x8_t v1005 = vqrdmulhq_n_s16(v1004, 21195); + int16x8_t v1006 = vaddq_s16(v1000, v1005); + int16x8_t v1007 = vqrdmulhq_n_s16(v1006, 17401); + int16x8_t v1008 = vaddq_s16(v996, v1007); + int16x8_t v1009 = vsubq_s16(v147, v157); + int16x8_t v1010 = vsubq_s16(v170, v179); + int16x8_t v1011_tmp = vqrdmulhq_n_s16(v1010, 18446); + int16x8_t v1011 = vmlaq_n_s16(v1011_tmp, v1010, 2); + int16x8_t v1012 = vaddq_s16(v1009, v1011); + int16x8_t v1013 = vsubq_s16(v194, v212); + int16x8_t v1014 = vsubq_s16(v219, v229); + int16x8_t v1015_tmp = vqrdmulhq_n_s16(v1014, 18446); + int16x8_t v1015 = vmlaq_n_s16(v1015_tmp, v1014, 2); + int16x8_t v1016 = vaddq_s16(v1013, v1015); + int16x8_t v1017 = vqrdmulhq_n_s16(v1016, 21195); + int16x8_t v1018 = vaddq_s16(v1012, v1017); + int16x8_t v1019 = vsubq_s16(v240, v250); + int16x8_t v1020 = vsubq_s16(v263, v272); + int16x8_t v1021_tmp = vqrdmulhq_n_s16(v1020, 18446); + int16x8_t v1021 = vmlaq_n_s16(v1021_tmp, v1020, 2); + int16x8_t v1022 = vaddq_s16(v1019, v1021); + int16x8_t v1023 = vsubq_s16(v281, v291); + int16x8_t v1024 = vsubq_s16(v298, v308); + int16x8_t v1025_tmp = vqrdmulhq_n_s16(v1024, 18446); + int16x8_t v1025 = vmlaq_n_s16(v1025_tmp, v1024, 2); + int16x8_t v1026 = vaddq_s16(v1023, v1025); + int16x8_t v1027 = vqrdmulhq_n_s16(v1026, 21195); + int16x8_t v1028 = vaddq_s16(v1022, v1027); + int16x8_t v1029 = vqrdmulhq_n_s16(v1028, 17401); + int16x8_t v1030 = vaddq_s16(v1018, v1029); + int16x8_t v1031 = vqrdmulhq_n_s16(v1030, 16629); + int16x8_t v1032 = vaddq_s16(v1008, v1031); + int16x8_t v1033 = vsubq_s16(v323, v333); + int16x8_t v1034 = vsubq_s16(v346, v355); + int16x8_t v1035_tmp = vqrdmulhq_n_s16(v1034, 18446); + int16x8_t v1035 = vmlaq_n_s16(v1035_tmp, v1034, 2); + int16x8_t v1036 = vaddq_s16(v1033, v1035); + int16x8_t v1037 = vsubq_s16(v370, v388); + int16x8_t v1038 = vsubq_s16(v395, v405); + int16x8_t v1039_tmp = vqrdmulhq_n_s16(v1038, 18446); + int16x8_t v1039 = vmlaq_n_s16(v1039_tmp, v1038, 2); + int16x8_t v1040 = vaddq_s16(v1037, v1039); + int16x8_t v1041 = vqrdmulhq_n_s16(v1040, 21195); + int16x8_t v1042 = vaddq_s16(v1036, v1041); + int16x8_t v1043 = vsubq_s16(v422, v440); + int16x8_t v1044 = vsubq_s16(v465, v478); + int16x8_t v1045_tmp = vqrdmulhq_n_s16(v1044, 18446); + int16x8_t v1045 = vmlaq_n_s16(v1045_tmp, v1044, 2); + int16x8_t v1046 = vaddq_s16(v1043, v1045); + int16x8_t v1047 = vsubq_s16(v487, v497); + int16x8_t v1048 = vsubq_s16(v504, v515); + int16x8_t v1049_tmp = vqrdmulhq_n_s16(v1048, 18446); + int16x8_t v1049 = vmlaq_n_s16(v1049_tmp, v1048, 2); + int16x8_t v1050 = vaddq_s16(v1047, v1049); + int16x8_t v1051 = vqrdmulhq_n_s16(v1050, 21195); + int16x8_t v1052 = vaddq_s16(v1046, v1051); + int16x8_t v1053 = vqrdmulhq_n_s16(v1052, 17401); + int16x8_t v1054 = vaddq_s16(v1042, v1053); + int16x8_t v1055 = vsubq_s16(v528, v538); + int16x8_t v1056 = vsubq_s16(v551, v560); + int16x8_t v1057_tmp = vqrdmulhq_n_s16(v1056, 18446); + int16x8_t v1057 = vmlaq_n_s16(v1057_tmp, v1056, 2); + int16x8_t v1058 = vaddq_s16(v1055, v1057); + int16x8_t v1059 = vsubq_s16(v575, v593); + int16x8_t v1060 = vsubq_s16(v600, v610); + int16x8_t v1061_tmp = vqrdmulhq_n_s16(v1060, 18446); + int16x8_t v1061 = vmlaq_n_s16(v1061_tmp, v1060, 2); + int16x8_t v1062 = vaddq_s16(v1059, v1061); + int16x8_t v1063 = vqrdmulhq_n_s16(v1062, 21195); + int16x8_t v1064 = vaddq_s16(v1058, v1063); + int16x8_t v1065 = vsubq_s16(v621, v631); + int16x8_t v1066 = vsubq_s16(v644, v653); + int16x8_t v1067_tmp = vqrdmulhq_n_s16(v1066, 18446); + int16x8_t v1067 = vmlaq_n_s16(v1067_tmp, v1066, 2); + int16x8_t v1068 = vaddq_s16(v1065, v1067); + int16x8_t v1069 = vsubq_s16(v662, v672); + int16x8_t v1070 = vsubq_s16(v679, v690); + int16x8_t v1071_tmp = vqrdmulhq_n_s16(v1070, 18446); + int16x8_t v1071 = vmlaq_n_s16(v1071_tmp, v1070, 2); + int16x8_t v1072 = vaddq_s16(v1069, v1071); + int16x8_t v1073 = vqrdmulhq_n_s16(v1072, 21195); + int16x8_t v1074 = vaddq_s16(v1068, v1073); + int16x8_t v1075 = vqrdmulhq_n_s16(v1074, 17401); + int16x8_t v1076 = vaddq_s16(v1064, v1075); + int16x8_t v1077 = vqrdmulhq_n_s16(v1076, 16629); + int16x8_t v1078 = vaddq_s16(v1054, v1077); + int16x8_t v1079 = vqrdmulhq_n_s16(v1078, 16445); + int16x8_t v1080 = vaddq_s16(v1032, v1079); + int16x8_t v1081 = vsubq_s16(v987, v989); + int16x8_t v1082 = vsubq_s16(v991, v993); + int16x8_t v1083 = vqrdmulhq_n_s16(v1082, 25826); + int16x8_t v1084 = vaddq_s16(v1081, v1083); + int16x8_t v1085 = vsubq_s16(v997, v999); + int16x8_t v1086 = vsubq_s16(v1001, v1003); + int16x8_t v1087 = vqrdmulhq_n_s16(v1086, 25826); + int16x8_t v1088 = vaddq_s16(v1085, v1087); + int16x8_t v1089 = vqrdmulhq_n_s16(v1088, 18124); + int16x8_t v1090 = vaddq_s16(v1084, v1089); + int16x8_t v1091 = vsubq_s16(v1009, v1011); + int16x8_t v1092 = vsubq_s16(v1013, v1015); + int16x8_t v1093 = vqrdmulhq_n_s16(v1092, 25826); + int16x8_t v1094 = vaddq_s16(v1091, v1093); + int16x8_t v1095 = vsubq_s16(v1019, v1021); + int16x8_t v1096 = vsubq_s16(v1023, v1025); + int16x8_t v1097 = vqrdmulhq_n_s16(v1096, 25826); + int16x8_t v1098 = vaddq_s16(v1095, v1097); + int16x8_t v1099 = vqrdmulhq_n_s16(v1098, 18124); + int16x8_t v1100 = vaddq_s16(v1094, v1099); + int16x8_t v1101 = vqrdmulhq_n_s16(v1100, 16792); + int16x8_t v1102 = vaddq_s16(v1090, v1101); + int16x8_t v1103 = vsubq_s16(v1033, v1035); + int16x8_t v1104 = vsubq_s16(v1037, v1039); + int16x8_t v1105 = vqrdmulhq_n_s16(v1104, 25826); + int16x8_t v1106 = vaddq_s16(v1103, v1105); + int16x8_t v1107 = vsubq_s16(v1043, v1045); + int16x8_t v1108 = vsubq_s16(v1047, v1049); + int16x8_t v1109 = vqrdmulhq_n_s16(v1108, 25826); + int16x8_t v1110 = vaddq_s16(v1107, v1109); + int16x8_t v1111 = vqrdmulhq_n_s16(v1110, 18124); + int16x8_t v1112 = vaddq_s16(v1106, v1111); + int16x8_t v1113 = vsubq_s16(v1055, v1057); + int16x8_t v1114 = vsubq_s16(v1059, v1061); + int16x8_t v1115 = vqrdmulhq_n_s16(v1114, 25826); + int16x8_t v1116 = vaddq_s16(v1113, v1115); + int16x8_t v1117 = vsubq_s16(v1065, v1067); + int16x8_t v1118 = vsubq_s16(v1069, v1071); + int16x8_t v1119 = vqrdmulhq_n_s16(v1118, 25826); + int16x8_t v1120 = vaddq_s16(v1117, v1119); + int16x8_t v1121 = vqrdmulhq_n_s16(v1120, 18124); + int16x8_t v1122 = vaddq_s16(v1116, v1121); + int16x8_t v1123 = vqrdmulhq_n_s16(v1122, 16792); + int16x8_t v1124 = vaddq_s16(v1112, v1123); + int16x8_t v1125 = vqrdmulhq_n_s16(v1124, 16484); + int16x8_t v1126 = vaddq_s16(v1102, v1125); + int16x8_t v1127 = vsubq_s16(v892, v894); + int16x8_t v1128 = vsubq_s16(v896, v898); + int16x8_t v1129_tmp = vqrdmulhq_n_s16(v1128, 1988); + int16x8_t v1129 = vaddq_s16(v1129_tmp, v1128); + int16x8_t v1130 = vaddq_s16(v1127, v1129); + int16x8_t v1131 = vsubq_s16(v902, v904); + int16x8_t v1132 = vsubq_s16(v906, v908); + int16x8_t v1133_tmp = vqrdmulhq_n_s16(v1132, 1988); + int16x8_t v1133 = vaddq_s16(v1133_tmp, v1132); + int16x8_t v1134 = vaddq_s16(v1131, v1133); + int16x8_t v1135 = vqrdmulhq_n_s16(v1134, 19102); + int16x8_t v1136 = vaddq_s16(v1130, v1135); + int16x8_t v1137 = vsubq_s16(v914, v917); + int16x8_t v1138 = vsubq_s16(v919, v921); + int16x8_t v1139_tmp = vqrdmulhq_n_s16(v1138, 1988); + int16x8_t v1139 = vaddq_s16(v1139_tmp, v1138); + int16x8_t v1140 = vaddq_s16(v1137, v1139); + int16x8_t v1141 = vsubq_s16(v925, v927); + int16x8_t v1142 = vsubq_s16(v929, v931); + int16x8_t v1143_tmp = vqrdmulhq_n_s16(v1142, 1988); + int16x8_t v1143 = vaddq_s16(v1143_tmp, v1142); + int16x8_t v1144 = vaddq_s16(v1141, v1143); + int16x8_t v1145 = vqrdmulhq_n_s16(v1144, 19102); + int16x8_t v1146 = vaddq_s16(v1140, v1145); + int16x8_t v1147 = vqrdmulhq_n_s16(v1146, 17000); + int16x8_t v1148 = vaddq_s16(v1136, v1147); + int16x8_t v1149 = vsubq_s16(v939, v941); + int16x8_t v1150 = vsubq_s16(v943, v945); + int16x8_t v1151_tmp = vqrdmulhq_n_s16(v1150, 1988); + int16x8_t v1151 = vaddq_s16(v1151_tmp, v1150); + int16x8_t v1152 = vaddq_s16(v1149, v1151); + int16x8_t v1153 = vsubq_s16(v949, v951); + int16x8_t v1154 = vsubq_s16(v953, v955); + int16x8_t v1155_tmp = vqrdmulhq_n_s16(v1154, 1988); + int16x8_t v1155 = vaddq_s16(v1155_tmp, v1154); + int16x8_t v1156 = vaddq_s16(v1153, v1155); + int16x8_t v1157 = vqrdmulhq_n_s16(v1156, 19102); + int16x8_t v1158 = vaddq_s16(v1152, v1157); + int16x8_t v1159 = vsubq_s16(v961, v963); + int16x8_t v1160 = vsubq_s16(v965, v967); + int16x8_t v1161_tmp = vqrdmulhq_n_s16(v1160, 1988); + int16x8_t v1161 = vaddq_s16(v1161_tmp, v1160); + int16x8_t v1162 = vaddq_s16(v1159, v1161); + int16x8_t v1163 = vsubq_s16(v971, v973); + int16x8_t v1164 = vsubq_s16(v975, v977); + int16x8_t v1165_tmp = vqrdmulhq_n_s16(v1164, 1988); + int16x8_t v1165 = vaddq_s16(v1165_tmp, v1164); + int16x8_t v1166 = vaddq_s16(v1163, v1165); + int16x8_t v1167 = vqrdmulhq_n_s16(v1166, 19102); + int16x8_t v1168 = vaddq_s16(v1162, v1167); + int16x8_t v1169 = vqrdmulhq_n_s16(v1168, 17000); + int16x8_t v1170 = vaddq_s16(v1158, v1169); + int16x8_t v1171 = vqrdmulhq_n_s16(v1170, 16534); + int16x8_t v1172 = vaddq_s16(v1148, v1171); + int16x8_t v1173 = vsubq_s16(v705, v710); + int16x8_t v1174 = vsubq_s16(v715, v720); + int16x8_t v1175_tmp = vqrdmulhq_n_s16(v1174, 23673); + int16x8_t v1175 = vaddq_s16(v1175_tmp, v1174); + int16x8_t v1176 = vaddq_s16(v1173, v1175); + int16x8_t v1177 = vsubq_s16(v727, v732); + int16x8_t v1178 = vsubq_s16(v737, v742); + int16x8_t v1179_tmp = vqrdmulhq_n_s16(v1178, 23673); + int16x8_t v1179 = vaddq_s16(v1179_tmp, v1178); + int16x8_t v1180 = vaddq_s16(v1177, v1179); + int16x8_t v1181 = vqrdmulhq_n_s16(v1180, 20398); + int16x8_t v1182 = vaddq_s16(v1176, v1181); + int16x8_t v1183 = vsubq_s16(v751, v756); + int16x8_t v1184 = vsubq_s16(v761, v766); + int16x8_t v1185_tmp = vqrdmulhq_n_s16(v1184, 23673); + int16x8_t v1185 = vaddq_s16(v1185_tmp, v1184); + int16x8_t v1186 = vaddq_s16(v1183, v1185); + int16x8_t v1187 = vsubq_s16(v773, v778); + int16x8_t v1188 = vsubq_s16(v783, v788); + int16x8_t v1189_tmp = vqrdmulhq_n_s16(v1188, 23673); + int16x8_t v1189 = vaddq_s16(v1189_tmp, v1188); + int16x8_t v1190 = vaddq_s16(v1187, v1189); + int16x8_t v1191 = vqrdmulhq_n_s16(v1190, 20398); + int16x8_t v1192 = vaddq_s16(v1186, v1191); + int16x8_t v1193 = vqrdmulhq_n_s16(v1192, 17255); + int16x8_t v1194 = vaddq_s16(v1182, v1193); + int16x8_t v1195 = vsubq_s16(v799, v804); + int16x8_t v1196 = vsubq_s16(v809, v814); + int16x8_t v1197_tmp = vqrdmulhq_n_s16(v1196, 23673); + int16x8_t v1197 = vaddq_s16(v1197_tmp, v1196); + int16x8_t v1198 = vaddq_s16(v1195, v1197); + int16x8_t v1199 = vsubq_s16(v821, v826); + int16x8_t v1200 = vsubq_s16(v831, v836); + int16x8_t v1201_tmp = vqrdmulhq_n_s16(v1200, 23673); + int16x8_t v1201 = vaddq_s16(v1201_tmp, v1200); + int16x8_t v1202 = vaddq_s16(v1199, v1201); + int16x8_t v1203 = vqrdmulhq_n_s16(v1202, 20398); + int16x8_t v1204 = vaddq_s16(v1198, v1203); + int16x8_t v1205 = vsubq_s16(v845, v850); + int16x8_t v1206 = vsubq_s16(v855, v860); + int16x8_t v1207_tmp = vqrdmulhq_n_s16(v1206, 23673); + int16x8_t v1207 = vaddq_s16(v1207_tmp, v1206); + int16x8_t v1208 = vaddq_s16(v1205, v1207); + int16x8_t v1209 = vsubq_s16(v867, v872); + int16x8_t v1210 = vsubq_s16(v877, v882); + int16x8_t v1211_tmp = vqrdmulhq_n_s16(v1210, 23673); + int16x8_t v1211 = vaddq_s16(v1211_tmp, v1210); + int16x8_t v1212 = vaddq_s16(v1209, v1211); + int16x8_t v1213 = vqrdmulhq_n_s16(v1212, 20398); + int16x8_t v1214 = vaddq_s16(v1208, v1213); + int16x8_t v1215 = vqrdmulhq_n_s16(v1214, 17255); + int16x8_t v1216 = vaddq_s16(v1204, v1215); + int16x8_t v1217 = vqrdmulhq_n_s16(v1216, 16595); + int16x8_t v1218 = vaddq_s16(v1194, v1217); + int16x8_t v1219 = vsubq_s16(v9, v24); + int16x8_t v1220 = vsubq_s16(v42, v58); + int16x8_t v1221_tmp = vqrdmulhq_n_s16(v1220, 3314); + int16x8_t v1221 = vmlaq_n_s16(v1221_tmp, v1220, 5); + int16x8_t v1222 = vaddq_s16(v1219, v1221); + int16x8_t v1223 = vsubq_s16(v78, v101); + int16x8_t v1224 = vsubq_s16(v119, v136); + int16x8_t v1225_tmp = vqrdmulhq_n_s16(v1224, 3314); + int16x8_t v1225 = vmlaq_n_s16(v1225_tmp, v1224, 5); + int16x8_t v1226 = vaddq_s16(v1223, v1225); + int16x8_t v1227 = vqrdmulhq_n_s16(v1226, 22112); + int16x8_t v1228 = vaddq_s16(v1222, v1227); + int16x8_t v1229 = vsubq_s16(v158, v181); + int16x8_t v1230 = vsubq_s16(v213, v231); + int16x8_t v1231_tmp = vqrdmulhq_n_s16(v1230, 3314); + int16x8_t v1231 = vmlaq_n_s16(v1231_tmp, v1230, 5); + int16x8_t v1232 = vaddq_s16(v1229, v1231); + int16x8_t v1233 = vsubq_s16(v251, v274); + int16x8_t v1234 = vsubq_s16(v292, v310); + int16x8_t v1235_tmp = vqrdmulhq_n_s16(v1234, 3314); + int16x8_t v1235 = vmlaq_n_s16(v1235_tmp, v1234, 5); + int16x8_t v1236 = vaddq_s16(v1233, v1235); + int16x8_t v1237 = vqrdmulhq_n_s16(v1236, 22112); + int16x8_t v1238 = vaddq_s16(v1232, v1237); + int16x8_t v1239 = vqrdmulhq_n_s16(v1238, 17561); + int16x8_t v1240 = vaddq_s16(v1228, v1239); + int16x8_t v1241 = vsubq_s16(v334, v357); + int16x8_t v1242 = vsubq_s16(v389, v407); + int16x8_t v1243_tmp = vqrdmulhq_n_s16(v1242, 3314); + int16x8_t v1243 = vmlaq_n_s16(v1243_tmp, v1242, 5); + int16x8_t v1244 = vaddq_s16(v1241, v1243); + int16x8_t v1245 = vsubq_s16(v441, v480); + int16x8_t v1246 = vsubq_s16(v498, v517); + int16x8_t v1247_tmp = vqrdmulhq_n_s16(v1246, 3314); + int16x8_t v1247 = vmlaq_n_s16(v1247_tmp, v1246, 5); + int16x8_t v1248 = vaddq_s16(v1245, v1247); + int16x8_t v1249 = vqrdmulhq_n_s16(v1248, 22112); + int16x8_t v1250 = vaddq_s16(v1244, v1249); + int16x8_t v1251 = vsubq_s16(v539, v562); + int16x8_t v1252 = vsubq_s16(v594, v612); + int16x8_t v1253_tmp = vqrdmulhq_n_s16(v1252, 3314); + int16x8_t v1253 = vmlaq_n_s16(v1253_tmp, v1252, 5); + int16x8_t v1254 = vaddq_s16(v1251, v1253); + int16x8_t v1255 = vsubq_s16(v632, v655); + int16x8_t v1256 = vsubq_s16(v673, v692); + int16x8_t v1257_tmp = vqrdmulhq_n_s16(v1256, 3314); + int16x8_t v1257 = vmlaq_n_s16(v1257_tmp, v1256, 5); + int16x8_t v1258 = vaddq_s16(v1255, v1257); + int16x8_t v1259 = vqrdmulhq_n_s16(v1258, 22112); + int16x8_t v1260 = vaddq_s16(v1254, v1259); + int16x8_t v1261 = vqrdmulhq_n_s16(v1260, 17561); + int16x8_t v1262 = vaddq_s16(v1250, v1261); + int16x8_t v1263 = vqrdmulhq_n_s16(v1262, 16666); + int16x8_t v1264 = vaddq_s16(v1240, v1263); + int16x8_t v1265 = vsubq_s16(v1219, v1221); + int16x8_t v1266 = vsubq_s16(v1223, v1225); + int16x8_t v1267 = vqrdmulhq_n_s16(v1266, 24397); + int16x8_t v1268 = vaddq_s16(v1265, v1267); + int16x8_t v1269 = vsubq_s16(v1229, v1231); + int16x8_t v1270 = vsubq_s16(v1233, v1235); + int16x8_t v1271 = vqrdmulhq_n_s16(v1270, 24397); + int16x8_t v1272 = vaddq_s16(v1269, v1271); + int16x8_t v1273 = vqrdmulhq_n_s16(v1272, 17921); + int16x8_t v1274 = vaddq_s16(v1268, v1273); + int16x8_t v1275 = vsubq_s16(v1241, v1243); + int16x8_t v1276 = vsubq_s16(v1245, v1247); + int16x8_t v1277 = vqrdmulhq_n_s16(v1276, 24397); + int16x8_t v1278 = vaddq_s16(v1275, v1277); + int16x8_t v1279 = vsubq_s16(v1251, v1253); + int16x8_t v1280 = vsubq_s16(v1255, v1257); + int16x8_t v1281 = vqrdmulhq_n_s16(v1280, 24397); + int16x8_t v1282 = vaddq_s16(v1279, v1281); + int16x8_t v1283 = vqrdmulhq_n_s16(v1282, 17921); + int16x8_t v1284 = vaddq_s16(v1278, v1283); + int16x8_t v1285 = vqrdmulhq_n_s16(v1284, 16747); + int16x8_t v1286 = vaddq_s16(v1274, v1285); + int16x8_t v1287 = vsubq_s16(v1173, v1175); + int16x8_t v1288 = vsubq_s16(v1177, v1179); + int16x8_t v1289 = vqrdmulhq_n_s16(v1288, 27504); + int16x8_t v1290 = vaddq_s16(v1287, v1289); + int16x8_t v1291 = vsubq_s16(v1183, v1185); + int16x8_t v1292 = vsubq_s16(v1187, v1189); + int16x8_t v1293 = vqrdmulhq_n_s16(v1292, 27504); + int16x8_t v1294 = vaddq_s16(v1291, v1293); + int16x8_t v1295 = vqrdmulhq_n_s16(v1294, 18343); + int16x8_t v1296 = vaddq_s16(v1290, v1295); + int16x8_t v1297 = vsubq_s16(v1195, v1197); + int16x8_t v1298 = vsubq_s16(v1199, v1201); + int16x8_t v1299 = vqrdmulhq_n_s16(v1298, 27504); + int16x8_t v1300 = vaddq_s16(v1297, v1299); + int16x8_t v1301 = vsubq_s16(v1205, v1207); + int16x8_t v1302 = vsubq_s16(v1209, v1211); + int16x8_t v1303 = vqrdmulhq_n_s16(v1302, 27504); + int16x8_t v1304 = vaddq_s16(v1301, v1303); + int16x8_t v1305 = vqrdmulhq_n_s16(v1304, 18343); + int16x8_t v1306 = vaddq_s16(v1300, v1305); + int16x8_t v1307 = vqrdmulhq_n_s16(v1306, 16840); + int16x8_t v1308 = vaddq_s16(v1296, v1307); + int16x8_t v1309 = vsubq_s16(v1127, v1129); + int16x8_t v1310 = vsubq_s16(v1131, v1133); + int16x8_t v1311 = vqrdmulhq_n_s16(v1310, 31869); + int16x8_t v1312 = vaddq_s16(v1309, v1311); + int16x8_t v1313 = vsubq_s16(v1137, v1139); + int16x8_t v1314 = vsubq_s16(v1141, v1143); + int16x8_t v1315 = vqrdmulhq_n_s16(v1314, 31869); + int16x8_t v1316 = vaddq_s16(v1313, v1315); + int16x8_t v1317 = vqrdmulhq_n_s16(v1316, 18830); + int16x8_t v1318 = vaddq_s16(v1312, v1317); + int16x8_t v1319 = vsubq_s16(v1149, v1151); + int16x8_t v1320 = vsubq_s16(v1153, v1155); + int16x8_t v1321 = vqrdmulhq_n_s16(v1320, 31869); + int16x8_t v1322 = vaddq_s16(v1319, v1321); + int16x8_t v1323 = vsubq_s16(v1159, v1161); + int16x8_t v1324 = vsubq_s16(v1163, v1165); + int16x8_t v1325 = vqrdmulhq_n_s16(v1324, 31869); + int16x8_t v1326 = vaddq_s16(v1323, v1325); + int16x8_t v1327 = vqrdmulhq_n_s16(v1326, 18830); + int16x8_t v1328 = vaddq_s16(v1322, v1327); + int16x8_t v1329 = vqrdmulhq_n_s16(v1328, 16944); + int16x8_t v1330 = vaddq_s16(v1318, v1329); + int16x8_t v1331 = vsubq_s16(v1081, v1083); + int16x8_t v1332 = vsubq_s16(v1085, v1087); + int16x8_t v1333_tmp = vqrdmulhq_n_s16(v1332, 5552); + int16x8_t v1333 = vaddq_s16(v1333_tmp, v1332); + int16x8_t v1334 = vaddq_s16(v1331, v1333); + int16x8_t v1335 = vsubq_s16(v1091, v1093); + int16x8_t v1336 = vsubq_s16(v1095, v1097); + int16x8_t v1337_tmp = vqrdmulhq_n_s16(v1336, 5552); + int16x8_t v1337 = vaddq_s16(v1337_tmp, v1336); + int16x8_t v1338 = vaddq_s16(v1335, v1337); + int16x8_t v1339 = vqrdmulhq_n_s16(v1338, 19393); + int16x8_t v1340 = vaddq_s16(v1334, v1339); + int16x8_t v1341 = vsubq_s16(v1103, v1105); + int16x8_t v1342 = vsubq_s16(v1107, v1109); + int16x8_t v1343_tmp = vqrdmulhq_n_s16(v1342, 5552); + int16x8_t v1343 = vaddq_s16(v1343_tmp, v1342); + int16x8_t v1344 = vaddq_s16(v1341, v1343); + int16x8_t v1345 = vsubq_s16(v1113, v1115); + int16x8_t v1346 = vsubq_s16(v1117, v1119); + int16x8_t v1347_tmp = vqrdmulhq_n_s16(v1346, 5552); + int16x8_t v1347 = vaddq_s16(v1347_tmp, v1346); + int16x8_t v1348 = vaddq_s16(v1345, v1347); + int16x8_t v1349 = vqrdmulhq_n_s16(v1348, 19393); + int16x8_t v1350 = vaddq_s16(v1344, v1349); + int16x8_t v1351 = vqrdmulhq_n_s16(v1350, 17059); + int16x8_t v1352 = vaddq_s16(v1340, v1351); + int16x8_t v1353 = vsubq_s16(v990, v995); + int16x8_t v1354 = vsubq_s16(v1000, v1005); + int16x8_t v1355_tmp = vqrdmulhq_n_s16(v1354, 15865); + int16x8_t v1355 = vaddq_s16(v1355_tmp, v1354); + int16x8_t v1356 = vaddq_s16(v1353, v1355); + int16x8_t v1357 = vsubq_s16(v1012, v1017); + int16x8_t v1358 = vsubq_s16(v1022, v1027); + int16x8_t v1359_tmp = vqrdmulhq_n_s16(v1358, 15865); + int16x8_t v1359 = vaddq_s16(v1359_tmp, v1358); + int16x8_t v1360 = vaddq_s16(v1357, v1359); + int16x8_t v1361 = vqrdmulhq_n_s16(v1360, 20040); + int16x8_t v1362 = vaddq_s16(v1356, v1361); + int16x8_t v1363 = vsubq_s16(v1036, v1041); + int16x8_t v1364 = vsubq_s16(v1046, v1051); + int16x8_t v1365_tmp = vqrdmulhq_n_s16(v1364, 15865); + int16x8_t v1365 = vaddq_s16(v1365_tmp, v1364); + int16x8_t v1366 = vaddq_s16(v1363, v1365); + int16x8_t v1367 = vsubq_s16(v1058, v1063); + int16x8_t v1368 = vsubq_s16(v1068, v1073); + int16x8_t v1369_tmp = vqrdmulhq_n_s16(v1368, 15865); + int16x8_t v1369 = vaddq_s16(v1369_tmp, v1368); + int16x8_t v1370 = vaddq_s16(v1367, v1369); + int16x8_t v1371 = vqrdmulhq_n_s16(v1370, 20040); + int16x8_t v1372 = vaddq_s16(v1366, v1371); + int16x8_t v1373 = vqrdmulhq_n_s16(v1372, 17187); + int16x8_t v1374 = vaddq_s16(v1362, v1373); + int16x8_t v1375 = vsubq_s16(v895, v900); + int16x8_t v1376 = vsubq_s16(v905, v910); + int16x8_t v1377_tmp = vqrdmulhq_n_s16(v1376, 1893); + int16x8_t v1377 = vmlaq_n_s16(v1377_tmp, v1376, 2); + int16x8_t v1378 = vaddq_s16(v1375, v1377); + int16x8_t v1379 = vsubq_s16(v918, v923); + int16x8_t v1380 = vsubq_s16(v928, v933); + int16x8_t v1381_tmp = vqrdmulhq_n_s16(v1380, 1893); + int16x8_t v1381 = vmlaq_n_s16(v1381_tmp, v1380, 2); + int16x8_t v1382 = vaddq_s16(v1379, v1381); + int16x8_t v1383 = vqrdmulhq_n_s16(v1382, 20783); + int16x8_t v1384 = vaddq_s16(v1378, v1383); + int16x8_t v1385 = vsubq_s16(v942, v947); + int16x8_t v1386 = vsubq_s16(v952, v957); + int16x8_t v1387_tmp = vqrdmulhq_n_s16(v1386, 1893); + int16x8_t v1387 = vmlaq_n_s16(v1387_tmp, v1386, 2); + int16x8_t v1388 = vaddq_s16(v1385, v1387); + int16x8_t v1389 = vsubq_s16(v964, v969); + int16x8_t v1390 = vsubq_s16(v974, v979); + int16x8_t v1391_tmp = vqrdmulhq_n_s16(v1390, 1893); + int16x8_t v1391 = vmlaq_n_s16(v1391_tmp, v1390, 2); + int16x8_t v1392 = vaddq_s16(v1389, v1391); + int16x8_t v1393 = vqrdmulhq_n_s16(v1392, 20783); + int16x8_t v1394 = vaddq_s16(v1388, v1393); + int16x8_t v1395 = vqrdmulhq_n_s16(v1394, 17326); + int16x8_t v1396 = vaddq_s16(v1384, v1395); + int16x8_t v1397 = vsubq_s16(v711, v722); + int16x8_t v1398 = vsubq_s16(v733, v744); + int16x8_t v1399_tmp = vqrdmulhq_n_s16(v1398, 13357); + int16x8_t v1399 = vmlaq_n_s16(v1399_tmp, v1398, 3); + int16x8_t v1400 = vaddq_s16(v1397, v1399); + int16x8_t v1401 = vsubq_s16(v757, v768); + int16x8_t v1402 = vsubq_s16(v779, v790); + int16x8_t v1403_tmp = vqrdmulhq_n_s16(v1402, 13357); + int16x8_t v1403 = vmlaq_n_s16(v1403_tmp, v1402, 3); + int16x8_t v1404 = vaddq_s16(v1401, v1403); + int16x8_t v1405 = vqrdmulhq_n_s16(v1404, 21637); + int16x8_t v1406 = vaddq_s16(v1400, v1405); + int16x8_t v1407 = vsubq_s16(v805, v816); + int16x8_t v1408 = vsubq_s16(v827, v838); + int16x8_t v1409_tmp = vqrdmulhq_n_s16(v1408, 13357); + int16x8_t v1409 = vmlaq_n_s16(v1409_tmp, v1408, 3); + int16x8_t v1410 = vaddq_s16(v1407, v1409); + int16x8_t v1411 = vsubq_s16(v851, v862); + int16x8_t v1412 = vsubq_s16(v873, v884); + int16x8_t v1413_tmp = vqrdmulhq_n_s16(v1412, 13357); + int16x8_t v1413 = vmlaq_n_s16(v1413_tmp, v1412, 3); + int16x8_t v1414 = vaddq_s16(v1411, v1413); + int16x8_t v1415 = vqrdmulhq_n_s16(v1414, 21637); + int16x8_t v1416 = vaddq_s16(v1410, v1415); + int16x8_t v1417 = vqrdmulhq_n_s16(v1416, 17479); + int16x8_t v1418 = vaddq_s16(v1406, v1417); + int16x8_t v1419 = vsubq_s16(v25, v60); + int16x8_t v1420 = vsubq_s16(v102, v138); + int16x8_t v1421_tmp = vqrdmulhq_n_s16(v1420, 6226); + int16x8_t v1421 = vmlaq_n_s16(v1421_tmp, v1420, 10); + int16x8_t v1422 = vaddq_s16(v1419, v1421); + int16x8_t v1423 = vsubq_s16(v182, v233); + int16x8_t v1424 = vsubq_s16(v275, v312); + int16x8_t v1425_tmp = vqrdmulhq_n_s16(v1424, 6226); + int16x8_t v1425 = vmlaq_n_s16(v1425_tmp, v1424, 10); + int16x8_t v1426 = vaddq_s16(v1423, v1425); + int16x8_t v1427 = vqrdmulhq_n_s16(v1426, 22622); + int16x8_t v1428 = vaddq_s16(v1422, v1427); + int16x8_t v1429 = vsubq_s16(v358, v409); + int16x8_t v1430 = vsubq_s16(v481, v519); + int16x8_t v1431_tmp = vqrdmulhq_n_s16(v1430, 6226); + int16x8_t v1431 = vmlaq_n_s16(v1431_tmp, v1430, 10); + int16x8_t v1432 = vaddq_s16(v1429, v1431); + int16x8_t v1433 = vsubq_s16(v563, v614); + int16x8_t v1434 = vsubq_s16(v656, v694); + int16x8_t v1435_tmp = vqrdmulhq_n_s16(v1434, 6226); + int16x8_t v1435 = vmlaq_n_s16(v1435_tmp, v1434, 10); + int16x8_t v1436 = vaddq_s16(v1433, v1435); + int16x8_t v1437 = vqrdmulhq_n_s16(v1436, 22622); + int16x8_t v1438 = vaddq_s16(v1432, v1437); + int16x8_t v1439 = vqrdmulhq_n_s16(v1438, 17646); + int16x8_t v1440 = vaddq_s16(v1428, v1439); + int16x8_t v1441 = vsubq_s16(v1419, v1421); + int16x8_t v1442 = vsubq_s16(v1423, v1425); + int16x8_t v1443 = vqrdmulhq_n_s16(v1442, 23761); + int16x8_t v1444 = vaddq_s16(v1441, v1443); + int16x8_t v1445 = vsubq_s16(v1429, v1431); + int16x8_t v1446 = vsubq_s16(v1433, v1435); + int16x8_t v1447 = vqrdmulhq_n_s16(v1446, 23761); + int16x8_t v1448 = vaddq_s16(v1445, v1447); + int16x8_t v1449 = vqrdmulhq_n_s16(v1448, 17826); + int16x8_t v1450 = vaddq_s16(v1444, v1449); + int16x8_t v1451 = vsubq_s16(v1397, v1399); + int16x8_t v1452 = vsubq_s16(v1401, v1403); + int16x8_t v1453 = vqrdmulhq_n_s16(v1452, 25084); + int16x8_t v1454 = vaddq_s16(v1451, v1453); + int16x8_t v1455 = vsubq_s16(v1407, v1409); + int16x8_t v1456 = vsubq_s16(v1411, v1413); + int16x8_t v1457 = vqrdmulhq_n_s16(v1456, 25084); + int16x8_t v1458 = vaddq_s16(v1455, v1457); + int16x8_t v1459 = vqrdmulhq_n_s16(v1458, 18021); + int16x8_t v1460 = vaddq_s16(v1454, v1459); + int16x8_t v1461 = vsubq_s16(v1375, v1377); + int16x8_t v1462 = vsubq_s16(v1379, v1381); + int16x8_t v1463 = vqrdmulhq_n_s16(v1462, 26631); + int16x8_t v1464 = vaddq_s16(v1461, v1463); + int16x8_t v1465 = vsubq_s16(v1385, v1387); + int16x8_t v1466 = vsubq_s16(v1389, v1391); + int16x8_t v1467 = vqrdmulhq_n_s16(v1466, 26631); + int16x8_t v1468 = vaddq_s16(v1465, v1467); + int16x8_t v1469 = vqrdmulhq_n_s16(v1468, 18231); + int16x8_t v1470 = vaddq_s16(v1464, v1469); + int16x8_t v1471 = vsubq_s16(v1353, v1355); + int16x8_t v1472 = vsubq_s16(v1357, v1359); + int16x8_t v1473 = vqrdmulhq_n_s16(v1472, 28454); + int16x8_t v1474 = vaddq_s16(v1471, v1473); + int16x8_t v1475 = vsubq_s16(v1363, v1365); + int16x8_t v1476 = vsubq_s16(v1367, v1369); + int16x8_t v1477 = vqrdmulhq_n_s16(v1476, 28454); + int16x8_t v1478 = vaddq_s16(v1475, v1477); + int16x8_t v1479 = vqrdmulhq_n_s16(v1478, 18458); + int16x8_t v1480 = vaddq_s16(v1474, v1479); + int16x8_t v1481 = vsubq_s16(v1331, v1333); + int16x8_t v1482 = vsubq_s16(v1335, v1337); + int16x8_t v1483 = vqrdmulhq_n_s16(v1482, 30624); + int16x8_t v1484 = vaddq_s16(v1481, v1483); + int16x8_t v1485 = vsubq_s16(v1341, v1343); + int16x8_t v1486 = vsubq_s16(v1345, v1347); + int16x8_t v1487 = vqrdmulhq_n_s16(v1486, 30624); + int16x8_t v1488 = vaddq_s16(v1485, v1487); + int16x8_t v1489 = vqrdmulhq_n_s16(v1488, 18702); + int16x8_t v1490 = vaddq_s16(v1484, v1489); + int16x8_t v1491 = vsubq_s16(v1309, v1311); + int16x8_t v1492 = vsubq_s16(v1313, v1315); + int16x8_t v1493_tmp = vqrdmulhq_n_s16(v1492, 472); + int16x8_t v1493 = vaddq_s16(v1493_tmp, v1492); + int16x8_t v1494 = vaddq_s16(v1491, v1493); + int16x8_t v1495 = vsubq_s16(v1319, v1321); + int16x8_t v1496 = vsubq_s16(v1323, v1325); + int16x8_t v1497_tmp = vqrdmulhq_n_s16(v1496, 472); + int16x8_t v1497 = vaddq_s16(v1497_tmp, v1496); + int16x8_t v1498 = vaddq_s16(v1495, v1497); + int16x8_t v1499 = vqrdmulhq_n_s16(v1498, 18964); + int16x8_t v1500 = vaddq_s16(v1494, v1499); + int16x8_t v1501 = vsubq_s16(v1287, v1289); + int16x8_t v1502 = vsubq_s16(v1291, v1293); + int16x8_t v1503_tmp = vqrdmulhq_n_s16(v1502, 3672); + int16x8_t v1503 = vaddq_s16(v1503_tmp, v1502); + int16x8_t v1504 = vaddq_s16(v1501, v1503); + int16x8_t v1505 = vsubq_s16(v1297, v1299); + int16x8_t v1506 = vsubq_s16(v1301, v1303); + int16x8_t v1507_tmp = vqrdmulhq_n_s16(v1506, 3672); + int16x8_t v1507 = vaddq_s16(v1507_tmp, v1506); + int16x8_t v1508 = vaddq_s16(v1505, v1507); + int16x8_t v1509 = vqrdmulhq_n_s16(v1508, 19245); + int16x8_t v1510 = vaddq_s16(v1504, v1509); + int16x8_t v1511 = vsubq_s16(v1265, v1267); + int16x8_t v1512 = vsubq_s16(v1269, v1271); + int16x8_t v1513_tmp = vqrdmulhq_n_s16(v1512, 7662); + int16x8_t v1513 = vaddq_s16(v1513_tmp, v1512); + int16x8_t v1514 = vaddq_s16(v1511, v1513); + int16x8_t v1515 = vsubq_s16(v1275, v1277); + int16x8_t v1516 = vsubq_s16(v1279, v1281); + int16x8_t v1517_tmp = vqrdmulhq_n_s16(v1516, 7662); + int16x8_t v1517 = vaddq_s16(v1517_tmp, v1516); + int16x8_t v1518 = vaddq_s16(v1515, v1517); + int16x8_t v1519 = vqrdmulhq_n_s16(v1518, 19546); + int16x8_t v1520 = vaddq_s16(v1514, v1519); + int16x8_t v1521 = vsubq_s16(v1222, v1227); + int16x8_t v1522 = vsubq_s16(v1232, v1237); + int16x8_t v1523_tmp = vqrdmulhq_n_s16(v1522, 12756); + int16x8_t v1523 = vaddq_s16(v1523_tmp, v1522); + int16x8_t v1524 = vaddq_s16(v1521, v1523); + int16x8_t v1525 = vsubq_s16(v1244, v1249); + int16x8_t v1526 = vsubq_s16(v1254, v1259); + int16x8_t v1527_tmp = vqrdmulhq_n_s16(v1526, 12756); + int16x8_t v1527 = vaddq_s16(v1527_tmp, v1526); + int16x8_t v1528 = vaddq_s16(v1525, v1527); + int16x8_t v1529 = vqrdmulhq_n_s16(v1528, 19869); + int16x8_t v1530 = vaddq_s16(v1524, v1529); + int16x8_t v1531 = vsubq_s16(v1176, v1181); + int16x8_t v1532 = vsubq_s16(v1186, v1191); + int16x8_t v1533_tmp = vqrdmulhq_n_s16(v1532, 19463); + int16x8_t v1533 = vaddq_s16(v1533_tmp, v1532); + int16x8_t v1534 = vaddq_s16(v1531, v1533); + int16x8_t v1535 = vsubq_s16(v1198, v1203); + int16x8_t v1536 = vsubq_s16(v1208, v1213); + int16x8_t v1537_tmp = vqrdmulhq_n_s16(v1536, 19463); + int16x8_t v1537 = vaddq_s16(v1537_tmp, v1536); + int16x8_t v1538 = vaddq_s16(v1535, v1537); + int16x8_t v1539 = vqrdmulhq_n_s16(v1538, 20216); + int16x8_t v1540 = vaddq_s16(v1534, v1539); + int16x8_t v1541 = vsubq_s16(v1130, v1135); + int16x8_t v1542 = vsubq_s16(v1140, v1145); + int16x8_t v1543_tmp = vqrdmulhq_n_s16(v1542, 28661); + int16x8_t v1543 = vaddq_s16(v1543_tmp, v1542); + int16x8_t v1544 = vaddq_s16(v1541, v1543); + int16x8_t v1545 = vsubq_s16(v1152, v1157); + int16x8_t v1546 = vsubq_s16(v1162, v1167); + int16x8_t v1547_tmp = vqrdmulhq_n_s16(v1546, 28661); + int16x8_t v1547 = vaddq_s16(v1547_tmp, v1546); + int16x8_t v1548 = vaddq_s16(v1545, v1547); + int16x8_t v1549 = vqrdmulhq_n_s16(v1548, 20587); + int16x8_t v1550 = vaddq_s16(v1544, v1549); + int16x8_t v1551 = vsubq_s16(v1084, v1089); + int16x8_t v1552 = vsubq_s16(v1094, v1099); + int16x8_t v1553_tmp = vqrdmulhq_n_s16(v1552, 9242); + int16x8_t v1553 = vmlaq_n_s16(v1553_tmp, v1552, 2); + int16x8_t v1554 = vaddq_s16(v1551, v1553); + int16x8_t v1555 = vsubq_s16(v1106, v1111); + int16x8_t v1556 = vsubq_s16(v1116, v1121); + int16x8_t v1557_tmp = vqrdmulhq_n_s16(v1556, 9242); + int16x8_t v1557 = vmlaq_n_s16(v1557_tmp, v1556, 2); + int16x8_t v1558 = vaddq_s16(v1555, v1557); + int16x8_t v1559 = vqrdmulhq_n_s16(v1558, 20985); + int16x8_t v1560 = vaddq_s16(v1554, v1559); + int16x8_t v1561 = vsubq_s16(v996, v1007); + int16x8_t v1562 = vsubq_s16(v1018, v1029); + int16x8_t v1563_tmp = vqrdmulhq_n_s16(v1562, 30298); + int16x8_t v1563 = vmlaq_n_s16(v1563_tmp, v1562, 2); + int16x8_t v1564 = vaddq_s16(v1561, v1563); + int16x8_t v1565 = vsubq_s16(v1042, v1053); + int16x8_t v1566 = vsubq_s16(v1064, v1075); + int16x8_t v1567_tmp = vqrdmulhq_n_s16(v1566, 30298); + int16x8_t v1567 = vmlaq_n_s16(v1567_tmp, v1566, 2); + int16x8_t v1568 = vaddq_s16(v1565, v1567); + int16x8_t v1569 = vqrdmulhq_n_s16(v1568, 21412); + int16x8_t v1570 = vaddq_s16(v1564, v1569); + int16x8_t v1571 = vsubq_s16(v901, v912); + int16x8_t v1572 = vsubq_s16(v924, v935); + int16x8_t v1573_tmp = vqrdmulhq_n_s16(v1572, 2773); + int16x8_t v1573 = vmlaq_n_s16(v1573_tmp, v1572, 4); + int16x8_t v1574 = vaddq_s16(v1571, v1573); + int16x8_t v1575 = vsubq_s16(v948, v959); + int16x8_t v1576 = vsubq_s16(v970, v981); + int16x8_t v1577_tmp = vqrdmulhq_n_s16(v1576, 2773); + int16x8_t v1577 = vmlaq_n_s16(v1577_tmp, v1576, 4); + int16x8_t v1578 = vaddq_s16(v1575, v1577); + int16x8_t v1579 = vqrdmulhq_n_s16(v1578, 21871); + int16x8_t v1580 = vaddq_s16(v1574, v1579); + int16x8_t v1581 = vsubq_s16(v723, v746); + int16x8_t v1582 = vsubq_s16(v769, v792); + int16x8_t v1583_tmp = vqrdmulhq_n_s16(v1582, 26108); + int16x8_t v1583 = vmlaq_n_s16(v1583_tmp, v1582, 6); + int16x8_t v1584 = vaddq_s16(v1581, v1583); + int16x8_t v1585 = vsubq_s16(v817, v840); + int16x8_t v1586 = vsubq_s16(v863, v886); + int16x8_t v1587_tmp = vqrdmulhq_n_s16(v1586, 26108); + int16x8_t v1587 = vmlaq_n_s16(v1587_tmp, v1586, 6); + int16x8_t v1588 = vaddq_s16(v1585, v1587); + int16x8_t v1589 = vqrdmulhq_n_s16(v1588, 22363); + int16x8_t v1590 = vaddq_s16(v1584, v1589); + int16x8_t v1591 = vsubq_s16(v61, v140); + int16x8_t v1592 = vsubq_s16(v234, v314); + int16x8_t v1593_tmp = vqrdmulhq_n_s16(v1592, 12251); + int16x8_t v1593 = vmlaq_n_s16(v1593_tmp, v1592, 20); + int16x8_t v1594 = vaddq_s16(v1591, v1593); + int16x8_t v1595 = vsubq_s16(v410, v521); + int16x8_t v1596 = vsubq_s16(v615, v696); + int16x8_t v1597_tmp = vqrdmulhq_n_s16(v1596, 12251); + int16x8_t v1597 = vmlaq_n_s16(v1597_tmp, v1596, 20); + int16x8_t v1598 = vaddq_s16(v1595, v1597); + int16x8_t v1599 = vqrdmulhq_n_s16(v1598, 22891); + int16x8_t v1600 = vaddq_s16(v1594, v1599); + int16x8_t v1601 = vsubq_s16(v1591, v1593); + int16x8_t v1602 = vsubq_s16(v1595, v1597); + int16x8_t v1603 = vqrdmulhq_n_s16(v1602, 23460); + int16x8_t v1604 = vaddq_s16(v1601, v1603); + int16x8_t v1605 = vsubq_s16(v1581, v1583); + int16x8_t v1606 = vsubq_s16(v1585, v1587); + int16x8_t v1607 = vqrdmulhq_n_s16(v1606, 24073); + int16x8_t v1608 = vaddq_s16(v1605, v1607); + int16x8_t v1609 = vsubq_s16(v1571, v1573); + int16x8_t v1610 = vsubq_s16(v1575, v1577); + int16x8_t v1611 = vqrdmulhq_n_s16(v1610, 24734); + int16x8_t v1612 = vaddq_s16(v1609, v1611); + int16x8_t v1613 = vsubq_s16(v1561, v1563); + int16x8_t v1614 = vsubq_s16(v1565, v1567); + int16x8_t v1615 = vqrdmulhq_n_s16(v1614, 25448); + int16x8_t v1616 = vaddq_s16(v1613, v1615); + int16x8_t v1617 = vsubq_s16(v1551, v1553); + int16x8_t v1618 = vsubq_s16(v1555, v1557); + int16x8_t v1619 = vqrdmulhq_n_s16(v1618, 26220); + int16x8_t v1620 = vaddq_s16(v1617, v1619); + int16x8_t v1621 = vsubq_s16(v1541, v1543); + int16x8_t v1622 = vsubq_s16(v1545, v1547); + int16x8_t v1623 = vqrdmulhq_n_s16(v1622, 27058); + int16x8_t v1624 = vaddq_s16(v1621, v1623); + int16x8_t v1625 = vsubq_s16(v1531, v1533); + int16x8_t v1626 = vsubq_s16(v1535, v1537); + int16x8_t v1627 = vqrdmulhq_n_s16(v1626, 27969); + int16x8_t v1628 = vaddq_s16(v1625, v1627); + int16x8_t v1629 = vsubq_s16(v1521, v1523); + int16x8_t v1630 = vsubq_s16(v1525, v1527); + int16x8_t v1631 = vqrdmulhq_n_s16(v1630, 28961); + int16x8_t v1632 = vaddq_s16(v1629, v1631); + int16x8_t v1633 = vsubq_s16(v1511, v1513); + int16x8_t v1634 = vsubq_s16(v1515, v1517); + int16x8_t v1635 = vqrdmulhq_n_s16(v1634, 30044); + int16x8_t v1636 = vaddq_s16(v1633, v1635); + int16x8_t v1637 = vsubq_s16(v1501, v1503); + int16x8_t v1638 = vsubq_s16(v1505, v1507); + int16x8_t v1639 = vqrdmulhq_n_s16(v1638, 31232); + int16x8_t v1640 = vaddq_s16(v1637, v1639); + int16x8_t v1641 = vsubq_s16(v1491, v1493); + int16x8_t v1642 = vsubq_s16(v1495, v1497); + int16x8_t v1643 = vqrdmulhq_n_s16(v1642, 32538); + int16x8_t v1644 = vaddq_s16(v1641, v1643); + int16x8_t v1645 = vsubq_s16(v1481, v1483); + int16x8_t v1646 = vsubq_s16(v1485, v1487); + int16x8_t v1647_tmp = vqrdmulhq_n_s16(v1646, 1211); + int16x8_t v1647 = vaddq_s16(v1647_tmp, v1646); + int16x8_t v1648 = vaddq_s16(v1645, v1647); + int16x8_t v1649 = vsubq_s16(v1471, v1473); + int16x8_t v1650 = vsubq_s16(v1475, v1477); + int16x8_t v1651_tmp = vqrdmulhq_n_s16(v1650, 2808); + int16x8_t v1651 = vaddq_s16(v1651_tmp, v1650); + int16x8_t v1652 = vaddq_s16(v1649, v1651); + int16x8_t v1653 = vsubq_s16(v1461, v1463); + int16x8_t v1654 = vsubq_s16(v1465, v1467); + int16x8_t v1655_tmp = vqrdmulhq_n_s16(v1654, 4586); + int16x8_t v1655 = vaddq_s16(v1655_tmp, v1654); + int16x8_t v1656 = vaddq_s16(v1653, v1655); + int16x8_t v1657 = vsubq_s16(v1451, v1453); + int16x8_t v1658 = vsubq_s16(v1455, v1457); + int16x8_t v1659_tmp = vqrdmulhq_n_s16(v1658, 6576); + int16x8_t v1659 = vaddq_s16(v1659_tmp, v1658); + int16x8_t v1660 = vaddq_s16(v1657, v1659); + int16x8_t v1661 = vsubq_s16(v1441, v1443); + int16x8_t v1662 = vsubq_s16(v1445, v1447); + int16x8_t v1663_tmp = vqrdmulhq_n_s16(v1662, 8817); + int16x8_t v1663 = vaddq_s16(v1663_tmp, v1662); + int16x8_t v1664 = vaddq_s16(v1661, v1663); + int16x8_t v1665 = vsubq_s16(v1422, v1427); + int16x8_t v1666 = vsubq_s16(v1432, v1437); + int16x8_t v1667_tmp = vqrdmulhq_n_s16(v1666, 11356); + int16x8_t v1667 = vaddq_s16(v1667_tmp, v1666); + int16x8_t v1668 = vaddq_s16(v1665, v1667); + int16x8_t v1669 = vsubq_s16(v1400, v1405); + int16x8_t v1670 = vsubq_s16(v1410, v1415); + int16x8_t v1671_tmp = vqrdmulhq_n_s16(v1670, 14256); + int16x8_t v1671 = vaddq_s16(v1671_tmp, v1670); + int16x8_t v1672 = vaddq_s16(v1669, v1671); + int16x8_t v1673 = vsubq_s16(v1378, v1383); + int16x8_t v1674 = vsubq_s16(v1388, v1393); + int16x8_t v1675_tmp = vqrdmulhq_n_s16(v1674, 17596); + int16x8_t v1675 = vaddq_s16(v1675_tmp, v1674); + int16x8_t v1676 = vaddq_s16(v1673, v1675); + int16x8_t v1677 = vsubq_s16(v1356, v1361); + int16x8_t v1678 = vsubq_s16(v1366, v1371); + int16x8_t v1679_tmp = vqrdmulhq_n_s16(v1678, 21483); + int16x8_t v1679 = vaddq_s16(v1679_tmp, v1678); + int16x8_t v1680 = vaddq_s16(v1677, v1679); + int16x8_t v1681 = vsubq_s16(v1334, v1339); + int16x8_t v1682 = vsubq_s16(v1344, v1349); + int16x8_t v1683_tmp = vqrdmulhq_n_s16(v1682, 26057); + int16x8_t v1683 = vaddq_s16(v1683_tmp, v1682); + int16x8_t v1684 = vaddq_s16(v1681, v1683); + int16x8_t v1685 = vsubq_s16(v1312, v1317); + int16x8_t v1686 = vsubq_s16(v1322, v1327); + int16x8_t v1687_tmp = vqrdmulhq_n_s16(v1686, 31517); + int16x8_t v1687 = vaddq_s16(v1687_tmp, v1686); + int16x8_t v1688 = vaddq_s16(v1685, v1687); + int16x8_t v1689 = vsubq_s16(v1290, v1295); + int16x8_t v1690 = vsubq_s16(v1300, v1305); + int16x8_t v1691_tmp = vqrdmulhq_n_s16(v1690, 5373); + int16x8_t v1691 = vmlaq_n_s16(v1691_tmp, v1690, 2); + int16x8_t v1692 = vaddq_s16(v1689, v1691); + int16x8_t v1693 = vsubq_s16(v1268, v1273); + int16x8_t v1694 = vsubq_s16(v1278, v1283); + int16x8_t v1695_tmp = vqrdmulhq_n_s16(v1694, 13571); + int16x8_t v1695 = vmlaq_n_s16(v1695_tmp, v1694, 2); + int16x8_t v1696 = vaddq_s16(v1693, v1695); + int16x8_t v1697 = vsubq_s16(v1228, v1239); + int16x8_t v1698 = vsubq_s16(v1250, v1261); + int16x8_t v1699_tmp = vqrdmulhq_n_s16(v1698, 23975); + int16x8_t v1699 = vmlaq_n_s16(v1699_tmp, v1698, 2); + int16x8_t v1700 = vaddq_s16(v1697, v1699); + int16x8_t v1701 = vsubq_s16(v1182, v1193); + int16x8_t v1702 = vsubq_s16(v1204, v1215); + int16x8_t v1703_tmp = vqrdmulhq_n_s16(v1702, 4832); + int16x8_t v1703 = vmlaq_n_s16(v1703_tmp, v1702, 3); + int16x8_t v1704 = vaddq_s16(v1701, v1703); + int16x8_t v1705 = vsubq_s16(v1136, v1147); + int16x8_t v1706 = vsubq_s16(v1158, v1169); + int16x8_t v1707_tmp = vqrdmulhq_n_s16(v1706, 23437); + int16x8_t v1707 = vmlaq_n_s16(v1707_tmp, v1706, 3); + int16x8_t v1708 = vaddq_s16(v1705, v1707); + int16x8_t v1709 = vsubq_s16(v1090, v1101); + int16x8_t v1710 = vsubq_s16(v1112, v1123); + int16x8_t v1711_tmp = vqrdmulhq_n_s16(v1710, 17573); + int16x8_t v1711 = vmlaq_n_s16(v1711_tmp, v1710, 4); + int16x8_t v1712 = vaddq_s16(v1709, v1711); + int16x8_t v1713 = vsubq_s16(v1008, v1031); + int16x8_t v1714 = vsubq_s16(v1054, v1077); + int16x8_t v1715_tmp = vqrdmulhq_n_s16(v1714, 27122); + int16x8_t v1715 = vmlaq_n_s16(v1715_tmp, v1714, 5); + int16x8_t v1716 = vaddq_s16(v1713, v1715); + int16x8_t v1717 = vsubq_s16(v913, v937); + int16x8_t v1718 = vsubq_s16(v960, v983); + int16x8_t v1719_tmp = vqrdmulhq_n_s16(v1718, 5041); + int16x8_t v1719 = vmlaq_n_s16(v1719_tmp, v1718, 8); + int16x8_t v1720 = vaddq_s16(v1717, v1719); + int16x8_t v1721 = vsubq_s16(v747, v794); + int16x8_t v1722 = vsubq_s16(v841, v888); + int16x8_t v1723_tmp = vqrdmulhq_n_s16(v1722, 19146); + int16x8_t v1723 = vmlaq_n_s16(v1723_tmp, v1722, 13); + int16x8_t v1724 = vaddq_s16(v1721, v1723); + int16x8_t v1725 = vsubq_s16(v141, v316); + int16x8_t v1726 = vsubq_s16(v522, v698); + int16x8_t v1727_tmp = vqrdmulhq_n_s16(v1726, 24402); + int16x8_t v1727 = vmlaq_n_s16(v1727_tmp, v1726, 40); + int16x8_t v1728 = vaddq_s16(v1725, v1727); + int16x8_t v1729 = vsubq_s16(v1725, v1727); + int16x8_t v1730 = vsubq_s16(v1721, v1723); + int16x8_t v1731 = vsubq_s16(v1717, v1719); + int16x8_t v1732 = vsubq_s16(v1713, v1715); + int16x8_t v1733 = vsubq_s16(v1709, v1711); + int16x8_t v1734 = vsubq_s16(v1705, v1707); + int16x8_t v1735 = vsubq_s16(v1701, v1703); + int16x8_t v1736 = vsubq_s16(v1697, v1699); + int16x8_t v1737 = vsubq_s16(v1693, v1695); + int16x8_t v1738 = vsubq_s16(v1689, v1691); + int16x8_t v1739 = vsubq_s16(v1685, v1687); + int16x8_t v1740 = vsubq_s16(v1681, v1683); + int16x8_t v1741 = vsubq_s16(v1677, v1679); + int16x8_t v1742 = vsubq_s16(v1673, v1675); + int16x8_t v1743 = vsubq_s16(v1669, v1671); + int16x8_t v1744 = vsubq_s16(v1665, v1667); + int16x8_t v1745 = vsubq_s16(v1661, v1663); + int16x8_t v1746 = vsubq_s16(v1657, v1659); + int16x8_t v1747 = vsubq_s16(v1653, v1655); + int16x8_t v1748 = vsubq_s16(v1649, v1651); + int16x8_t v1749 = vsubq_s16(v1645, v1647); + int16x8_t v1750 = vsubq_s16(v1641, v1643); + int16x8_t v1751 = vsubq_s16(v1637, v1639); + int16x8_t v1752 = vsubq_s16(v1633, v1635); + int16x8_t v1753 = vsubq_s16(v1629, v1631); + int16x8_t v1754 = vsubq_s16(v1625, v1627); + int16x8_t v1755 = vsubq_s16(v1621, v1623); + int16x8_t v1756 = vsubq_s16(v1617, v1619); + int16x8_t v1757 = vsubq_s16(v1613, v1615); + int16x8_t v1758 = vsubq_s16(v1609, v1611); + int16x8_t v1759 = vsubq_s16(v1605, v1607); + int16x8_t v1760 = vsubq_s16(v1601, v1603); + int16x8_t v1761 = vsubq_s16(v1594, v1599); + int16x8_t v1762 = vsubq_s16(v1584, v1589); + int16x8_t v1763 = vsubq_s16(v1574, v1579); + int16x8_t v1764 = vsubq_s16(v1564, v1569); + int16x8_t v1765 = vsubq_s16(v1554, v1559); + int16x8_t v1766 = vsubq_s16(v1544, v1549); + int16x8_t v1767 = vsubq_s16(v1534, v1539); + int16x8_t v1768 = vsubq_s16(v1524, v1529); + int16x8_t v1769 = vsubq_s16(v1514, v1519); + int16x8_t v1770 = vsubq_s16(v1504, v1509); + int16x8_t v1771 = vsubq_s16(v1494, v1499); + int16x8_t v1772 = vsubq_s16(v1484, v1489); + int16x8_t v1773 = vsubq_s16(v1474, v1479); + int16x8_t v1774 = vsubq_s16(v1464, v1469); + int16x8_t v1775 = vsubq_s16(v1454, v1459); + int16x8_t v1776 = vsubq_s16(v1444, v1449); + int16x8_t v1777 = vsubq_s16(v1428, v1439); + int16x8_t v1778 = vsubq_s16(v1406, v1417); + int16x8_t v1779 = vsubq_s16(v1384, v1395); + int16x8_t v1780 = vsubq_s16(v1362, v1373); + int16x8_t v1781 = vsubq_s16(v1340, v1351); + int16x8_t v1782 = vsubq_s16(v1318, v1329); + int16x8_t v1783 = vsubq_s16(v1296, v1307); + int16x8_t v1784 = vsubq_s16(v1274, v1285); + int16x8_t v1785 = vsubq_s16(v1240, v1263); + int16x8_t v1786 = vsubq_s16(v1194, v1217); + int16x8_t v1787 = vsubq_s16(v1148, v1171); + int16x8_t v1788 = vsubq_s16(v1102, v1125); + int16x8_t v1789 = vsubq_s16(v1032, v1079); + int16x8_t v1790 = vsubq_s16(v938, v985); + int16x8_t v1791 = vsubq_s16(v795, v890); + int16x8_t v1792 = vsubq_s16(v317, v700); + vst1q_s16(out + out_stride * 0 + i, v701); + vst1q_s16(out + out_stride * 1 + i, v891); + vst1q_s16(out + out_stride * 2 + i, v986); + vst1q_s16(out + out_stride * 3 + i, v1080); + vst1q_s16(out + out_stride * 4 + i, v1126); + vst1q_s16(out + out_stride * 5 + i, v1172); + vst1q_s16(out + out_stride * 6 + i, v1218); + vst1q_s16(out + out_stride * 7 + i, v1264); + vst1q_s16(out + out_stride * 8 + i, v1286); + vst1q_s16(out + out_stride * 9 + i, v1308); + vst1q_s16(out + out_stride * 10 + i, v1330); + vst1q_s16(out + out_stride * 11 + i, v1352); + vst1q_s16(out + out_stride * 12 + i, v1374); + vst1q_s16(out + out_stride * 13 + i, v1396); + vst1q_s16(out + out_stride * 14 + i, v1418); + vst1q_s16(out + out_stride * 15 + i, v1440); + vst1q_s16(out + out_stride * 16 + i, v1450); + vst1q_s16(out + out_stride * 17 + i, v1460); + vst1q_s16(out + out_stride * 18 + i, v1470); + vst1q_s16(out + out_stride * 19 + i, v1480); + vst1q_s16(out + out_stride * 20 + i, v1490); + vst1q_s16(out + out_stride * 21 + i, v1500); + vst1q_s16(out + out_stride * 22 + i, v1510); + vst1q_s16(out + out_stride * 23 + i, v1520); + vst1q_s16(out + out_stride * 24 + i, v1530); + vst1q_s16(out + out_stride * 25 + i, v1540); + vst1q_s16(out + out_stride * 26 + i, v1550); + vst1q_s16(out + out_stride * 27 + i, v1560); + vst1q_s16(out + out_stride * 28 + i, v1570); + vst1q_s16(out + out_stride * 29 + i, v1580); + vst1q_s16(out + out_stride * 30 + i, v1590); + vst1q_s16(out + out_stride * 31 + i, v1600); + vst1q_s16(out + out_stride * 32 + i, v1604); + vst1q_s16(out + out_stride * 33 + i, v1608); + vst1q_s16(out + out_stride * 34 + i, v1612); + vst1q_s16(out + out_stride * 35 + i, v1616); + vst1q_s16(out + out_stride * 36 + i, v1620); + vst1q_s16(out + out_stride * 37 + i, v1624); + vst1q_s16(out + out_stride * 38 + i, v1628); + vst1q_s16(out + out_stride * 39 + i, v1632); + vst1q_s16(out + out_stride * 40 + i, v1636); + vst1q_s16(out + out_stride * 41 + i, v1640); + vst1q_s16(out + out_stride * 42 + i, v1644); + vst1q_s16(out + out_stride * 43 + i, v1648); + vst1q_s16(out + out_stride * 44 + i, v1652); + vst1q_s16(out + out_stride * 45 + i, v1656); + vst1q_s16(out + out_stride * 46 + i, v1660); + vst1q_s16(out + out_stride * 47 + i, v1664); + vst1q_s16(out + out_stride * 48 + i, v1668); + vst1q_s16(out + out_stride * 49 + i, v1672); + vst1q_s16(out + out_stride * 50 + i, v1676); + vst1q_s16(out + out_stride * 51 + i, v1680); + vst1q_s16(out + out_stride * 52 + i, v1684); + vst1q_s16(out + out_stride * 53 + i, v1688); + vst1q_s16(out + out_stride * 54 + i, v1692); + vst1q_s16(out + out_stride * 55 + i, v1696); + vst1q_s16(out + out_stride * 56 + i, v1700); + vst1q_s16(out + out_stride * 57 + i, v1704); + vst1q_s16(out + out_stride * 58 + i, v1708); + vst1q_s16(out + out_stride * 59 + i, v1712); + vst1q_s16(out + out_stride * 60 + i, v1716); + vst1q_s16(out + out_stride * 61 + i, v1720); + vst1q_s16(out + out_stride * 62 + i, v1724); + vst1q_s16(out + out_stride * 63 + i, v1728); + vst1q_s16(out + out_stride * 64 + i, v1729); + vst1q_s16(out + out_stride * 65 + i, v1730); + vst1q_s16(out + out_stride * 66 + i, v1731); + vst1q_s16(out + out_stride * 67 + i, v1732); + vst1q_s16(out + out_stride * 68 + i, v1733); + vst1q_s16(out + out_stride * 69 + i, v1734); + vst1q_s16(out + out_stride * 70 + i, v1735); + vst1q_s16(out + out_stride * 71 + i, v1736); + vst1q_s16(out + out_stride * 72 + i, v1737); + vst1q_s16(out + out_stride * 73 + i, v1738); + vst1q_s16(out + out_stride * 74 + i, v1739); + vst1q_s16(out + out_stride * 75 + i, v1740); + vst1q_s16(out + out_stride * 76 + i, v1741); + vst1q_s16(out + out_stride * 77 + i, v1742); + vst1q_s16(out + out_stride * 78 + i, v1743); + vst1q_s16(out + out_stride * 79 + i, v1744); + vst1q_s16(out + out_stride * 80 + i, v1745); + vst1q_s16(out + out_stride * 81 + i, v1746); + vst1q_s16(out + out_stride * 82 + i, v1747); + vst1q_s16(out + out_stride * 83 + i, v1748); + vst1q_s16(out + out_stride * 84 + i, v1749); + vst1q_s16(out + out_stride * 85 + i, v1750); + vst1q_s16(out + out_stride * 86 + i, v1751); + vst1q_s16(out + out_stride * 87 + i, v1752); + vst1q_s16(out + out_stride * 88 + i, v1753); + vst1q_s16(out + out_stride * 89 + i, v1754); + vst1q_s16(out + out_stride * 90 + i, v1755); + vst1q_s16(out + out_stride * 91 + i, v1756); + vst1q_s16(out + out_stride * 92 + i, v1757); + vst1q_s16(out + out_stride * 93 + i, v1758); + vst1q_s16(out + out_stride * 94 + i, v1759); + vst1q_s16(out + out_stride * 95 + i, v1760); + vst1q_s16(out + out_stride * 96 + i, v1761); + vst1q_s16(out + out_stride * 97 + i, v1762); + vst1q_s16(out + out_stride * 98 + i, v1763); + vst1q_s16(out + out_stride * 99 + i, v1764); + vst1q_s16(out + out_stride * 100 + i, v1765); + vst1q_s16(out + out_stride * 101 + i, v1766); + vst1q_s16(out + out_stride * 102 + i, v1767); + vst1q_s16(out + out_stride * 103 + i, v1768); + vst1q_s16(out + out_stride * 104 + i, v1769); + vst1q_s16(out + out_stride * 105 + i, v1770); + vst1q_s16(out + out_stride * 106 + i, v1771); + vst1q_s16(out + out_stride * 107 + i, v1772); + vst1q_s16(out + out_stride * 108 + i, v1773); + vst1q_s16(out + out_stride * 109 + i, v1774); + vst1q_s16(out + out_stride * 110 + i, v1775); + vst1q_s16(out + out_stride * 111 + i, v1776); + vst1q_s16(out + out_stride * 112 + i, v1777); + vst1q_s16(out + out_stride * 113 + i, v1778); + vst1q_s16(out + out_stride * 114 + i, v1779); + vst1q_s16(out + out_stride * 115 + i, v1780); + vst1q_s16(out + out_stride * 116 + i, v1781); + vst1q_s16(out + out_stride * 117 + i, v1782); + vst1q_s16(out + out_stride * 118 + i, v1783); + vst1q_s16(out + out_stride * 119 + i, v1784); + vst1q_s16(out + out_stride * 120 + i, v1785); + vst1q_s16(out + out_stride * 121 + i, v1786); + vst1q_s16(out + out_stride * 122 + i, v1787); + vst1q_s16(out + out_stride * 123 + i, v1788); + vst1q_s16(out + out_stride * 124 + i, v1789); + vst1q_s16(out + out_stride * 125 + i, v1790); + vst1q_s16(out + out_stride * 126 + i, v1791); + vst1q_s16(out + out_stride * 127 + i, v1792); + } +} diff --git a/media/libjxl/src/lib/jxl/fast_dct16-inl.h b/media/libjxl/src/lib/jxl/fast_dct16-inl.h new file mode 100644 index 000000000..472ec20d4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct16-inl.h @@ -0,0 +1,180 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<16>) { return 1; } + +void FastIDCT(FastDCTTag<16>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17 = vqrdmulhq_n_s16(v16, 25080); + int16x8_t v18 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v16, v19); + int16x8_t v21 = vqrdmulhq_n_s16(v20, 17734); + int16x8_t v22 = vaddq_s16(v17, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v27 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v28 = vaddq_s16(v26, v27); + int16x8_t v29 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v30 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v31 = vaddq_s16(v29, v30); + int16x8_t v32 = vaddq_s16(v28, v31); + int16x8_t v33 = vqrdmulhq_n_s16(v32, 17734); + int16x8_t v34 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v35 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v36 = vaddq_s16(v34, v35); + int16x8_t v37 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v38 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v39 = vaddq_s16(v37, v38); + int16x8_t v40 = vaddq_s16(v36, v39); + int16x8_t v41_tmp = vqrdmulhq_n_s16(v40, 10045); + int16x8_t v41 = vaddq_s16(v41_tmp, v40); + int16x8_t v42 = vaddq_s16(v33, v41); + int16x8_t v43 = vqrdmulhq_n_s16(v42, 16705); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v36, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v36); + int16x8_t v45 = vaddq_s16(v39, v31); + int16x8_t v46 = vaddq_s16(v44, v45); + int16x8_t v47 = vqrdmulhq_n_s16(v46, 16705); + int16x8_t v48 = vaddq_s16(v43, v47); + int16x8_t v49_tmp = vqrdmulhq_n_s16(v35, 13573); + int16x8_t v49 = vaddq_s16(v49_tmp, v35); + int16x8_t v50 = vaddq_s16(v30, v37); + int16x8_t v51 = vaddq_s16(v49, v50); + int16x8_t v52 = vaddq_s16(v38, v34); + int16x8_t v53 = vaddq_s16(v27, v29); + int16x8_t v54 = vaddq_s16(v52, v53); + int16x8_t v55 = vqrdmulhq_n_s16(v54, 17734); + int16x8_t v56 = vqrdmulhq_n_s16(v52, 25080); + int16x8_t v57 = vaddq_s16(v55, v56); + int16x8_t v58 = vaddq_s16(v51, v57); + int16x8_t v59 = vaddq_s16(v48, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vsubq_s16(v0, v1); + int16x8_t v63 = vsubq_s16(v4, v6); + int16x8_t v64_tmp = vqrdmulhq_n_s16(v63, 10045); + int16x8_t v64 = vaddq_s16(v64_tmp, v63); + int16x8_t v65 = vaddq_s16(v62, v64); + int16x8_t v66 = vsubq_s16(v11, v14); + int16x8_t v67 = vqrdmulhq_n_s16(v16, 17734); + int16x8_t v68_tmp = vqrdmulhq_n_s16(v19, 10045); + int16x8_t v68 = vaddq_s16(v68_tmp, v19); + int16x8_t v69 = vsubq_s16(v67, v68); + int16x8_t v70 = vaddq_s16(v66, v69); + int16x8_t v71 = vqrdmulhq_n_s16(v70, 19705); + int16x8_t v72 = vaddq_s16(v65, v71); + int16x8_t v73 = vsubq_s16(v49, v50); + int16x8_t v74 = vqrdmulhq_n_s16(v52, 17734); + int16x8_t v75_tmp = vqrdmulhq_n_s16(v53, 10045); + int16x8_t v75 = vaddq_s16(v75_tmp, v53); + int16x8_t v76 = vsubq_s16(v74, v75); + int16x8_t v77 = vaddq_s16(v73, v76); + int16x8_t v78 = vsubq_s16(v44, v45); + int16x8_t v79 = vqrdmulhq_n_s16(v78, 19705); + int16x8_t v80 = vqrdmulhq_n_s16(v40, 13573); + int16x8_t v81 = vsubq_s16(v80, v32); + int16x8_t v82 = vqrdmulhq_n_s16(v81, 25746); + int16x8_t v83 = vaddq_s16(v79, v82); + int16x8_t v84 = vaddq_s16(v77, v83); + int16x8_t v85 = vqrdmulhq_n_s16(v84, 17121); + int16x8_t v86 = vaddq_s16(v72, v85); + int16x8_t v87 = vsubq_s16(v62, v64); + int16x8_t v88 = vsubq_s16(v66, v69); + int16x8_t v89 = vqrdmulhq_n_s16(v88, 29490); + int16x8_t v90 = vaddq_s16(v87, v89); + int16x8_t v91 = vsubq_s16(v73, v76); + int16x8_t v92 = vqrdmulhq_n_s16(v78, 29490); + int16x8_t v93_tmp = vqrdmulhq_n_s16(v81, 5763); + int16x8_t v93 = vaddq_s16(v93_tmp, v81); + int16x8_t v94 = vsubq_s16(v92, v93); + int16x8_t v95 = vaddq_s16(v91, v94); + int16x8_t v96 = vqrdmulhq_n_s16(v95, 18578); + int16x8_t v97 = vaddq_s16(v90, v96); + int16x8_t v98 = vsubq_s16(v46, v42); + int16x8_t v99_tmp = vqrdmulhq_n_s16(v98, 18446); + int16x8_t v99 = vmlaq_n_s16(v99_tmp, v98, 2); + int16x8_t v100 = vsubq_s16(v51, v57); + int16x8_t v101 = vaddq_s16(v99, v100); + int16x8_t v102 = vqrdmulhq_n_s16(v101, 21195); + int16x8_t v103 = vsubq_s16(v2, v8); + int16x8_t v104 = vsubq_s16(v15, v22); + int16x8_t v105_tmp = vqrdmulhq_n_s16(v104, 18446); + int16x8_t v105 = vmlaq_n_s16(v105_tmp, v104, 2); + int16x8_t v106 = vaddq_s16(v103, v105); + int16x8_t v107 = vaddq_s16(v102, v106); + int16x8_t v108 = vsubq_s16(v103, v105); + int16x8_t v109 = vsubq_s16(v100, v99); + int16x8_t v110 = vqrdmulhq_n_s16(v109, 25826); + int16x8_t v111 = vaddq_s16(v108, v110); + int16x8_t v112 = vsubq_s16(v87, v89); + int16x8_t v113 = vsubq_s16(v91, v94); + int16x8_t v114_tmp = vqrdmulhq_n_s16(v113, 1988); + int16x8_t v114 = vaddq_s16(v114_tmp, v113); + int16x8_t v115 = vaddq_s16(v112, v114); + int16x8_t v116 = vsubq_s16(v65, v71); + int16x8_t v117 = vsubq_s16(v77, v83); + int16x8_t v118_tmp = vqrdmulhq_n_s16(v117, 23673); + int16x8_t v118 = vaddq_s16(v118_tmp, v117); + int16x8_t v119 = vaddq_s16(v116, v118); + int16x8_t v120 = vsubq_s16(v58, v48); + int16x8_t v121_tmp = vqrdmulhq_n_s16(v120, 3314); + int16x8_t v121 = vmlaq_n_s16(v121_tmp, v120, 5); + int16x8_t v122 = vsubq_s16(v9, v24); + int16x8_t v123 = vaddq_s16(v121, v122); + int16x8_t v124 = vsubq_s16(v122, v121); + int16x8_t v125 = vsubq_s16(v116, v118); + int16x8_t v126 = vsubq_s16(v112, v114); + int16x8_t v127 = vsubq_s16(v108, v110); + int16x8_t v128 = vsubq_s16(v106, v102); + int16x8_t v129 = vsubq_s16(v90, v96); + int16x8_t v130 = vsubq_s16(v72, v85); + int16x8_t v131 = vsubq_s16(v25, v60); + vst1q_s16(out + out_stride * 0 + i, v61); + vst1q_s16(out + out_stride * 1 + i, v86); + vst1q_s16(out + out_stride * 2 + i, v97); + vst1q_s16(out + out_stride * 3 + i, v107); + vst1q_s16(out + out_stride * 4 + i, v111); + vst1q_s16(out + out_stride * 5 + i, v115); + vst1q_s16(out + out_stride * 6 + i, v119); + vst1q_s16(out + out_stride * 7 + i, v123); + vst1q_s16(out + out_stride * 8 + i, v124); + vst1q_s16(out + out_stride * 9 + i, v125); + vst1q_s16(out + out_stride * 10 + i, v126); + vst1q_s16(out + out_stride * 11 + i, v127); + vst1q_s16(out + out_stride * 12 + i, v128); + vst1q_s16(out + out_stride * 13 + i, v129); + vst1q_s16(out + out_stride * 14 + i, v130); + vst1q_s16(out + out_stride * 15 + i, v131); + } +} diff --git a/media/libjxl/src/lib/jxl/fast_dct256-inl.h b/media/libjxl/src/lib/jxl/fast_dct256-inl.h new file mode 100644 index 000000000..a823440af --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct256-inl.h @@ -0,0 +1,4811 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<256>) { return 3; } + +void FastIDCT(FastDCTTag<256>, const int16_t* in, size_t in_stride, + int16_t* out, size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 128 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 64 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 192 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 32 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 160 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 96 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17_tmp = vqrdmulhq_n_s16(v16, 13573); + int16x8_t v17 = vaddq_s16(v17_tmp, v16); + int16x8_t v18 = vld1q_s16(in + in_stride * 224 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v19, v16); + int16x8_t v21 = vaddq_s16(v17, v20); + int16x8_t v22 = vqrdmulhq_n_s16(v21, 17734); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 144 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 112 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 80 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 48 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35_tmp = vqrdmulhq_n_s16(v34, 13573); + int16x8_t v35 = vaddq_s16(v35_tmp, v34); + int16x8_t v36 = vld1q_s16(in + in_stride * 208 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 176 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vaddq_s16(v35, v39); + int16x8_t v41 = vqrdmulhq_n_s16(v40, 17734); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v37, v28); + int16x8_t v46 = vaddq_s16(v29, v32); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vaddq_s16(v46, v43); + int16x8_t v50_tmp = vqrdmulhq_n_s16(v49, 13573); + int16x8_t v50 = vaddq_s16(v50_tmp, v49); + int16x8_t v51 = vld1q_s16(in + in_stride * 240 + i); + int16x8_t v52 = vaddq_s16(v51, v36); + int16x8_t v53 = vaddq_s16(v52, v45); + int16x8_t v54 = vaddq_s16(v53, v49); + int16x8_t v55 = vaddq_s16(v50, v54); + int16x8_t v56 = vqrdmulhq_n_s16(v55, 17734); + int16x8_t v57 = vaddq_s16(v48, v56); + int16x8_t v58 = vqrdmulhq_n_s16(v57, 16705); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v63_tmp = vqrdmulhq_n_s16(v62, 13573); + int16x8_t v63 = vaddq_s16(v63_tmp, v62); + int16x8_t v64 = vld1q_s16(in + in_stride * 136 + i); + int16x8_t v65 = vld1q_s16(in + in_stride * 120 + i); + int16x8_t v66 = vaddq_s16(v64, v65); + int16x8_t v67 = vaddq_s16(v63, v66); + int16x8_t v68 = vld1q_s16(in + in_stride * 72 + i); + int16x8_t v69 = vld1q_s16(in + in_stride * 56 + i); + int16x8_t v70 = vaddq_s16(v68, v69); + int16x8_t v71_tmp = vqrdmulhq_n_s16(v70, 13573); + int16x8_t v71 = vaddq_s16(v71_tmp, v70); + int16x8_t v72 = vld1q_s16(in + in_stride * 200 + i); + int16x8_t v73 = vld1q_s16(in + in_stride * 184 + i); + int16x8_t v74 = vaddq_s16(v72, v73); + int16x8_t v75 = vaddq_s16(v74, v70); + int16x8_t v76 = vaddq_s16(v71, v75); + int16x8_t v77 = vqrdmulhq_n_s16(v76, 17734); + int16x8_t v78 = vaddq_s16(v67, v77); + int16x8_t v79 = vld1q_s16(in + in_stride * 40 + i); + int16x8_t v80 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v81 = vaddq_s16(v79, v80); + int16x8_t v82_tmp = vqrdmulhq_n_s16(v81, 13573); + int16x8_t v82 = vaddq_s16(v82_tmp, v81); + int16x8_t v83 = vld1q_s16(in + in_stride * 168 + i); + int16x8_t v84 = vld1q_s16(in + in_stride * 152 + i); + int16x8_t v85 = vaddq_s16(v83, v84); + int16x8_t v86 = vld1q_s16(in + in_stride * 104 + i); + int16x8_t v87 = vld1q_s16(in + in_stride * 88 + i); + int16x8_t v88 = vaddq_s16(v86, v87); + int16x8_t v89 = vaddq_s16(v85, v88); + int16x8_t v90 = vaddq_s16(v82, v89); + int16x8_t v91 = vaddq_s16(v88, v81); + int16x8_t v92_tmp = vqrdmulhq_n_s16(v91, 13573); + int16x8_t v92 = vaddq_s16(v92_tmp, v91); + int16x8_t v93 = vld1q_s16(in + in_stride * 232 + i); + int16x8_t v94 = vld1q_s16(in + in_stride * 216 + i); + int16x8_t v95 = vaddq_s16(v93, v94); + int16x8_t v96 = vaddq_s16(v95, v85); + int16x8_t v97 = vaddq_s16(v96, v91); + int16x8_t v98 = vaddq_s16(v92, v97); + int16x8_t v99 = vqrdmulhq_n_s16(v98, 17734); + int16x8_t v100 = vaddq_s16(v90, v99); + int16x8_t v101 = vqrdmulhq_n_s16(v100, 16705); + int16x8_t v102 = vaddq_s16(v78, v101); + int16x8_t v103 = vaddq_s16(v80, v62); + int16x8_t v104_tmp = vqrdmulhq_n_s16(v103, 13573); + int16x8_t v104 = vaddq_s16(v104_tmp, v103); + int16x8_t v105 = vaddq_s16(v84, v64); + int16x8_t v106 = vaddq_s16(v65, v86); + int16x8_t v107 = vaddq_s16(v105, v106); + int16x8_t v108 = vaddq_s16(v104, v107); + int16x8_t v109 = vaddq_s16(v87, v68); + int16x8_t v110 = vaddq_s16(v69, v79); + int16x8_t v111 = vaddq_s16(v109, v110); + int16x8_t v112_tmp = vqrdmulhq_n_s16(v111, 13573); + int16x8_t v112 = vaddq_s16(v112_tmp, v111); + int16x8_t v113 = vaddq_s16(v94, v72); + int16x8_t v114 = vaddq_s16(v73, v83); + int16x8_t v115 = vaddq_s16(v113, v114); + int16x8_t v116 = vaddq_s16(v115, v111); + int16x8_t v117 = vaddq_s16(v112, v116); + int16x8_t v118 = vqrdmulhq_n_s16(v117, 17734); + int16x8_t v119 = vaddq_s16(v108, v118); + int16x8_t v120 = vaddq_s16(v110, v103); + int16x8_t v121_tmp = vqrdmulhq_n_s16(v120, 13573); + int16x8_t v121 = vaddq_s16(v121_tmp, v120); + int16x8_t v122 = vaddq_s16(v114, v105); + int16x8_t v123 = vaddq_s16(v106, v109); + int16x8_t v124 = vaddq_s16(v122, v123); + int16x8_t v125 = vaddq_s16(v121, v124); + int16x8_t v126 = vaddq_s16(v123, v120); + int16x8_t v127_tmp = vqrdmulhq_n_s16(v126, 13573); + int16x8_t v127 = vaddq_s16(v127_tmp, v126); + int16x8_t v128 = vld1q_s16(in + in_stride * 248 + i); + int16x8_t v129 = vaddq_s16(v128, v93); + int16x8_t v130 = vaddq_s16(v129, v113); + int16x8_t v131 = vaddq_s16(v130, v122); + int16x8_t v132 = vaddq_s16(v131, v126); + int16x8_t v133 = vaddq_s16(v127, v132); + int16x8_t v134 = vqrdmulhq_n_s16(v133, 17734); + int16x8_t v135 = vaddq_s16(v125, v134); + int16x8_t v136 = vqrdmulhq_n_s16(v135, 16705); + int16x8_t v137 = vaddq_s16(v119, v136); + int16x8_t v138 = vqrdmulhq_n_s16(v137, 16463); + int16x8_t v139 = vaddq_s16(v102, v138); + int16x8_t v140 = vqrdmulhq_n_s16(v139, 16404); + int16x8_t v141 = vaddq_s16(v61, v140); + int16x8_t v142 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v143_tmp = vqrdmulhq_n_s16(v142, 13573); + int16x8_t v143 = vaddq_s16(v143_tmp, v142); + int16x8_t v144 = vld1q_s16(in + in_stride * 132 + i); + int16x8_t v145 = vld1q_s16(in + in_stride * 124 + i); + int16x8_t v146 = vaddq_s16(v144, v145); + int16x8_t v147 = vaddq_s16(v143, v146); + int16x8_t v148 = vld1q_s16(in + in_stride * 68 + i); + int16x8_t v149 = vld1q_s16(in + in_stride * 60 + i); + int16x8_t v150 = vaddq_s16(v148, v149); + int16x8_t v151_tmp = vqrdmulhq_n_s16(v150, 13573); + int16x8_t v151 = vaddq_s16(v151_tmp, v150); + int16x8_t v152 = vld1q_s16(in + in_stride * 196 + i); + int16x8_t v153 = vld1q_s16(in + in_stride * 188 + i); + int16x8_t v154 = vaddq_s16(v152, v153); + int16x8_t v155 = vaddq_s16(v154, v150); + int16x8_t v156 = vaddq_s16(v151, v155); + int16x8_t v157 = vqrdmulhq_n_s16(v156, 17734); + int16x8_t v158 = vaddq_s16(v147, v157); + int16x8_t v159 = vld1q_s16(in + in_stride * 36 + i); + int16x8_t v160 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v161 = vaddq_s16(v159, v160); + int16x8_t v162_tmp = vqrdmulhq_n_s16(v161, 13573); + int16x8_t v162 = vaddq_s16(v162_tmp, v161); + int16x8_t v163 = vld1q_s16(in + in_stride * 164 + i); + int16x8_t v164 = vld1q_s16(in + in_stride * 156 + i); + int16x8_t v165 = vaddq_s16(v163, v164); + int16x8_t v166 = vld1q_s16(in + in_stride * 100 + i); + int16x8_t v167 = vld1q_s16(in + in_stride * 92 + i); + int16x8_t v168 = vaddq_s16(v166, v167); + int16x8_t v169 = vaddq_s16(v165, v168); + int16x8_t v170 = vaddq_s16(v162, v169); + int16x8_t v171 = vaddq_s16(v168, v161); + int16x8_t v172_tmp = vqrdmulhq_n_s16(v171, 13573); + int16x8_t v172 = vaddq_s16(v172_tmp, v171); + int16x8_t v173 = vld1q_s16(in + in_stride * 228 + i); + int16x8_t v174 = vld1q_s16(in + in_stride * 220 + i); + int16x8_t v175 = vaddq_s16(v173, v174); + int16x8_t v176 = vaddq_s16(v175, v165); + int16x8_t v177 = vaddq_s16(v176, v171); + int16x8_t v178 = vaddq_s16(v172, v177); + int16x8_t v179 = vqrdmulhq_n_s16(v178, 17734); + int16x8_t v180 = vaddq_s16(v170, v179); + int16x8_t v181 = vqrdmulhq_n_s16(v180, 16705); + int16x8_t v182 = vaddq_s16(v158, v181); + int16x8_t v183 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v184 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v185 = vaddq_s16(v183, v184); + int16x8_t v186_tmp = vqrdmulhq_n_s16(v185, 13573); + int16x8_t v186 = vaddq_s16(v186_tmp, v185); + int16x8_t v187 = vld1q_s16(in + in_stride * 148 + i); + int16x8_t v188 = vld1q_s16(in + in_stride * 140 + i); + int16x8_t v189 = vaddq_s16(v187, v188); + int16x8_t v190 = vld1q_s16(in + in_stride * 116 + i); + int16x8_t v191 = vld1q_s16(in + in_stride * 108 + i); + int16x8_t v192 = vaddq_s16(v190, v191); + int16x8_t v193 = vaddq_s16(v189, v192); + int16x8_t v194 = vaddq_s16(v186, v193); + int16x8_t v195 = vld1q_s16(in + in_stride * 84 + i); + int16x8_t v196 = vld1q_s16(in + in_stride * 76 + i); + int16x8_t v197 = vaddq_s16(v195, v196); + int16x8_t v198 = vld1q_s16(in + in_stride * 52 + i); + int16x8_t v199 = vld1q_s16(in + in_stride * 44 + i); + int16x8_t v200 = vaddq_s16(v198, v199); + int16x8_t v201 = vaddq_s16(v197, v200); + int16x8_t v202_tmp = vqrdmulhq_n_s16(v201, 13573); + int16x8_t v202 = vaddq_s16(v202_tmp, v201); + int16x8_t v203 = vld1q_s16(in + in_stride * 212 + i); + int16x8_t v204 = vld1q_s16(in + in_stride * 204 + i); + int16x8_t v205 = vaddq_s16(v203, v204); + int16x8_t v206 = vld1q_s16(in + in_stride * 180 + i); + int16x8_t v207 = vld1q_s16(in + in_stride * 172 + i); + int16x8_t v208 = vaddq_s16(v206, v207); + int16x8_t v209 = vaddq_s16(v205, v208); + int16x8_t v210 = vaddq_s16(v209, v201); + int16x8_t v211 = vaddq_s16(v202, v210); + int16x8_t v212 = vqrdmulhq_n_s16(v211, 17734); + int16x8_t v213 = vaddq_s16(v194, v212); + int16x8_t v214 = vaddq_s16(v200, v185); + int16x8_t v215_tmp = vqrdmulhq_n_s16(v214, 13573); + int16x8_t v215 = vaddq_s16(v215_tmp, v214); + int16x8_t v216 = vaddq_s16(v208, v189); + int16x8_t v217 = vaddq_s16(v192, v197); + int16x8_t v218 = vaddq_s16(v216, v217); + int16x8_t v219 = vaddq_s16(v215, v218); + int16x8_t v220 = vaddq_s16(v217, v214); + int16x8_t v221_tmp = vqrdmulhq_n_s16(v220, 13573); + int16x8_t v221 = vaddq_s16(v221_tmp, v220); + int16x8_t v222 = vld1q_s16(in + in_stride * 244 + i); + int16x8_t v223 = vld1q_s16(in + in_stride * 236 + i); + int16x8_t v224 = vaddq_s16(v222, v223); + int16x8_t v225 = vaddq_s16(v224, v205); + int16x8_t v226 = vaddq_s16(v225, v216); + int16x8_t v227 = vaddq_s16(v226, v220); + int16x8_t v228 = vaddq_s16(v221, v227); + int16x8_t v229 = vqrdmulhq_n_s16(v228, 17734); + int16x8_t v230 = vaddq_s16(v219, v229); + int16x8_t v231 = vqrdmulhq_n_s16(v230, 16705); + int16x8_t v232 = vaddq_s16(v213, v231); + int16x8_t v233 = vqrdmulhq_n_s16(v232, 16463); + int16x8_t v234 = vaddq_s16(v182, v233); + int16x8_t v235 = vaddq_s16(v184, v142); + int16x8_t v236_tmp = vqrdmulhq_n_s16(v235, 13573); + int16x8_t v236 = vaddq_s16(v236_tmp, v235); + int16x8_t v237 = vaddq_s16(v188, v144); + int16x8_t v238 = vaddq_s16(v145, v190); + int16x8_t v239 = vaddq_s16(v237, v238); + int16x8_t v240 = vaddq_s16(v236, v239); + int16x8_t v241 = vaddq_s16(v196, v148); + int16x8_t v242 = vaddq_s16(v149, v198); + int16x8_t v243 = vaddq_s16(v241, v242); + int16x8_t v244_tmp = vqrdmulhq_n_s16(v243, 13573); + int16x8_t v244 = vaddq_s16(v244_tmp, v243); + int16x8_t v245 = vaddq_s16(v204, v152); + int16x8_t v246 = vaddq_s16(v153, v206); + int16x8_t v247 = vaddq_s16(v245, v246); + int16x8_t v248 = vaddq_s16(v247, v243); + int16x8_t v249 = vaddq_s16(v244, v248); + int16x8_t v250 = vqrdmulhq_n_s16(v249, 17734); + int16x8_t v251 = vaddq_s16(v240, v250); + int16x8_t v252 = vaddq_s16(v199, v159); + int16x8_t v253 = vaddq_s16(v160, v183); + int16x8_t v254 = vaddq_s16(v252, v253); + int16x8_t v255_tmp = vqrdmulhq_n_s16(v254, 13573); + int16x8_t v255 = vaddq_s16(v255_tmp, v254); + int16x8_t v256 = vaddq_s16(v207, v163); + int16x8_t v257 = vaddq_s16(v164, v187); + int16x8_t v258 = vaddq_s16(v256, v257); + int16x8_t v259 = vaddq_s16(v191, v166); + int16x8_t v260 = vaddq_s16(v167, v195); + int16x8_t v261 = vaddq_s16(v259, v260); + int16x8_t v262 = vaddq_s16(v258, v261); + int16x8_t v263 = vaddq_s16(v255, v262); + int16x8_t v264 = vaddq_s16(v261, v254); + int16x8_t v265_tmp = vqrdmulhq_n_s16(v264, 13573); + int16x8_t v265 = vaddq_s16(v265_tmp, v264); + int16x8_t v266 = vaddq_s16(v223, v173); + int16x8_t v267 = vaddq_s16(v174, v203); + int16x8_t v268 = vaddq_s16(v266, v267); + int16x8_t v269 = vaddq_s16(v268, v258); + int16x8_t v270 = vaddq_s16(v269, v264); + int16x8_t v271 = vaddq_s16(v265, v270); + int16x8_t v272 = vqrdmulhq_n_s16(v271, 17734); + int16x8_t v273 = vaddq_s16(v263, v272); + int16x8_t v274 = vqrdmulhq_n_s16(v273, 16705); + int16x8_t v275 = vaddq_s16(v251, v274); + int16x8_t v276 = vaddq_s16(v253, v235); + int16x8_t v277_tmp = vqrdmulhq_n_s16(v276, 13573); + int16x8_t v277 = vaddq_s16(v277_tmp, v276); + int16x8_t v278 = vaddq_s16(v257, v237); + int16x8_t v279 = vaddq_s16(v238, v259); + int16x8_t v280 = vaddq_s16(v278, v279); + int16x8_t v281 = vaddq_s16(v277, v280); + int16x8_t v282 = vaddq_s16(v260, v241); + int16x8_t v283 = vaddq_s16(v242, v252); + int16x8_t v284 = vaddq_s16(v282, v283); + int16x8_t v285_tmp = vqrdmulhq_n_s16(v284, 13573); + int16x8_t v285 = vaddq_s16(v285_tmp, v284); + int16x8_t v286 = vaddq_s16(v267, v245); + int16x8_t v287 = vaddq_s16(v246, v256); + int16x8_t v288 = vaddq_s16(v286, v287); + int16x8_t v289 = vaddq_s16(v288, v284); + int16x8_t v290 = vaddq_s16(v285, v289); + int16x8_t v291 = vqrdmulhq_n_s16(v290, 17734); + int16x8_t v292 = vaddq_s16(v281, v291); + int16x8_t v293 = vaddq_s16(v283, v276); + int16x8_t v294_tmp = vqrdmulhq_n_s16(v293, 13573); + int16x8_t v294 = vaddq_s16(v294_tmp, v293); + int16x8_t v295 = vaddq_s16(v287, v278); + int16x8_t v296 = vaddq_s16(v279, v282); + int16x8_t v297 = vaddq_s16(v295, v296); + int16x8_t v298 = vaddq_s16(v294, v297); + int16x8_t v299 = vaddq_s16(v296, v293); + int16x8_t v300_tmp = vqrdmulhq_n_s16(v299, 13573); + int16x8_t v300 = vaddq_s16(v300_tmp, v299); + int16x8_t v301 = vld1q_s16(in + in_stride * 252 + i); + int16x8_t v302 = vaddq_s16(v301, v222); + int16x8_t v303 = vaddq_s16(v302, v266); + int16x8_t v304 = vaddq_s16(v303, v286); + int16x8_t v305 = vaddq_s16(v304, v295); + int16x8_t v306 = vaddq_s16(v305, v299); + int16x8_t v307 = vaddq_s16(v300, v306); + int16x8_t v308 = vqrdmulhq_n_s16(v307, 17734); + int16x8_t v309 = vaddq_s16(v298, v308); + int16x8_t v310 = vqrdmulhq_n_s16(v309, 16705); + int16x8_t v311 = vaddq_s16(v292, v310); + int16x8_t v312 = vqrdmulhq_n_s16(v311, 16463); + int16x8_t v313 = vaddq_s16(v275, v312); + int16x8_t v314 = vqrdmulhq_n_s16(v313, 16404); + int16x8_t v315 = vaddq_s16(v234, v314); + int16x8_t v316 = vqrdmulhq_n_s16(v315, 16389); + int16x8_t v317 = vaddq_s16(v141, v316); + int16x8_t v318 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v319_tmp = vqrdmulhq_n_s16(v318, 13573); + int16x8_t v319 = vaddq_s16(v319_tmp, v318); + int16x8_t v320 = vld1q_s16(in + in_stride * 130 + i); + int16x8_t v321 = vld1q_s16(in + in_stride * 126 + i); + int16x8_t v322 = vaddq_s16(v320, v321); + int16x8_t v323 = vaddq_s16(v319, v322); + int16x8_t v324 = vld1q_s16(in + in_stride * 66 + i); + int16x8_t v325 = vld1q_s16(in + in_stride * 62 + i); + int16x8_t v326 = vaddq_s16(v324, v325); + int16x8_t v327_tmp = vqrdmulhq_n_s16(v326, 13573); + int16x8_t v327 = vaddq_s16(v327_tmp, v326); + int16x8_t v328 = vld1q_s16(in + in_stride * 194 + i); + int16x8_t v329 = vld1q_s16(in + in_stride * 190 + i); + int16x8_t v330 = vaddq_s16(v328, v329); + int16x8_t v331 = vaddq_s16(v330, v326); + int16x8_t v332 = vaddq_s16(v327, v331); + int16x8_t v333 = vqrdmulhq_n_s16(v332, 17734); + int16x8_t v334 = vaddq_s16(v323, v333); + int16x8_t v335 = vld1q_s16(in + in_stride * 34 + i); + int16x8_t v336 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v337 = vaddq_s16(v335, v336); + int16x8_t v338_tmp = vqrdmulhq_n_s16(v337, 13573); + int16x8_t v338 = vaddq_s16(v338_tmp, v337); + int16x8_t v339 = vld1q_s16(in + in_stride * 162 + i); + int16x8_t v340 = vld1q_s16(in + in_stride * 158 + i); + int16x8_t v341 = vaddq_s16(v339, v340); + int16x8_t v342 = vld1q_s16(in + in_stride * 98 + i); + int16x8_t v343 = vld1q_s16(in + in_stride * 94 + i); + int16x8_t v344 = vaddq_s16(v342, v343); + int16x8_t v345 = vaddq_s16(v341, v344); + int16x8_t v346 = vaddq_s16(v338, v345); + int16x8_t v347 = vaddq_s16(v344, v337); + int16x8_t v348_tmp = vqrdmulhq_n_s16(v347, 13573); + int16x8_t v348 = vaddq_s16(v348_tmp, v347); + int16x8_t v349 = vld1q_s16(in + in_stride * 226 + i); + int16x8_t v350 = vld1q_s16(in + in_stride * 222 + i); + int16x8_t v351 = vaddq_s16(v349, v350); + int16x8_t v352 = vaddq_s16(v351, v341); + int16x8_t v353 = vaddq_s16(v352, v347); + int16x8_t v354 = vaddq_s16(v348, v353); + int16x8_t v355 = vqrdmulhq_n_s16(v354, 17734); + int16x8_t v356 = vaddq_s16(v346, v355); + int16x8_t v357 = vqrdmulhq_n_s16(v356, 16705); + int16x8_t v358 = vaddq_s16(v334, v357); + int16x8_t v359 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v360 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v361 = vaddq_s16(v359, v360); + int16x8_t v362_tmp = vqrdmulhq_n_s16(v361, 13573); + int16x8_t v362 = vaddq_s16(v362_tmp, v361); + int16x8_t v363 = vld1q_s16(in + in_stride * 146 + i); + int16x8_t v364 = vld1q_s16(in + in_stride * 142 + i); + int16x8_t v365 = vaddq_s16(v363, v364); + int16x8_t v366 = vld1q_s16(in + in_stride * 114 + i); + int16x8_t v367 = vld1q_s16(in + in_stride * 110 + i); + int16x8_t v368 = vaddq_s16(v366, v367); + int16x8_t v369 = vaddq_s16(v365, v368); + int16x8_t v370 = vaddq_s16(v362, v369); + int16x8_t v371 = vld1q_s16(in + in_stride * 82 + i); + int16x8_t v372 = vld1q_s16(in + in_stride * 78 + i); + int16x8_t v373 = vaddq_s16(v371, v372); + int16x8_t v374 = vld1q_s16(in + in_stride * 50 + i); + int16x8_t v375 = vld1q_s16(in + in_stride * 46 + i); + int16x8_t v376 = vaddq_s16(v374, v375); + int16x8_t v377 = vaddq_s16(v373, v376); + int16x8_t v378_tmp = vqrdmulhq_n_s16(v377, 13573); + int16x8_t v378 = vaddq_s16(v378_tmp, v377); + int16x8_t v379 = vld1q_s16(in + in_stride * 210 + i); + int16x8_t v380 = vld1q_s16(in + in_stride * 206 + i); + int16x8_t v381 = vaddq_s16(v379, v380); + int16x8_t v382 = vld1q_s16(in + in_stride * 178 + i); + int16x8_t v383 = vld1q_s16(in + in_stride * 174 + i); + int16x8_t v384 = vaddq_s16(v382, v383); + int16x8_t v385 = vaddq_s16(v381, v384); + int16x8_t v386 = vaddq_s16(v385, v377); + int16x8_t v387 = vaddq_s16(v378, v386); + int16x8_t v388 = vqrdmulhq_n_s16(v387, 17734); + int16x8_t v389 = vaddq_s16(v370, v388); + int16x8_t v390 = vaddq_s16(v376, v361); + int16x8_t v391_tmp = vqrdmulhq_n_s16(v390, 13573); + int16x8_t v391 = vaddq_s16(v391_tmp, v390); + int16x8_t v392 = vaddq_s16(v384, v365); + int16x8_t v393 = vaddq_s16(v368, v373); + int16x8_t v394 = vaddq_s16(v392, v393); + int16x8_t v395 = vaddq_s16(v391, v394); + int16x8_t v396 = vaddq_s16(v393, v390); + int16x8_t v397_tmp = vqrdmulhq_n_s16(v396, 13573); + int16x8_t v397 = vaddq_s16(v397_tmp, v396); + int16x8_t v398 = vld1q_s16(in + in_stride * 242 + i); + int16x8_t v399 = vld1q_s16(in + in_stride * 238 + i); + int16x8_t v400 = vaddq_s16(v398, v399); + int16x8_t v401 = vaddq_s16(v400, v381); + int16x8_t v402 = vaddq_s16(v401, v392); + int16x8_t v403 = vaddq_s16(v402, v396); + int16x8_t v404 = vaddq_s16(v397, v403); + int16x8_t v405 = vqrdmulhq_n_s16(v404, 17734); + int16x8_t v406 = vaddq_s16(v395, v405); + int16x8_t v407 = vqrdmulhq_n_s16(v406, 16705); + int16x8_t v408 = vaddq_s16(v389, v407); + int16x8_t v409 = vqrdmulhq_n_s16(v408, 16463); + int16x8_t v410 = vaddq_s16(v358, v409); + int16x8_t v411 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v412 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v413 = vaddq_s16(v411, v412); + int16x8_t v414_tmp = vqrdmulhq_n_s16(v413, 13573); + int16x8_t v414 = vaddq_s16(v414_tmp, v413); + int16x8_t v415 = vld1q_s16(in + in_stride * 138 + i); + int16x8_t v416 = vld1q_s16(in + in_stride * 134 + i); + int16x8_t v417 = vaddq_s16(v415, v416); + int16x8_t v418 = vld1q_s16(in + in_stride * 122 + i); + int16x8_t v419 = vld1q_s16(in + in_stride * 118 + i); + int16x8_t v420 = vaddq_s16(v418, v419); + int16x8_t v421 = vaddq_s16(v417, v420); + int16x8_t v422 = vaddq_s16(v414, v421); + int16x8_t v423 = vld1q_s16(in + in_stride * 74 + i); + int16x8_t v424 = vld1q_s16(in + in_stride * 70 + i); + int16x8_t v425 = vaddq_s16(v423, v424); + int16x8_t v426 = vld1q_s16(in + in_stride * 58 + i); + int16x8_t v427 = vld1q_s16(in + in_stride * 54 + i); + int16x8_t v428 = vaddq_s16(v426, v427); + int16x8_t v429 = vaddq_s16(v425, v428); + int16x8_t v430_tmp = vqrdmulhq_n_s16(v429, 13573); + int16x8_t v430 = vaddq_s16(v430_tmp, v429); + int16x8_t v431 = vld1q_s16(in + in_stride * 202 + i); + int16x8_t v432 = vld1q_s16(in + in_stride * 198 + i); + int16x8_t v433 = vaddq_s16(v431, v432); + int16x8_t v434 = vld1q_s16(in + in_stride * 186 + i); + int16x8_t v435 = vld1q_s16(in + in_stride * 182 + i); + int16x8_t v436 = vaddq_s16(v434, v435); + int16x8_t v437 = vaddq_s16(v433, v436); + int16x8_t v438 = vaddq_s16(v437, v429); + int16x8_t v439 = vaddq_s16(v430, v438); + int16x8_t v440 = vqrdmulhq_n_s16(v439, 17734); + int16x8_t v441 = vaddq_s16(v422, v440); + int16x8_t v442 = vld1q_s16(in + in_stride * 42 + i); + int16x8_t v443 = vld1q_s16(in + in_stride * 38 + i); + int16x8_t v444 = vaddq_s16(v442, v443); + int16x8_t v445 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v446 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v447 = vaddq_s16(v445, v446); + int16x8_t v448 = vaddq_s16(v444, v447); + int16x8_t v449_tmp = vqrdmulhq_n_s16(v448, 13573); + int16x8_t v449 = vaddq_s16(v449_tmp, v448); + int16x8_t v450 = vld1q_s16(in + in_stride * 170 + i); + int16x8_t v451 = vld1q_s16(in + in_stride * 166 + i); + int16x8_t v452 = vaddq_s16(v450, v451); + int16x8_t v453 = vld1q_s16(in + in_stride * 154 + i); + int16x8_t v454 = vld1q_s16(in + in_stride * 150 + i); + int16x8_t v455 = vaddq_s16(v453, v454); + int16x8_t v456 = vaddq_s16(v452, v455); + int16x8_t v457 = vld1q_s16(in + in_stride * 106 + i); + int16x8_t v458 = vld1q_s16(in + in_stride * 102 + i); + int16x8_t v459 = vaddq_s16(v457, v458); + int16x8_t v460 = vld1q_s16(in + in_stride * 90 + i); + int16x8_t v461 = vld1q_s16(in + in_stride * 86 + i); + int16x8_t v462 = vaddq_s16(v460, v461); + int16x8_t v463 = vaddq_s16(v459, v462); + int16x8_t v464 = vaddq_s16(v456, v463); + int16x8_t v465 = vaddq_s16(v449, v464); + int16x8_t v466 = vaddq_s16(v463, v448); + int16x8_t v467_tmp = vqrdmulhq_n_s16(v466, 13573); + int16x8_t v467 = vaddq_s16(v467_tmp, v466); + int16x8_t v468 = vld1q_s16(in + in_stride * 234 + i); + int16x8_t v469 = vld1q_s16(in + in_stride * 230 + i); + int16x8_t v470 = vaddq_s16(v468, v469); + int16x8_t v471 = vld1q_s16(in + in_stride * 218 + i); + int16x8_t v472 = vld1q_s16(in + in_stride * 214 + i); + int16x8_t v473 = vaddq_s16(v471, v472); + int16x8_t v474 = vaddq_s16(v470, v473); + int16x8_t v475 = vaddq_s16(v474, v456); + int16x8_t v476 = vaddq_s16(v475, v466); + int16x8_t v477 = vaddq_s16(v467, v476); + int16x8_t v478 = vqrdmulhq_n_s16(v477, 17734); + int16x8_t v479 = vaddq_s16(v465, v478); + int16x8_t v480 = vqrdmulhq_n_s16(v479, 16705); + int16x8_t v481 = vaddq_s16(v441, v480); + int16x8_t v482 = vaddq_s16(v447, v413); + int16x8_t v483_tmp = vqrdmulhq_n_s16(v482, 13573); + int16x8_t v483 = vaddq_s16(v483_tmp, v482); + int16x8_t v484 = vaddq_s16(v455, v417); + int16x8_t v485 = vaddq_s16(v420, v459); + int16x8_t v486 = vaddq_s16(v484, v485); + int16x8_t v487 = vaddq_s16(v483, v486); + int16x8_t v488 = vaddq_s16(v462, v425); + int16x8_t v489 = vaddq_s16(v428, v444); + int16x8_t v490 = vaddq_s16(v488, v489); + int16x8_t v491_tmp = vqrdmulhq_n_s16(v490, 13573); + int16x8_t v491 = vaddq_s16(v491_tmp, v490); + int16x8_t v492 = vaddq_s16(v473, v433); + int16x8_t v493 = vaddq_s16(v436, v452); + int16x8_t v494 = vaddq_s16(v492, v493); + int16x8_t v495 = vaddq_s16(v494, v490); + int16x8_t v496 = vaddq_s16(v491, v495); + int16x8_t v497 = vqrdmulhq_n_s16(v496, 17734); + int16x8_t v498 = vaddq_s16(v487, v497); + int16x8_t v499 = vaddq_s16(v489, v482); + int16x8_t v500_tmp = vqrdmulhq_n_s16(v499, 13573); + int16x8_t v500 = vaddq_s16(v500_tmp, v499); + int16x8_t v501 = vaddq_s16(v493, v484); + int16x8_t v502 = vaddq_s16(v485, v488); + int16x8_t v503 = vaddq_s16(v501, v502); + int16x8_t v504 = vaddq_s16(v500, v503); + int16x8_t v505 = vaddq_s16(v502, v499); + int16x8_t v506_tmp = vqrdmulhq_n_s16(v505, 13573); + int16x8_t v506 = vaddq_s16(v506_tmp, v505); + int16x8_t v507 = vld1q_s16(in + in_stride * 250 + i); + int16x8_t v508 = vld1q_s16(in + in_stride * 246 + i); + int16x8_t v509 = vaddq_s16(v507, v508); + int16x8_t v510 = vaddq_s16(v509, v470); + int16x8_t v511 = vaddq_s16(v510, v492); + int16x8_t v512 = vaddq_s16(v511, v501); + int16x8_t v513 = vaddq_s16(v512, v505); + int16x8_t v514 = vaddq_s16(v506, v513); + int16x8_t v515 = vqrdmulhq_n_s16(v514, 17734); + int16x8_t v516 = vaddq_s16(v504, v515); + int16x8_t v517 = vqrdmulhq_n_s16(v516, 16705); + int16x8_t v518 = vaddq_s16(v498, v517); + int16x8_t v519 = vqrdmulhq_n_s16(v518, 16463); + int16x8_t v520 = vaddq_s16(v481, v519); + int16x8_t v521 = vqrdmulhq_n_s16(v520, 16404); + int16x8_t v522 = vaddq_s16(v410, v521); + int16x8_t v523 = vaddq_s16(v412, v318); + int16x8_t v524_tmp = vqrdmulhq_n_s16(v523, 13573); + int16x8_t v524 = vaddq_s16(v524_tmp, v523); + int16x8_t v525 = vaddq_s16(v416, v320); + int16x8_t v526 = vaddq_s16(v321, v418); + int16x8_t v527 = vaddq_s16(v525, v526); + int16x8_t v528 = vaddq_s16(v524, v527); + int16x8_t v529 = vaddq_s16(v424, v324); + int16x8_t v530 = vaddq_s16(v325, v426); + int16x8_t v531 = vaddq_s16(v529, v530); + int16x8_t v532_tmp = vqrdmulhq_n_s16(v531, 13573); + int16x8_t v532 = vaddq_s16(v532_tmp, v531); + int16x8_t v533 = vaddq_s16(v432, v328); + int16x8_t v534 = vaddq_s16(v329, v434); + int16x8_t v535 = vaddq_s16(v533, v534); + int16x8_t v536 = vaddq_s16(v535, v531); + int16x8_t v537 = vaddq_s16(v532, v536); + int16x8_t v538 = vqrdmulhq_n_s16(v537, 17734); + int16x8_t v539 = vaddq_s16(v528, v538); + int16x8_t v540 = vaddq_s16(v443, v335); + int16x8_t v541 = vaddq_s16(v336, v445); + int16x8_t v542 = vaddq_s16(v540, v541); + int16x8_t v543_tmp = vqrdmulhq_n_s16(v542, 13573); + int16x8_t v543 = vaddq_s16(v543_tmp, v542); + int16x8_t v544 = vaddq_s16(v451, v339); + int16x8_t v545 = vaddq_s16(v340, v453); + int16x8_t v546 = vaddq_s16(v544, v545); + int16x8_t v547 = vaddq_s16(v458, v342); + int16x8_t v548 = vaddq_s16(v343, v460); + int16x8_t v549 = vaddq_s16(v547, v548); + int16x8_t v550 = vaddq_s16(v546, v549); + int16x8_t v551 = vaddq_s16(v543, v550); + int16x8_t v552 = vaddq_s16(v549, v542); + int16x8_t v553_tmp = vqrdmulhq_n_s16(v552, 13573); + int16x8_t v553 = vaddq_s16(v553_tmp, v552); + int16x8_t v554 = vaddq_s16(v469, v349); + int16x8_t v555 = vaddq_s16(v350, v471); + int16x8_t v556 = vaddq_s16(v554, v555); + int16x8_t v557 = vaddq_s16(v556, v546); + int16x8_t v558 = vaddq_s16(v557, v552); + int16x8_t v559 = vaddq_s16(v553, v558); + int16x8_t v560 = vqrdmulhq_n_s16(v559, 17734); + int16x8_t v561 = vaddq_s16(v551, v560); + int16x8_t v562 = vqrdmulhq_n_s16(v561, 16705); + int16x8_t v563 = vaddq_s16(v539, v562); + int16x8_t v564 = vaddq_s16(v446, v359); + int16x8_t v565 = vaddq_s16(v360, v411); + int16x8_t v566 = vaddq_s16(v564, v565); + int16x8_t v567_tmp = vqrdmulhq_n_s16(v566, 13573); + int16x8_t v567 = vaddq_s16(v567_tmp, v566); + int16x8_t v568 = vaddq_s16(v454, v363); + int16x8_t v569 = vaddq_s16(v364, v415); + int16x8_t v570 = vaddq_s16(v568, v569); + int16x8_t v571 = vaddq_s16(v419, v366); + int16x8_t v572 = vaddq_s16(v367, v457); + int16x8_t v573 = vaddq_s16(v571, v572); + int16x8_t v574 = vaddq_s16(v570, v573); + int16x8_t v575 = vaddq_s16(v567, v574); + int16x8_t v576 = vaddq_s16(v461, v371); + int16x8_t v577 = vaddq_s16(v372, v423); + int16x8_t v578 = vaddq_s16(v576, v577); + int16x8_t v579 = vaddq_s16(v427, v374); + int16x8_t v580 = vaddq_s16(v375, v442); + int16x8_t v581 = vaddq_s16(v579, v580); + int16x8_t v582 = vaddq_s16(v578, v581); + int16x8_t v583_tmp = vqrdmulhq_n_s16(v582, 13573); + int16x8_t v583 = vaddq_s16(v583_tmp, v582); + int16x8_t v584 = vaddq_s16(v472, v379); + int16x8_t v585 = vaddq_s16(v380, v431); + int16x8_t v586 = vaddq_s16(v584, v585); + int16x8_t v587 = vaddq_s16(v435, v382); + int16x8_t v588 = vaddq_s16(v383, v450); + int16x8_t v589 = vaddq_s16(v587, v588); + int16x8_t v590 = vaddq_s16(v586, v589); + int16x8_t v591 = vaddq_s16(v590, v582); + int16x8_t v592 = vaddq_s16(v583, v591); + int16x8_t v593 = vqrdmulhq_n_s16(v592, 17734); + int16x8_t v594 = vaddq_s16(v575, v593); + int16x8_t v595 = vaddq_s16(v581, v566); + int16x8_t v596_tmp = vqrdmulhq_n_s16(v595, 13573); + int16x8_t v596 = vaddq_s16(v596_tmp, v595); + int16x8_t v597 = vaddq_s16(v589, v570); + int16x8_t v598 = vaddq_s16(v573, v578); + int16x8_t v599 = vaddq_s16(v597, v598); + int16x8_t v600 = vaddq_s16(v596, v599); + int16x8_t v601 = vaddq_s16(v598, v595); + int16x8_t v602_tmp = vqrdmulhq_n_s16(v601, 13573); + int16x8_t v602 = vaddq_s16(v602_tmp, v601); + int16x8_t v603 = vaddq_s16(v508, v398); + int16x8_t v604 = vaddq_s16(v399, v468); + int16x8_t v605 = vaddq_s16(v603, v604); + int16x8_t v606 = vaddq_s16(v605, v586); + int16x8_t v607 = vaddq_s16(v606, v597); + int16x8_t v608 = vaddq_s16(v607, v601); + int16x8_t v609 = vaddq_s16(v602, v608); + int16x8_t v610 = vqrdmulhq_n_s16(v609, 17734); + int16x8_t v611 = vaddq_s16(v600, v610); + int16x8_t v612 = vqrdmulhq_n_s16(v611, 16705); + int16x8_t v613 = vaddq_s16(v594, v612); + int16x8_t v614 = vqrdmulhq_n_s16(v613, 16463); + int16x8_t v615 = vaddq_s16(v563, v614); + int16x8_t v616 = vaddq_s16(v565, v523); + int16x8_t v617_tmp = vqrdmulhq_n_s16(v616, 13573); + int16x8_t v617 = vaddq_s16(v617_tmp, v616); + int16x8_t v618 = vaddq_s16(v569, v525); + int16x8_t v619 = vaddq_s16(v526, v571); + int16x8_t v620 = vaddq_s16(v618, v619); + int16x8_t v621 = vaddq_s16(v617, v620); + int16x8_t v622 = vaddq_s16(v577, v529); + int16x8_t v623 = vaddq_s16(v530, v579); + int16x8_t v624 = vaddq_s16(v622, v623); + int16x8_t v625_tmp = vqrdmulhq_n_s16(v624, 13573); + int16x8_t v625 = vaddq_s16(v625_tmp, v624); + int16x8_t v626 = vaddq_s16(v585, v533); + int16x8_t v627 = vaddq_s16(v534, v587); + int16x8_t v628 = vaddq_s16(v626, v627); + int16x8_t v629 = vaddq_s16(v628, v624); + int16x8_t v630 = vaddq_s16(v625, v629); + int16x8_t v631 = vqrdmulhq_n_s16(v630, 17734); + int16x8_t v632 = vaddq_s16(v621, v631); + int16x8_t v633 = vaddq_s16(v580, v540); + int16x8_t v634 = vaddq_s16(v541, v564); + int16x8_t v635 = vaddq_s16(v633, v634); + int16x8_t v636_tmp = vqrdmulhq_n_s16(v635, 13573); + int16x8_t v636 = vaddq_s16(v636_tmp, v635); + int16x8_t v637 = vaddq_s16(v588, v544); + int16x8_t v638 = vaddq_s16(v545, v568); + int16x8_t v639 = vaddq_s16(v637, v638); + int16x8_t v640 = vaddq_s16(v572, v547); + int16x8_t v641 = vaddq_s16(v548, v576); + int16x8_t v642 = vaddq_s16(v640, v641); + int16x8_t v643 = vaddq_s16(v639, v642); + int16x8_t v644 = vaddq_s16(v636, v643); + int16x8_t v645 = vaddq_s16(v642, v635); + int16x8_t v646_tmp = vqrdmulhq_n_s16(v645, 13573); + int16x8_t v646 = vaddq_s16(v646_tmp, v645); + int16x8_t v647 = vaddq_s16(v604, v554); + int16x8_t v648 = vaddq_s16(v555, v584); + int16x8_t v649 = vaddq_s16(v647, v648); + int16x8_t v650 = vaddq_s16(v649, v639); + int16x8_t v651 = vaddq_s16(v650, v645); + int16x8_t v652 = vaddq_s16(v646, v651); + int16x8_t v653 = vqrdmulhq_n_s16(v652, 17734); + int16x8_t v654 = vaddq_s16(v644, v653); + int16x8_t v655 = vqrdmulhq_n_s16(v654, 16705); + int16x8_t v656 = vaddq_s16(v632, v655); + int16x8_t v657 = vaddq_s16(v634, v616); + int16x8_t v658_tmp = vqrdmulhq_n_s16(v657, 13573); + int16x8_t v658 = vaddq_s16(v658_tmp, v657); + int16x8_t v659 = vaddq_s16(v638, v618); + int16x8_t v660 = vaddq_s16(v619, v640); + int16x8_t v661 = vaddq_s16(v659, v660); + int16x8_t v662 = vaddq_s16(v658, v661); + int16x8_t v663 = vaddq_s16(v641, v622); + int16x8_t v664 = vaddq_s16(v623, v633); + int16x8_t v665 = vaddq_s16(v663, v664); + int16x8_t v666_tmp = vqrdmulhq_n_s16(v665, 13573); + int16x8_t v666 = vaddq_s16(v666_tmp, v665); + int16x8_t v667 = vaddq_s16(v648, v626); + int16x8_t v668 = vaddq_s16(v627, v637); + int16x8_t v669 = vaddq_s16(v667, v668); + int16x8_t v670 = vaddq_s16(v669, v665); + int16x8_t v671 = vaddq_s16(v666, v670); + int16x8_t v672 = vqrdmulhq_n_s16(v671, 17734); + int16x8_t v673 = vaddq_s16(v662, v672); + int16x8_t v674 = vaddq_s16(v664, v657); + int16x8_t v675_tmp = vqrdmulhq_n_s16(v674, 13573); + int16x8_t v675 = vaddq_s16(v675_tmp, v674); + int16x8_t v676 = vaddq_s16(v668, v659); + int16x8_t v677 = vaddq_s16(v660, v663); + int16x8_t v678 = vaddq_s16(v676, v677); + int16x8_t v679 = vaddq_s16(v675, v678); + int16x8_t v680 = vaddq_s16(v677, v674); + int16x8_t v681_tmp = vqrdmulhq_n_s16(v680, 13573); + int16x8_t v681 = vaddq_s16(v681_tmp, v680); + int16x8_t v682 = vld1q_s16(in + in_stride * 254 + i); + int16x8_t v683 = vaddq_s16(v682, v507); + int16x8_t v684 = vaddq_s16(v683, v603); + int16x8_t v685 = vaddq_s16(v684, v647); + int16x8_t v686 = vaddq_s16(v685, v667); + int16x8_t v687 = vaddq_s16(v686, v676); + int16x8_t v688 = vaddq_s16(v687, v680); + int16x8_t v689 = vaddq_s16(v681, v688); + int16x8_t v690 = vqrdmulhq_n_s16(v689, 17734); + int16x8_t v691 = vaddq_s16(v679, v690); + int16x8_t v692 = vqrdmulhq_n_s16(v691, 16705); + int16x8_t v693 = vaddq_s16(v673, v692); + int16x8_t v694 = vqrdmulhq_n_s16(v693, 16463); + int16x8_t v695 = vaddq_s16(v656, v694); + int16x8_t v696 = vqrdmulhq_n_s16(v695, 16404); + int16x8_t v697 = vaddq_s16(v615, v696); + int16x8_t v698 = vqrdmulhq_n_s16(v697, 16389); + int16x8_t v699 = vaddq_s16(v522, v698); + int16x8_t v700 = vqrdmulhq_n_s16(v699, 16385); + int16x8_t v701 = vaddq_s16(v317, v700); + int16x8_t v702 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v703_tmp = vqrdmulhq_n_s16(v702, 13573); + int16x8_t v703 = vaddq_s16(v703_tmp, v702); + int16x8_t v704 = vld1q_s16(in + in_stride * 129 + i); + int16x8_t v705 = vld1q_s16(in + in_stride * 127 + i); + int16x8_t v706 = vaddq_s16(v704, v705); + int16x8_t v707 = vaddq_s16(v703, v706); + int16x8_t v708 = vld1q_s16(in + in_stride * 65 + i); + int16x8_t v709 = vld1q_s16(in + in_stride * 63 + i); + int16x8_t v710 = vaddq_s16(v708, v709); + int16x8_t v711_tmp = vqrdmulhq_n_s16(v710, 13573); + int16x8_t v711 = vaddq_s16(v711_tmp, v710); + int16x8_t v712 = vld1q_s16(in + in_stride * 193 + i); + int16x8_t v713 = vld1q_s16(in + in_stride * 191 + i); + int16x8_t v714 = vaddq_s16(v712, v713); + int16x8_t v715 = vaddq_s16(v714, v710); + int16x8_t v716 = vaddq_s16(v711, v715); + int16x8_t v717 = vqrdmulhq_n_s16(v716, 17734); + int16x8_t v718 = vaddq_s16(v707, v717); + int16x8_t v719 = vld1q_s16(in + in_stride * 33 + i); + int16x8_t v720 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v721 = vaddq_s16(v719, v720); + int16x8_t v722_tmp = vqrdmulhq_n_s16(v721, 13573); + int16x8_t v722 = vaddq_s16(v722_tmp, v721); + int16x8_t v723 = vld1q_s16(in + in_stride * 161 + i); + int16x8_t v724 = vld1q_s16(in + in_stride * 159 + i); + int16x8_t v725 = vaddq_s16(v723, v724); + int16x8_t v726 = vld1q_s16(in + in_stride * 97 + i); + int16x8_t v727 = vld1q_s16(in + in_stride * 95 + i); + int16x8_t v728 = vaddq_s16(v726, v727); + int16x8_t v729 = vaddq_s16(v725, v728); + int16x8_t v730 = vaddq_s16(v722, v729); + int16x8_t v731 = vaddq_s16(v728, v721); + int16x8_t v732_tmp = vqrdmulhq_n_s16(v731, 13573); + int16x8_t v732 = vaddq_s16(v732_tmp, v731); + int16x8_t v733 = vld1q_s16(in + in_stride * 225 + i); + int16x8_t v734 = vld1q_s16(in + in_stride * 223 + i); + int16x8_t v735 = vaddq_s16(v733, v734); + int16x8_t v736 = vaddq_s16(v735, v725); + int16x8_t v737 = vaddq_s16(v736, v731); + int16x8_t v738 = vaddq_s16(v732, v737); + int16x8_t v739 = vqrdmulhq_n_s16(v738, 17734); + int16x8_t v740 = vaddq_s16(v730, v739); + int16x8_t v741 = vqrdmulhq_n_s16(v740, 16705); + int16x8_t v742 = vaddq_s16(v718, v741); + int16x8_t v743 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v744 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v745 = vaddq_s16(v743, v744); + int16x8_t v746_tmp = vqrdmulhq_n_s16(v745, 13573); + int16x8_t v746 = vaddq_s16(v746_tmp, v745); + int16x8_t v747 = vld1q_s16(in + in_stride * 145 + i); + int16x8_t v748 = vld1q_s16(in + in_stride * 143 + i); + int16x8_t v749 = vaddq_s16(v747, v748); + int16x8_t v750 = vld1q_s16(in + in_stride * 113 + i); + int16x8_t v751 = vld1q_s16(in + in_stride * 111 + i); + int16x8_t v752 = vaddq_s16(v750, v751); + int16x8_t v753 = vaddq_s16(v749, v752); + int16x8_t v754 = vaddq_s16(v746, v753); + int16x8_t v755 = vld1q_s16(in + in_stride * 81 + i); + int16x8_t v756 = vld1q_s16(in + in_stride * 79 + i); + int16x8_t v757 = vaddq_s16(v755, v756); + int16x8_t v758 = vld1q_s16(in + in_stride * 49 + i); + int16x8_t v759 = vld1q_s16(in + in_stride * 47 + i); + int16x8_t v760 = vaddq_s16(v758, v759); + int16x8_t v761 = vaddq_s16(v757, v760); + int16x8_t v762_tmp = vqrdmulhq_n_s16(v761, 13573); + int16x8_t v762 = vaddq_s16(v762_tmp, v761); + int16x8_t v763 = vld1q_s16(in + in_stride * 209 + i); + int16x8_t v764 = vld1q_s16(in + in_stride * 207 + i); + int16x8_t v765 = vaddq_s16(v763, v764); + int16x8_t v766 = vld1q_s16(in + in_stride * 177 + i); + int16x8_t v767 = vld1q_s16(in + in_stride * 175 + i); + int16x8_t v768 = vaddq_s16(v766, v767); + int16x8_t v769 = vaddq_s16(v765, v768); + int16x8_t v770 = vaddq_s16(v769, v761); + int16x8_t v771 = vaddq_s16(v762, v770); + int16x8_t v772 = vqrdmulhq_n_s16(v771, 17734); + int16x8_t v773 = vaddq_s16(v754, v772); + int16x8_t v774 = vaddq_s16(v760, v745); + int16x8_t v775_tmp = vqrdmulhq_n_s16(v774, 13573); + int16x8_t v775 = vaddq_s16(v775_tmp, v774); + int16x8_t v776 = vaddq_s16(v768, v749); + int16x8_t v777 = vaddq_s16(v752, v757); + int16x8_t v778 = vaddq_s16(v776, v777); + int16x8_t v779 = vaddq_s16(v775, v778); + int16x8_t v780 = vaddq_s16(v777, v774); + int16x8_t v781_tmp = vqrdmulhq_n_s16(v780, 13573); + int16x8_t v781 = vaddq_s16(v781_tmp, v780); + int16x8_t v782 = vld1q_s16(in + in_stride * 241 + i); + int16x8_t v783 = vld1q_s16(in + in_stride * 239 + i); + int16x8_t v784 = vaddq_s16(v782, v783); + int16x8_t v785 = vaddq_s16(v784, v765); + int16x8_t v786 = vaddq_s16(v785, v776); + int16x8_t v787 = vaddq_s16(v786, v780); + int16x8_t v788 = vaddq_s16(v781, v787); + int16x8_t v789 = vqrdmulhq_n_s16(v788, 17734); + int16x8_t v790 = vaddq_s16(v779, v789); + int16x8_t v791 = vqrdmulhq_n_s16(v790, 16705); + int16x8_t v792 = vaddq_s16(v773, v791); + int16x8_t v793 = vqrdmulhq_n_s16(v792, 16463); + int16x8_t v794 = vaddq_s16(v742, v793); + int16x8_t v795 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v796 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v797 = vaddq_s16(v795, v796); + int16x8_t v798_tmp = vqrdmulhq_n_s16(v797, 13573); + int16x8_t v798 = vaddq_s16(v798_tmp, v797); + int16x8_t v799 = vld1q_s16(in + in_stride * 137 + i); + int16x8_t v800 = vld1q_s16(in + in_stride * 135 + i); + int16x8_t v801 = vaddq_s16(v799, v800); + int16x8_t v802 = vld1q_s16(in + in_stride * 121 + i); + int16x8_t v803 = vld1q_s16(in + in_stride * 119 + i); + int16x8_t v804 = vaddq_s16(v802, v803); + int16x8_t v805 = vaddq_s16(v801, v804); + int16x8_t v806 = vaddq_s16(v798, v805); + int16x8_t v807 = vld1q_s16(in + in_stride * 73 + i); + int16x8_t v808 = vld1q_s16(in + in_stride * 71 + i); + int16x8_t v809 = vaddq_s16(v807, v808); + int16x8_t v810 = vld1q_s16(in + in_stride * 57 + i); + int16x8_t v811 = vld1q_s16(in + in_stride * 55 + i); + int16x8_t v812 = vaddq_s16(v810, v811); + int16x8_t v813 = vaddq_s16(v809, v812); + int16x8_t v814_tmp = vqrdmulhq_n_s16(v813, 13573); + int16x8_t v814 = vaddq_s16(v814_tmp, v813); + int16x8_t v815 = vld1q_s16(in + in_stride * 201 + i); + int16x8_t v816 = vld1q_s16(in + in_stride * 199 + i); + int16x8_t v817 = vaddq_s16(v815, v816); + int16x8_t v818 = vld1q_s16(in + in_stride * 185 + i); + int16x8_t v819 = vld1q_s16(in + in_stride * 183 + i); + int16x8_t v820 = vaddq_s16(v818, v819); + int16x8_t v821 = vaddq_s16(v817, v820); + int16x8_t v822 = vaddq_s16(v821, v813); + int16x8_t v823 = vaddq_s16(v814, v822); + int16x8_t v824 = vqrdmulhq_n_s16(v823, 17734); + int16x8_t v825 = vaddq_s16(v806, v824); + int16x8_t v826 = vld1q_s16(in + in_stride * 41 + i); + int16x8_t v827 = vld1q_s16(in + in_stride * 39 + i); + int16x8_t v828 = vaddq_s16(v826, v827); + int16x8_t v829 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v830 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v831 = vaddq_s16(v829, v830); + int16x8_t v832 = vaddq_s16(v828, v831); + int16x8_t v833_tmp = vqrdmulhq_n_s16(v832, 13573); + int16x8_t v833 = vaddq_s16(v833_tmp, v832); + int16x8_t v834 = vld1q_s16(in + in_stride * 169 + i); + int16x8_t v835 = vld1q_s16(in + in_stride * 167 + i); + int16x8_t v836 = vaddq_s16(v834, v835); + int16x8_t v837 = vld1q_s16(in + in_stride * 153 + i); + int16x8_t v838 = vld1q_s16(in + in_stride * 151 + i); + int16x8_t v839 = vaddq_s16(v837, v838); + int16x8_t v840 = vaddq_s16(v836, v839); + int16x8_t v841 = vld1q_s16(in + in_stride * 105 + i); + int16x8_t v842 = vld1q_s16(in + in_stride * 103 + i); + int16x8_t v843 = vaddq_s16(v841, v842); + int16x8_t v844 = vld1q_s16(in + in_stride * 89 + i); + int16x8_t v845 = vld1q_s16(in + in_stride * 87 + i); + int16x8_t v846 = vaddq_s16(v844, v845); + int16x8_t v847 = vaddq_s16(v843, v846); + int16x8_t v848 = vaddq_s16(v840, v847); + int16x8_t v849 = vaddq_s16(v833, v848); + int16x8_t v850 = vaddq_s16(v847, v832); + int16x8_t v851_tmp = vqrdmulhq_n_s16(v850, 13573); + int16x8_t v851 = vaddq_s16(v851_tmp, v850); + int16x8_t v852 = vld1q_s16(in + in_stride * 233 + i); + int16x8_t v853 = vld1q_s16(in + in_stride * 231 + i); + int16x8_t v854 = vaddq_s16(v852, v853); + int16x8_t v855 = vld1q_s16(in + in_stride * 217 + i); + int16x8_t v856 = vld1q_s16(in + in_stride * 215 + i); + int16x8_t v857 = vaddq_s16(v855, v856); + int16x8_t v858 = vaddq_s16(v854, v857); + int16x8_t v859 = vaddq_s16(v858, v840); + int16x8_t v860 = vaddq_s16(v859, v850); + int16x8_t v861 = vaddq_s16(v851, v860); + int16x8_t v862 = vqrdmulhq_n_s16(v861, 17734); + int16x8_t v863 = vaddq_s16(v849, v862); + int16x8_t v864 = vqrdmulhq_n_s16(v863, 16705); + int16x8_t v865 = vaddq_s16(v825, v864); + int16x8_t v866 = vaddq_s16(v831, v797); + int16x8_t v867_tmp = vqrdmulhq_n_s16(v866, 13573); + int16x8_t v867 = vaddq_s16(v867_tmp, v866); + int16x8_t v868 = vaddq_s16(v839, v801); + int16x8_t v869 = vaddq_s16(v804, v843); + int16x8_t v870 = vaddq_s16(v868, v869); + int16x8_t v871 = vaddq_s16(v867, v870); + int16x8_t v872 = vaddq_s16(v846, v809); + int16x8_t v873 = vaddq_s16(v812, v828); + int16x8_t v874 = vaddq_s16(v872, v873); + int16x8_t v875_tmp = vqrdmulhq_n_s16(v874, 13573); + int16x8_t v875 = vaddq_s16(v875_tmp, v874); + int16x8_t v876 = vaddq_s16(v857, v817); + int16x8_t v877 = vaddq_s16(v820, v836); + int16x8_t v878 = vaddq_s16(v876, v877); + int16x8_t v879 = vaddq_s16(v878, v874); + int16x8_t v880 = vaddq_s16(v875, v879); + int16x8_t v881 = vqrdmulhq_n_s16(v880, 17734); + int16x8_t v882 = vaddq_s16(v871, v881); + int16x8_t v883 = vaddq_s16(v873, v866); + int16x8_t v884_tmp = vqrdmulhq_n_s16(v883, 13573); + int16x8_t v884 = vaddq_s16(v884_tmp, v883); + int16x8_t v885 = vaddq_s16(v877, v868); + int16x8_t v886 = vaddq_s16(v869, v872); + int16x8_t v887 = vaddq_s16(v885, v886); + int16x8_t v888 = vaddq_s16(v884, v887); + int16x8_t v889 = vaddq_s16(v886, v883); + int16x8_t v890_tmp = vqrdmulhq_n_s16(v889, 13573); + int16x8_t v890 = vaddq_s16(v890_tmp, v889); + int16x8_t v891 = vld1q_s16(in + in_stride * 249 + i); + int16x8_t v892 = vld1q_s16(in + in_stride * 247 + i); + int16x8_t v893 = vaddq_s16(v891, v892); + int16x8_t v894 = vaddq_s16(v893, v854); + int16x8_t v895 = vaddq_s16(v894, v876); + int16x8_t v896 = vaddq_s16(v895, v885); + int16x8_t v897 = vaddq_s16(v896, v889); + int16x8_t v898 = vaddq_s16(v890, v897); + int16x8_t v899 = vqrdmulhq_n_s16(v898, 17734); + int16x8_t v900 = vaddq_s16(v888, v899); + int16x8_t v901 = vqrdmulhq_n_s16(v900, 16705); + int16x8_t v902 = vaddq_s16(v882, v901); + int16x8_t v903 = vqrdmulhq_n_s16(v902, 16463); + int16x8_t v904 = vaddq_s16(v865, v903); + int16x8_t v905 = vqrdmulhq_n_s16(v904, 16404); + int16x8_t v906 = vaddq_s16(v794, v905); + int16x8_t v907 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v908 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v909 = vaddq_s16(v907, v908); + int16x8_t v910_tmp = vqrdmulhq_n_s16(v909, 13573); + int16x8_t v910 = vaddq_s16(v910_tmp, v909); + int16x8_t v911 = vld1q_s16(in + in_stride * 133 + i); + int16x8_t v912 = vld1q_s16(in + in_stride * 131 + i); + int16x8_t v913 = vaddq_s16(v911, v912); + int16x8_t v914 = vld1q_s16(in + in_stride * 125 + i); + int16x8_t v915 = vld1q_s16(in + in_stride * 123 + i); + int16x8_t v916 = vaddq_s16(v914, v915); + int16x8_t v917 = vaddq_s16(v913, v916); + int16x8_t v918 = vaddq_s16(v910, v917); + int16x8_t v919 = vld1q_s16(in + in_stride * 69 + i); + int16x8_t v920 = vld1q_s16(in + in_stride * 67 + i); + int16x8_t v921 = vaddq_s16(v919, v920); + int16x8_t v922 = vld1q_s16(in + in_stride * 61 + i); + int16x8_t v923 = vld1q_s16(in + in_stride * 59 + i); + int16x8_t v924 = vaddq_s16(v922, v923); + int16x8_t v925 = vaddq_s16(v921, v924); + int16x8_t v926_tmp = vqrdmulhq_n_s16(v925, 13573); + int16x8_t v926 = vaddq_s16(v926_tmp, v925); + int16x8_t v927 = vld1q_s16(in + in_stride * 197 + i); + int16x8_t v928 = vld1q_s16(in + in_stride * 195 + i); + int16x8_t v929 = vaddq_s16(v927, v928); + int16x8_t v930 = vld1q_s16(in + in_stride * 189 + i); + int16x8_t v931 = vld1q_s16(in + in_stride * 187 + i); + int16x8_t v932 = vaddq_s16(v930, v931); + int16x8_t v933 = vaddq_s16(v929, v932); + int16x8_t v934 = vaddq_s16(v933, v925); + int16x8_t v935 = vaddq_s16(v926, v934); + int16x8_t v936 = vqrdmulhq_n_s16(v935, 17734); + int16x8_t v937 = vaddq_s16(v918, v936); + int16x8_t v938 = vld1q_s16(in + in_stride * 37 + i); + int16x8_t v939 = vld1q_s16(in + in_stride * 35 + i); + int16x8_t v940 = vaddq_s16(v938, v939); + int16x8_t v941 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v942 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v943 = vaddq_s16(v941, v942); + int16x8_t v944 = vaddq_s16(v940, v943); + int16x8_t v945_tmp = vqrdmulhq_n_s16(v944, 13573); + int16x8_t v945 = vaddq_s16(v945_tmp, v944); + int16x8_t v946 = vld1q_s16(in + in_stride * 165 + i); + int16x8_t v947 = vld1q_s16(in + in_stride * 163 + i); + int16x8_t v948 = vaddq_s16(v946, v947); + int16x8_t v949 = vld1q_s16(in + in_stride * 157 + i); + int16x8_t v950 = vld1q_s16(in + in_stride * 155 + i); + int16x8_t v951 = vaddq_s16(v949, v950); + int16x8_t v952 = vaddq_s16(v948, v951); + int16x8_t v953 = vld1q_s16(in + in_stride * 101 + i); + int16x8_t v954 = vld1q_s16(in + in_stride * 99 + i); + int16x8_t v955 = vaddq_s16(v953, v954); + int16x8_t v956 = vld1q_s16(in + in_stride * 93 + i); + int16x8_t v957 = vld1q_s16(in + in_stride * 91 + i); + int16x8_t v958 = vaddq_s16(v956, v957); + int16x8_t v959 = vaddq_s16(v955, v958); + int16x8_t v960 = vaddq_s16(v952, v959); + int16x8_t v961 = vaddq_s16(v945, v960); + int16x8_t v962 = vaddq_s16(v959, v944); + int16x8_t v963_tmp = vqrdmulhq_n_s16(v962, 13573); + int16x8_t v963 = vaddq_s16(v963_tmp, v962); + int16x8_t v964 = vld1q_s16(in + in_stride * 229 + i); + int16x8_t v965 = vld1q_s16(in + in_stride * 227 + i); + int16x8_t v966 = vaddq_s16(v964, v965); + int16x8_t v967 = vld1q_s16(in + in_stride * 221 + i); + int16x8_t v968 = vld1q_s16(in + in_stride * 219 + i); + int16x8_t v969 = vaddq_s16(v967, v968); + int16x8_t v970 = vaddq_s16(v966, v969); + int16x8_t v971 = vaddq_s16(v970, v952); + int16x8_t v972 = vaddq_s16(v971, v962); + int16x8_t v973 = vaddq_s16(v963, v972); + int16x8_t v974 = vqrdmulhq_n_s16(v973, 17734); + int16x8_t v975 = vaddq_s16(v961, v974); + int16x8_t v976 = vqrdmulhq_n_s16(v975, 16705); + int16x8_t v977 = vaddq_s16(v937, v976); + int16x8_t v978 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v979 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v980 = vaddq_s16(v978, v979); + int16x8_t v981 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v982 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v983 = vaddq_s16(v981, v982); + int16x8_t v984 = vaddq_s16(v980, v983); + int16x8_t v985_tmp = vqrdmulhq_n_s16(v984, 13573); + int16x8_t v985 = vaddq_s16(v985_tmp, v984); + int16x8_t v986 = vld1q_s16(in + in_stride * 149 + i); + int16x8_t v987 = vld1q_s16(in + in_stride * 147 + i); + int16x8_t v988 = vaddq_s16(v986, v987); + int16x8_t v989 = vld1q_s16(in + in_stride * 141 + i); + int16x8_t v990 = vld1q_s16(in + in_stride * 139 + i); + int16x8_t v991 = vaddq_s16(v989, v990); + int16x8_t v992 = vaddq_s16(v988, v991); + int16x8_t v993 = vld1q_s16(in + in_stride * 117 + i); + int16x8_t v994 = vld1q_s16(in + in_stride * 115 + i); + int16x8_t v995 = vaddq_s16(v993, v994); + int16x8_t v996 = vld1q_s16(in + in_stride * 109 + i); + int16x8_t v997 = vld1q_s16(in + in_stride * 107 + i); + int16x8_t v998 = vaddq_s16(v996, v997); + int16x8_t v999 = vaddq_s16(v995, v998); + int16x8_t v1000 = vaddq_s16(v992, v999); + int16x8_t v1001 = vaddq_s16(v985, v1000); + int16x8_t v1002 = vld1q_s16(in + in_stride * 85 + i); + int16x8_t v1003 = vld1q_s16(in + in_stride * 83 + i); + int16x8_t v1004 = vaddq_s16(v1002, v1003); + int16x8_t v1005 = vld1q_s16(in + in_stride * 77 + i); + int16x8_t v1006 = vld1q_s16(in + in_stride * 75 + i); + int16x8_t v1007 = vaddq_s16(v1005, v1006); + int16x8_t v1008 = vaddq_s16(v1004, v1007); + int16x8_t v1009 = vld1q_s16(in + in_stride * 53 + i); + int16x8_t v1010 = vld1q_s16(in + in_stride * 51 + i); + int16x8_t v1011 = vaddq_s16(v1009, v1010); + int16x8_t v1012 = vld1q_s16(in + in_stride * 45 + i); + int16x8_t v1013 = vld1q_s16(in + in_stride * 43 + i); + int16x8_t v1014 = vaddq_s16(v1012, v1013); + int16x8_t v1015 = vaddq_s16(v1011, v1014); + int16x8_t v1016 = vaddq_s16(v1008, v1015); + int16x8_t v1017_tmp = vqrdmulhq_n_s16(v1016, 13573); + int16x8_t v1017 = vaddq_s16(v1017_tmp, v1016); + int16x8_t v1018 = vld1q_s16(in + in_stride * 213 + i); + int16x8_t v1019 = vld1q_s16(in + in_stride * 211 + i); + int16x8_t v1020 = vaddq_s16(v1018, v1019); + int16x8_t v1021 = vld1q_s16(in + in_stride * 205 + i); + int16x8_t v1022 = vld1q_s16(in + in_stride * 203 + i); + int16x8_t v1023 = vaddq_s16(v1021, v1022); + int16x8_t v1024 = vaddq_s16(v1020, v1023); + int16x8_t v1025 = vld1q_s16(in + in_stride * 181 + i); + int16x8_t v1026 = vld1q_s16(in + in_stride * 179 + i); + int16x8_t v1027 = vaddq_s16(v1025, v1026); + int16x8_t v1028 = vld1q_s16(in + in_stride * 173 + i); + int16x8_t v1029 = vld1q_s16(in + in_stride * 171 + i); + int16x8_t v1030 = vaddq_s16(v1028, v1029); + int16x8_t v1031 = vaddq_s16(v1027, v1030); + int16x8_t v1032 = vaddq_s16(v1024, v1031); + int16x8_t v1033 = vaddq_s16(v1032, v1016); + int16x8_t v1034 = vaddq_s16(v1017, v1033); + int16x8_t v1035 = vqrdmulhq_n_s16(v1034, 17734); + int16x8_t v1036 = vaddq_s16(v1001, v1035); + int16x8_t v1037 = vaddq_s16(v1015, v984); + int16x8_t v1038_tmp = vqrdmulhq_n_s16(v1037, 13573); + int16x8_t v1038 = vaddq_s16(v1038_tmp, v1037); + int16x8_t v1039 = vaddq_s16(v1031, v992); + int16x8_t v1040 = vaddq_s16(v999, v1008); + int16x8_t v1041 = vaddq_s16(v1039, v1040); + int16x8_t v1042 = vaddq_s16(v1038, v1041); + int16x8_t v1043 = vaddq_s16(v1040, v1037); + int16x8_t v1044_tmp = vqrdmulhq_n_s16(v1043, 13573); + int16x8_t v1044 = vaddq_s16(v1044_tmp, v1043); + int16x8_t v1045 = vld1q_s16(in + in_stride * 245 + i); + int16x8_t v1046 = vld1q_s16(in + in_stride * 243 + i); + int16x8_t v1047 = vaddq_s16(v1045, v1046); + int16x8_t v1048 = vld1q_s16(in + in_stride * 237 + i); + int16x8_t v1049 = vld1q_s16(in + in_stride * 235 + i); + int16x8_t v1050 = vaddq_s16(v1048, v1049); + int16x8_t v1051 = vaddq_s16(v1047, v1050); + int16x8_t v1052 = vaddq_s16(v1051, v1024); + int16x8_t v1053 = vaddq_s16(v1052, v1039); + int16x8_t v1054 = vaddq_s16(v1053, v1043); + int16x8_t v1055 = vaddq_s16(v1044, v1054); + int16x8_t v1056 = vqrdmulhq_n_s16(v1055, 17734); + int16x8_t v1057 = vaddq_s16(v1042, v1056); + int16x8_t v1058 = vqrdmulhq_n_s16(v1057, 16705); + int16x8_t v1059 = vaddq_s16(v1036, v1058); + int16x8_t v1060 = vqrdmulhq_n_s16(v1059, 16463); + int16x8_t v1061 = vaddq_s16(v977, v1060); + int16x8_t v1062 = vaddq_s16(v983, v909); + int16x8_t v1063_tmp = vqrdmulhq_n_s16(v1062, 13573); + int16x8_t v1063 = vaddq_s16(v1063_tmp, v1062); + int16x8_t v1064 = vaddq_s16(v991, v913); + int16x8_t v1065 = vaddq_s16(v916, v995); + int16x8_t v1066 = vaddq_s16(v1064, v1065); + int16x8_t v1067 = vaddq_s16(v1063, v1066); + int16x8_t v1068 = vaddq_s16(v1007, v921); + int16x8_t v1069 = vaddq_s16(v924, v1011); + int16x8_t v1070 = vaddq_s16(v1068, v1069); + int16x8_t v1071_tmp = vqrdmulhq_n_s16(v1070, 13573); + int16x8_t v1071 = vaddq_s16(v1071_tmp, v1070); + int16x8_t v1072 = vaddq_s16(v1023, v929); + int16x8_t v1073 = vaddq_s16(v932, v1027); + int16x8_t v1074 = vaddq_s16(v1072, v1073); + int16x8_t v1075 = vaddq_s16(v1074, v1070); + int16x8_t v1076 = vaddq_s16(v1071, v1075); + int16x8_t v1077 = vqrdmulhq_n_s16(v1076, 17734); + int16x8_t v1078 = vaddq_s16(v1067, v1077); + int16x8_t v1079 = vaddq_s16(v1014, v940); + int16x8_t v1080 = vaddq_s16(v943, v980); + int16x8_t v1081 = vaddq_s16(v1079, v1080); + int16x8_t v1082_tmp = vqrdmulhq_n_s16(v1081, 13573); + int16x8_t v1082 = vaddq_s16(v1082_tmp, v1081); + int16x8_t v1083 = vaddq_s16(v1030, v948); + int16x8_t v1084 = vaddq_s16(v951, v988); + int16x8_t v1085 = vaddq_s16(v1083, v1084); + int16x8_t v1086 = vaddq_s16(v998, v955); + int16x8_t v1087 = vaddq_s16(v958, v1004); + int16x8_t v1088 = vaddq_s16(v1086, v1087); + int16x8_t v1089 = vaddq_s16(v1085, v1088); + int16x8_t v1090 = vaddq_s16(v1082, v1089); + int16x8_t v1091 = vaddq_s16(v1088, v1081); + int16x8_t v1092_tmp = vqrdmulhq_n_s16(v1091, 13573); + int16x8_t v1092 = vaddq_s16(v1092_tmp, v1091); + int16x8_t v1093 = vaddq_s16(v1050, v966); + int16x8_t v1094 = vaddq_s16(v969, v1020); + int16x8_t v1095 = vaddq_s16(v1093, v1094); + int16x8_t v1096 = vaddq_s16(v1095, v1085); + int16x8_t v1097 = vaddq_s16(v1096, v1091); + int16x8_t v1098 = vaddq_s16(v1092, v1097); + int16x8_t v1099 = vqrdmulhq_n_s16(v1098, 17734); + int16x8_t v1100 = vaddq_s16(v1090, v1099); + int16x8_t v1101 = vqrdmulhq_n_s16(v1100, 16705); + int16x8_t v1102 = vaddq_s16(v1078, v1101); + int16x8_t v1103 = vaddq_s16(v1080, v1062); + int16x8_t v1104_tmp = vqrdmulhq_n_s16(v1103, 13573); + int16x8_t v1104 = vaddq_s16(v1104_tmp, v1103); + int16x8_t v1105 = vaddq_s16(v1084, v1064); + int16x8_t v1106 = vaddq_s16(v1065, v1086); + int16x8_t v1107 = vaddq_s16(v1105, v1106); + int16x8_t v1108 = vaddq_s16(v1104, v1107); + int16x8_t v1109 = vaddq_s16(v1087, v1068); + int16x8_t v1110 = vaddq_s16(v1069, v1079); + int16x8_t v1111 = vaddq_s16(v1109, v1110); + int16x8_t v1112_tmp = vqrdmulhq_n_s16(v1111, 13573); + int16x8_t v1112 = vaddq_s16(v1112_tmp, v1111); + int16x8_t v1113 = vaddq_s16(v1094, v1072); + int16x8_t v1114 = vaddq_s16(v1073, v1083); + int16x8_t v1115 = vaddq_s16(v1113, v1114); + int16x8_t v1116 = vaddq_s16(v1115, v1111); + int16x8_t v1117 = vaddq_s16(v1112, v1116); + int16x8_t v1118 = vqrdmulhq_n_s16(v1117, 17734); + int16x8_t v1119 = vaddq_s16(v1108, v1118); + int16x8_t v1120 = vaddq_s16(v1110, v1103); + int16x8_t v1121_tmp = vqrdmulhq_n_s16(v1120, 13573); + int16x8_t v1121 = vaddq_s16(v1121_tmp, v1120); + int16x8_t v1122 = vaddq_s16(v1114, v1105); + int16x8_t v1123 = vaddq_s16(v1106, v1109); + int16x8_t v1124 = vaddq_s16(v1122, v1123); + int16x8_t v1125 = vaddq_s16(v1121, v1124); + int16x8_t v1126 = vaddq_s16(v1123, v1120); + int16x8_t v1127_tmp = vqrdmulhq_n_s16(v1126, 13573); + int16x8_t v1127 = vaddq_s16(v1127_tmp, v1126); + int16x8_t v1128 = vld1q_s16(in + in_stride * 253 + i); + int16x8_t v1129 = vld1q_s16(in + in_stride * 251 + i); + int16x8_t v1130 = vaddq_s16(v1128, v1129); + int16x8_t v1131 = vaddq_s16(v1130, v1047); + int16x8_t v1132 = vaddq_s16(v1131, v1093); + int16x8_t v1133 = vaddq_s16(v1132, v1113); + int16x8_t v1134 = vaddq_s16(v1133, v1122); + int16x8_t v1135 = vaddq_s16(v1134, v1126); + int16x8_t v1136 = vaddq_s16(v1127, v1135); + int16x8_t v1137 = vqrdmulhq_n_s16(v1136, 17734); + int16x8_t v1138 = vaddq_s16(v1125, v1137); + int16x8_t v1139 = vqrdmulhq_n_s16(v1138, 16705); + int16x8_t v1140 = vaddq_s16(v1119, v1139); + int16x8_t v1141 = vqrdmulhq_n_s16(v1140, 16463); + int16x8_t v1142 = vaddq_s16(v1102, v1141); + int16x8_t v1143 = vqrdmulhq_n_s16(v1142, 16404); + int16x8_t v1144 = vaddq_s16(v1061, v1143); + int16x8_t v1145 = vqrdmulhq_n_s16(v1144, 16389); + int16x8_t v1146 = vaddq_s16(v906, v1145); + int16x8_t v1147 = vaddq_s16(v908, v702); + int16x8_t v1148_tmp = vqrdmulhq_n_s16(v1147, 13573); + int16x8_t v1148 = vaddq_s16(v1148_tmp, v1147); + int16x8_t v1149 = vaddq_s16(v912, v704); + int16x8_t v1150 = vaddq_s16(v705, v914); + int16x8_t v1151 = vaddq_s16(v1149, v1150); + int16x8_t v1152 = vaddq_s16(v1148, v1151); + int16x8_t v1153 = vaddq_s16(v920, v708); + int16x8_t v1154 = vaddq_s16(v709, v922); + int16x8_t v1155 = vaddq_s16(v1153, v1154); + int16x8_t v1156_tmp = vqrdmulhq_n_s16(v1155, 13573); + int16x8_t v1156 = vaddq_s16(v1156_tmp, v1155); + int16x8_t v1157 = vaddq_s16(v928, v712); + int16x8_t v1158 = vaddq_s16(v713, v930); + int16x8_t v1159 = vaddq_s16(v1157, v1158); + int16x8_t v1160 = vaddq_s16(v1159, v1155); + int16x8_t v1161 = vaddq_s16(v1156, v1160); + int16x8_t v1162 = vqrdmulhq_n_s16(v1161, 17734); + int16x8_t v1163 = vaddq_s16(v1152, v1162); + int16x8_t v1164 = vaddq_s16(v939, v719); + int16x8_t v1165 = vaddq_s16(v720, v941); + int16x8_t v1166 = vaddq_s16(v1164, v1165); + int16x8_t v1167_tmp = vqrdmulhq_n_s16(v1166, 13573); + int16x8_t v1167 = vaddq_s16(v1167_tmp, v1166); + int16x8_t v1168 = vaddq_s16(v947, v723); + int16x8_t v1169 = vaddq_s16(v724, v949); + int16x8_t v1170 = vaddq_s16(v1168, v1169); + int16x8_t v1171 = vaddq_s16(v954, v726); + int16x8_t v1172 = vaddq_s16(v727, v956); + int16x8_t v1173 = vaddq_s16(v1171, v1172); + int16x8_t v1174 = vaddq_s16(v1170, v1173); + int16x8_t v1175 = vaddq_s16(v1167, v1174); + int16x8_t v1176 = vaddq_s16(v1173, v1166); + int16x8_t v1177_tmp = vqrdmulhq_n_s16(v1176, 13573); + int16x8_t v1177 = vaddq_s16(v1177_tmp, v1176); + int16x8_t v1178 = vaddq_s16(v965, v733); + int16x8_t v1179 = vaddq_s16(v734, v967); + int16x8_t v1180 = vaddq_s16(v1178, v1179); + int16x8_t v1181 = vaddq_s16(v1180, v1170); + int16x8_t v1182 = vaddq_s16(v1181, v1176); + int16x8_t v1183 = vaddq_s16(v1177, v1182); + int16x8_t v1184 = vqrdmulhq_n_s16(v1183, 17734); + int16x8_t v1185 = vaddq_s16(v1175, v1184); + int16x8_t v1186 = vqrdmulhq_n_s16(v1185, 16705); + int16x8_t v1187 = vaddq_s16(v1163, v1186); + int16x8_t v1188 = vaddq_s16(v979, v743); + int16x8_t v1189 = vaddq_s16(v744, v981); + int16x8_t v1190 = vaddq_s16(v1188, v1189); + int16x8_t v1191_tmp = vqrdmulhq_n_s16(v1190, 13573); + int16x8_t v1191 = vaddq_s16(v1191_tmp, v1190); + int16x8_t v1192 = vaddq_s16(v987, v747); + int16x8_t v1193 = vaddq_s16(v748, v989); + int16x8_t v1194 = vaddq_s16(v1192, v1193); + int16x8_t v1195 = vaddq_s16(v994, v750); + int16x8_t v1196 = vaddq_s16(v751, v996); + int16x8_t v1197 = vaddq_s16(v1195, v1196); + int16x8_t v1198 = vaddq_s16(v1194, v1197); + int16x8_t v1199 = vaddq_s16(v1191, v1198); + int16x8_t v1200 = vaddq_s16(v1003, v755); + int16x8_t v1201 = vaddq_s16(v756, v1005); + int16x8_t v1202 = vaddq_s16(v1200, v1201); + int16x8_t v1203 = vaddq_s16(v1010, v758); + int16x8_t v1204 = vaddq_s16(v759, v1012); + int16x8_t v1205 = vaddq_s16(v1203, v1204); + int16x8_t v1206 = vaddq_s16(v1202, v1205); + int16x8_t v1207_tmp = vqrdmulhq_n_s16(v1206, 13573); + int16x8_t v1207 = vaddq_s16(v1207_tmp, v1206); + int16x8_t v1208 = vaddq_s16(v1019, v763); + int16x8_t v1209 = vaddq_s16(v764, v1021); + int16x8_t v1210 = vaddq_s16(v1208, v1209); + int16x8_t v1211 = vaddq_s16(v1026, v766); + int16x8_t v1212 = vaddq_s16(v767, v1028); + int16x8_t v1213 = vaddq_s16(v1211, v1212); + int16x8_t v1214 = vaddq_s16(v1210, v1213); + int16x8_t v1215 = vaddq_s16(v1214, v1206); + int16x8_t v1216 = vaddq_s16(v1207, v1215); + int16x8_t v1217 = vqrdmulhq_n_s16(v1216, 17734); + int16x8_t v1218 = vaddq_s16(v1199, v1217); + int16x8_t v1219 = vaddq_s16(v1205, v1190); + int16x8_t v1220_tmp = vqrdmulhq_n_s16(v1219, 13573); + int16x8_t v1220 = vaddq_s16(v1220_tmp, v1219); + int16x8_t v1221 = vaddq_s16(v1213, v1194); + int16x8_t v1222 = vaddq_s16(v1197, v1202); + int16x8_t v1223 = vaddq_s16(v1221, v1222); + int16x8_t v1224 = vaddq_s16(v1220, v1223); + int16x8_t v1225 = vaddq_s16(v1222, v1219); + int16x8_t v1226_tmp = vqrdmulhq_n_s16(v1225, 13573); + int16x8_t v1226 = vaddq_s16(v1226_tmp, v1225); + int16x8_t v1227 = vaddq_s16(v1046, v782); + int16x8_t v1228 = vaddq_s16(v783, v1048); + int16x8_t v1229 = vaddq_s16(v1227, v1228); + int16x8_t v1230 = vaddq_s16(v1229, v1210); + int16x8_t v1231 = vaddq_s16(v1230, v1221); + int16x8_t v1232 = vaddq_s16(v1231, v1225); + int16x8_t v1233 = vaddq_s16(v1226, v1232); + int16x8_t v1234 = vqrdmulhq_n_s16(v1233, 17734); + int16x8_t v1235 = vaddq_s16(v1224, v1234); + int16x8_t v1236 = vqrdmulhq_n_s16(v1235, 16705); + int16x8_t v1237 = vaddq_s16(v1218, v1236); + int16x8_t v1238 = vqrdmulhq_n_s16(v1237, 16463); + int16x8_t v1239 = vaddq_s16(v1187, v1238); + int16x8_t v1240 = vaddq_s16(v982, v795); + int16x8_t v1241 = vaddq_s16(v796, v907); + int16x8_t v1242 = vaddq_s16(v1240, v1241); + int16x8_t v1243_tmp = vqrdmulhq_n_s16(v1242, 13573); + int16x8_t v1243 = vaddq_s16(v1243_tmp, v1242); + int16x8_t v1244 = vaddq_s16(v990, v799); + int16x8_t v1245 = vaddq_s16(v800, v911); + int16x8_t v1246 = vaddq_s16(v1244, v1245); + int16x8_t v1247 = vaddq_s16(v915, v802); + int16x8_t v1248 = vaddq_s16(v803, v993); + int16x8_t v1249 = vaddq_s16(v1247, v1248); + int16x8_t v1250 = vaddq_s16(v1246, v1249); + int16x8_t v1251 = vaddq_s16(v1243, v1250); + int16x8_t v1252 = vaddq_s16(v1006, v807); + int16x8_t v1253 = vaddq_s16(v808, v919); + int16x8_t v1254 = vaddq_s16(v1252, v1253); + int16x8_t v1255 = vaddq_s16(v923, v810); + int16x8_t v1256 = vaddq_s16(v811, v1009); + int16x8_t v1257 = vaddq_s16(v1255, v1256); + int16x8_t v1258 = vaddq_s16(v1254, v1257); + int16x8_t v1259_tmp = vqrdmulhq_n_s16(v1258, 13573); + int16x8_t v1259 = vaddq_s16(v1259_tmp, v1258); + int16x8_t v1260 = vaddq_s16(v1022, v815); + int16x8_t v1261 = vaddq_s16(v816, v927); + int16x8_t v1262 = vaddq_s16(v1260, v1261); + int16x8_t v1263 = vaddq_s16(v931, v818); + int16x8_t v1264 = vaddq_s16(v819, v1025); + int16x8_t v1265 = vaddq_s16(v1263, v1264); + int16x8_t v1266 = vaddq_s16(v1262, v1265); + int16x8_t v1267 = vaddq_s16(v1266, v1258); + int16x8_t v1268 = vaddq_s16(v1259, v1267); + int16x8_t v1269 = vqrdmulhq_n_s16(v1268, 17734); + int16x8_t v1270 = vaddq_s16(v1251, v1269); + int16x8_t v1271 = vaddq_s16(v1013, v826); + int16x8_t v1272 = vaddq_s16(v827, v938); + int16x8_t v1273 = vaddq_s16(v1271, v1272); + int16x8_t v1274 = vaddq_s16(v942, v829); + int16x8_t v1275 = vaddq_s16(v830, v978); + int16x8_t v1276 = vaddq_s16(v1274, v1275); + int16x8_t v1277 = vaddq_s16(v1273, v1276); + int16x8_t v1278_tmp = vqrdmulhq_n_s16(v1277, 13573); + int16x8_t v1278 = vaddq_s16(v1278_tmp, v1277); + int16x8_t v1279 = vaddq_s16(v1029, v834); + int16x8_t v1280 = vaddq_s16(v835, v946); + int16x8_t v1281 = vaddq_s16(v1279, v1280); + int16x8_t v1282 = vaddq_s16(v950, v837); + int16x8_t v1283 = vaddq_s16(v838, v986); + int16x8_t v1284 = vaddq_s16(v1282, v1283); + int16x8_t v1285 = vaddq_s16(v1281, v1284); + int16x8_t v1286 = vaddq_s16(v997, v841); + int16x8_t v1287 = vaddq_s16(v842, v953); + int16x8_t v1288 = vaddq_s16(v1286, v1287); + int16x8_t v1289 = vaddq_s16(v957, v844); + int16x8_t v1290 = vaddq_s16(v845, v1002); + int16x8_t v1291 = vaddq_s16(v1289, v1290); + int16x8_t v1292 = vaddq_s16(v1288, v1291); + int16x8_t v1293 = vaddq_s16(v1285, v1292); + int16x8_t v1294 = vaddq_s16(v1278, v1293); + int16x8_t v1295 = vaddq_s16(v1292, v1277); + int16x8_t v1296_tmp = vqrdmulhq_n_s16(v1295, 13573); + int16x8_t v1296 = vaddq_s16(v1296_tmp, v1295); + int16x8_t v1297 = vaddq_s16(v1049, v852); + int16x8_t v1298 = vaddq_s16(v853, v964); + int16x8_t v1299 = vaddq_s16(v1297, v1298); + int16x8_t v1300 = vaddq_s16(v968, v855); + int16x8_t v1301 = vaddq_s16(v856, v1018); + int16x8_t v1302 = vaddq_s16(v1300, v1301); + int16x8_t v1303 = vaddq_s16(v1299, v1302); + int16x8_t v1304 = vaddq_s16(v1303, v1285); + int16x8_t v1305 = vaddq_s16(v1304, v1295); + int16x8_t v1306 = vaddq_s16(v1296, v1305); + int16x8_t v1307 = vqrdmulhq_n_s16(v1306, 17734); + int16x8_t v1308 = vaddq_s16(v1294, v1307); + int16x8_t v1309 = vqrdmulhq_n_s16(v1308, 16705); + int16x8_t v1310 = vaddq_s16(v1270, v1309); + int16x8_t v1311 = vaddq_s16(v1276, v1242); + int16x8_t v1312_tmp = vqrdmulhq_n_s16(v1311, 13573); + int16x8_t v1312 = vaddq_s16(v1312_tmp, v1311); + int16x8_t v1313 = vaddq_s16(v1284, v1246); + int16x8_t v1314 = vaddq_s16(v1249, v1288); + int16x8_t v1315 = vaddq_s16(v1313, v1314); + int16x8_t v1316 = vaddq_s16(v1312, v1315); + int16x8_t v1317 = vaddq_s16(v1291, v1254); + int16x8_t v1318 = vaddq_s16(v1257, v1273); + int16x8_t v1319 = vaddq_s16(v1317, v1318); + int16x8_t v1320_tmp = vqrdmulhq_n_s16(v1319, 13573); + int16x8_t v1320 = vaddq_s16(v1320_tmp, v1319); + int16x8_t v1321 = vaddq_s16(v1302, v1262); + int16x8_t v1322 = vaddq_s16(v1265, v1281); + int16x8_t v1323 = vaddq_s16(v1321, v1322); + int16x8_t v1324 = vaddq_s16(v1323, v1319); + int16x8_t v1325 = vaddq_s16(v1320, v1324); + int16x8_t v1326 = vqrdmulhq_n_s16(v1325, 17734); + int16x8_t v1327 = vaddq_s16(v1316, v1326); + int16x8_t v1328 = vaddq_s16(v1318, v1311); + int16x8_t v1329_tmp = vqrdmulhq_n_s16(v1328, 13573); + int16x8_t v1329 = vaddq_s16(v1329_tmp, v1328); + int16x8_t v1330 = vaddq_s16(v1322, v1313); + int16x8_t v1331 = vaddq_s16(v1314, v1317); + int16x8_t v1332 = vaddq_s16(v1330, v1331); + int16x8_t v1333 = vaddq_s16(v1329, v1332); + int16x8_t v1334 = vaddq_s16(v1331, v1328); + int16x8_t v1335_tmp = vqrdmulhq_n_s16(v1334, 13573); + int16x8_t v1335 = vaddq_s16(v1335_tmp, v1334); + int16x8_t v1336 = vaddq_s16(v1129, v891); + int16x8_t v1337 = vaddq_s16(v892, v1045); + int16x8_t v1338 = vaddq_s16(v1336, v1337); + int16x8_t v1339 = vaddq_s16(v1338, v1299); + int16x8_t v1340 = vaddq_s16(v1339, v1321); + int16x8_t v1341 = vaddq_s16(v1340, v1330); + int16x8_t v1342 = vaddq_s16(v1341, v1334); + int16x8_t v1343 = vaddq_s16(v1335, v1342); + int16x8_t v1344 = vqrdmulhq_n_s16(v1343, 17734); + int16x8_t v1345 = vaddq_s16(v1333, v1344); + int16x8_t v1346 = vqrdmulhq_n_s16(v1345, 16705); + int16x8_t v1347 = vaddq_s16(v1327, v1346); + int16x8_t v1348 = vqrdmulhq_n_s16(v1347, 16463); + int16x8_t v1349 = vaddq_s16(v1310, v1348); + int16x8_t v1350 = vqrdmulhq_n_s16(v1349, 16404); + int16x8_t v1351 = vaddq_s16(v1239, v1350); + int16x8_t v1352 = vaddq_s16(v1241, v1147); + int16x8_t v1353_tmp = vqrdmulhq_n_s16(v1352, 13573); + int16x8_t v1353 = vaddq_s16(v1353_tmp, v1352); + int16x8_t v1354 = vaddq_s16(v1245, v1149); + int16x8_t v1355 = vaddq_s16(v1150, v1247); + int16x8_t v1356 = vaddq_s16(v1354, v1355); + int16x8_t v1357 = vaddq_s16(v1353, v1356); + int16x8_t v1358 = vaddq_s16(v1253, v1153); + int16x8_t v1359 = vaddq_s16(v1154, v1255); + int16x8_t v1360 = vaddq_s16(v1358, v1359); + int16x8_t v1361_tmp = vqrdmulhq_n_s16(v1360, 13573); + int16x8_t v1361 = vaddq_s16(v1361_tmp, v1360); + int16x8_t v1362 = vaddq_s16(v1261, v1157); + int16x8_t v1363 = vaddq_s16(v1158, v1263); + int16x8_t v1364 = vaddq_s16(v1362, v1363); + int16x8_t v1365 = vaddq_s16(v1364, v1360); + int16x8_t v1366 = vaddq_s16(v1361, v1365); + int16x8_t v1367 = vqrdmulhq_n_s16(v1366, 17734); + int16x8_t v1368 = vaddq_s16(v1357, v1367); + int16x8_t v1369 = vaddq_s16(v1272, v1164); + int16x8_t v1370 = vaddq_s16(v1165, v1274); + int16x8_t v1371 = vaddq_s16(v1369, v1370); + int16x8_t v1372_tmp = vqrdmulhq_n_s16(v1371, 13573); + int16x8_t v1372 = vaddq_s16(v1372_tmp, v1371); + int16x8_t v1373 = vaddq_s16(v1280, v1168); + int16x8_t v1374 = vaddq_s16(v1169, v1282); + int16x8_t v1375 = vaddq_s16(v1373, v1374); + int16x8_t v1376 = vaddq_s16(v1287, v1171); + int16x8_t v1377 = vaddq_s16(v1172, v1289); + int16x8_t v1378 = vaddq_s16(v1376, v1377); + int16x8_t v1379 = vaddq_s16(v1375, v1378); + int16x8_t v1380 = vaddq_s16(v1372, v1379); + int16x8_t v1381 = vaddq_s16(v1378, v1371); + int16x8_t v1382_tmp = vqrdmulhq_n_s16(v1381, 13573); + int16x8_t v1382 = vaddq_s16(v1382_tmp, v1381); + int16x8_t v1383 = vaddq_s16(v1298, v1178); + int16x8_t v1384 = vaddq_s16(v1179, v1300); + int16x8_t v1385 = vaddq_s16(v1383, v1384); + int16x8_t v1386 = vaddq_s16(v1385, v1375); + int16x8_t v1387 = vaddq_s16(v1386, v1381); + int16x8_t v1388 = vaddq_s16(v1382, v1387); + int16x8_t v1389 = vqrdmulhq_n_s16(v1388, 17734); + int16x8_t v1390 = vaddq_s16(v1380, v1389); + int16x8_t v1391 = vqrdmulhq_n_s16(v1390, 16705); + int16x8_t v1392 = vaddq_s16(v1368, v1391); + int16x8_t v1393 = vaddq_s16(v1275, v1188); + int16x8_t v1394 = vaddq_s16(v1189, v1240); + int16x8_t v1395 = vaddq_s16(v1393, v1394); + int16x8_t v1396_tmp = vqrdmulhq_n_s16(v1395, 13573); + int16x8_t v1396 = vaddq_s16(v1396_tmp, v1395); + int16x8_t v1397 = vaddq_s16(v1283, v1192); + int16x8_t v1398 = vaddq_s16(v1193, v1244); + int16x8_t v1399 = vaddq_s16(v1397, v1398); + int16x8_t v1400 = vaddq_s16(v1248, v1195); + int16x8_t v1401 = vaddq_s16(v1196, v1286); + int16x8_t v1402 = vaddq_s16(v1400, v1401); + int16x8_t v1403 = vaddq_s16(v1399, v1402); + int16x8_t v1404 = vaddq_s16(v1396, v1403); + int16x8_t v1405 = vaddq_s16(v1290, v1200); + int16x8_t v1406 = vaddq_s16(v1201, v1252); + int16x8_t v1407 = vaddq_s16(v1405, v1406); + int16x8_t v1408 = vaddq_s16(v1256, v1203); + int16x8_t v1409 = vaddq_s16(v1204, v1271); + int16x8_t v1410 = vaddq_s16(v1408, v1409); + int16x8_t v1411 = vaddq_s16(v1407, v1410); + int16x8_t v1412_tmp = vqrdmulhq_n_s16(v1411, 13573); + int16x8_t v1412 = vaddq_s16(v1412_tmp, v1411); + int16x8_t v1413 = vaddq_s16(v1301, v1208); + int16x8_t v1414 = vaddq_s16(v1209, v1260); + int16x8_t v1415 = vaddq_s16(v1413, v1414); + int16x8_t v1416 = vaddq_s16(v1264, v1211); + int16x8_t v1417 = vaddq_s16(v1212, v1279); + int16x8_t v1418 = vaddq_s16(v1416, v1417); + int16x8_t v1419 = vaddq_s16(v1415, v1418); + int16x8_t v1420 = vaddq_s16(v1419, v1411); + int16x8_t v1421 = vaddq_s16(v1412, v1420); + int16x8_t v1422 = vqrdmulhq_n_s16(v1421, 17734); + int16x8_t v1423 = vaddq_s16(v1404, v1422); + int16x8_t v1424 = vaddq_s16(v1410, v1395); + int16x8_t v1425_tmp = vqrdmulhq_n_s16(v1424, 13573); + int16x8_t v1425 = vaddq_s16(v1425_tmp, v1424); + int16x8_t v1426 = vaddq_s16(v1418, v1399); + int16x8_t v1427 = vaddq_s16(v1402, v1407); + int16x8_t v1428 = vaddq_s16(v1426, v1427); + int16x8_t v1429 = vaddq_s16(v1425, v1428); + int16x8_t v1430 = vaddq_s16(v1427, v1424); + int16x8_t v1431_tmp = vqrdmulhq_n_s16(v1430, 13573); + int16x8_t v1431 = vaddq_s16(v1431_tmp, v1430); + int16x8_t v1432 = vaddq_s16(v1337, v1227); + int16x8_t v1433 = vaddq_s16(v1228, v1297); + int16x8_t v1434 = vaddq_s16(v1432, v1433); + int16x8_t v1435 = vaddq_s16(v1434, v1415); + int16x8_t v1436 = vaddq_s16(v1435, v1426); + int16x8_t v1437 = vaddq_s16(v1436, v1430); + int16x8_t v1438 = vaddq_s16(v1431, v1437); + int16x8_t v1439 = vqrdmulhq_n_s16(v1438, 17734); + int16x8_t v1440 = vaddq_s16(v1429, v1439); + int16x8_t v1441 = vqrdmulhq_n_s16(v1440, 16705); + int16x8_t v1442 = vaddq_s16(v1423, v1441); + int16x8_t v1443 = vqrdmulhq_n_s16(v1442, 16463); + int16x8_t v1444 = vaddq_s16(v1392, v1443); + int16x8_t v1445 = vaddq_s16(v1394, v1352); + int16x8_t v1446_tmp = vqrdmulhq_n_s16(v1445, 13573); + int16x8_t v1446 = vaddq_s16(v1446_tmp, v1445); + int16x8_t v1447 = vaddq_s16(v1398, v1354); + int16x8_t v1448 = vaddq_s16(v1355, v1400); + int16x8_t v1449 = vaddq_s16(v1447, v1448); + int16x8_t v1450 = vaddq_s16(v1446, v1449); + int16x8_t v1451 = vaddq_s16(v1406, v1358); + int16x8_t v1452 = vaddq_s16(v1359, v1408); + int16x8_t v1453 = vaddq_s16(v1451, v1452); + int16x8_t v1454_tmp = vqrdmulhq_n_s16(v1453, 13573); + int16x8_t v1454 = vaddq_s16(v1454_tmp, v1453); + int16x8_t v1455 = vaddq_s16(v1414, v1362); + int16x8_t v1456 = vaddq_s16(v1363, v1416); + int16x8_t v1457 = vaddq_s16(v1455, v1456); + int16x8_t v1458 = vaddq_s16(v1457, v1453); + int16x8_t v1459 = vaddq_s16(v1454, v1458); + int16x8_t v1460 = vqrdmulhq_n_s16(v1459, 17734); + int16x8_t v1461 = vaddq_s16(v1450, v1460); + int16x8_t v1462 = vaddq_s16(v1409, v1369); + int16x8_t v1463 = vaddq_s16(v1370, v1393); + int16x8_t v1464 = vaddq_s16(v1462, v1463); + int16x8_t v1465_tmp = vqrdmulhq_n_s16(v1464, 13573); + int16x8_t v1465 = vaddq_s16(v1465_tmp, v1464); + int16x8_t v1466 = vaddq_s16(v1417, v1373); + int16x8_t v1467 = vaddq_s16(v1374, v1397); + int16x8_t v1468 = vaddq_s16(v1466, v1467); + int16x8_t v1469 = vaddq_s16(v1401, v1376); + int16x8_t v1470 = vaddq_s16(v1377, v1405); + int16x8_t v1471 = vaddq_s16(v1469, v1470); + int16x8_t v1472 = vaddq_s16(v1468, v1471); + int16x8_t v1473 = vaddq_s16(v1465, v1472); + int16x8_t v1474 = vaddq_s16(v1471, v1464); + int16x8_t v1475_tmp = vqrdmulhq_n_s16(v1474, 13573); + int16x8_t v1475 = vaddq_s16(v1475_tmp, v1474); + int16x8_t v1476 = vaddq_s16(v1433, v1383); + int16x8_t v1477 = vaddq_s16(v1384, v1413); + int16x8_t v1478 = vaddq_s16(v1476, v1477); + int16x8_t v1479 = vaddq_s16(v1478, v1468); + int16x8_t v1480 = vaddq_s16(v1479, v1474); + int16x8_t v1481 = vaddq_s16(v1475, v1480); + int16x8_t v1482 = vqrdmulhq_n_s16(v1481, 17734); + int16x8_t v1483 = vaddq_s16(v1473, v1482); + int16x8_t v1484 = vqrdmulhq_n_s16(v1483, 16705); + int16x8_t v1485 = vaddq_s16(v1461, v1484); + int16x8_t v1486 = vaddq_s16(v1463, v1445); + int16x8_t v1487_tmp = vqrdmulhq_n_s16(v1486, 13573); + int16x8_t v1487 = vaddq_s16(v1487_tmp, v1486); + int16x8_t v1488 = vaddq_s16(v1467, v1447); + int16x8_t v1489 = vaddq_s16(v1448, v1469); + int16x8_t v1490 = vaddq_s16(v1488, v1489); + int16x8_t v1491 = vaddq_s16(v1487, v1490); + int16x8_t v1492 = vaddq_s16(v1470, v1451); + int16x8_t v1493 = vaddq_s16(v1452, v1462); + int16x8_t v1494 = vaddq_s16(v1492, v1493); + int16x8_t v1495_tmp = vqrdmulhq_n_s16(v1494, 13573); + int16x8_t v1495 = vaddq_s16(v1495_tmp, v1494); + int16x8_t v1496 = vaddq_s16(v1477, v1455); + int16x8_t v1497 = vaddq_s16(v1456, v1466); + int16x8_t v1498 = vaddq_s16(v1496, v1497); + int16x8_t v1499 = vaddq_s16(v1498, v1494); + int16x8_t v1500 = vaddq_s16(v1495, v1499); + int16x8_t v1501 = vqrdmulhq_n_s16(v1500, 17734); + int16x8_t v1502 = vaddq_s16(v1491, v1501); + int16x8_t v1503 = vaddq_s16(v1493, v1486); + int16x8_t v1504_tmp = vqrdmulhq_n_s16(v1503, 13573); + int16x8_t v1504 = vaddq_s16(v1504_tmp, v1503); + int16x8_t v1505 = vaddq_s16(v1497, v1488); + int16x8_t v1506 = vaddq_s16(v1489, v1492); + int16x8_t v1507 = vaddq_s16(v1505, v1506); + int16x8_t v1508 = vaddq_s16(v1504, v1507); + int16x8_t v1509 = vaddq_s16(v1506, v1503); + int16x8_t v1510_tmp = vqrdmulhq_n_s16(v1509, 13573); + int16x8_t v1510 = vaddq_s16(v1510_tmp, v1509); + int16x8_t v1511 = vld1q_s16(in + in_stride * 255 + i); + int16x8_t v1512 = vaddq_s16(v1511, v1128); + int16x8_t v1513 = vaddq_s16(v1512, v1336); + int16x8_t v1514 = vaddq_s16(v1513, v1432); + int16x8_t v1515 = vaddq_s16(v1514, v1476); + int16x8_t v1516 = vaddq_s16(v1515, v1496); + int16x8_t v1517 = vaddq_s16(v1516, v1505); + int16x8_t v1518 = vaddq_s16(v1517, v1509); + int16x8_t v1519 = vaddq_s16(v1510, v1518); + int16x8_t v1520 = vqrdmulhq_n_s16(v1519, 17734); + int16x8_t v1521 = vaddq_s16(v1508, v1520); + int16x8_t v1522 = vqrdmulhq_n_s16(v1521, 16705); + int16x8_t v1523 = vaddq_s16(v1502, v1522); + int16x8_t v1524 = vqrdmulhq_n_s16(v1523, 16463); + int16x8_t v1525 = vaddq_s16(v1485, v1524); + int16x8_t v1526 = vqrdmulhq_n_s16(v1525, 16404); + int16x8_t v1527 = vaddq_s16(v1444, v1526); + int16x8_t v1528 = vqrdmulhq_n_s16(v1527, 16389); + int16x8_t v1529 = vaddq_s16(v1351, v1528); + int16x8_t v1530 = vqrdmulhq_n_s16(v1529, 16385); + int16x8_t v1531 = vaddq_s16(v1146, v1530); + int16x8_t v1532 = vqrdmulhq_n_s16(v1531, 16384); + int16x8_t v1533 = vaddq_s16(v701, v1532); + int16x8_t v1534 = vsubq_s16(v0, v1); + int16x8_t v1535 = vsubq_s16(v4, v6); + int16x8_t v1536_tmp = vqrdmulhq_n_s16(v1535, 10045); + int16x8_t v1536 = vaddq_s16(v1536_tmp, v1535); + int16x8_t v1537 = vaddq_s16(v1534, v1536); + int16x8_t v1538 = vsubq_s16(v11, v14); + int16x8_t v1539 = vsubq_s16(v17, v20); + int16x8_t v1540_tmp = vqrdmulhq_n_s16(v1539, 10045); + int16x8_t v1540 = vaddq_s16(v1540_tmp, v1539); + int16x8_t v1541 = vaddq_s16(v1538, v1540); + int16x8_t v1542 = vqrdmulhq_n_s16(v1541, 19705); + int16x8_t v1543 = vaddq_s16(v1537, v1542); + int16x8_t v1544 = vsubq_s16(v27, v30); + int16x8_t v1545 = vsubq_s16(v35, v39); + int16x8_t v1546_tmp = vqrdmulhq_n_s16(v1545, 10045); + int16x8_t v1546 = vaddq_s16(v1546_tmp, v1545); + int16x8_t v1547 = vaddq_s16(v1544, v1546); + int16x8_t v1548 = vsubq_s16(v44, v47); + int16x8_t v1549 = vsubq_s16(v50, v54); + int16x8_t v1550_tmp = vqrdmulhq_n_s16(v1549, 10045); + int16x8_t v1550 = vaddq_s16(v1550_tmp, v1549); + int16x8_t v1551 = vaddq_s16(v1548, v1550); + int16x8_t v1552 = vqrdmulhq_n_s16(v1551, 19705); + int16x8_t v1553 = vaddq_s16(v1547, v1552); + int16x8_t v1554 = vqrdmulhq_n_s16(v1553, 17121); + int16x8_t v1555 = vaddq_s16(v1543, v1554); + int16x8_t v1556 = vsubq_s16(v63, v66); + int16x8_t v1557 = vsubq_s16(v71, v75); + int16x8_t v1558_tmp = vqrdmulhq_n_s16(v1557, 10045); + int16x8_t v1558 = vaddq_s16(v1558_tmp, v1557); + int16x8_t v1559 = vaddq_s16(v1556, v1558); + int16x8_t v1560 = vsubq_s16(v82, v89); + int16x8_t v1561 = vsubq_s16(v92, v97); + int16x8_t v1562_tmp = vqrdmulhq_n_s16(v1561, 10045); + int16x8_t v1562 = vaddq_s16(v1562_tmp, v1561); + int16x8_t v1563 = vaddq_s16(v1560, v1562); + int16x8_t v1564 = vqrdmulhq_n_s16(v1563, 19705); + int16x8_t v1565 = vaddq_s16(v1559, v1564); + int16x8_t v1566 = vsubq_s16(v104, v107); + int16x8_t v1567 = vsubq_s16(v112, v116); + int16x8_t v1568_tmp = vqrdmulhq_n_s16(v1567, 10045); + int16x8_t v1568 = vaddq_s16(v1568_tmp, v1567); + int16x8_t v1569 = vaddq_s16(v1566, v1568); + int16x8_t v1570 = vsubq_s16(v121, v124); + int16x8_t v1571 = vsubq_s16(v127, v132); + int16x8_t v1572_tmp = vqrdmulhq_n_s16(v1571, 10045); + int16x8_t v1572 = vaddq_s16(v1572_tmp, v1571); + int16x8_t v1573 = vaddq_s16(v1570, v1572); + int16x8_t v1574 = vqrdmulhq_n_s16(v1573, 19705); + int16x8_t v1575 = vaddq_s16(v1569, v1574); + int16x8_t v1576 = vqrdmulhq_n_s16(v1575, 17121); + int16x8_t v1577 = vaddq_s16(v1565, v1576); + int16x8_t v1578 = vqrdmulhq_n_s16(v1577, 16563); + int16x8_t v1579 = vaddq_s16(v1555, v1578); + int16x8_t v1580 = vsubq_s16(v143, v146); + int16x8_t v1581 = vsubq_s16(v151, v155); + int16x8_t v1582_tmp = vqrdmulhq_n_s16(v1581, 10045); + int16x8_t v1582 = vaddq_s16(v1582_tmp, v1581); + int16x8_t v1583 = vaddq_s16(v1580, v1582); + int16x8_t v1584 = vsubq_s16(v162, v169); + int16x8_t v1585 = vsubq_s16(v172, v177); + int16x8_t v1586_tmp = vqrdmulhq_n_s16(v1585, 10045); + int16x8_t v1586 = vaddq_s16(v1586_tmp, v1585); + int16x8_t v1587 = vaddq_s16(v1584, v1586); + int16x8_t v1588 = vqrdmulhq_n_s16(v1587, 19705); + int16x8_t v1589 = vaddq_s16(v1583, v1588); + int16x8_t v1590 = vsubq_s16(v186, v193); + int16x8_t v1591 = vsubq_s16(v202, v210); + int16x8_t v1592_tmp = vqrdmulhq_n_s16(v1591, 10045); + int16x8_t v1592 = vaddq_s16(v1592_tmp, v1591); + int16x8_t v1593 = vaddq_s16(v1590, v1592); + int16x8_t v1594 = vsubq_s16(v215, v218); + int16x8_t v1595 = vsubq_s16(v221, v227); + int16x8_t v1596_tmp = vqrdmulhq_n_s16(v1595, 10045); + int16x8_t v1596 = vaddq_s16(v1596_tmp, v1595); + int16x8_t v1597 = vaddq_s16(v1594, v1596); + int16x8_t v1598 = vqrdmulhq_n_s16(v1597, 19705); + int16x8_t v1599 = vaddq_s16(v1593, v1598); + int16x8_t v1600 = vqrdmulhq_n_s16(v1599, 17121); + int16x8_t v1601 = vaddq_s16(v1589, v1600); + int16x8_t v1602 = vsubq_s16(v236, v239); + int16x8_t v1603 = vsubq_s16(v244, v248); + int16x8_t v1604_tmp = vqrdmulhq_n_s16(v1603, 10045); + int16x8_t v1604 = vaddq_s16(v1604_tmp, v1603); + int16x8_t v1605 = vaddq_s16(v1602, v1604); + int16x8_t v1606 = vsubq_s16(v255, v262); + int16x8_t v1607 = vsubq_s16(v265, v270); + int16x8_t v1608_tmp = vqrdmulhq_n_s16(v1607, 10045); + int16x8_t v1608 = vaddq_s16(v1608_tmp, v1607); + int16x8_t v1609 = vaddq_s16(v1606, v1608); + int16x8_t v1610 = vqrdmulhq_n_s16(v1609, 19705); + int16x8_t v1611 = vaddq_s16(v1605, v1610); + int16x8_t v1612 = vsubq_s16(v277, v280); + int16x8_t v1613 = vsubq_s16(v285, v289); + int16x8_t v1614_tmp = vqrdmulhq_n_s16(v1613, 10045); + int16x8_t v1614 = vaddq_s16(v1614_tmp, v1613); + int16x8_t v1615 = vaddq_s16(v1612, v1614); + int16x8_t v1616 = vsubq_s16(v294, v297); + int16x8_t v1617 = vsubq_s16(v300, v306); + int16x8_t v1618_tmp = vqrdmulhq_n_s16(v1617, 10045); + int16x8_t v1618 = vaddq_s16(v1618_tmp, v1617); + int16x8_t v1619 = vaddq_s16(v1616, v1618); + int16x8_t v1620 = vqrdmulhq_n_s16(v1619, 19705); + int16x8_t v1621 = vaddq_s16(v1615, v1620); + int16x8_t v1622 = vqrdmulhq_n_s16(v1621, 17121); + int16x8_t v1623 = vaddq_s16(v1611, v1622); + int16x8_t v1624 = vqrdmulhq_n_s16(v1623, 16563); + int16x8_t v1625 = vaddq_s16(v1601, v1624); + int16x8_t v1626 = vqrdmulhq_n_s16(v1625, 16429); + int16x8_t v1627 = vaddq_s16(v1579, v1626); + int16x8_t v1628 = vsubq_s16(v319, v322); + int16x8_t v1629 = vsubq_s16(v327, v331); + int16x8_t v1630_tmp = vqrdmulhq_n_s16(v1629, 10045); + int16x8_t v1630 = vaddq_s16(v1630_tmp, v1629); + int16x8_t v1631 = vaddq_s16(v1628, v1630); + int16x8_t v1632 = vsubq_s16(v338, v345); + int16x8_t v1633 = vsubq_s16(v348, v353); + int16x8_t v1634_tmp = vqrdmulhq_n_s16(v1633, 10045); + int16x8_t v1634 = vaddq_s16(v1634_tmp, v1633); + int16x8_t v1635 = vaddq_s16(v1632, v1634); + int16x8_t v1636 = vqrdmulhq_n_s16(v1635, 19705); + int16x8_t v1637 = vaddq_s16(v1631, v1636); + int16x8_t v1638 = vsubq_s16(v362, v369); + int16x8_t v1639 = vsubq_s16(v378, v386); + int16x8_t v1640_tmp = vqrdmulhq_n_s16(v1639, 10045); + int16x8_t v1640 = vaddq_s16(v1640_tmp, v1639); + int16x8_t v1641 = vaddq_s16(v1638, v1640); + int16x8_t v1642 = vsubq_s16(v391, v394); + int16x8_t v1643 = vsubq_s16(v397, v403); + int16x8_t v1644_tmp = vqrdmulhq_n_s16(v1643, 10045); + int16x8_t v1644 = vaddq_s16(v1644_tmp, v1643); + int16x8_t v1645 = vaddq_s16(v1642, v1644); + int16x8_t v1646 = vqrdmulhq_n_s16(v1645, 19705); + int16x8_t v1647 = vaddq_s16(v1641, v1646); + int16x8_t v1648 = vqrdmulhq_n_s16(v1647, 17121); + int16x8_t v1649 = vaddq_s16(v1637, v1648); + int16x8_t v1650 = vsubq_s16(v414, v421); + int16x8_t v1651 = vsubq_s16(v430, v438); + int16x8_t v1652_tmp = vqrdmulhq_n_s16(v1651, 10045); + int16x8_t v1652 = vaddq_s16(v1652_tmp, v1651); + int16x8_t v1653 = vaddq_s16(v1650, v1652); + int16x8_t v1654 = vsubq_s16(v449, v464); + int16x8_t v1655 = vsubq_s16(v467, v476); + int16x8_t v1656_tmp = vqrdmulhq_n_s16(v1655, 10045); + int16x8_t v1656 = vaddq_s16(v1656_tmp, v1655); + int16x8_t v1657 = vaddq_s16(v1654, v1656); + int16x8_t v1658 = vqrdmulhq_n_s16(v1657, 19705); + int16x8_t v1659 = vaddq_s16(v1653, v1658); + int16x8_t v1660 = vsubq_s16(v483, v486); + int16x8_t v1661 = vsubq_s16(v491, v495); + int16x8_t v1662_tmp = vqrdmulhq_n_s16(v1661, 10045); + int16x8_t v1662 = vaddq_s16(v1662_tmp, v1661); + int16x8_t v1663 = vaddq_s16(v1660, v1662); + int16x8_t v1664 = vsubq_s16(v500, v503); + int16x8_t v1665 = vsubq_s16(v506, v513); + int16x8_t v1666_tmp = vqrdmulhq_n_s16(v1665, 10045); + int16x8_t v1666 = vaddq_s16(v1666_tmp, v1665); + int16x8_t v1667 = vaddq_s16(v1664, v1666); + int16x8_t v1668 = vqrdmulhq_n_s16(v1667, 19705); + int16x8_t v1669 = vaddq_s16(v1663, v1668); + int16x8_t v1670 = vqrdmulhq_n_s16(v1669, 17121); + int16x8_t v1671 = vaddq_s16(v1659, v1670); + int16x8_t v1672 = vqrdmulhq_n_s16(v1671, 16563); + int16x8_t v1673 = vaddq_s16(v1649, v1672); + int16x8_t v1674 = vsubq_s16(v524, v527); + int16x8_t v1675 = vsubq_s16(v532, v536); + int16x8_t v1676_tmp = vqrdmulhq_n_s16(v1675, 10045); + int16x8_t v1676 = vaddq_s16(v1676_tmp, v1675); + int16x8_t v1677 = vaddq_s16(v1674, v1676); + int16x8_t v1678 = vsubq_s16(v543, v550); + int16x8_t v1679 = vsubq_s16(v553, v558); + int16x8_t v1680_tmp = vqrdmulhq_n_s16(v1679, 10045); + int16x8_t v1680 = vaddq_s16(v1680_tmp, v1679); + int16x8_t v1681 = vaddq_s16(v1678, v1680); + int16x8_t v1682 = vqrdmulhq_n_s16(v1681, 19705); + int16x8_t v1683 = vaddq_s16(v1677, v1682); + int16x8_t v1684 = vsubq_s16(v567, v574); + int16x8_t v1685 = vsubq_s16(v583, v591); + int16x8_t v1686_tmp = vqrdmulhq_n_s16(v1685, 10045); + int16x8_t v1686 = vaddq_s16(v1686_tmp, v1685); + int16x8_t v1687 = vaddq_s16(v1684, v1686); + int16x8_t v1688 = vsubq_s16(v596, v599); + int16x8_t v1689 = vsubq_s16(v602, v608); + int16x8_t v1690_tmp = vqrdmulhq_n_s16(v1689, 10045); + int16x8_t v1690 = vaddq_s16(v1690_tmp, v1689); + int16x8_t v1691 = vaddq_s16(v1688, v1690); + int16x8_t v1692 = vqrdmulhq_n_s16(v1691, 19705); + int16x8_t v1693 = vaddq_s16(v1687, v1692); + int16x8_t v1694 = vqrdmulhq_n_s16(v1693, 17121); + int16x8_t v1695 = vaddq_s16(v1683, v1694); + int16x8_t v1696 = vsubq_s16(v617, v620); + int16x8_t v1697 = vsubq_s16(v625, v629); + int16x8_t v1698_tmp = vqrdmulhq_n_s16(v1697, 10045); + int16x8_t v1698 = vaddq_s16(v1698_tmp, v1697); + int16x8_t v1699 = vaddq_s16(v1696, v1698); + int16x8_t v1700 = vsubq_s16(v636, v643); + int16x8_t v1701 = vsubq_s16(v646, v651); + int16x8_t v1702_tmp = vqrdmulhq_n_s16(v1701, 10045); + int16x8_t v1702 = vaddq_s16(v1702_tmp, v1701); + int16x8_t v1703 = vaddq_s16(v1700, v1702); + int16x8_t v1704 = vqrdmulhq_n_s16(v1703, 19705); + int16x8_t v1705 = vaddq_s16(v1699, v1704); + int16x8_t v1706 = vsubq_s16(v658, v661); + int16x8_t v1707 = vsubq_s16(v666, v670); + int16x8_t v1708_tmp = vqrdmulhq_n_s16(v1707, 10045); + int16x8_t v1708 = vaddq_s16(v1708_tmp, v1707); + int16x8_t v1709 = vaddq_s16(v1706, v1708); + int16x8_t v1710 = vsubq_s16(v675, v678); + int16x8_t v1711 = vsubq_s16(v681, v688); + int16x8_t v1712_tmp = vqrdmulhq_n_s16(v1711, 10045); + int16x8_t v1712 = vaddq_s16(v1712_tmp, v1711); + int16x8_t v1713 = vaddq_s16(v1710, v1712); + int16x8_t v1714 = vqrdmulhq_n_s16(v1713, 19705); + int16x8_t v1715 = vaddq_s16(v1709, v1714); + int16x8_t v1716 = vqrdmulhq_n_s16(v1715, 17121); + int16x8_t v1717 = vaddq_s16(v1705, v1716); + int16x8_t v1718 = vqrdmulhq_n_s16(v1717, 16563); + int16x8_t v1719 = vaddq_s16(v1695, v1718); + int16x8_t v1720 = vqrdmulhq_n_s16(v1719, 16429); + int16x8_t v1721 = vaddq_s16(v1673, v1720); + int16x8_t v1722 = vqrdmulhq_n_s16(v1721, 16395); + int16x8_t v1723 = vaddq_s16(v1627, v1722); + int16x8_t v1724 = vsubq_s16(v703, v706); + int16x8_t v1725 = vsubq_s16(v711, v715); + int16x8_t v1726_tmp = vqrdmulhq_n_s16(v1725, 10045); + int16x8_t v1726 = vaddq_s16(v1726_tmp, v1725); + int16x8_t v1727 = vaddq_s16(v1724, v1726); + int16x8_t v1728 = vsubq_s16(v722, v729); + int16x8_t v1729 = vsubq_s16(v732, v737); + int16x8_t v1730_tmp = vqrdmulhq_n_s16(v1729, 10045); + int16x8_t v1730 = vaddq_s16(v1730_tmp, v1729); + int16x8_t v1731 = vaddq_s16(v1728, v1730); + int16x8_t v1732 = vqrdmulhq_n_s16(v1731, 19705); + int16x8_t v1733 = vaddq_s16(v1727, v1732); + int16x8_t v1734 = vsubq_s16(v746, v753); + int16x8_t v1735 = vsubq_s16(v762, v770); + int16x8_t v1736_tmp = vqrdmulhq_n_s16(v1735, 10045); + int16x8_t v1736 = vaddq_s16(v1736_tmp, v1735); + int16x8_t v1737 = vaddq_s16(v1734, v1736); + int16x8_t v1738 = vsubq_s16(v775, v778); + int16x8_t v1739 = vsubq_s16(v781, v787); + int16x8_t v1740_tmp = vqrdmulhq_n_s16(v1739, 10045); + int16x8_t v1740 = vaddq_s16(v1740_tmp, v1739); + int16x8_t v1741 = vaddq_s16(v1738, v1740); + int16x8_t v1742 = vqrdmulhq_n_s16(v1741, 19705); + int16x8_t v1743 = vaddq_s16(v1737, v1742); + int16x8_t v1744 = vqrdmulhq_n_s16(v1743, 17121); + int16x8_t v1745 = vaddq_s16(v1733, v1744); + int16x8_t v1746 = vsubq_s16(v798, v805); + int16x8_t v1747 = vsubq_s16(v814, v822); + int16x8_t v1748_tmp = vqrdmulhq_n_s16(v1747, 10045); + int16x8_t v1748 = vaddq_s16(v1748_tmp, v1747); + int16x8_t v1749 = vaddq_s16(v1746, v1748); + int16x8_t v1750 = vsubq_s16(v833, v848); + int16x8_t v1751 = vsubq_s16(v851, v860); + int16x8_t v1752_tmp = vqrdmulhq_n_s16(v1751, 10045); + int16x8_t v1752 = vaddq_s16(v1752_tmp, v1751); + int16x8_t v1753 = vaddq_s16(v1750, v1752); + int16x8_t v1754 = vqrdmulhq_n_s16(v1753, 19705); + int16x8_t v1755 = vaddq_s16(v1749, v1754); + int16x8_t v1756 = vsubq_s16(v867, v870); + int16x8_t v1757 = vsubq_s16(v875, v879); + int16x8_t v1758_tmp = vqrdmulhq_n_s16(v1757, 10045); + int16x8_t v1758 = vaddq_s16(v1758_tmp, v1757); + int16x8_t v1759 = vaddq_s16(v1756, v1758); + int16x8_t v1760 = vsubq_s16(v884, v887); + int16x8_t v1761 = vsubq_s16(v890, v897); + int16x8_t v1762_tmp = vqrdmulhq_n_s16(v1761, 10045); + int16x8_t v1762 = vaddq_s16(v1762_tmp, v1761); + int16x8_t v1763 = vaddq_s16(v1760, v1762); + int16x8_t v1764 = vqrdmulhq_n_s16(v1763, 19705); + int16x8_t v1765 = vaddq_s16(v1759, v1764); + int16x8_t v1766 = vqrdmulhq_n_s16(v1765, 17121); + int16x8_t v1767 = vaddq_s16(v1755, v1766); + int16x8_t v1768 = vqrdmulhq_n_s16(v1767, 16563); + int16x8_t v1769 = vaddq_s16(v1745, v1768); + int16x8_t v1770 = vsubq_s16(v910, v917); + int16x8_t v1771 = vsubq_s16(v926, v934); + int16x8_t v1772_tmp = vqrdmulhq_n_s16(v1771, 10045); + int16x8_t v1772 = vaddq_s16(v1772_tmp, v1771); + int16x8_t v1773 = vaddq_s16(v1770, v1772); + int16x8_t v1774 = vsubq_s16(v945, v960); + int16x8_t v1775 = vsubq_s16(v963, v972); + int16x8_t v1776_tmp = vqrdmulhq_n_s16(v1775, 10045); + int16x8_t v1776 = vaddq_s16(v1776_tmp, v1775); + int16x8_t v1777 = vaddq_s16(v1774, v1776); + int16x8_t v1778 = vqrdmulhq_n_s16(v1777, 19705); + int16x8_t v1779 = vaddq_s16(v1773, v1778); + int16x8_t v1780 = vsubq_s16(v985, v1000); + int16x8_t v1781 = vsubq_s16(v1017, v1033); + int16x8_t v1782_tmp = vqrdmulhq_n_s16(v1781, 10045); + int16x8_t v1782 = vaddq_s16(v1782_tmp, v1781); + int16x8_t v1783 = vaddq_s16(v1780, v1782); + int16x8_t v1784 = vsubq_s16(v1038, v1041); + int16x8_t v1785 = vsubq_s16(v1044, v1054); + int16x8_t v1786_tmp = vqrdmulhq_n_s16(v1785, 10045); + int16x8_t v1786 = vaddq_s16(v1786_tmp, v1785); + int16x8_t v1787 = vaddq_s16(v1784, v1786); + int16x8_t v1788 = vqrdmulhq_n_s16(v1787, 19705); + int16x8_t v1789 = vaddq_s16(v1783, v1788); + int16x8_t v1790 = vqrdmulhq_n_s16(v1789, 17121); + int16x8_t v1791 = vaddq_s16(v1779, v1790); + int16x8_t v1792 = vsubq_s16(v1063, v1066); + int16x8_t v1793 = vsubq_s16(v1071, v1075); + int16x8_t v1794_tmp = vqrdmulhq_n_s16(v1793, 10045); + int16x8_t v1794 = vaddq_s16(v1794_tmp, v1793); + int16x8_t v1795 = vaddq_s16(v1792, v1794); + int16x8_t v1796 = vsubq_s16(v1082, v1089); + int16x8_t v1797 = vsubq_s16(v1092, v1097); + int16x8_t v1798_tmp = vqrdmulhq_n_s16(v1797, 10045); + int16x8_t v1798 = vaddq_s16(v1798_tmp, v1797); + int16x8_t v1799 = vaddq_s16(v1796, v1798); + int16x8_t v1800 = vqrdmulhq_n_s16(v1799, 19705); + int16x8_t v1801 = vaddq_s16(v1795, v1800); + int16x8_t v1802 = vsubq_s16(v1104, v1107); + int16x8_t v1803 = vsubq_s16(v1112, v1116); + int16x8_t v1804_tmp = vqrdmulhq_n_s16(v1803, 10045); + int16x8_t v1804 = vaddq_s16(v1804_tmp, v1803); + int16x8_t v1805 = vaddq_s16(v1802, v1804); + int16x8_t v1806 = vsubq_s16(v1121, v1124); + int16x8_t v1807 = vsubq_s16(v1127, v1135); + int16x8_t v1808_tmp = vqrdmulhq_n_s16(v1807, 10045); + int16x8_t v1808 = vaddq_s16(v1808_tmp, v1807); + int16x8_t v1809 = vaddq_s16(v1806, v1808); + int16x8_t v1810 = vqrdmulhq_n_s16(v1809, 19705); + int16x8_t v1811 = vaddq_s16(v1805, v1810); + int16x8_t v1812 = vqrdmulhq_n_s16(v1811, 17121); + int16x8_t v1813 = vaddq_s16(v1801, v1812); + int16x8_t v1814 = vqrdmulhq_n_s16(v1813, 16563); + int16x8_t v1815 = vaddq_s16(v1791, v1814); + int16x8_t v1816 = vqrdmulhq_n_s16(v1815, 16429); + int16x8_t v1817 = vaddq_s16(v1769, v1816); + int16x8_t v1818 = vsubq_s16(v1148, v1151); + int16x8_t v1819 = vsubq_s16(v1156, v1160); + int16x8_t v1820_tmp = vqrdmulhq_n_s16(v1819, 10045); + int16x8_t v1820 = vaddq_s16(v1820_tmp, v1819); + int16x8_t v1821 = vaddq_s16(v1818, v1820); + int16x8_t v1822 = vsubq_s16(v1167, v1174); + int16x8_t v1823 = vsubq_s16(v1177, v1182); + int16x8_t v1824_tmp = vqrdmulhq_n_s16(v1823, 10045); + int16x8_t v1824 = vaddq_s16(v1824_tmp, v1823); + int16x8_t v1825 = vaddq_s16(v1822, v1824); + int16x8_t v1826 = vqrdmulhq_n_s16(v1825, 19705); + int16x8_t v1827 = vaddq_s16(v1821, v1826); + int16x8_t v1828 = vsubq_s16(v1191, v1198); + int16x8_t v1829 = vsubq_s16(v1207, v1215); + int16x8_t v1830_tmp = vqrdmulhq_n_s16(v1829, 10045); + int16x8_t v1830 = vaddq_s16(v1830_tmp, v1829); + int16x8_t v1831 = vaddq_s16(v1828, v1830); + int16x8_t v1832 = vsubq_s16(v1220, v1223); + int16x8_t v1833 = vsubq_s16(v1226, v1232); + int16x8_t v1834_tmp = vqrdmulhq_n_s16(v1833, 10045); + int16x8_t v1834 = vaddq_s16(v1834_tmp, v1833); + int16x8_t v1835 = vaddq_s16(v1832, v1834); + int16x8_t v1836 = vqrdmulhq_n_s16(v1835, 19705); + int16x8_t v1837 = vaddq_s16(v1831, v1836); + int16x8_t v1838 = vqrdmulhq_n_s16(v1837, 17121); + int16x8_t v1839 = vaddq_s16(v1827, v1838); + int16x8_t v1840 = vsubq_s16(v1243, v1250); + int16x8_t v1841 = vsubq_s16(v1259, v1267); + int16x8_t v1842_tmp = vqrdmulhq_n_s16(v1841, 10045); + int16x8_t v1842 = vaddq_s16(v1842_tmp, v1841); + int16x8_t v1843 = vaddq_s16(v1840, v1842); + int16x8_t v1844 = vsubq_s16(v1278, v1293); + int16x8_t v1845 = vsubq_s16(v1296, v1305); + int16x8_t v1846_tmp = vqrdmulhq_n_s16(v1845, 10045); + int16x8_t v1846 = vaddq_s16(v1846_tmp, v1845); + int16x8_t v1847 = vaddq_s16(v1844, v1846); + int16x8_t v1848 = vqrdmulhq_n_s16(v1847, 19705); + int16x8_t v1849 = vaddq_s16(v1843, v1848); + int16x8_t v1850 = vsubq_s16(v1312, v1315); + int16x8_t v1851 = vsubq_s16(v1320, v1324); + int16x8_t v1852_tmp = vqrdmulhq_n_s16(v1851, 10045); + int16x8_t v1852 = vaddq_s16(v1852_tmp, v1851); + int16x8_t v1853 = vaddq_s16(v1850, v1852); + int16x8_t v1854 = vsubq_s16(v1329, v1332); + int16x8_t v1855 = vsubq_s16(v1335, v1342); + int16x8_t v1856_tmp = vqrdmulhq_n_s16(v1855, 10045); + int16x8_t v1856 = vaddq_s16(v1856_tmp, v1855); + int16x8_t v1857 = vaddq_s16(v1854, v1856); + int16x8_t v1858 = vqrdmulhq_n_s16(v1857, 19705); + int16x8_t v1859 = vaddq_s16(v1853, v1858); + int16x8_t v1860 = vqrdmulhq_n_s16(v1859, 17121); + int16x8_t v1861 = vaddq_s16(v1849, v1860); + int16x8_t v1862 = vqrdmulhq_n_s16(v1861, 16563); + int16x8_t v1863 = vaddq_s16(v1839, v1862); + int16x8_t v1864 = vsubq_s16(v1353, v1356); + int16x8_t v1865 = vsubq_s16(v1361, v1365); + int16x8_t v1866_tmp = vqrdmulhq_n_s16(v1865, 10045); + int16x8_t v1866 = vaddq_s16(v1866_tmp, v1865); + int16x8_t v1867 = vaddq_s16(v1864, v1866); + int16x8_t v1868 = vsubq_s16(v1372, v1379); + int16x8_t v1869 = vsubq_s16(v1382, v1387); + int16x8_t v1870_tmp = vqrdmulhq_n_s16(v1869, 10045); + int16x8_t v1870 = vaddq_s16(v1870_tmp, v1869); + int16x8_t v1871 = vaddq_s16(v1868, v1870); + int16x8_t v1872 = vqrdmulhq_n_s16(v1871, 19705); + int16x8_t v1873 = vaddq_s16(v1867, v1872); + int16x8_t v1874 = vsubq_s16(v1396, v1403); + int16x8_t v1875 = vsubq_s16(v1412, v1420); + int16x8_t v1876_tmp = vqrdmulhq_n_s16(v1875, 10045); + int16x8_t v1876 = vaddq_s16(v1876_tmp, v1875); + int16x8_t v1877 = vaddq_s16(v1874, v1876); + int16x8_t v1878 = vsubq_s16(v1425, v1428); + int16x8_t v1879 = vsubq_s16(v1431, v1437); + int16x8_t v1880_tmp = vqrdmulhq_n_s16(v1879, 10045); + int16x8_t v1880 = vaddq_s16(v1880_tmp, v1879); + int16x8_t v1881 = vaddq_s16(v1878, v1880); + int16x8_t v1882 = vqrdmulhq_n_s16(v1881, 19705); + int16x8_t v1883 = vaddq_s16(v1877, v1882); + int16x8_t v1884 = vqrdmulhq_n_s16(v1883, 17121); + int16x8_t v1885 = vaddq_s16(v1873, v1884); + int16x8_t v1886 = vsubq_s16(v1446, v1449); + int16x8_t v1887 = vsubq_s16(v1454, v1458); + int16x8_t v1888_tmp = vqrdmulhq_n_s16(v1887, 10045); + int16x8_t v1888 = vaddq_s16(v1888_tmp, v1887); + int16x8_t v1889 = vaddq_s16(v1886, v1888); + int16x8_t v1890 = vsubq_s16(v1465, v1472); + int16x8_t v1891 = vsubq_s16(v1475, v1480); + int16x8_t v1892_tmp = vqrdmulhq_n_s16(v1891, 10045); + int16x8_t v1892 = vaddq_s16(v1892_tmp, v1891); + int16x8_t v1893 = vaddq_s16(v1890, v1892); + int16x8_t v1894 = vqrdmulhq_n_s16(v1893, 19705); + int16x8_t v1895 = vaddq_s16(v1889, v1894); + int16x8_t v1896 = vsubq_s16(v1487, v1490); + int16x8_t v1897 = vsubq_s16(v1495, v1499); + int16x8_t v1898_tmp = vqrdmulhq_n_s16(v1897, 10045); + int16x8_t v1898 = vaddq_s16(v1898_tmp, v1897); + int16x8_t v1899 = vaddq_s16(v1896, v1898); + int16x8_t v1900 = vsubq_s16(v1504, v1507); + int16x8_t v1901 = vsubq_s16(v1510, v1518); + int16x8_t v1902_tmp = vqrdmulhq_n_s16(v1901, 10045); + int16x8_t v1902 = vaddq_s16(v1902_tmp, v1901); + int16x8_t v1903 = vaddq_s16(v1900, v1902); + int16x8_t v1904 = vqrdmulhq_n_s16(v1903, 19705); + int16x8_t v1905 = vaddq_s16(v1899, v1904); + int16x8_t v1906 = vqrdmulhq_n_s16(v1905, 17121); + int16x8_t v1907 = vaddq_s16(v1895, v1906); + int16x8_t v1908 = vqrdmulhq_n_s16(v1907, 16563); + int16x8_t v1909 = vaddq_s16(v1885, v1908); + int16x8_t v1910 = vqrdmulhq_n_s16(v1909, 16429); + int16x8_t v1911 = vaddq_s16(v1863, v1910); + int16x8_t v1912 = vqrdmulhq_n_s16(v1911, 16395); + int16x8_t v1913 = vaddq_s16(v1817, v1912); + int16x8_t v1914 = vqrdmulhq_n_s16(v1913, 16387); + int16x8_t v1915 = vaddq_s16(v1723, v1914); + int16x8_t v1916 = vsubq_s16(v1534, v1536); + int16x8_t v1917 = vsubq_s16(v1538, v1540); + int16x8_t v1918 = vqrdmulhq_n_s16(v1917, 29490); + int16x8_t v1919 = vaddq_s16(v1916, v1918); + int16x8_t v1920 = vsubq_s16(v1544, v1546); + int16x8_t v1921 = vsubq_s16(v1548, v1550); + int16x8_t v1922 = vqrdmulhq_n_s16(v1921, 29490); + int16x8_t v1923 = vaddq_s16(v1920, v1922); + int16x8_t v1924 = vqrdmulhq_n_s16(v1923, 18578); + int16x8_t v1925 = vaddq_s16(v1919, v1924); + int16x8_t v1926 = vsubq_s16(v1556, v1558); + int16x8_t v1927 = vsubq_s16(v1560, v1562); + int16x8_t v1928 = vqrdmulhq_n_s16(v1927, 29490); + int16x8_t v1929 = vaddq_s16(v1926, v1928); + int16x8_t v1930 = vsubq_s16(v1566, v1568); + int16x8_t v1931 = vsubq_s16(v1570, v1572); + int16x8_t v1932 = vqrdmulhq_n_s16(v1931, 29490); + int16x8_t v1933 = vaddq_s16(v1930, v1932); + int16x8_t v1934 = vqrdmulhq_n_s16(v1933, 18578); + int16x8_t v1935 = vaddq_s16(v1929, v1934); + int16x8_t v1936 = vqrdmulhq_n_s16(v1935, 16890); + int16x8_t v1937 = vaddq_s16(v1925, v1936); + int16x8_t v1938 = vsubq_s16(v1580, v1582); + int16x8_t v1939 = vsubq_s16(v1584, v1586); + int16x8_t v1940 = vqrdmulhq_n_s16(v1939, 29490); + int16x8_t v1941 = vaddq_s16(v1938, v1940); + int16x8_t v1942 = vsubq_s16(v1590, v1592); + int16x8_t v1943 = vsubq_s16(v1594, v1596); + int16x8_t v1944 = vqrdmulhq_n_s16(v1943, 29490); + int16x8_t v1945 = vaddq_s16(v1942, v1944); + int16x8_t v1946 = vqrdmulhq_n_s16(v1945, 18578); + int16x8_t v1947 = vaddq_s16(v1941, v1946); + int16x8_t v1948 = vsubq_s16(v1602, v1604); + int16x8_t v1949 = vsubq_s16(v1606, v1608); + int16x8_t v1950 = vqrdmulhq_n_s16(v1949, 29490); + int16x8_t v1951 = vaddq_s16(v1948, v1950); + int16x8_t v1952 = vsubq_s16(v1612, v1614); + int16x8_t v1953 = vsubq_s16(v1616, v1618); + int16x8_t v1954 = vqrdmulhq_n_s16(v1953, 29490); + int16x8_t v1955 = vaddq_s16(v1952, v1954); + int16x8_t v1956 = vqrdmulhq_n_s16(v1955, 18578); + int16x8_t v1957 = vaddq_s16(v1951, v1956); + int16x8_t v1958 = vqrdmulhq_n_s16(v1957, 16890); + int16x8_t v1959 = vaddq_s16(v1947, v1958); + int16x8_t v1960 = vqrdmulhq_n_s16(v1959, 16508); + int16x8_t v1961 = vaddq_s16(v1937, v1960); + int16x8_t v1962 = vsubq_s16(v1628, v1630); + int16x8_t v1963 = vsubq_s16(v1632, v1634); + int16x8_t v1964 = vqrdmulhq_n_s16(v1963, 29490); + int16x8_t v1965 = vaddq_s16(v1962, v1964); + int16x8_t v1966 = vsubq_s16(v1638, v1640); + int16x8_t v1967 = vsubq_s16(v1642, v1644); + int16x8_t v1968 = vqrdmulhq_n_s16(v1967, 29490); + int16x8_t v1969 = vaddq_s16(v1966, v1968); + int16x8_t v1970 = vqrdmulhq_n_s16(v1969, 18578); + int16x8_t v1971 = vaddq_s16(v1965, v1970); + int16x8_t v1972 = vsubq_s16(v1650, v1652); + int16x8_t v1973 = vsubq_s16(v1654, v1656); + int16x8_t v1974 = vqrdmulhq_n_s16(v1973, 29490); + int16x8_t v1975 = vaddq_s16(v1972, v1974); + int16x8_t v1976 = vsubq_s16(v1660, v1662); + int16x8_t v1977 = vsubq_s16(v1664, v1666); + int16x8_t v1978 = vqrdmulhq_n_s16(v1977, 29490); + int16x8_t v1979 = vaddq_s16(v1976, v1978); + int16x8_t v1980 = vqrdmulhq_n_s16(v1979, 18578); + int16x8_t v1981 = vaddq_s16(v1975, v1980); + int16x8_t v1982 = vqrdmulhq_n_s16(v1981, 16890); + int16x8_t v1983 = vaddq_s16(v1971, v1982); + int16x8_t v1984 = vsubq_s16(v1674, v1676); + int16x8_t v1985 = vsubq_s16(v1678, v1680); + int16x8_t v1986 = vqrdmulhq_n_s16(v1985, 29490); + int16x8_t v1987 = vaddq_s16(v1984, v1986); + int16x8_t v1988 = vsubq_s16(v1684, v1686); + int16x8_t v1989 = vsubq_s16(v1688, v1690); + int16x8_t v1990 = vqrdmulhq_n_s16(v1989, 29490); + int16x8_t v1991 = vaddq_s16(v1988, v1990); + int16x8_t v1992 = vqrdmulhq_n_s16(v1991, 18578); + int16x8_t v1993 = vaddq_s16(v1987, v1992); + int16x8_t v1994 = vsubq_s16(v1696, v1698); + int16x8_t v1995 = vsubq_s16(v1700, v1702); + int16x8_t v1996 = vqrdmulhq_n_s16(v1995, 29490); + int16x8_t v1997 = vaddq_s16(v1994, v1996); + int16x8_t v1998 = vsubq_s16(v1706, v1708); + int16x8_t v1999 = vsubq_s16(v1710, v1712); + int16x8_t v2000 = vqrdmulhq_n_s16(v1999, 29490); + int16x8_t v2001 = vaddq_s16(v1998, v2000); + int16x8_t v2002 = vqrdmulhq_n_s16(v2001, 18578); + int16x8_t v2003 = vaddq_s16(v1997, v2002); + int16x8_t v2004 = vqrdmulhq_n_s16(v2003, 16890); + int16x8_t v2005 = vaddq_s16(v1993, v2004); + int16x8_t v2006 = vqrdmulhq_n_s16(v2005, 16508); + int16x8_t v2007 = vaddq_s16(v1983, v2006); + int16x8_t v2008 = vqrdmulhq_n_s16(v2007, 16415); + int16x8_t v2009 = vaddq_s16(v1961, v2008); + int16x8_t v2010 = vsubq_s16(v1724, v1726); + int16x8_t v2011 = vsubq_s16(v1728, v1730); + int16x8_t v2012 = vqrdmulhq_n_s16(v2011, 29490); + int16x8_t v2013 = vaddq_s16(v2010, v2012); + int16x8_t v2014 = vsubq_s16(v1734, v1736); + int16x8_t v2015 = vsubq_s16(v1738, v1740); + int16x8_t v2016 = vqrdmulhq_n_s16(v2015, 29490); + int16x8_t v2017 = vaddq_s16(v2014, v2016); + int16x8_t v2018 = vqrdmulhq_n_s16(v2017, 18578); + int16x8_t v2019 = vaddq_s16(v2013, v2018); + int16x8_t v2020 = vsubq_s16(v1746, v1748); + int16x8_t v2021 = vsubq_s16(v1750, v1752); + int16x8_t v2022 = vqrdmulhq_n_s16(v2021, 29490); + int16x8_t v2023 = vaddq_s16(v2020, v2022); + int16x8_t v2024 = vsubq_s16(v1756, v1758); + int16x8_t v2025 = vsubq_s16(v1760, v1762); + int16x8_t v2026 = vqrdmulhq_n_s16(v2025, 29490); + int16x8_t v2027 = vaddq_s16(v2024, v2026); + int16x8_t v2028 = vqrdmulhq_n_s16(v2027, 18578); + int16x8_t v2029 = vaddq_s16(v2023, v2028); + int16x8_t v2030 = vqrdmulhq_n_s16(v2029, 16890); + int16x8_t v2031 = vaddq_s16(v2019, v2030); + int16x8_t v2032 = vsubq_s16(v1770, v1772); + int16x8_t v2033 = vsubq_s16(v1774, v1776); + int16x8_t v2034 = vqrdmulhq_n_s16(v2033, 29490); + int16x8_t v2035 = vaddq_s16(v2032, v2034); + int16x8_t v2036 = vsubq_s16(v1780, v1782); + int16x8_t v2037 = vsubq_s16(v1784, v1786); + int16x8_t v2038 = vqrdmulhq_n_s16(v2037, 29490); + int16x8_t v2039 = vaddq_s16(v2036, v2038); + int16x8_t v2040 = vqrdmulhq_n_s16(v2039, 18578); + int16x8_t v2041 = vaddq_s16(v2035, v2040); + int16x8_t v2042 = vsubq_s16(v1792, v1794); + int16x8_t v2043 = vsubq_s16(v1796, v1798); + int16x8_t v2044 = vqrdmulhq_n_s16(v2043, 29490); + int16x8_t v2045 = vaddq_s16(v2042, v2044); + int16x8_t v2046 = vsubq_s16(v1802, v1804); + int16x8_t v2047 = vsubq_s16(v1806, v1808); + int16x8_t v2048 = vqrdmulhq_n_s16(v2047, 29490); + int16x8_t v2049 = vaddq_s16(v2046, v2048); + int16x8_t v2050 = vqrdmulhq_n_s16(v2049, 18578); + int16x8_t v2051 = vaddq_s16(v2045, v2050); + int16x8_t v2052 = vqrdmulhq_n_s16(v2051, 16890); + int16x8_t v2053 = vaddq_s16(v2041, v2052); + int16x8_t v2054 = vqrdmulhq_n_s16(v2053, 16508); + int16x8_t v2055 = vaddq_s16(v2031, v2054); + int16x8_t v2056 = vsubq_s16(v1818, v1820); + int16x8_t v2057 = vsubq_s16(v1822, v1824); + int16x8_t v2058 = vqrdmulhq_n_s16(v2057, 29490); + int16x8_t v2059 = vaddq_s16(v2056, v2058); + int16x8_t v2060 = vsubq_s16(v1828, v1830); + int16x8_t v2061 = vsubq_s16(v1832, v1834); + int16x8_t v2062 = vqrdmulhq_n_s16(v2061, 29490); + int16x8_t v2063 = vaddq_s16(v2060, v2062); + int16x8_t v2064 = vqrdmulhq_n_s16(v2063, 18578); + int16x8_t v2065 = vaddq_s16(v2059, v2064); + int16x8_t v2066 = vsubq_s16(v1840, v1842); + int16x8_t v2067 = vsubq_s16(v1844, v1846); + int16x8_t v2068 = vqrdmulhq_n_s16(v2067, 29490); + int16x8_t v2069 = vaddq_s16(v2066, v2068); + int16x8_t v2070 = vsubq_s16(v1850, v1852); + int16x8_t v2071 = vqrdmulhq_n_s16(v2070, 18578); + int16x8_t v2072 = vsubq_s16(v1854, v1856); + int16x8_t v2073 = vqrdmulhq_n_s16(v2072, 16719); + int16x8_t v2074 = vaddq_s16(v2071, v2073); + int16x8_t v2075 = vaddq_s16(v2069, v2074); + int16x8_t v2076 = vqrdmulhq_n_s16(v2075, 16890); + int16x8_t v2077 = vaddq_s16(v2065, v2076); + int16x8_t v2078 = vsubq_s16(v1864, v1866); + int16x8_t v2079 = vsubq_s16(v1868, v1870); + int16x8_t v2080 = vqrdmulhq_n_s16(v2079, 29490); + int16x8_t v2081 = vaddq_s16(v2078, v2080); + int16x8_t v2082 = vsubq_s16(v1874, v1876); + int16x8_t v2083 = vsubq_s16(v1878, v1880); + int16x8_t v2084 = vqrdmulhq_n_s16(v2083, 29490); + int16x8_t v2085 = vaddq_s16(v2082, v2084); + int16x8_t v2086 = vqrdmulhq_n_s16(v2085, 18578); + int16x8_t v2087 = vaddq_s16(v2081, v2086); + int16x8_t v2088 = vsubq_s16(v1886, v1888); + int16x8_t v2089 = vsubq_s16(v1890, v1892); + int16x8_t v2090 = vqrdmulhq_n_s16(v2089, 29490); + int16x8_t v2091 = vaddq_s16(v2088, v2090); + int16x8_t v2092 = vsubq_s16(v1896, v1898); + int16x8_t v2093 = vsubq_s16(v1900, v1902); + int16x8_t v2094 = vqrdmulhq_n_s16(v2093, 29490); + int16x8_t v2095 = vaddq_s16(v2092, v2094); + int16x8_t v2096 = vqrdmulhq_n_s16(v2095, 18578); + int16x8_t v2097 = vaddq_s16(v2091, v2096); + int16x8_t v2098 = vqrdmulhq_n_s16(v2097, 16890); + int16x8_t v2099 = vaddq_s16(v2087, v2098); + int16x8_t v2100 = vqrdmulhq_n_s16(v2099, 16508); + int16x8_t v2101 = vaddq_s16(v2077, v2100); + int16x8_t v2102 = vqrdmulhq_n_s16(v2101, 16415); + int16x8_t v2103 = vaddq_s16(v2055, v2102); + int16x8_t v2104 = vqrdmulhq_n_s16(v2103, 16392); + int16x8_t v2105 = vaddq_s16(v2009, v2104); + int16x8_t v2106 = vsubq_s16(v2, v8); + int16x8_t v2107 = vsubq_s16(v15, v22); + int16x8_t v2108_tmp = vqrdmulhq_n_s16(v2107, 18446); + int16x8_t v2108 = vmlaq_n_s16(v2108_tmp, v2107, 2); + int16x8_t v2109 = vaddq_s16(v2106, v2108); + int16x8_t v2110 = vsubq_s16(v31, v41); + int16x8_t v2111 = vsubq_s16(v48, v56); + int16x8_t v2112_tmp = vqrdmulhq_n_s16(v2111, 18446); + int16x8_t v2112 = vmlaq_n_s16(v2112_tmp, v2111, 2); + int16x8_t v2113 = vaddq_s16(v2110, v2112); + int16x8_t v2114 = vqrdmulhq_n_s16(v2113, 21195); + int16x8_t v2115 = vaddq_s16(v2109, v2114); + int16x8_t v2116 = vsubq_s16(v67, v77); + int16x8_t v2117 = vsubq_s16(v90, v99); + int16x8_t v2118_tmp = vqrdmulhq_n_s16(v2117, 18446); + int16x8_t v2118 = vmlaq_n_s16(v2118_tmp, v2117, 2); + int16x8_t v2119 = vaddq_s16(v2116, v2118); + int16x8_t v2120 = vsubq_s16(v108, v118); + int16x8_t v2121 = vsubq_s16(v125, v134); + int16x8_t v2122_tmp = vqrdmulhq_n_s16(v2121, 18446); + int16x8_t v2122 = vmlaq_n_s16(v2122_tmp, v2121, 2); + int16x8_t v2123 = vaddq_s16(v2120, v2122); + int16x8_t v2124 = vqrdmulhq_n_s16(v2123, 21195); + int16x8_t v2125 = vaddq_s16(v2119, v2124); + int16x8_t v2126 = vqrdmulhq_n_s16(v2125, 17401); + int16x8_t v2127 = vaddq_s16(v2115, v2126); + int16x8_t v2128 = vsubq_s16(v147, v157); + int16x8_t v2129 = vsubq_s16(v170, v179); + int16x8_t v2130_tmp = vqrdmulhq_n_s16(v2129, 18446); + int16x8_t v2130 = vmlaq_n_s16(v2130_tmp, v2129, 2); + int16x8_t v2131 = vaddq_s16(v2128, v2130); + int16x8_t v2132 = vsubq_s16(v194, v212); + int16x8_t v2133 = vsubq_s16(v219, v229); + int16x8_t v2134_tmp = vqrdmulhq_n_s16(v2133, 18446); + int16x8_t v2134 = vmlaq_n_s16(v2134_tmp, v2133, 2); + int16x8_t v2135 = vaddq_s16(v2132, v2134); + int16x8_t v2136 = vqrdmulhq_n_s16(v2135, 21195); + int16x8_t v2137 = vaddq_s16(v2131, v2136); + int16x8_t v2138 = vsubq_s16(v240, v250); + int16x8_t v2139 = vsubq_s16(v263, v272); + int16x8_t v2140_tmp = vqrdmulhq_n_s16(v2139, 18446); + int16x8_t v2140 = vmlaq_n_s16(v2140_tmp, v2139, 2); + int16x8_t v2141 = vaddq_s16(v2138, v2140); + int16x8_t v2142 = vsubq_s16(v281, v291); + int16x8_t v2143 = vsubq_s16(v298, v308); + int16x8_t v2144_tmp = vqrdmulhq_n_s16(v2143, 18446); + int16x8_t v2144 = vmlaq_n_s16(v2144_tmp, v2143, 2); + int16x8_t v2145 = vaddq_s16(v2142, v2144); + int16x8_t v2146 = vqrdmulhq_n_s16(v2145, 21195); + int16x8_t v2147 = vaddq_s16(v2141, v2146); + int16x8_t v2148 = vqrdmulhq_n_s16(v2147, 17401); + int16x8_t v2149 = vaddq_s16(v2137, v2148); + int16x8_t v2150 = vqrdmulhq_n_s16(v2149, 16629); + int16x8_t v2151 = vaddq_s16(v2127, v2150); + int16x8_t v2152 = vsubq_s16(v323, v333); + int16x8_t v2153 = vsubq_s16(v346, v355); + int16x8_t v2154_tmp = vqrdmulhq_n_s16(v2153, 18446); + int16x8_t v2154 = vmlaq_n_s16(v2154_tmp, v2153, 2); + int16x8_t v2155 = vaddq_s16(v2152, v2154); + int16x8_t v2156 = vsubq_s16(v370, v388); + int16x8_t v2157 = vsubq_s16(v395, v405); + int16x8_t v2158_tmp = vqrdmulhq_n_s16(v2157, 18446); + int16x8_t v2158 = vmlaq_n_s16(v2158_tmp, v2157, 2); + int16x8_t v2159 = vaddq_s16(v2156, v2158); + int16x8_t v2160 = vqrdmulhq_n_s16(v2159, 21195); + int16x8_t v2161 = vaddq_s16(v2155, v2160); + int16x8_t v2162 = vsubq_s16(v422, v440); + int16x8_t v2163 = vsubq_s16(v465, v478); + int16x8_t v2164_tmp = vqrdmulhq_n_s16(v2163, 18446); + int16x8_t v2164 = vmlaq_n_s16(v2164_tmp, v2163, 2); + int16x8_t v2165 = vaddq_s16(v2162, v2164); + int16x8_t v2166 = vsubq_s16(v487, v497); + int16x8_t v2167 = vsubq_s16(v504, v515); + int16x8_t v2168_tmp = vqrdmulhq_n_s16(v2167, 18446); + int16x8_t v2168 = vmlaq_n_s16(v2168_tmp, v2167, 2); + int16x8_t v2169 = vaddq_s16(v2166, v2168); + int16x8_t v2170 = vqrdmulhq_n_s16(v2169, 21195); + int16x8_t v2171 = vaddq_s16(v2165, v2170); + int16x8_t v2172 = vqrdmulhq_n_s16(v2171, 17401); + int16x8_t v2173 = vaddq_s16(v2161, v2172); + int16x8_t v2174 = vsubq_s16(v528, v538); + int16x8_t v2175 = vsubq_s16(v551, v560); + int16x8_t v2176_tmp = vqrdmulhq_n_s16(v2175, 18446); + int16x8_t v2176 = vmlaq_n_s16(v2176_tmp, v2175, 2); + int16x8_t v2177 = vaddq_s16(v2174, v2176); + int16x8_t v2178 = vsubq_s16(v575, v593); + int16x8_t v2179 = vsubq_s16(v600, v610); + int16x8_t v2180_tmp = vqrdmulhq_n_s16(v2179, 18446); + int16x8_t v2180 = vmlaq_n_s16(v2180_tmp, v2179, 2); + int16x8_t v2181 = vaddq_s16(v2178, v2180); + int16x8_t v2182 = vqrdmulhq_n_s16(v2181, 21195); + int16x8_t v2183 = vaddq_s16(v2177, v2182); + int16x8_t v2184 = vsubq_s16(v621, v631); + int16x8_t v2185 = vsubq_s16(v644, v653); + int16x8_t v2186_tmp = vqrdmulhq_n_s16(v2185, 18446); + int16x8_t v2186 = vmlaq_n_s16(v2186_tmp, v2185, 2); + int16x8_t v2187 = vaddq_s16(v2184, v2186); + int16x8_t v2188 = vsubq_s16(v662, v672); + int16x8_t v2189 = vsubq_s16(v679, v690); + int16x8_t v2190_tmp = vqrdmulhq_n_s16(v2189, 18446); + int16x8_t v2190 = vmlaq_n_s16(v2190_tmp, v2189, 2); + int16x8_t v2191 = vaddq_s16(v2188, v2190); + int16x8_t v2192 = vqrdmulhq_n_s16(v2191, 21195); + int16x8_t v2193 = vaddq_s16(v2187, v2192); + int16x8_t v2194 = vqrdmulhq_n_s16(v2193, 17401); + int16x8_t v2195 = vaddq_s16(v2183, v2194); + int16x8_t v2196 = vqrdmulhq_n_s16(v2195, 16629); + int16x8_t v2197 = vaddq_s16(v2173, v2196); + int16x8_t v2198 = vqrdmulhq_n_s16(v2197, 16445); + int16x8_t v2199 = vaddq_s16(v2151, v2198); + int16x8_t v2200 = vsubq_s16(v707, v717); + int16x8_t v2201 = vsubq_s16(v730, v739); + int16x8_t v2202_tmp = vqrdmulhq_n_s16(v2201, 18446); + int16x8_t v2202 = vmlaq_n_s16(v2202_tmp, v2201, 2); + int16x8_t v2203 = vaddq_s16(v2200, v2202); + int16x8_t v2204 = vsubq_s16(v754, v772); + int16x8_t v2205 = vsubq_s16(v779, v789); + int16x8_t v2206_tmp = vqrdmulhq_n_s16(v2205, 18446); + int16x8_t v2206 = vmlaq_n_s16(v2206_tmp, v2205, 2); + int16x8_t v2207 = vaddq_s16(v2204, v2206); + int16x8_t v2208 = vqrdmulhq_n_s16(v2207, 21195); + int16x8_t v2209 = vaddq_s16(v2203, v2208); + int16x8_t v2210 = vsubq_s16(v806, v824); + int16x8_t v2211 = vsubq_s16(v849, v862); + int16x8_t v2212_tmp = vqrdmulhq_n_s16(v2211, 18446); + int16x8_t v2212 = vmlaq_n_s16(v2212_tmp, v2211, 2); + int16x8_t v2213 = vaddq_s16(v2210, v2212); + int16x8_t v2214 = vsubq_s16(v871, v881); + int16x8_t v2215 = vsubq_s16(v888, v899); + int16x8_t v2216_tmp = vqrdmulhq_n_s16(v2215, 18446); + int16x8_t v2216 = vmlaq_n_s16(v2216_tmp, v2215, 2); + int16x8_t v2217 = vaddq_s16(v2214, v2216); + int16x8_t v2218 = vqrdmulhq_n_s16(v2217, 21195); + int16x8_t v2219 = vaddq_s16(v2213, v2218); + int16x8_t v2220 = vqrdmulhq_n_s16(v2219, 17401); + int16x8_t v2221 = vaddq_s16(v2209, v2220); + int16x8_t v2222 = vsubq_s16(v918, v936); + int16x8_t v2223 = vsubq_s16(v961, v974); + int16x8_t v2224_tmp = vqrdmulhq_n_s16(v2223, 18446); + int16x8_t v2224 = vmlaq_n_s16(v2224_tmp, v2223, 2); + int16x8_t v2225 = vaddq_s16(v2222, v2224); + int16x8_t v2226 = vsubq_s16(v1001, v1035); + int16x8_t v2227 = vsubq_s16(v1042, v1056); + int16x8_t v2228_tmp = vqrdmulhq_n_s16(v2227, 18446); + int16x8_t v2228 = vmlaq_n_s16(v2228_tmp, v2227, 2); + int16x8_t v2229 = vaddq_s16(v2226, v2228); + int16x8_t v2230 = vqrdmulhq_n_s16(v2229, 21195); + int16x8_t v2231 = vaddq_s16(v2225, v2230); + int16x8_t v2232 = vsubq_s16(v1067, v1077); + int16x8_t v2233 = vsubq_s16(v1090, v1099); + int16x8_t v2234_tmp = vqrdmulhq_n_s16(v2233, 18446); + int16x8_t v2234 = vmlaq_n_s16(v2234_tmp, v2233, 2); + int16x8_t v2235 = vaddq_s16(v2232, v2234); + int16x8_t v2236 = vsubq_s16(v1108, v1118); + int16x8_t v2237 = vsubq_s16(v1125, v1137); + int16x8_t v2238_tmp = vqrdmulhq_n_s16(v2237, 18446); + int16x8_t v2238 = vmlaq_n_s16(v2238_tmp, v2237, 2); + int16x8_t v2239 = vaddq_s16(v2236, v2238); + int16x8_t v2240 = vqrdmulhq_n_s16(v2239, 21195); + int16x8_t v2241 = vaddq_s16(v2235, v2240); + int16x8_t v2242 = vqrdmulhq_n_s16(v2241, 17401); + int16x8_t v2243 = vaddq_s16(v2231, v2242); + int16x8_t v2244 = vqrdmulhq_n_s16(v2243, 16629); + int16x8_t v2245 = vaddq_s16(v2221, v2244); + int16x8_t v2246 = vsubq_s16(v1152, v1162); + int16x8_t v2247 = vsubq_s16(v1175, v1184); + int16x8_t v2248_tmp = vqrdmulhq_n_s16(v2247, 18446); + int16x8_t v2248 = vmlaq_n_s16(v2248_tmp, v2247, 2); + int16x8_t v2249 = vaddq_s16(v2246, v2248); + int16x8_t v2250 = vsubq_s16(v1199, v1217); + int16x8_t v2251 = vsubq_s16(v1224, v1234); + int16x8_t v2252_tmp = vqrdmulhq_n_s16(v2251, 18446); + int16x8_t v2252 = vmlaq_n_s16(v2252_tmp, v2251, 2); + int16x8_t v2253 = vaddq_s16(v2250, v2252); + int16x8_t v2254 = vqrdmulhq_n_s16(v2253, 21195); + int16x8_t v2255 = vaddq_s16(v2249, v2254); + int16x8_t v2256 = vsubq_s16(v1251, v1269); + int16x8_t v2257 = vsubq_s16(v1294, v1307); + int16x8_t v2258_tmp = vqrdmulhq_n_s16(v2257, 18446); + int16x8_t v2258 = vmlaq_n_s16(v2258_tmp, v2257, 2); + int16x8_t v2259 = vaddq_s16(v2256, v2258); + int16x8_t v2260 = vsubq_s16(v1316, v1326); + int16x8_t v2261 = vsubq_s16(v1333, v1344); + int16x8_t v2262_tmp = vqrdmulhq_n_s16(v2261, 18446); + int16x8_t v2262 = vmlaq_n_s16(v2262_tmp, v2261, 2); + int16x8_t v2263 = vaddq_s16(v2260, v2262); + int16x8_t v2264 = vqrdmulhq_n_s16(v2263, 21195); + int16x8_t v2265 = vaddq_s16(v2259, v2264); + int16x8_t v2266 = vqrdmulhq_n_s16(v2265, 17401); + int16x8_t v2267 = vaddq_s16(v2255, v2266); + int16x8_t v2268 = vsubq_s16(v1357, v1367); + int16x8_t v2269 = vsubq_s16(v1380, v1389); + int16x8_t v2270_tmp = vqrdmulhq_n_s16(v2269, 18446); + int16x8_t v2270 = vmlaq_n_s16(v2270_tmp, v2269, 2); + int16x8_t v2271 = vaddq_s16(v2268, v2270); + int16x8_t v2272 = vsubq_s16(v1404, v1422); + int16x8_t v2273 = vsubq_s16(v1429, v1439); + int16x8_t v2274_tmp = vqrdmulhq_n_s16(v2273, 18446); + int16x8_t v2274 = vmlaq_n_s16(v2274_tmp, v2273, 2); + int16x8_t v2275 = vaddq_s16(v2272, v2274); + int16x8_t v2276 = vqrdmulhq_n_s16(v2275, 21195); + int16x8_t v2277 = vaddq_s16(v2271, v2276); + int16x8_t v2278 = vsubq_s16(v1450, v1460); + int16x8_t v2279 = vsubq_s16(v1473, v1482); + int16x8_t v2280_tmp = vqrdmulhq_n_s16(v2279, 18446); + int16x8_t v2280 = vmlaq_n_s16(v2280_tmp, v2279, 2); + int16x8_t v2281 = vaddq_s16(v2278, v2280); + int16x8_t v2282 = vsubq_s16(v1491, v1501); + int16x8_t v2283 = vsubq_s16(v1508, v1520); + int16x8_t v2284_tmp = vqrdmulhq_n_s16(v2283, 18446); + int16x8_t v2284 = vmlaq_n_s16(v2284_tmp, v2283, 2); + int16x8_t v2285 = vaddq_s16(v2282, v2284); + int16x8_t v2286 = vqrdmulhq_n_s16(v2285, 21195); + int16x8_t v2287 = vaddq_s16(v2281, v2286); + int16x8_t v2288 = vqrdmulhq_n_s16(v2287, 17401); + int16x8_t v2289 = vaddq_s16(v2277, v2288); + int16x8_t v2290 = vqrdmulhq_n_s16(v2289, 16629); + int16x8_t v2291 = vaddq_s16(v2267, v2290); + int16x8_t v2292 = vqrdmulhq_n_s16(v2291, 16445); + int16x8_t v2293 = vaddq_s16(v2245, v2292); + int16x8_t v2294 = vqrdmulhq_n_s16(v2293, 16399); + int16x8_t v2295 = vaddq_s16(v2199, v2294); + int16x8_t v2296 = vsubq_s16(v2106, v2108); + int16x8_t v2297 = vsubq_s16(v2110, v2112); + int16x8_t v2298 = vqrdmulhq_n_s16(v2297, 25826); + int16x8_t v2299 = vaddq_s16(v2296, v2298); + int16x8_t v2300 = vsubq_s16(v2116, v2118); + int16x8_t v2301 = vsubq_s16(v2120, v2122); + int16x8_t v2302 = vqrdmulhq_n_s16(v2301, 25826); + int16x8_t v2303 = vaddq_s16(v2300, v2302); + int16x8_t v2304 = vqrdmulhq_n_s16(v2303, 18124); + int16x8_t v2305 = vaddq_s16(v2299, v2304); + int16x8_t v2306 = vsubq_s16(v2128, v2130); + int16x8_t v2307 = vsubq_s16(v2132, v2134); + int16x8_t v2308 = vqrdmulhq_n_s16(v2307, 25826); + int16x8_t v2309 = vaddq_s16(v2306, v2308); + int16x8_t v2310 = vsubq_s16(v2138, v2140); + int16x8_t v2311 = vsubq_s16(v2142, v2144); + int16x8_t v2312 = vqrdmulhq_n_s16(v2311, 25826); + int16x8_t v2313 = vaddq_s16(v2310, v2312); + int16x8_t v2314 = vqrdmulhq_n_s16(v2313, 18124); + int16x8_t v2315 = vaddq_s16(v2309, v2314); + int16x8_t v2316 = vqrdmulhq_n_s16(v2315, 16792); + int16x8_t v2317 = vaddq_s16(v2305, v2316); + int16x8_t v2318 = vsubq_s16(v2152, v2154); + int16x8_t v2319 = vsubq_s16(v2156, v2158); + int16x8_t v2320 = vqrdmulhq_n_s16(v2319, 25826); + int16x8_t v2321 = vaddq_s16(v2318, v2320); + int16x8_t v2322 = vsubq_s16(v2162, v2164); + int16x8_t v2323 = vsubq_s16(v2166, v2168); + int16x8_t v2324 = vqrdmulhq_n_s16(v2323, 25826); + int16x8_t v2325 = vaddq_s16(v2322, v2324); + int16x8_t v2326 = vqrdmulhq_n_s16(v2325, 18124); + int16x8_t v2327 = vaddq_s16(v2321, v2326); + int16x8_t v2328 = vsubq_s16(v2174, v2176); + int16x8_t v2329 = vsubq_s16(v2178, v2180); + int16x8_t v2330 = vqrdmulhq_n_s16(v2329, 25826); + int16x8_t v2331 = vaddq_s16(v2328, v2330); + int16x8_t v2332 = vsubq_s16(v2184, v2186); + int16x8_t v2333 = vsubq_s16(v2188, v2190); + int16x8_t v2334 = vqrdmulhq_n_s16(v2333, 25826); + int16x8_t v2335 = vaddq_s16(v2332, v2334); + int16x8_t v2336 = vqrdmulhq_n_s16(v2335, 18124); + int16x8_t v2337 = vaddq_s16(v2331, v2336); + int16x8_t v2338 = vqrdmulhq_n_s16(v2337, 16792); + int16x8_t v2339 = vaddq_s16(v2327, v2338); + int16x8_t v2340 = vqrdmulhq_n_s16(v2339, 16484); + int16x8_t v2341 = vaddq_s16(v2317, v2340); + int16x8_t v2342 = vsubq_s16(v2200, v2202); + int16x8_t v2343 = vsubq_s16(v2204, v2206); + int16x8_t v2344 = vqrdmulhq_n_s16(v2343, 25826); + int16x8_t v2345 = vaddq_s16(v2342, v2344); + int16x8_t v2346 = vsubq_s16(v2210, v2212); + int16x8_t v2347 = vsubq_s16(v2214, v2216); + int16x8_t v2348 = vqrdmulhq_n_s16(v2347, 25826); + int16x8_t v2349 = vaddq_s16(v2346, v2348); + int16x8_t v2350 = vqrdmulhq_n_s16(v2349, 18124); + int16x8_t v2351 = vaddq_s16(v2345, v2350); + int16x8_t v2352 = vsubq_s16(v2222, v2224); + int16x8_t v2353 = vsubq_s16(v2226, v2228); + int16x8_t v2354 = vqrdmulhq_n_s16(v2353, 25826); + int16x8_t v2355 = vaddq_s16(v2352, v2354); + int16x8_t v2356 = vsubq_s16(v2232, v2234); + int16x8_t v2357 = vsubq_s16(v2236, v2238); + int16x8_t v2358 = vqrdmulhq_n_s16(v2357, 25826); + int16x8_t v2359 = vaddq_s16(v2356, v2358); + int16x8_t v2360 = vqrdmulhq_n_s16(v2359, 18124); + int16x8_t v2361 = vaddq_s16(v2355, v2360); + int16x8_t v2362 = vqrdmulhq_n_s16(v2361, 16792); + int16x8_t v2363 = vaddq_s16(v2351, v2362); + int16x8_t v2364 = vsubq_s16(v2246, v2248); + int16x8_t v2365 = vsubq_s16(v2250, v2252); + int16x8_t v2366 = vqrdmulhq_n_s16(v2365, 25826); + int16x8_t v2367 = vaddq_s16(v2364, v2366); + int16x8_t v2368 = vsubq_s16(v2256, v2258); + int16x8_t v2369 = vsubq_s16(v2260, v2262); + int16x8_t v2370 = vqrdmulhq_n_s16(v2369, 25826); + int16x8_t v2371 = vaddq_s16(v2368, v2370); + int16x8_t v2372 = vqrdmulhq_n_s16(v2371, 18124); + int16x8_t v2373 = vaddq_s16(v2367, v2372); + int16x8_t v2374 = vsubq_s16(v2268, v2270); + int16x8_t v2375 = vsubq_s16(v2272, v2274); + int16x8_t v2376 = vqrdmulhq_n_s16(v2375, 25826); + int16x8_t v2377 = vaddq_s16(v2374, v2376); + int16x8_t v2378 = vsubq_s16(v2278, v2280); + int16x8_t v2379 = vsubq_s16(v2282, v2284); + int16x8_t v2380 = vqrdmulhq_n_s16(v2379, 25826); + int16x8_t v2381 = vaddq_s16(v2378, v2380); + int16x8_t v2382 = vqrdmulhq_n_s16(v2381, 18124); + int16x8_t v2383 = vaddq_s16(v2377, v2382); + int16x8_t v2384 = vqrdmulhq_n_s16(v2383, 16792); + int16x8_t v2385 = vaddq_s16(v2373, v2384); + int16x8_t v2386 = vqrdmulhq_n_s16(v2385, 16484); + int16x8_t v2387 = vaddq_s16(v2363, v2386); + int16x8_t v2388 = vqrdmulhq_n_s16(v2387, 16409); + int16x8_t v2389 = vaddq_s16(v2341, v2388); + int16x8_t v2390 = vsubq_s16(v1916, v1918); + int16x8_t v2391 = vsubq_s16(v1920, v1922); + int16x8_t v2392_tmp = vqrdmulhq_n_s16(v2391, 1988); + int16x8_t v2392 = vaddq_s16(v2392_tmp, v2391); + int16x8_t v2393 = vaddq_s16(v2390, v2392); + int16x8_t v2394 = vsubq_s16(v1926, v1928); + int16x8_t v2395 = vsubq_s16(v1930, v1932); + int16x8_t v2396_tmp = vqrdmulhq_n_s16(v2395, 1988); + int16x8_t v2396 = vaddq_s16(v2396_tmp, v2395); + int16x8_t v2397 = vaddq_s16(v2394, v2396); + int16x8_t v2398 = vqrdmulhq_n_s16(v2397, 19102); + int16x8_t v2399 = vaddq_s16(v2393, v2398); + int16x8_t v2400 = vsubq_s16(v1938, v1940); + int16x8_t v2401 = vsubq_s16(v1942, v1944); + int16x8_t v2402_tmp = vqrdmulhq_n_s16(v2401, 1988); + int16x8_t v2402 = vaddq_s16(v2402_tmp, v2401); + int16x8_t v2403 = vaddq_s16(v2400, v2402); + int16x8_t v2404 = vsubq_s16(v1948, v1950); + int16x8_t v2405 = vsubq_s16(v1952, v1954); + int16x8_t v2406_tmp = vqrdmulhq_n_s16(v2405, 1988); + int16x8_t v2406 = vaddq_s16(v2406_tmp, v2405); + int16x8_t v2407 = vaddq_s16(v2404, v2406); + int16x8_t v2408 = vqrdmulhq_n_s16(v2407, 19102); + int16x8_t v2409 = vaddq_s16(v2403, v2408); + int16x8_t v2410 = vqrdmulhq_n_s16(v2409, 17000); + int16x8_t v2411 = vaddq_s16(v2399, v2410); + int16x8_t v2412 = vsubq_s16(v1962, v1964); + int16x8_t v2413 = vsubq_s16(v1966, v1968); + int16x8_t v2414_tmp = vqrdmulhq_n_s16(v2413, 1988); + int16x8_t v2414 = vaddq_s16(v2414_tmp, v2413); + int16x8_t v2415 = vaddq_s16(v2412, v2414); + int16x8_t v2416 = vsubq_s16(v1972, v1974); + int16x8_t v2417 = vsubq_s16(v1976, v1978); + int16x8_t v2418_tmp = vqrdmulhq_n_s16(v2417, 1988); + int16x8_t v2418 = vaddq_s16(v2418_tmp, v2417); + int16x8_t v2419 = vaddq_s16(v2416, v2418); + int16x8_t v2420 = vqrdmulhq_n_s16(v2419, 19102); + int16x8_t v2421 = vaddq_s16(v2415, v2420); + int16x8_t v2422 = vsubq_s16(v1984, v1986); + int16x8_t v2423 = vsubq_s16(v1988, v1990); + int16x8_t v2424_tmp = vqrdmulhq_n_s16(v2423, 1988); + int16x8_t v2424 = vaddq_s16(v2424_tmp, v2423); + int16x8_t v2425 = vaddq_s16(v2422, v2424); + int16x8_t v2426 = vsubq_s16(v1994, v1996); + int16x8_t v2427 = vsubq_s16(v1998, v2000); + int16x8_t v2428_tmp = vqrdmulhq_n_s16(v2427, 1988); + int16x8_t v2428 = vaddq_s16(v2428_tmp, v2427); + int16x8_t v2429 = vaddq_s16(v2426, v2428); + int16x8_t v2430 = vqrdmulhq_n_s16(v2429, 19102); + int16x8_t v2431 = vaddq_s16(v2425, v2430); + int16x8_t v2432 = vqrdmulhq_n_s16(v2431, 17000); + int16x8_t v2433 = vaddq_s16(v2421, v2432); + int16x8_t v2434 = vqrdmulhq_n_s16(v2433, 16534); + int16x8_t v2435 = vaddq_s16(v2411, v2434); + int16x8_t v2436 = vsubq_s16(v2010, v2012); + int16x8_t v2437 = vsubq_s16(v2014, v2016); + int16x8_t v2438_tmp = vqrdmulhq_n_s16(v2437, 1988); + int16x8_t v2438 = vaddq_s16(v2438_tmp, v2437); + int16x8_t v2439 = vaddq_s16(v2436, v2438); + int16x8_t v2440 = vsubq_s16(v2020, v2022); + int16x8_t v2441 = vsubq_s16(v2024, v2026); + int16x8_t v2442_tmp = vqrdmulhq_n_s16(v2441, 1988); + int16x8_t v2442 = vaddq_s16(v2442_tmp, v2441); + int16x8_t v2443 = vaddq_s16(v2440, v2442); + int16x8_t v2444 = vqrdmulhq_n_s16(v2443, 19102); + int16x8_t v2445 = vaddq_s16(v2439, v2444); + int16x8_t v2446 = vsubq_s16(v2032, v2034); + int16x8_t v2447 = vsubq_s16(v2036, v2038); + int16x8_t v2448_tmp = vqrdmulhq_n_s16(v2447, 1988); + int16x8_t v2448 = vaddq_s16(v2448_tmp, v2447); + int16x8_t v2449 = vaddq_s16(v2446, v2448); + int16x8_t v2450 = vsubq_s16(v2042, v2044); + int16x8_t v2451 = vsubq_s16(v2046, v2048); + int16x8_t v2452_tmp = vqrdmulhq_n_s16(v2451, 1988); + int16x8_t v2452 = vaddq_s16(v2452_tmp, v2451); + int16x8_t v2453 = vaddq_s16(v2450, v2452); + int16x8_t v2454 = vqrdmulhq_n_s16(v2453, 19102); + int16x8_t v2455 = vaddq_s16(v2449, v2454); + int16x8_t v2456 = vqrdmulhq_n_s16(v2455, 17000); + int16x8_t v2457 = vaddq_s16(v2445, v2456); + int16x8_t v2458 = vsubq_s16(v2056, v2058); + int16x8_t v2459 = vsubq_s16(v2060, v2062); + int16x8_t v2460_tmp = vqrdmulhq_n_s16(v2459, 1988); + int16x8_t v2460 = vaddq_s16(v2460_tmp, v2459); + int16x8_t v2461 = vaddq_s16(v2458, v2460); + int16x8_t v2462 = vsubq_s16(v2066, v2068); + int16x8_t v2463 = vqrdmulhq_n_s16(v2072, 29490); + int16x8_t v2464 = vsubq_s16(v2070, v2463); + int16x8_t v2465_tmp = vqrdmulhq_n_s16(v2464, 1988); + int16x8_t v2465 = vaddq_s16(v2465_tmp, v2464); + int16x8_t v2466 = vaddq_s16(v2462, v2465); + int16x8_t v2467 = vqrdmulhq_n_s16(v2466, 19102); + int16x8_t v2468 = vaddq_s16(v2461, v2467); + int16x8_t v2469 = vsubq_s16(v2078, v2080); + int16x8_t v2470 = vsubq_s16(v2082, v2084); + int16x8_t v2471_tmp = vqrdmulhq_n_s16(v2470, 1988); + int16x8_t v2471 = vaddq_s16(v2471_tmp, v2470); + int16x8_t v2472 = vaddq_s16(v2469, v2471); + int16x8_t v2473 = vsubq_s16(v2088, v2090); + int16x8_t v2474 = vsubq_s16(v2092, v2094); + int16x8_t v2475_tmp = vqrdmulhq_n_s16(v2474, 1988); + int16x8_t v2475 = vaddq_s16(v2475_tmp, v2474); + int16x8_t v2476 = vaddq_s16(v2473, v2475); + int16x8_t v2477 = vqrdmulhq_n_s16(v2476, 19102); + int16x8_t v2478 = vaddq_s16(v2472, v2477); + int16x8_t v2479 = vqrdmulhq_n_s16(v2478, 17000); + int16x8_t v2480 = vaddq_s16(v2468, v2479); + int16x8_t v2481 = vqrdmulhq_n_s16(v2480, 16534); + int16x8_t v2482 = vaddq_s16(v2457, v2481); + int16x8_t v2483 = vqrdmulhq_n_s16(v2482, 16421); + int16x8_t v2484 = vaddq_s16(v2435, v2483); + int16x8_t v2485 = vsubq_s16(v1537, v1542); + int16x8_t v2486 = vsubq_s16(v1547, v1552); + int16x8_t v2487_tmp = vqrdmulhq_n_s16(v2486, 23673); + int16x8_t v2487 = vaddq_s16(v2487_tmp, v2486); + int16x8_t v2488 = vaddq_s16(v2485, v2487); + int16x8_t v2489 = vsubq_s16(v1559, v1564); + int16x8_t v2490 = vsubq_s16(v1569, v1574); + int16x8_t v2491_tmp = vqrdmulhq_n_s16(v2490, 23673); + int16x8_t v2491 = vaddq_s16(v2491_tmp, v2490); + int16x8_t v2492 = vaddq_s16(v2489, v2491); + int16x8_t v2493 = vqrdmulhq_n_s16(v2492, 20398); + int16x8_t v2494 = vaddq_s16(v2488, v2493); + int16x8_t v2495 = vsubq_s16(v1583, v1588); + int16x8_t v2496 = vsubq_s16(v1593, v1598); + int16x8_t v2497_tmp = vqrdmulhq_n_s16(v2496, 23673); + int16x8_t v2497 = vaddq_s16(v2497_tmp, v2496); + int16x8_t v2498 = vaddq_s16(v2495, v2497); + int16x8_t v2499 = vsubq_s16(v1605, v1610); + int16x8_t v2500 = vsubq_s16(v1615, v1620); + int16x8_t v2501_tmp = vqrdmulhq_n_s16(v2500, 23673); + int16x8_t v2501 = vaddq_s16(v2501_tmp, v2500); + int16x8_t v2502 = vaddq_s16(v2499, v2501); + int16x8_t v2503 = vqrdmulhq_n_s16(v2502, 20398); + int16x8_t v2504 = vaddq_s16(v2498, v2503); + int16x8_t v2505 = vqrdmulhq_n_s16(v2504, 17255); + int16x8_t v2506 = vaddq_s16(v2494, v2505); + int16x8_t v2507 = vsubq_s16(v1631, v1636); + int16x8_t v2508 = vsubq_s16(v1641, v1646); + int16x8_t v2509_tmp = vqrdmulhq_n_s16(v2508, 23673); + int16x8_t v2509 = vaddq_s16(v2509_tmp, v2508); + int16x8_t v2510 = vaddq_s16(v2507, v2509); + int16x8_t v2511 = vsubq_s16(v1653, v1658); + int16x8_t v2512 = vsubq_s16(v1663, v1668); + int16x8_t v2513_tmp = vqrdmulhq_n_s16(v2512, 23673); + int16x8_t v2513 = vaddq_s16(v2513_tmp, v2512); + int16x8_t v2514 = vaddq_s16(v2511, v2513); + int16x8_t v2515 = vqrdmulhq_n_s16(v2514, 20398); + int16x8_t v2516 = vaddq_s16(v2510, v2515); + int16x8_t v2517 = vsubq_s16(v1677, v1682); + int16x8_t v2518 = vsubq_s16(v1687, v1692); + int16x8_t v2519_tmp = vqrdmulhq_n_s16(v2518, 23673); + int16x8_t v2519 = vaddq_s16(v2519_tmp, v2518); + int16x8_t v2520 = vaddq_s16(v2517, v2519); + int16x8_t v2521 = vsubq_s16(v1699, v1704); + int16x8_t v2522 = vsubq_s16(v1709, v1714); + int16x8_t v2523_tmp = vqrdmulhq_n_s16(v2522, 23673); + int16x8_t v2523 = vaddq_s16(v2523_tmp, v2522); + int16x8_t v2524 = vaddq_s16(v2521, v2523); + int16x8_t v2525 = vqrdmulhq_n_s16(v2524, 20398); + int16x8_t v2526 = vaddq_s16(v2520, v2525); + int16x8_t v2527 = vqrdmulhq_n_s16(v2526, 17255); + int16x8_t v2528 = vaddq_s16(v2516, v2527); + int16x8_t v2529 = vqrdmulhq_n_s16(v2528, 16595); + int16x8_t v2530 = vaddq_s16(v2506, v2529); + int16x8_t v2531 = vsubq_s16(v1727, v1732); + int16x8_t v2532 = vsubq_s16(v1737, v1742); + int16x8_t v2533_tmp = vqrdmulhq_n_s16(v2532, 23673); + int16x8_t v2533 = vaddq_s16(v2533_tmp, v2532); + int16x8_t v2534 = vaddq_s16(v2531, v2533); + int16x8_t v2535 = vsubq_s16(v1749, v1754); + int16x8_t v2536 = vsubq_s16(v1759, v1764); + int16x8_t v2537_tmp = vqrdmulhq_n_s16(v2536, 23673); + int16x8_t v2537 = vaddq_s16(v2537_tmp, v2536); + int16x8_t v2538 = vaddq_s16(v2535, v2537); + int16x8_t v2539 = vqrdmulhq_n_s16(v2538, 20398); + int16x8_t v2540 = vaddq_s16(v2534, v2539); + int16x8_t v2541 = vsubq_s16(v1773, v1778); + int16x8_t v2542 = vsubq_s16(v1783, v1788); + int16x8_t v2543_tmp = vqrdmulhq_n_s16(v2542, 23673); + int16x8_t v2543 = vaddq_s16(v2543_tmp, v2542); + int16x8_t v2544 = vaddq_s16(v2541, v2543); + int16x8_t v2545 = vsubq_s16(v1795, v1800); + int16x8_t v2546 = vsubq_s16(v1805, v1810); + int16x8_t v2547_tmp = vqrdmulhq_n_s16(v2546, 23673); + int16x8_t v2547 = vaddq_s16(v2547_tmp, v2546); + int16x8_t v2548 = vaddq_s16(v2545, v2547); + int16x8_t v2549 = vqrdmulhq_n_s16(v2548, 20398); + int16x8_t v2550 = vaddq_s16(v2544, v2549); + int16x8_t v2551 = vqrdmulhq_n_s16(v2550, 17255); + int16x8_t v2552 = vaddq_s16(v2540, v2551); + int16x8_t v2553 = vsubq_s16(v1821, v1826); + int16x8_t v2554 = vsubq_s16(v1831, v1836); + int16x8_t v2555_tmp = vqrdmulhq_n_s16(v2554, 23673); + int16x8_t v2555 = vaddq_s16(v2555_tmp, v2554); + int16x8_t v2556 = vaddq_s16(v2553, v2555); + int16x8_t v2557 = vsubq_s16(v1843, v1848); + int16x8_t v2558 = vsubq_s16(v1853, v1858); + int16x8_t v2559_tmp = vqrdmulhq_n_s16(v2558, 23673); + int16x8_t v2559 = vaddq_s16(v2559_tmp, v2558); + int16x8_t v2560 = vaddq_s16(v2557, v2559); + int16x8_t v2561 = vqrdmulhq_n_s16(v2560, 20398); + int16x8_t v2562 = vaddq_s16(v2556, v2561); + int16x8_t v2563 = vsubq_s16(v1867, v1872); + int16x8_t v2564 = vsubq_s16(v1877, v1882); + int16x8_t v2565_tmp = vqrdmulhq_n_s16(v2564, 23673); + int16x8_t v2565 = vaddq_s16(v2565_tmp, v2564); + int16x8_t v2566 = vaddq_s16(v2563, v2565); + int16x8_t v2567 = vsubq_s16(v1889, v1894); + int16x8_t v2568 = vsubq_s16(v1899, v1904); + int16x8_t v2569_tmp = vqrdmulhq_n_s16(v2568, 23673); + int16x8_t v2569 = vaddq_s16(v2569_tmp, v2568); + int16x8_t v2570 = vaddq_s16(v2567, v2569); + int16x8_t v2571 = vqrdmulhq_n_s16(v2570, 20398); + int16x8_t v2572 = vaddq_s16(v2566, v2571); + int16x8_t v2573 = vqrdmulhq_n_s16(v2572, 17255); + int16x8_t v2574 = vaddq_s16(v2562, v2573); + int16x8_t v2575 = vqrdmulhq_n_s16(v2574, 16595); + int16x8_t v2576 = vaddq_s16(v2552, v2575); + int16x8_t v2577 = vqrdmulhq_n_s16(v2576, 16436); + int16x8_t v2578 = vaddq_s16(v2530, v2577); + int16x8_t v2579 = vsubq_s16(v9, v24); + int16x8_t v2580 = vsubq_s16(v42, v58); + int16x8_t v2581_tmp = vqrdmulhq_n_s16(v2580, 3314); + int16x8_t v2581 = vmlaq_n_s16(v2581_tmp, v2580, 5); + int16x8_t v2582 = vaddq_s16(v2579, v2581); + int16x8_t v2583 = vsubq_s16(v78, v101); + int16x8_t v2584 = vsubq_s16(v119, v136); + int16x8_t v2585_tmp = vqrdmulhq_n_s16(v2584, 3314); + int16x8_t v2585 = vmlaq_n_s16(v2585_tmp, v2584, 5); + int16x8_t v2586 = vaddq_s16(v2583, v2585); + int16x8_t v2587 = vqrdmulhq_n_s16(v2586, 22112); + int16x8_t v2588 = vaddq_s16(v2582, v2587); + int16x8_t v2589 = vsubq_s16(v158, v181); + int16x8_t v2590 = vsubq_s16(v213, v231); + int16x8_t v2591_tmp = vqrdmulhq_n_s16(v2590, 3314); + int16x8_t v2591 = vmlaq_n_s16(v2591_tmp, v2590, 5); + int16x8_t v2592 = vaddq_s16(v2589, v2591); + int16x8_t v2593 = vsubq_s16(v251, v274); + int16x8_t v2594 = vsubq_s16(v292, v310); + int16x8_t v2595_tmp = vqrdmulhq_n_s16(v2594, 3314); + int16x8_t v2595 = vmlaq_n_s16(v2595_tmp, v2594, 5); + int16x8_t v2596 = vaddq_s16(v2593, v2595); + int16x8_t v2597 = vqrdmulhq_n_s16(v2596, 22112); + int16x8_t v2598 = vaddq_s16(v2592, v2597); + int16x8_t v2599 = vqrdmulhq_n_s16(v2598, 17561); + int16x8_t v2600 = vaddq_s16(v2588, v2599); + int16x8_t v2601 = vsubq_s16(v334, v357); + int16x8_t v2602 = vsubq_s16(v389, v407); + int16x8_t v2603_tmp = vqrdmulhq_n_s16(v2602, 3314); + int16x8_t v2603 = vmlaq_n_s16(v2603_tmp, v2602, 5); + int16x8_t v2604 = vaddq_s16(v2601, v2603); + int16x8_t v2605 = vsubq_s16(v441, v480); + int16x8_t v2606 = vsubq_s16(v498, v517); + int16x8_t v2607_tmp = vqrdmulhq_n_s16(v2606, 3314); + int16x8_t v2607 = vmlaq_n_s16(v2607_tmp, v2606, 5); + int16x8_t v2608 = vaddq_s16(v2605, v2607); + int16x8_t v2609 = vqrdmulhq_n_s16(v2608, 22112); + int16x8_t v2610 = vaddq_s16(v2604, v2609); + int16x8_t v2611 = vsubq_s16(v539, v562); + int16x8_t v2612 = vsubq_s16(v594, v612); + int16x8_t v2613_tmp = vqrdmulhq_n_s16(v2612, 3314); + int16x8_t v2613 = vmlaq_n_s16(v2613_tmp, v2612, 5); + int16x8_t v2614 = vaddq_s16(v2611, v2613); + int16x8_t v2615 = vsubq_s16(v632, v655); + int16x8_t v2616 = vsubq_s16(v673, v692); + int16x8_t v2617_tmp = vqrdmulhq_n_s16(v2616, 3314); + int16x8_t v2617 = vmlaq_n_s16(v2617_tmp, v2616, 5); + int16x8_t v2618 = vaddq_s16(v2615, v2617); + int16x8_t v2619 = vqrdmulhq_n_s16(v2618, 22112); + int16x8_t v2620 = vaddq_s16(v2614, v2619); + int16x8_t v2621 = vqrdmulhq_n_s16(v2620, 17561); + int16x8_t v2622 = vaddq_s16(v2610, v2621); + int16x8_t v2623 = vqrdmulhq_n_s16(v2622, 16666); + int16x8_t v2624 = vaddq_s16(v2600, v2623); + int16x8_t v2625 = vsubq_s16(v718, v741); + int16x8_t v2626 = vsubq_s16(v773, v791); + int16x8_t v2627_tmp = vqrdmulhq_n_s16(v2626, 3314); + int16x8_t v2627 = vmlaq_n_s16(v2627_tmp, v2626, 5); + int16x8_t v2628 = vaddq_s16(v2625, v2627); + int16x8_t v2629 = vsubq_s16(v825, v864); + int16x8_t v2630 = vsubq_s16(v882, v901); + int16x8_t v2631_tmp = vqrdmulhq_n_s16(v2630, 3314); + int16x8_t v2631 = vmlaq_n_s16(v2631_tmp, v2630, 5); + int16x8_t v2632 = vaddq_s16(v2629, v2631); + int16x8_t v2633 = vqrdmulhq_n_s16(v2632, 22112); + int16x8_t v2634 = vaddq_s16(v2628, v2633); + int16x8_t v2635 = vsubq_s16(v937, v976); + int16x8_t v2636 = vsubq_s16(v1036, v1058); + int16x8_t v2637_tmp = vqrdmulhq_n_s16(v2636, 3314); + int16x8_t v2637 = vmlaq_n_s16(v2637_tmp, v2636, 5); + int16x8_t v2638 = vaddq_s16(v2635, v2637); + int16x8_t v2639 = vsubq_s16(v1078, v1101); + int16x8_t v2640 = vsubq_s16(v1119, v1139); + int16x8_t v2641_tmp = vqrdmulhq_n_s16(v2640, 3314); + int16x8_t v2641 = vmlaq_n_s16(v2641_tmp, v2640, 5); + int16x8_t v2642 = vaddq_s16(v2639, v2641); + int16x8_t v2643 = vqrdmulhq_n_s16(v2642, 22112); + int16x8_t v2644 = vaddq_s16(v2638, v2643); + int16x8_t v2645 = vqrdmulhq_n_s16(v2644, 17561); + int16x8_t v2646 = vaddq_s16(v2634, v2645); + int16x8_t v2647 = vsubq_s16(v1163, v1186); + int16x8_t v2648 = vsubq_s16(v1218, v1236); + int16x8_t v2649_tmp = vqrdmulhq_n_s16(v2648, 3314); + int16x8_t v2649 = vmlaq_n_s16(v2649_tmp, v2648, 5); + int16x8_t v2650 = vaddq_s16(v2647, v2649); + int16x8_t v2651 = vsubq_s16(v1270, v1309); + int16x8_t v2652 = vsubq_s16(v1327, v1346); + int16x8_t v2653_tmp = vqrdmulhq_n_s16(v2652, 3314); + int16x8_t v2653 = vmlaq_n_s16(v2653_tmp, v2652, 5); + int16x8_t v2654 = vaddq_s16(v2651, v2653); + int16x8_t v2655 = vqrdmulhq_n_s16(v2654, 22112); + int16x8_t v2656 = vaddq_s16(v2650, v2655); + int16x8_t v2657 = vsubq_s16(v1368, v1391); + int16x8_t v2658 = vsubq_s16(v1423, v1441); + int16x8_t v2659_tmp = vqrdmulhq_n_s16(v2658, 3314); + int16x8_t v2659 = vmlaq_n_s16(v2659_tmp, v2658, 5); + int16x8_t v2660 = vaddq_s16(v2657, v2659); + int16x8_t v2661 = vsubq_s16(v1461, v1484); + int16x8_t v2662 = vsubq_s16(v1502, v1522); + int16x8_t v2663_tmp = vqrdmulhq_n_s16(v2662, 3314); + int16x8_t v2663 = vmlaq_n_s16(v2663_tmp, v2662, 5); + int16x8_t v2664 = vaddq_s16(v2661, v2663); + int16x8_t v2665 = vqrdmulhq_n_s16(v2664, 22112); + int16x8_t v2666 = vaddq_s16(v2660, v2665); + int16x8_t v2667 = vqrdmulhq_n_s16(v2666, 17561); + int16x8_t v2668 = vaddq_s16(v2656, v2667); + int16x8_t v2669 = vqrdmulhq_n_s16(v2668, 16666); + int16x8_t v2670 = vaddq_s16(v2646, v2669); + int16x8_t v2671 = vqrdmulhq_n_s16(v2670, 16454); + int16x8_t v2672 = vaddq_s16(v2624, v2671); + int16x8_t v2673 = vsubq_s16(v2579, v2581); + int16x8_t v2674 = vsubq_s16(v2583, v2585); + int16x8_t v2675 = vqrdmulhq_n_s16(v2674, 24397); + int16x8_t v2676 = vaddq_s16(v2673, v2675); + int16x8_t v2677 = vsubq_s16(v2589, v2591); + int16x8_t v2678 = vsubq_s16(v2593, v2595); + int16x8_t v2679 = vqrdmulhq_n_s16(v2678, 24397); + int16x8_t v2680 = vaddq_s16(v2677, v2679); + int16x8_t v2681 = vqrdmulhq_n_s16(v2680, 17921); + int16x8_t v2682 = vaddq_s16(v2676, v2681); + int16x8_t v2683 = vsubq_s16(v2601, v2603); + int16x8_t v2684 = vsubq_s16(v2605, v2607); + int16x8_t v2685 = vqrdmulhq_n_s16(v2684, 24397); + int16x8_t v2686 = vaddq_s16(v2683, v2685); + int16x8_t v2687 = vsubq_s16(v2611, v2613); + int16x8_t v2688 = vsubq_s16(v2615, v2617); + int16x8_t v2689 = vqrdmulhq_n_s16(v2688, 24397); + int16x8_t v2690 = vaddq_s16(v2687, v2689); + int16x8_t v2691 = vqrdmulhq_n_s16(v2690, 17921); + int16x8_t v2692 = vaddq_s16(v2686, v2691); + int16x8_t v2693 = vqrdmulhq_n_s16(v2692, 16747); + int16x8_t v2694 = vaddq_s16(v2682, v2693); + int16x8_t v2695 = vsubq_s16(v2625, v2627); + int16x8_t v2696 = vsubq_s16(v2629, v2631); + int16x8_t v2697 = vqrdmulhq_n_s16(v2696, 24397); + int16x8_t v2698 = vaddq_s16(v2695, v2697); + int16x8_t v2699 = vsubq_s16(v2635, v2637); + int16x8_t v2700 = vsubq_s16(v2639, v2641); + int16x8_t v2701 = vqrdmulhq_n_s16(v2700, 24397); + int16x8_t v2702 = vaddq_s16(v2699, v2701); + int16x8_t v2703 = vqrdmulhq_n_s16(v2702, 17921); + int16x8_t v2704 = vaddq_s16(v2698, v2703); + int16x8_t v2705 = vsubq_s16(v2647, v2649); + int16x8_t v2706 = vsubq_s16(v2651, v2653); + int16x8_t v2707 = vqrdmulhq_n_s16(v2706, 24397); + int16x8_t v2708 = vaddq_s16(v2705, v2707); + int16x8_t v2709 = vsubq_s16(v2657, v2659); + int16x8_t v2710 = vsubq_s16(v2661, v2663); + int16x8_t v2711 = vqrdmulhq_n_s16(v2710, 24397); + int16x8_t v2712 = vaddq_s16(v2709, v2711); + int16x8_t v2713 = vqrdmulhq_n_s16(v2712, 17921); + int16x8_t v2714 = vaddq_s16(v2708, v2713); + int16x8_t v2715 = vqrdmulhq_n_s16(v2714, 16747); + int16x8_t v2716 = vaddq_s16(v2704, v2715); + int16x8_t v2717 = vqrdmulhq_n_s16(v2716, 16474); + int16x8_t v2718 = vaddq_s16(v2694, v2717); + int16x8_t v2719 = vsubq_s16(v2485, v2487); + int16x8_t v2720 = vsubq_s16(v2489, v2491); + int16x8_t v2721 = vqrdmulhq_n_s16(v2720, 27504); + int16x8_t v2722 = vaddq_s16(v2719, v2721); + int16x8_t v2723 = vsubq_s16(v2495, v2497); + int16x8_t v2724 = vsubq_s16(v2499, v2501); + int16x8_t v2725 = vqrdmulhq_n_s16(v2724, 27504); + int16x8_t v2726 = vaddq_s16(v2723, v2725); + int16x8_t v2727 = vqrdmulhq_n_s16(v2726, 18343); + int16x8_t v2728 = vaddq_s16(v2722, v2727); + int16x8_t v2729 = vsubq_s16(v2507, v2509); + int16x8_t v2730 = vsubq_s16(v2511, v2513); + int16x8_t v2731 = vqrdmulhq_n_s16(v2730, 27504); + int16x8_t v2732 = vaddq_s16(v2729, v2731); + int16x8_t v2733 = vsubq_s16(v2517, v2519); + int16x8_t v2734 = vsubq_s16(v2521, v2523); + int16x8_t v2735 = vqrdmulhq_n_s16(v2734, 27504); + int16x8_t v2736 = vaddq_s16(v2733, v2735); + int16x8_t v2737 = vqrdmulhq_n_s16(v2736, 18343); + int16x8_t v2738 = vaddq_s16(v2732, v2737); + int16x8_t v2739 = vqrdmulhq_n_s16(v2738, 16840); + int16x8_t v2740 = vaddq_s16(v2728, v2739); + int16x8_t v2741 = vsubq_s16(v2531, v2533); + int16x8_t v2742 = vsubq_s16(v2535, v2537); + int16x8_t v2743 = vqrdmulhq_n_s16(v2742, 27504); + int16x8_t v2744 = vaddq_s16(v2741, v2743); + int16x8_t v2745 = vsubq_s16(v2541, v2543); + int16x8_t v2746 = vsubq_s16(v2545, v2547); + int16x8_t v2747 = vqrdmulhq_n_s16(v2746, 27504); + int16x8_t v2748 = vaddq_s16(v2745, v2747); + int16x8_t v2749 = vqrdmulhq_n_s16(v2748, 18343); + int16x8_t v2750 = vaddq_s16(v2744, v2749); + int16x8_t v2751 = vsubq_s16(v2553, v2555); + int16x8_t v2752 = vsubq_s16(v2557, v2559); + int16x8_t v2753 = vqrdmulhq_n_s16(v2752, 27504); + int16x8_t v2754 = vaddq_s16(v2751, v2753); + int16x8_t v2755 = vsubq_s16(v2563, v2565); + int16x8_t v2756 = vsubq_s16(v2567, v2569); + int16x8_t v2757 = vqrdmulhq_n_s16(v2756, 27504); + int16x8_t v2758 = vaddq_s16(v2755, v2757); + int16x8_t v2759 = vqrdmulhq_n_s16(v2758, 18343); + int16x8_t v2760 = vaddq_s16(v2754, v2759); + int16x8_t v2761 = vqrdmulhq_n_s16(v2760, 16840); + int16x8_t v2762 = vaddq_s16(v2750, v2761); + int16x8_t v2763 = vqrdmulhq_n_s16(v2762, 16496); + int16x8_t v2764 = vaddq_s16(v2740, v2763); + int16x8_t v2765 = vsubq_s16(v2390, v2392); + int16x8_t v2766 = vsubq_s16(v2394, v2396); + int16x8_t v2767 = vqrdmulhq_n_s16(v2766, 31869); + int16x8_t v2768 = vaddq_s16(v2765, v2767); + int16x8_t v2769 = vsubq_s16(v2400, v2402); + int16x8_t v2770 = vsubq_s16(v2404, v2406); + int16x8_t v2771 = vqrdmulhq_n_s16(v2770, 31869); + int16x8_t v2772 = vaddq_s16(v2769, v2771); + int16x8_t v2773 = vqrdmulhq_n_s16(v2772, 18830); + int16x8_t v2774 = vaddq_s16(v2768, v2773); + int16x8_t v2775 = vsubq_s16(v2412, v2414); + int16x8_t v2776 = vsubq_s16(v2416, v2418); + int16x8_t v2777 = vqrdmulhq_n_s16(v2776, 31869); + int16x8_t v2778 = vaddq_s16(v2775, v2777); + int16x8_t v2779 = vsubq_s16(v2422, v2424); + int16x8_t v2780 = vsubq_s16(v2426, v2428); + int16x8_t v2781 = vqrdmulhq_n_s16(v2780, 31869); + int16x8_t v2782 = vaddq_s16(v2779, v2781); + int16x8_t v2783 = vqrdmulhq_n_s16(v2782, 18830); + int16x8_t v2784 = vaddq_s16(v2778, v2783); + int16x8_t v2785 = vqrdmulhq_n_s16(v2784, 16944); + int16x8_t v2786 = vaddq_s16(v2774, v2785); + int16x8_t v2787 = vsubq_s16(v2436, v2438); + int16x8_t v2788 = vsubq_s16(v2440, v2442); + int16x8_t v2789 = vqrdmulhq_n_s16(v2788, 31869); + int16x8_t v2790 = vaddq_s16(v2787, v2789); + int16x8_t v2791 = vsubq_s16(v2446, v2448); + int16x8_t v2792 = vsubq_s16(v2450, v2452); + int16x8_t v2793 = vqrdmulhq_n_s16(v2792, 31869); + int16x8_t v2794 = vaddq_s16(v2791, v2793); + int16x8_t v2795 = vqrdmulhq_n_s16(v2794, 18830); + int16x8_t v2796 = vaddq_s16(v2790, v2795); + int16x8_t v2797 = vsubq_s16(v2458, v2460); + int16x8_t v2798 = vsubq_s16(v2462, v2465); + int16x8_t v2799 = vqrdmulhq_n_s16(v2798, 31869); + int16x8_t v2800 = vaddq_s16(v2797, v2799); + int16x8_t v2801 = vsubq_s16(v2469, v2471); + int16x8_t v2802 = vsubq_s16(v2473, v2475); + int16x8_t v2803 = vqrdmulhq_n_s16(v2802, 31869); + int16x8_t v2804 = vaddq_s16(v2801, v2803); + int16x8_t v2805 = vqrdmulhq_n_s16(v2804, 18830); + int16x8_t v2806 = vaddq_s16(v2800, v2805); + int16x8_t v2807 = vqrdmulhq_n_s16(v2806, 16944); + int16x8_t v2808 = vaddq_s16(v2796, v2807); + int16x8_t v2809 = vqrdmulhq_n_s16(v2808, 16521); + int16x8_t v2810 = vaddq_s16(v2786, v2809); + int16x8_t v2811 = vsubq_s16(v2296, v2298); + int16x8_t v2812 = vsubq_s16(v2300, v2302); + int16x8_t v2813_tmp = vqrdmulhq_n_s16(v2812, 5552); + int16x8_t v2813 = vaddq_s16(v2813_tmp, v2812); + int16x8_t v2814 = vaddq_s16(v2811, v2813); + int16x8_t v2815 = vsubq_s16(v2306, v2308); + int16x8_t v2816 = vsubq_s16(v2310, v2312); + int16x8_t v2817_tmp = vqrdmulhq_n_s16(v2816, 5552); + int16x8_t v2817 = vaddq_s16(v2817_tmp, v2816); + int16x8_t v2818 = vaddq_s16(v2815, v2817); + int16x8_t v2819 = vqrdmulhq_n_s16(v2818, 19393); + int16x8_t v2820 = vaddq_s16(v2814, v2819); + int16x8_t v2821 = vsubq_s16(v2318, v2320); + int16x8_t v2822 = vsubq_s16(v2322, v2324); + int16x8_t v2823_tmp = vqrdmulhq_n_s16(v2822, 5552); + int16x8_t v2823 = vaddq_s16(v2823_tmp, v2822); + int16x8_t v2824 = vaddq_s16(v2821, v2823); + int16x8_t v2825 = vsubq_s16(v2328, v2330); + int16x8_t v2826 = vsubq_s16(v2332, v2334); + int16x8_t v2827_tmp = vqrdmulhq_n_s16(v2826, 5552); + int16x8_t v2827 = vaddq_s16(v2827_tmp, v2826); + int16x8_t v2828 = vaddq_s16(v2825, v2827); + int16x8_t v2829 = vqrdmulhq_n_s16(v2828, 19393); + int16x8_t v2830 = vaddq_s16(v2824, v2829); + int16x8_t v2831 = vqrdmulhq_n_s16(v2830, 17059); + int16x8_t v2832 = vaddq_s16(v2820, v2831); + int16x8_t v2833 = vsubq_s16(v2342, v2344); + int16x8_t v2834 = vsubq_s16(v2346, v2348); + int16x8_t v2835_tmp = vqrdmulhq_n_s16(v2834, 5552); + int16x8_t v2835 = vaddq_s16(v2835_tmp, v2834); + int16x8_t v2836 = vaddq_s16(v2833, v2835); + int16x8_t v2837 = vsubq_s16(v2352, v2354); + int16x8_t v2838 = vsubq_s16(v2356, v2358); + int16x8_t v2839_tmp = vqrdmulhq_n_s16(v2838, 5552); + int16x8_t v2839 = vaddq_s16(v2839_tmp, v2838); + int16x8_t v2840 = vaddq_s16(v2837, v2839); + int16x8_t v2841 = vqrdmulhq_n_s16(v2840, 19393); + int16x8_t v2842 = vaddq_s16(v2836, v2841); + int16x8_t v2843 = vsubq_s16(v2364, v2366); + int16x8_t v2844 = vsubq_s16(v2368, v2370); + int16x8_t v2845_tmp = vqrdmulhq_n_s16(v2844, 5552); + int16x8_t v2845 = vaddq_s16(v2845_tmp, v2844); + int16x8_t v2846 = vaddq_s16(v2843, v2845); + int16x8_t v2847 = vsubq_s16(v2374, v2376); + int16x8_t v2848 = vsubq_s16(v2378, v2380); + int16x8_t v2849_tmp = vqrdmulhq_n_s16(v2848, 5552); + int16x8_t v2849 = vaddq_s16(v2849_tmp, v2848); + int16x8_t v2850 = vaddq_s16(v2847, v2849); + int16x8_t v2851 = vqrdmulhq_n_s16(v2850, 19393); + int16x8_t v2852 = vaddq_s16(v2846, v2851); + int16x8_t v2853 = vqrdmulhq_n_s16(v2852, 17059); + int16x8_t v2854 = vaddq_s16(v2842, v2853); + int16x8_t v2855 = vqrdmulhq_n_s16(v2854, 16549); + int16x8_t v2856 = vaddq_s16(v2832, v2855); + int16x8_t v2857 = vsubq_s16(v2109, v2114); + int16x8_t v2858 = vsubq_s16(v2119, v2124); + int16x8_t v2859_tmp = vqrdmulhq_n_s16(v2858, 15865); + int16x8_t v2859 = vaddq_s16(v2859_tmp, v2858); + int16x8_t v2860 = vaddq_s16(v2857, v2859); + int16x8_t v2861 = vsubq_s16(v2131, v2136); + int16x8_t v2862 = vsubq_s16(v2141, v2146); + int16x8_t v2863_tmp = vqrdmulhq_n_s16(v2862, 15865); + int16x8_t v2863 = vaddq_s16(v2863_tmp, v2862); + int16x8_t v2864 = vaddq_s16(v2861, v2863); + int16x8_t v2865 = vqrdmulhq_n_s16(v2864, 20040); + int16x8_t v2866 = vaddq_s16(v2860, v2865); + int16x8_t v2867 = vsubq_s16(v2155, v2160); + int16x8_t v2868 = vsubq_s16(v2165, v2170); + int16x8_t v2869_tmp = vqrdmulhq_n_s16(v2868, 15865); + int16x8_t v2869 = vaddq_s16(v2869_tmp, v2868); + int16x8_t v2870 = vaddq_s16(v2867, v2869); + int16x8_t v2871 = vsubq_s16(v2177, v2182); + int16x8_t v2872 = vsubq_s16(v2187, v2192); + int16x8_t v2873_tmp = vqrdmulhq_n_s16(v2872, 15865); + int16x8_t v2873 = vaddq_s16(v2873_tmp, v2872); + int16x8_t v2874 = vaddq_s16(v2871, v2873); + int16x8_t v2875 = vqrdmulhq_n_s16(v2874, 20040); + int16x8_t v2876 = vaddq_s16(v2870, v2875); + int16x8_t v2877 = vqrdmulhq_n_s16(v2876, 17187); + int16x8_t v2878 = vaddq_s16(v2866, v2877); + int16x8_t v2879 = vsubq_s16(v2203, v2208); + int16x8_t v2880 = vsubq_s16(v2213, v2218); + int16x8_t v2881_tmp = vqrdmulhq_n_s16(v2880, 15865); + int16x8_t v2881 = vaddq_s16(v2881_tmp, v2880); + int16x8_t v2882 = vaddq_s16(v2879, v2881); + int16x8_t v2883 = vsubq_s16(v2225, v2230); + int16x8_t v2884 = vsubq_s16(v2235, v2240); + int16x8_t v2885_tmp = vqrdmulhq_n_s16(v2884, 15865); + int16x8_t v2885 = vaddq_s16(v2885_tmp, v2884); + int16x8_t v2886 = vaddq_s16(v2883, v2885); + int16x8_t v2887 = vqrdmulhq_n_s16(v2886, 20040); + int16x8_t v2888 = vaddq_s16(v2882, v2887); + int16x8_t v2889 = vsubq_s16(v2249, v2254); + int16x8_t v2890 = vsubq_s16(v2259, v2264); + int16x8_t v2891_tmp = vqrdmulhq_n_s16(v2890, 15865); + int16x8_t v2891 = vaddq_s16(v2891_tmp, v2890); + int16x8_t v2892 = vaddq_s16(v2889, v2891); + int16x8_t v2893 = vsubq_s16(v2271, v2276); + int16x8_t v2894 = vsubq_s16(v2281, v2286); + int16x8_t v2895_tmp = vqrdmulhq_n_s16(v2894, 15865); + int16x8_t v2895 = vaddq_s16(v2895_tmp, v2894); + int16x8_t v2896 = vaddq_s16(v2893, v2895); + int16x8_t v2897 = vqrdmulhq_n_s16(v2896, 20040); + int16x8_t v2898 = vaddq_s16(v2892, v2897); + int16x8_t v2899 = vqrdmulhq_n_s16(v2898, 17187); + int16x8_t v2900 = vaddq_s16(v2888, v2899); + int16x8_t v2901 = vqrdmulhq_n_s16(v2900, 16579); + int16x8_t v2902 = vaddq_s16(v2878, v2901); + int16x8_t v2903 = vsubq_s16(v1919, v1924); + int16x8_t v2904 = vsubq_s16(v1929, v1934); + int16x8_t v2905_tmp = vqrdmulhq_n_s16(v2904, 1893); + int16x8_t v2905 = vmlaq_n_s16(v2905_tmp, v2904, 2); + int16x8_t v2906 = vaddq_s16(v2903, v2905); + int16x8_t v2907 = vsubq_s16(v1941, v1946); + int16x8_t v2908 = vsubq_s16(v1951, v1956); + int16x8_t v2909_tmp = vqrdmulhq_n_s16(v2908, 1893); + int16x8_t v2909 = vmlaq_n_s16(v2909_tmp, v2908, 2); + int16x8_t v2910 = vaddq_s16(v2907, v2909); + int16x8_t v2911 = vqrdmulhq_n_s16(v2910, 20783); + int16x8_t v2912 = vaddq_s16(v2906, v2911); + int16x8_t v2913 = vsubq_s16(v1965, v1970); + int16x8_t v2914 = vsubq_s16(v1975, v1980); + int16x8_t v2915_tmp = vqrdmulhq_n_s16(v2914, 1893); + int16x8_t v2915 = vmlaq_n_s16(v2915_tmp, v2914, 2); + int16x8_t v2916 = vaddq_s16(v2913, v2915); + int16x8_t v2917 = vsubq_s16(v1987, v1992); + int16x8_t v2918 = vsubq_s16(v1997, v2002); + int16x8_t v2919_tmp = vqrdmulhq_n_s16(v2918, 1893); + int16x8_t v2919 = vmlaq_n_s16(v2919_tmp, v2918, 2); + int16x8_t v2920 = vaddq_s16(v2917, v2919); + int16x8_t v2921 = vqrdmulhq_n_s16(v2920, 20783); + int16x8_t v2922 = vaddq_s16(v2916, v2921); + int16x8_t v2923 = vqrdmulhq_n_s16(v2922, 17326); + int16x8_t v2924 = vaddq_s16(v2912, v2923); + int16x8_t v2925 = vsubq_s16(v2013, v2018); + int16x8_t v2926 = vsubq_s16(v2023, v2028); + int16x8_t v2927_tmp = vqrdmulhq_n_s16(v2926, 1893); + int16x8_t v2927 = vmlaq_n_s16(v2927_tmp, v2926, 2); + int16x8_t v2928 = vaddq_s16(v2925, v2927); + int16x8_t v2929 = vsubq_s16(v2035, v2040); + int16x8_t v2930 = vsubq_s16(v2045, v2050); + int16x8_t v2931_tmp = vqrdmulhq_n_s16(v2930, 1893); + int16x8_t v2931 = vmlaq_n_s16(v2931_tmp, v2930, 2); + int16x8_t v2932 = vaddq_s16(v2929, v2931); + int16x8_t v2933 = vqrdmulhq_n_s16(v2932, 20783); + int16x8_t v2934 = vaddq_s16(v2928, v2933); + int16x8_t v2935 = vsubq_s16(v2059, v2064); + int16x8_t v2936 = vsubq_s16(v2069, v2074); + int16x8_t v2937_tmp = vqrdmulhq_n_s16(v2936, 1893); + int16x8_t v2937 = vmlaq_n_s16(v2937_tmp, v2936, 2); + int16x8_t v2938 = vaddq_s16(v2935, v2937); + int16x8_t v2939 = vsubq_s16(v2081, v2086); + int16x8_t v2940 = vsubq_s16(v2091, v2096); + int16x8_t v2941_tmp = vqrdmulhq_n_s16(v2940, 1893); + int16x8_t v2941 = vmlaq_n_s16(v2941_tmp, v2940, 2); + int16x8_t v2942 = vaddq_s16(v2939, v2941); + int16x8_t v2943 = vqrdmulhq_n_s16(v2942, 20783); + int16x8_t v2944 = vaddq_s16(v2938, v2943); + int16x8_t v2945 = vqrdmulhq_n_s16(v2944, 17326); + int16x8_t v2946 = vaddq_s16(v2934, v2945); + int16x8_t v2947 = vqrdmulhq_n_s16(v2946, 16611); + int16x8_t v2948 = vaddq_s16(v2924, v2947); + int16x8_t v2949 = vsubq_s16(v1543, v1554); + int16x8_t v2950 = vsubq_s16(v1565, v1576); + int16x8_t v2951_tmp = vqrdmulhq_n_s16(v2950, 13357); + int16x8_t v2951 = vmlaq_n_s16(v2951_tmp, v2950, 3); + int16x8_t v2952 = vaddq_s16(v2949, v2951); + int16x8_t v2953 = vsubq_s16(v1589, v1600); + int16x8_t v2954 = vsubq_s16(v1611, v1622); + int16x8_t v2955_tmp = vqrdmulhq_n_s16(v2954, 13357); + int16x8_t v2955 = vmlaq_n_s16(v2955_tmp, v2954, 3); + int16x8_t v2956 = vaddq_s16(v2953, v2955); + int16x8_t v2957 = vqrdmulhq_n_s16(v2956, 21637); + int16x8_t v2958 = vaddq_s16(v2952, v2957); + int16x8_t v2959 = vsubq_s16(v1637, v1648); + int16x8_t v2960 = vsubq_s16(v1659, v1670); + int16x8_t v2961_tmp = vqrdmulhq_n_s16(v2960, 13357); + int16x8_t v2961 = vmlaq_n_s16(v2961_tmp, v2960, 3); + int16x8_t v2962 = vaddq_s16(v2959, v2961); + int16x8_t v2963 = vsubq_s16(v1683, v1694); + int16x8_t v2964 = vsubq_s16(v1705, v1716); + int16x8_t v2965_tmp = vqrdmulhq_n_s16(v2964, 13357); + int16x8_t v2965 = vmlaq_n_s16(v2965_tmp, v2964, 3); + int16x8_t v2966 = vaddq_s16(v2963, v2965); + int16x8_t v2967 = vqrdmulhq_n_s16(v2966, 21637); + int16x8_t v2968 = vaddq_s16(v2962, v2967); + int16x8_t v2969 = vqrdmulhq_n_s16(v2968, 17479); + int16x8_t v2970 = vaddq_s16(v2958, v2969); + int16x8_t v2971 = vsubq_s16(v1733, v1744); + int16x8_t v2972 = vsubq_s16(v1755, v1766); + int16x8_t v2973_tmp = vqrdmulhq_n_s16(v2972, 13357); + int16x8_t v2973 = vmlaq_n_s16(v2973_tmp, v2972, 3); + int16x8_t v2974 = vaddq_s16(v2971, v2973); + int16x8_t v2975 = vsubq_s16(v1779, v1790); + int16x8_t v2976 = vsubq_s16(v1801, v1812); + int16x8_t v2977_tmp = vqrdmulhq_n_s16(v2976, 13357); + int16x8_t v2977 = vmlaq_n_s16(v2977_tmp, v2976, 3); + int16x8_t v2978 = vaddq_s16(v2975, v2977); + int16x8_t v2979 = vqrdmulhq_n_s16(v2978, 21637); + int16x8_t v2980 = vaddq_s16(v2974, v2979); + int16x8_t v2981 = vsubq_s16(v1827, v1838); + int16x8_t v2982 = vsubq_s16(v1849, v1860); + int16x8_t v2983_tmp = vqrdmulhq_n_s16(v2982, 13357); + int16x8_t v2983 = vmlaq_n_s16(v2983_tmp, v2982, 3); + int16x8_t v2984 = vaddq_s16(v2981, v2983); + int16x8_t v2985 = vsubq_s16(v1873, v1884); + int16x8_t v2986 = vsubq_s16(v1895, v1906); + int16x8_t v2987_tmp = vqrdmulhq_n_s16(v2986, 13357); + int16x8_t v2987 = vmlaq_n_s16(v2987_tmp, v2986, 3); + int16x8_t v2988 = vaddq_s16(v2985, v2987); + int16x8_t v2989 = vqrdmulhq_n_s16(v2988, 21637); + int16x8_t v2990 = vaddq_s16(v2984, v2989); + int16x8_t v2991 = vqrdmulhq_n_s16(v2990, 17479); + int16x8_t v2992 = vaddq_s16(v2980, v2991); + int16x8_t v2993 = vqrdmulhq_n_s16(v2992, 16647); + int16x8_t v2994 = vaddq_s16(v2970, v2993); + int16x8_t v2995 = vsubq_s16(v25, v60); + int16x8_t v2996 = vsubq_s16(v102, v138); + int16x8_t v2997_tmp = vqrdmulhq_n_s16(v2996, 6226); + int16x8_t v2997 = vmlaq_n_s16(v2997_tmp, v2996, 10); + int16x8_t v2998 = vaddq_s16(v2995, v2997); + int16x8_t v2999 = vsubq_s16(v182, v233); + int16x8_t v3000 = vsubq_s16(v275, v312); + int16x8_t v3001_tmp = vqrdmulhq_n_s16(v3000, 6226); + int16x8_t v3001 = vmlaq_n_s16(v3001_tmp, v3000, 10); + int16x8_t v3002 = vaddq_s16(v2999, v3001); + int16x8_t v3003 = vqrdmulhq_n_s16(v3002, 22622); + int16x8_t v3004 = vaddq_s16(v2998, v3003); + int16x8_t v3005 = vsubq_s16(v358, v409); + int16x8_t v3006 = vsubq_s16(v481, v519); + int16x8_t v3007_tmp = vqrdmulhq_n_s16(v3006, 6226); + int16x8_t v3007 = vmlaq_n_s16(v3007_tmp, v3006, 10); + int16x8_t v3008 = vaddq_s16(v3005, v3007); + int16x8_t v3009 = vsubq_s16(v563, v614); + int16x8_t v3010 = vsubq_s16(v656, v694); + int16x8_t v3011_tmp = vqrdmulhq_n_s16(v3010, 6226); + int16x8_t v3011 = vmlaq_n_s16(v3011_tmp, v3010, 10); + int16x8_t v3012 = vaddq_s16(v3009, v3011); + int16x8_t v3013 = vqrdmulhq_n_s16(v3012, 22622); + int16x8_t v3014 = vaddq_s16(v3008, v3013); + int16x8_t v3015 = vqrdmulhq_n_s16(v3014, 17646); + int16x8_t v3016 = vaddq_s16(v3004, v3015); + int16x8_t v3017 = vsubq_s16(v742, v793); + int16x8_t v3018 = vsubq_s16(v865, v903); + int16x8_t v3019_tmp = vqrdmulhq_n_s16(v3018, 6226); + int16x8_t v3019 = vmlaq_n_s16(v3019_tmp, v3018, 10); + int16x8_t v3020 = vaddq_s16(v3017, v3019); + int16x8_t v3021 = vsubq_s16(v977, v1060); + int16x8_t v3022 = vsubq_s16(v1102, v1141); + int16x8_t v3023_tmp = vqrdmulhq_n_s16(v3022, 6226); + int16x8_t v3023 = vmlaq_n_s16(v3023_tmp, v3022, 10); + int16x8_t v3024 = vaddq_s16(v3021, v3023); + int16x8_t v3025 = vqrdmulhq_n_s16(v3024, 22622); + int16x8_t v3026 = vaddq_s16(v3020, v3025); + int16x8_t v3027 = vsubq_s16(v1187, v1238); + int16x8_t v3028 = vsubq_s16(v1310, v1348); + int16x8_t v3029_tmp = vqrdmulhq_n_s16(v3028, 6226); + int16x8_t v3029 = vmlaq_n_s16(v3029_tmp, v3028, 10); + int16x8_t v3030 = vaddq_s16(v3027, v3029); + int16x8_t v3031 = vsubq_s16(v1392, v1443); + int16x8_t v3032 = vsubq_s16(v1485, v1524); + int16x8_t v3033_tmp = vqrdmulhq_n_s16(v3032, 6226); + int16x8_t v3033 = vmlaq_n_s16(v3033_tmp, v3032, 10); + int16x8_t v3034 = vaddq_s16(v3031, v3033); + int16x8_t v3035 = vqrdmulhq_n_s16(v3034, 22622); + int16x8_t v3036 = vaddq_s16(v3030, v3035); + int16x8_t v3037 = vqrdmulhq_n_s16(v3036, 17646); + int16x8_t v3038 = vaddq_s16(v3026, v3037); + int16x8_t v3039 = vqrdmulhq_n_s16(v3038, 16685); + int16x8_t v3040 = vaddq_s16(v3016, v3039); + int16x8_t v3041 = vsubq_s16(v2995, v2997); + int16x8_t v3042 = vsubq_s16(v2999, v3001); + int16x8_t v3043 = vqrdmulhq_n_s16(v3042, 23761); + int16x8_t v3044 = vaddq_s16(v3041, v3043); + int16x8_t v3045 = vsubq_s16(v3005, v3007); + int16x8_t v3046 = vsubq_s16(v3009, v3011); + int16x8_t v3047 = vqrdmulhq_n_s16(v3046, 23761); + int16x8_t v3048 = vaddq_s16(v3045, v3047); + int16x8_t v3049 = vqrdmulhq_n_s16(v3048, 17826); + int16x8_t v3050 = vaddq_s16(v3044, v3049); + int16x8_t v3051 = vsubq_s16(v3017, v3019); + int16x8_t v3052 = vsubq_s16(v3021, v3023); + int16x8_t v3053 = vqrdmulhq_n_s16(v3052, 23761); + int16x8_t v3054 = vaddq_s16(v3051, v3053); + int16x8_t v3055 = vsubq_s16(v3027, v3029); + int16x8_t v3056 = vsubq_s16(v3031, v3033); + int16x8_t v3057 = vqrdmulhq_n_s16(v3056, 23761); + int16x8_t v3058 = vaddq_s16(v3055, v3057); + int16x8_t v3059 = vqrdmulhq_n_s16(v3058, 17826); + int16x8_t v3060 = vaddq_s16(v3054, v3059); + int16x8_t v3061 = vqrdmulhq_n_s16(v3060, 16726); + int16x8_t v3062 = vaddq_s16(v3050, v3061); + int16x8_t v3063 = vsubq_s16(v2949, v2951); + int16x8_t v3064 = vsubq_s16(v2953, v2955); + int16x8_t v3065 = vqrdmulhq_n_s16(v3064, 25084); + int16x8_t v3066 = vaddq_s16(v3063, v3065); + int16x8_t v3067 = vsubq_s16(v2959, v2961); + int16x8_t v3068 = vsubq_s16(v2963, v2965); + int16x8_t v3069 = vqrdmulhq_n_s16(v3068, 25084); + int16x8_t v3070 = vaddq_s16(v3067, v3069); + int16x8_t v3071 = vqrdmulhq_n_s16(v3070, 18021); + int16x8_t v3072 = vaddq_s16(v3066, v3071); + int16x8_t v3073 = vsubq_s16(v2971, v2973); + int16x8_t v3074 = vsubq_s16(v2975, v2977); + int16x8_t v3075 = vqrdmulhq_n_s16(v3074, 25084); + int16x8_t v3076 = vaddq_s16(v3073, v3075); + int16x8_t v3077 = vsubq_s16(v2981, v2983); + int16x8_t v3078 = vsubq_s16(v2985, v2987); + int16x8_t v3079 = vqrdmulhq_n_s16(v3078, 25084); + int16x8_t v3080 = vaddq_s16(v3077, v3079); + int16x8_t v3081 = vqrdmulhq_n_s16(v3080, 18021); + int16x8_t v3082 = vaddq_s16(v3076, v3081); + int16x8_t v3083 = vqrdmulhq_n_s16(v3082, 16769); + int16x8_t v3084 = vaddq_s16(v3072, v3083); + int16x8_t v3085 = vsubq_s16(v2903, v2905); + int16x8_t v3086 = vsubq_s16(v2907, v2909); + int16x8_t v3087 = vqrdmulhq_n_s16(v3086, 26631); + int16x8_t v3088 = vaddq_s16(v3085, v3087); + int16x8_t v3089 = vsubq_s16(v2913, v2915); + int16x8_t v3090 = vsubq_s16(v2917, v2919); + int16x8_t v3091 = vqrdmulhq_n_s16(v3090, 26631); + int16x8_t v3092 = vaddq_s16(v3089, v3091); + int16x8_t v3093 = vqrdmulhq_n_s16(v3092, 18231); + int16x8_t v3094 = vaddq_s16(v3088, v3093); + int16x8_t v3095 = vsubq_s16(v2925, v2927); + int16x8_t v3096 = vsubq_s16(v2929, v2931); + int16x8_t v3097 = vqrdmulhq_n_s16(v3096, 26631); + int16x8_t v3098 = vaddq_s16(v3095, v3097); + int16x8_t v3099 = vsubq_s16(v2935, v2937); + int16x8_t v3100 = vsubq_s16(v2939, v2941); + int16x8_t v3101 = vqrdmulhq_n_s16(v3100, 26631); + int16x8_t v3102 = vaddq_s16(v3099, v3101); + int16x8_t v3103 = vqrdmulhq_n_s16(v3102, 18231); + int16x8_t v3104 = vaddq_s16(v3098, v3103); + int16x8_t v3105 = vqrdmulhq_n_s16(v3104, 16815); + int16x8_t v3106 = vaddq_s16(v3094, v3105); + int16x8_t v3107 = vsubq_s16(v2857, v2859); + int16x8_t v3108 = vsubq_s16(v2861, v2863); + int16x8_t v3109 = vqrdmulhq_n_s16(v3108, 28454); + int16x8_t v3110 = vaddq_s16(v3107, v3109); + int16x8_t v3111 = vsubq_s16(v2867, v2869); + int16x8_t v3112 = vsubq_s16(v2871, v2873); + int16x8_t v3113 = vqrdmulhq_n_s16(v3112, 28454); + int16x8_t v3114 = vaddq_s16(v3111, v3113); + int16x8_t v3115 = vqrdmulhq_n_s16(v3114, 18458); + int16x8_t v3116 = vaddq_s16(v3110, v3115); + int16x8_t v3117 = vsubq_s16(v2879, v2881); + int16x8_t v3118 = vsubq_s16(v2883, v2885); + int16x8_t v3119 = vqrdmulhq_n_s16(v3118, 28454); + int16x8_t v3120 = vaddq_s16(v3117, v3119); + int16x8_t v3121 = vsubq_s16(v2889, v2891); + int16x8_t v3122 = vsubq_s16(v2893, v2895); + int16x8_t v3123 = vqrdmulhq_n_s16(v3122, 28454); + int16x8_t v3124 = vaddq_s16(v3121, v3123); + int16x8_t v3125 = vqrdmulhq_n_s16(v3124, 18458); + int16x8_t v3126 = vaddq_s16(v3120, v3125); + int16x8_t v3127 = vqrdmulhq_n_s16(v3126, 16865); + int16x8_t v3128 = vaddq_s16(v3116, v3127); + int16x8_t v3129 = vsubq_s16(v2811, v2813); + int16x8_t v3130 = vsubq_s16(v2815, v2817); + int16x8_t v3131 = vqrdmulhq_n_s16(v3130, 30624); + int16x8_t v3132 = vaddq_s16(v3129, v3131); + int16x8_t v3133 = vsubq_s16(v2821, v2823); + int16x8_t v3134 = vsubq_s16(v2825, v2827); + int16x8_t v3135 = vqrdmulhq_n_s16(v3134, 30624); + int16x8_t v3136 = vaddq_s16(v3133, v3135); + int16x8_t v3137 = vqrdmulhq_n_s16(v3136, 18702); + int16x8_t v3138 = vaddq_s16(v3132, v3137); + int16x8_t v3139 = vsubq_s16(v2833, v2835); + int16x8_t v3140 = vsubq_s16(v2837, v2839); + int16x8_t v3141 = vqrdmulhq_n_s16(v3140, 30624); + int16x8_t v3142 = vaddq_s16(v3139, v3141); + int16x8_t v3143 = vsubq_s16(v2843, v2845); + int16x8_t v3144 = vsubq_s16(v2847, v2849); + int16x8_t v3145 = vqrdmulhq_n_s16(v3144, 30624); + int16x8_t v3146 = vaddq_s16(v3143, v3145); + int16x8_t v3147 = vqrdmulhq_n_s16(v3146, 18702); + int16x8_t v3148 = vaddq_s16(v3142, v3147); + int16x8_t v3149 = vqrdmulhq_n_s16(v3148, 16916); + int16x8_t v3150 = vaddq_s16(v3138, v3149); + int16x8_t v3151 = vsubq_s16(v2765, v2767); + int16x8_t v3152 = vsubq_s16(v2769, v2771); + int16x8_t v3153_tmp = vqrdmulhq_n_s16(v3152, 472); + int16x8_t v3153 = vaddq_s16(v3153_tmp, v3152); + int16x8_t v3154 = vaddq_s16(v3151, v3153); + int16x8_t v3155 = vsubq_s16(v2775, v2777); + int16x8_t v3156 = vsubq_s16(v2779, v2781); + int16x8_t v3157_tmp = vqrdmulhq_n_s16(v3156, 472); + int16x8_t v3157 = vaddq_s16(v3157_tmp, v3156); + int16x8_t v3158 = vaddq_s16(v3155, v3157); + int16x8_t v3159 = vqrdmulhq_n_s16(v3158, 18964); + int16x8_t v3160 = vaddq_s16(v3154, v3159); + int16x8_t v3161 = vsubq_s16(v2787, v2789); + int16x8_t v3162 = vsubq_s16(v2791, v2793); + int16x8_t v3163_tmp = vqrdmulhq_n_s16(v3162, 472); + int16x8_t v3163 = vaddq_s16(v3163_tmp, v3162); + int16x8_t v3164 = vaddq_s16(v3161, v3163); + int16x8_t v3165 = vsubq_s16(v2797, v2799); + int16x8_t v3166 = vsubq_s16(v2801, v2803); + int16x8_t v3167_tmp = vqrdmulhq_n_s16(v3166, 472); + int16x8_t v3167 = vaddq_s16(v3167_tmp, v3166); + int16x8_t v3168 = vaddq_s16(v3165, v3167); + int16x8_t v3169 = vqrdmulhq_n_s16(v3168, 18964); + int16x8_t v3170 = vaddq_s16(v3164, v3169); + int16x8_t v3171 = vqrdmulhq_n_s16(v3170, 16971); + int16x8_t v3172 = vaddq_s16(v3160, v3171); + int16x8_t v3173 = vsubq_s16(v2719, v2721); + int16x8_t v3174 = vsubq_s16(v2723, v2725); + int16x8_t v3175_tmp = vqrdmulhq_n_s16(v3174, 3672); + int16x8_t v3175 = vaddq_s16(v3175_tmp, v3174); + int16x8_t v3176 = vaddq_s16(v3173, v3175); + int16x8_t v3177 = vsubq_s16(v2729, v2731); + int16x8_t v3178 = vsubq_s16(v2733, v2735); + int16x8_t v3179_tmp = vqrdmulhq_n_s16(v3178, 3672); + int16x8_t v3179 = vaddq_s16(v3179_tmp, v3178); + int16x8_t v3180 = vaddq_s16(v3177, v3179); + int16x8_t v3181 = vqrdmulhq_n_s16(v3180, 19245); + int16x8_t v3182 = vaddq_s16(v3176, v3181); + int16x8_t v3183 = vsubq_s16(v2741, v2743); + int16x8_t v3184 = vsubq_s16(v2745, v2747); + int16x8_t v3185_tmp = vqrdmulhq_n_s16(v3184, 3672); + int16x8_t v3185 = vaddq_s16(v3185_tmp, v3184); + int16x8_t v3186 = vaddq_s16(v3183, v3185); + int16x8_t v3187 = vsubq_s16(v2751, v2753); + int16x8_t v3188 = vsubq_s16(v2755, v2757); + int16x8_t v3189_tmp = vqrdmulhq_n_s16(v3188, 3672); + int16x8_t v3189 = vaddq_s16(v3189_tmp, v3188); + int16x8_t v3190 = vaddq_s16(v3187, v3189); + int16x8_t v3191 = vqrdmulhq_n_s16(v3190, 19245); + int16x8_t v3192 = vaddq_s16(v3186, v3191); + int16x8_t v3193 = vqrdmulhq_n_s16(v3192, 17029); + int16x8_t v3194 = vaddq_s16(v3182, v3193); + int16x8_t v3195 = vsubq_s16(v2673, v2675); + int16x8_t v3196 = vsubq_s16(v2677, v2679); + int16x8_t v3197_tmp = vqrdmulhq_n_s16(v3196, 7662); + int16x8_t v3197 = vaddq_s16(v3197_tmp, v3196); + int16x8_t v3198 = vaddq_s16(v3195, v3197); + int16x8_t v3199 = vsubq_s16(v2683, v2685); + int16x8_t v3200 = vsubq_s16(v2687, v2689); + int16x8_t v3201_tmp = vqrdmulhq_n_s16(v3200, 7662); + int16x8_t v3201 = vaddq_s16(v3201_tmp, v3200); + int16x8_t v3202 = vaddq_s16(v3199, v3201); + int16x8_t v3203 = vqrdmulhq_n_s16(v3202, 19546); + int16x8_t v3204 = vaddq_s16(v3198, v3203); + int16x8_t v3205 = vsubq_s16(v2695, v2697); + int16x8_t v3206 = vsubq_s16(v2699, v2701); + int16x8_t v3207_tmp = vqrdmulhq_n_s16(v3206, 7662); + int16x8_t v3207 = vaddq_s16(v3207_tmp, v3206); + int16x8_t v3208 = vaddq_s16(v3205, v3207); + int16x8_t v3209 = vsubq_s16(v2705, v2707); + int16x8_t v3210 = vsubq_s16(v2709, v2711); + int16x8_t v3211_tmp = vqrdmulhq_n_s16(v3210, 7662); + int16x8_t v3211 = vaddq_s16(v3211_tmp, v3210); + int16x8_t v3212 = vaddq_s16(v3209, v3211); + int16x8_t v3213 = vqrdmulhq_n_s16(v3212, 19546); + int16x8_t v3214 = vaddq_s16(v3208, v3213); + int16x8_t v3215 = vqrdmulhq_n_s16(v3214, 17090); + int16x8_t v3216 = vaddq_s16(v3204, v3215); + int16x8_t v3217 = vsubq_s16(v2582, v2587); + int16x8_t v3218 = vsubq_s16(v2592, v2597); + int16x8_t v3219_tmp = vqrdmulhq_n_s16(v3218, 12756); + int16x8_t v3219 = vaddq_s16(v3219_tmp, v3218); + int16x8_t v3220 = vaddq_s16(v3217, v3219); + int16x8_t v3221 = vsubq_s16(v2604, v2609); + int16x8_t v3222 = vsubq_s16(v2614, v2619); + int16x8_t v3223_tmp = vqrdmulhq_n_s16(v3222, 12756); + int16x8_t v3223 = vaddq_s16(v3223_tmp, v3222); + int16x8_t v3224 = vaddq_s16(v3221, v3223); + int16x8_t v3225 = vqrdmulhq_n_s16(v3224, 19869); + int16x8_t v3226 = vaddq_s16(v3220, v3225); + int16x8_t v3227 = vsubq_s16(v2628, v2633); + int16x8_t v3228 = vsubq_s16(v2638, v2643); + int16x8_t v3229_tmp = vqrdmulhq_n_s16(v3228, 12756); + int16x8_t v3229 = vaddq_s16(v3229_tmp, v3228); + int16x8_t v3230 = vaddq_s16(v3227, v3229); + int16x8_t v3231 = vsubq_s16(v2650, v2655); + int16x8_t v3232 = vsubq_s16(v2660, v2665); + int16x8_t v3233_tmp = vqrdmulhq_n_s16(v3232, 12756); + int16x8_t v3233 = vaddq_s16(v3233_tmp, v3232); + int16x8_t v3234 = vaddq_s16(v3231, v3233); + int16x8_t v3235 = vqrdmulhq_n_s16(v3234, 19869); + int16x8_t v3236 = vaddq_s16(v3230, v3235); + int16x8_t v3237 = vqrdmulhq_n_s16(v3236, 17153); + int16x8_t v3238 = vaddq_s16(v3226, v3237); + int16x8_t v3239 = vsubq_s16(v2488, v2493); + int16x8_t v3240 = vsubq_s16(v2498, v2503); + int16x8_t v3241_tmp = vqrdmulhq_n_s16(v3240, 19463); + int16x8_t v3241 = vaddq_s16(v3241_tmp, v3240); + int16x8_t v3242 = vaddq_s16(v3239, v3241); + int16x8_t v3243 = vsubq_s16(v2510, v2515); + int16x8_t v3244 = vsubq_s16(v2520, v2525); + int16x8_t v3245_tmp = vqrdmulhq_n_s16(v3244, 19463); + int16x8_t v3245 = vaddq_s16(v3245_tmp, v3244); + int16x8_t v3246 = vaddq_s16(v3243, v3245); + int16x8_t v3247 = vqrdmulhq_n_s16(v3246, 20216); + int16x8_t v3248 = vaddq_s16(v3242, v3247); + int16x8_t v3249 = vsubq_s16(v2534, v2539); + int16x8_t v3250 = vsubq_s16(v2544, v2549); + int16x8_t v3251_tmp = vqrdmulhq_n_s16(v3250, 19463); + int16x8_t v3251 = vaddq_s16(v3251_tmp, v3250); + int16x8_t v3252 = vaddq_s16(v3249, v3251); + int16x8_t v3253 = vsubq_s16(v2556, v2561); + int16x8_t v3254 = vsubq_s16(v2566, v2571); + int16x8_t v3255_tmp = vqrdmulhq_n_s16(v3254, 19463); + int16x8_t v3255 = vaddq_s16(v3255_tmp, v3254); + int16x8_t v3256 = vaddq_s16(v3253, v3255); + int16x8_t v3257 = vqrdmulhq_n_s16(v3256, 20216); + int16x8_t v3258 = vaddq_s16(v3252, v3257); + int16x8_t v3259 = vqrdmulhq_n_s16(v3258, 17220); + int16x8_t v3260 = vaddq_s16(v3248, v3259); + int16x8_t v3261 = vsubq_s16(v2393, v2398); + int16x8_t v3262 = vsubq_s16(v2403, v2408); + int16x8_t v3263_tmp = vqrdmulhq_n_s16(v3262, 28661); + int16x8_t v3263 = vaddq_s16(v3263_tmp, v3262); + int16x8_t v3264 = vaddq_s16(v3261, v3263); + int16x8_t v3265 = vsubq_s16(v2415, v2420); + int16x8_t v3266 = vsubq_s16(v2425, v2430); + int16x8_t v3267_tmp = vqrdmulhq_n_s16(v3266, 28661); + int16x8_t v3267 = vaddq_s16(v3267_tmp, v3266); + int16x8_t v3268 = vaddq_s16(v3265, v3267); + int16x8_t v3269 = vqrdmulhq_n_s16(v3268, 20587); + int16x8_t v3270 = vaddq_s16(v3264, v3269); + int16x8_t v3271 = vsubq_s16(v2439, v2444); + int16x8_t v3272 = vsubq_s16(v2449, v2454); + int16x8_t v3273_tmp = vqrdmulhq_n_s16(v3272, 28661); + int16x8_t v3273 = vaddq_s16(v3273_tmp, v3272); + int16x8_t v3274 = vaddq_s16(v3271, v3273); + int16x8_t v3275 = vsubq_s16(v2461, v2467); + int16x8_t v3276 = vsubq_s16(v2472, v2477); + int16x8_t v3277_tmp = vqrdmulhq_n_s16(v3276, 28661); + int16x8_t v3277 = vaddq_s16(v3277_tmp, v3276); + int16x8_t v3278 = vaddq_s16(v3275, v3277); + int16x8_t v3279 = vqrdmulhq_n_s16(v3278, 20587); + int16x8_t v3280 = vaddq_s16(v3274, v3279); + int16x8_t v3281 = vqrdmulhq_n_s16(v3280, 17290); + int16x8_t v3282 = vaddq_s16(v3270, v3281); + int16x8_t v3283 = vsubq_s16(v2299, v2304); + int16x8_t v3284 = vsubq_s16(v2309, v2314); + int16x8_t v3285_tmp = vqrdmulhq_n_s16(v3284, 9242); + int16x8_t v3285 = vmlaq_n_s16(v3285_tmp, v3284, 2); + int16x8_t v3286 = vaddq_s16(v3283, v3285); + int16x8_t v3287 = vsubq_s16(v2321, v2326); + int16x8_t v3288 = vsubq_s16(v2331, v2336); + int16x8_t v3289_tmp = vqrdmulhq_n_s16(v3288, 9242); + int16x8_t v3289 = vmlaq_n_s16(v3289_tmp, v3288, 2); + int16x8_t v3290 = vaddq_s16(v3287, v3289); + int16x8_t v3291 = vqrdmulhq_n_s16(v3290, 20985); + int16x8_t v3292 = vaddq_s16(v3286, v3291); + int16x8_t v3293 = vsubq_s16(v2345, v2350); + int16x8_t v3294 = vsubq_s16(v2355, v2360); + int16x8_t v3295_tmp = vqrdmulhq_n_s16(v3294, 9242); + int16x8_t v3295 = vmlaq_n_s16(v3295_tmp, v3294, 2); + int16x8_t v3296 = vaddq_s16(v3293, v3295); + int16x8_t v3297 = vsubq_s16(v2367, v2372); + int16x8_t v3298 = vsubq_s16(v2377, v2382); + int16x8_t v3299_tmp = vqrdmulhq_n_s16(v3298, 9242); + int16x8_t v3299 = vmlaq_n_s16(v3299_tmp, v3298, 2); + int16x8_t v3300 = vaddq_s16(v3297, v3299); + int16x8_t v3301 = vqrdmulhq_n_s16(v3300, 20985); + int16x8_t v3302 = vaddq_s16(v3296, v3301); + int16x8_t v3303 = vqrdmulhq_n_s16(v3302, 17363); + int16x8_t v3304 = vaddq_s16(v3292, v3303); + int16x8_t v3305 = vsubq_s16(v2115, v2126); + int16x8_t v3306 = vsubq_s16(v2137, v2148); + int16x8_t v3307_tmp = vqrdmulhq_n_s16(v3306, 30298); + int16x8_t v3307 = vmlaq_n_s16(v3307_tmp, v3306, 2); + int16x8_t v3308 = vaddq_s16(v3305, v3307); + int16x8_t v3309 = vsubq_s16(v2161, v2172); + int16x8_t v3310 = vsubq_s16(v2183, v2194); + int16x8_t v3311_tmp = vqrdmulhq_n_s16(v3310, 30298); + int16x8_t v3311 = vmlaq_n_s16(v3311_tmp, v3310, 2); + int16x8_t v3312 = vaddq_s16(v3309, v3311); + int16x8_t v3313 = vqrdmulhq_n_s16(v3312, 21412); + int16x8_t v3314 = vaddq_s16(v3308, v3313); + int16x8_t v3315 = vsubq_s16(v2209, v2220); + int16x8_t v3316 = vsubq_s16(v2231, v2242); + int16x8_t v3317_tmp = vqrdmulhq_n_s16(v3316, 30298); + int16x8_t v3317 = vmlaq_n_s16(v3317_tmp, v3316, 2); + int16x8_t v3318 = vaddq_s16(v3315, v3317); + int16x8_t v3319 = vsubq_s16(v2255, v2266); + int16x8_t v3320 = vsubq_s16(v2277, v2288); + int16x8_t v3321_tmp = vqrdmulhq_n_s16(v3320, 30298); + int16x8_t v3321 = vmlaq_n_s16(v3321_tmp, v3320, 2); + int16x8_t v3322 = vaddq_s16(v3319, v3321); + int16x8_t v3323 = vqrdmulhq_n_s16(v3322, 21412); + int16x8_t v3324 = vaddq_s16(v3318, v3323); + int16x8_t v3325 = vqrdmulhq_n_s16(v3324, 17440); + int16x8_t v3326 = vaddq_s16(v3314, v3325); + int16x8_t v3327 = vsubq_s16(v1925, v1936); + int16x8_t v3328 = vsubq_s16(v1947, v1958); + int16x8_t v3329_tmp = vqrdmulhq_n_s16(v3328, 2773); + int16x8_t v3329 = vmlaq_n_s16(v3329_tmp, v3328, 4); + int16x8_t v3330 = vaddq_s16(v3327, v3329); + int16x8_t v3331 = vsubq_s16(v1971, v1982); + int16x8_t v3332 = vsubq_s16(v1993, v2004); + int16x8_t v3333_tmp = vqrdmulhq_n_s16(v3332, 2773); + int16x8_t v3333 = vmlaq_n_s16(v3333_tmp, v3332, 4); + int16x8_t v3334 = vaddq_s16(v3331, v3333); + int16x8_t v3335 = vqrdmulhq_n_s16(v3334, 21871); + int16x8_t v3336 = vaddq_s16(v3330, v3335); + int16x8_t v3337 = vsubq_s16(v2019, v2030); + int16x8_t v3338 = vsubq_s16(v2041, v2052); + int16x8_t v3339_tmp = vqrdmulhq_n_s16(v3338, 2773); + int16x8_t v3339 = vmlaq_n_s16(v3339_tmp, v3338, 4); + int16x8_t v3340 = vaddq_s16(v3337, v3339); + int16x8_t v3341 = vsubq_s16(v2065, v2076); + int16x8_t v3342 = vsubq_s16(v2087, v2098); + int16x8_t v3343_tmp = vqrdmulhq_n_s16(v3342, 2773); + int16x8_t v3343 = vmlaq_n_s16(v3343_tmp, v3342, 4); + int16x8_t v3344 = vaddq_s16(v3341, v3343); + int16x8_t v3345 = vqrdmulhq_n_s16(v3344, 21871); + int16x8_t v3346 = vaddq_s16(v3340, v3345); + int16x8_t v3347 = vqrdmulhq_n_s16(v3346, 17520); + int16x8_t v3348 = vaddq_s16(v3336, v3347); + int16x8_t v3349 = vsubq_s16(v1555, v1578); + int16x8_t v3350 = vsubq_s16(v1601, v1624); + int16x8_t v3351_tmp = vqrdmulhq_n_s16(v3350, 26108); + int16x8_t v3351 = vmlaq_n_s16(v3351_tmp, v3350, 6); + int16x8_t v3352 = vaddq_s16(v3349, v3351); + int16x8_t v3353 = vsubq_s16(v1649, v1672); + int16x8_t v3354 = vsubq_s16(v1695, v1718); + int16x8_t v3355_tmp = vqrdmulhq_n_s16(v3354, 26108); + int16x8_t v3355 = vmlaq_n_s16(v3355_tmp, v3354, 6); + int16x8_t v3356 = vaddq_s16(v3353, v3355); + int16x8_t v3357 = vqrdmulhq_n_s16(v3356, 22363); + int16x8_t v3358 = vaddq_s16(v3352, v3357); + int16x8_t v3359 = vsubq_s16(v1745, v1768); + int16x8_t v3360 = vsubq_s16(v1791, v1814); + int16x8_t v3361_tmp = vqrdmulhq_n_s16(v3360, 26108); + int16x8_t v3361 = vmlaq_n_s16(v3361_tmp, v3360, 6); + int16x8_t v3362 = vaddq_s16(v3359, v3361); + int16x8_t v3363 = vsubq_s16(v1839, v1862); + int16x8_t v3364 = vsubq_s16(v1885, v1908); + int16x8_t v3365_tmp = vqrdmulhq_n_s16(v3364, 26108); + int16x8_t v3365 = vmlaq_n_s16(v3365_tmp, v3364, 6); + int16x8_t v3366 = vaddq_s16(v3363, v3365); + int16x8_t v3367 = vqrdmulhq_n_s16(v3366, 22363); + int16x8_t v3368 = vaddq_s16(v3362, v3367); + int16x8_t v3369 = vqrdmulhq_n_s16(v3368, 17603); + int16x8_t v3370 = vaddq_s16(v3358, v3369); + int16x8_t v3371 = vsubq_s16(v61, v140); + int16x8_t v3372 = vsubq_s16(v234, v314); + int16x8_t v3373_tmp = vqrdmulhq_n_s16(v3372, 12251); + int16x8_t v3373 = vmlaq_n_s16(v3373_tmp, v3372, 20); + int16x8_t v3374 = vaddq_s16(v3371, v3373); + int16x8_t v3375 = vsubq_s16(v410, v521); + int16x8_t v3376 = vsubq_s16(v615, v696); + int16x8_t v3377_tmp = vqrdmulhq_n_s16(v3376, 12251); + int16x8_t v3377 = vmlaq_n_s16(v3377_tmp, v3376, 20); + int16x8_t v3378 = vaddq_s16(v3375, v3377); + int16x8_t v3379 = vqrdmulhq_n_s16(v3378, 22891); + int16x8_t v3380 = vaddq_s16(v3374, v3379); + int16x8_t v3381 = vsubq_s16(v794, v905); + int16x8_t v3382 = vsubq_s16(v1061, v1143); + int16x8_t v3383_tmp = vqrdmulhq_n_s16(v3382, 12251); + int16x8_t v3383 = vmlaq_n_s16(v3383_tmp, v3382, 20); + int16x8_t v3384 = vaddq_s16(v3381, v3383); + int16x8_t v3385 = vsubq_s16(v1239, v1350); + int16x8_t v3386 = vsubq_s16(v1444, v1526); + int16x8_t v3387_tmp = vqrdmulhq_n_s16(v3386, 12251); + int16x8_t v3387 = vmlaq_n_s16(v3387_tmp, v3386, 20); + int16x8_t v3388 = vaddq_s16(v3385, v3387); + int16x8_t v3389 = vqrdmulhq_n_s16(v3388, 22891); + int16x8_t v3390 = vaddq_s16(v3384, v3389); + int16x8_t v3391 = vqrdmulhq_n_s16(v3390, 17689); + int16x8_t v3392 = vaddq_s16(v3380, v3391); + int16x8_t v3393 = vsubq_s16(v3371, v3373); + int16x8_t v3394 = vsubq_s16(v3375, v3377); + int16x8_t v3395 = vqrdmulhq_n_s16(v3394, 23460); + int16x8_t v3396 = vaddq_s16(v3393, v3395); + int16x8_t v3397 = vsubq_s16(v3381, v3383); + int16x8_t v3398 = vsubq_s16(v3385, v3387); + int16x8_t v3399 = vqrdmulhq_n_s16(v3398, 23460); + int16x8_t v3400 = vaddq_s16(v3397, v3399); + int16x8_t v3401 = vqrdmulhq_n_s16(v3400, 17779); + int16x8_t v3402 = vaddq_s16(v3396, v3401); + int16x8_t v3403 = vsubq_s16(v3349, v3351); + int16x8_t v3404 = vsubq_s16(v3353, v3355); + int16x8_t v3405 = vqrdmulhq_n_s16(v3404, 24073); + int16x8_t v3406 = vaddq_s16(v3403, v3405); + int16x8_t v3407 = vsubq_s16(v3359, v3361); + int16x8_t v3408 = vsubq_s16(v3363, v3365); + int16x8_t v3409 = vqrdmulhq_n_s16(v3408, 24073); + int16x8_t v3410 = vaddq_s16(v3407, v3409); + int16x8_t v3411 = vqrdmulhq_n_s16(v3410, 17873); + int16x8_t v3412 = vaddq_s16(v3406, v3411); + int16x8_t v3413 = vsubq_s16(v3327, v3329); + int16x8_t v3414 = vsubq_s16(v3331, v3333); + int16x8_t v3415 = vqrdmulhq_n_s16(v3414, 24734); + int16x8_t v3416 = vaddq_s16(v3413, v3415); + int16x8_t v3417 = vsubq_s16(v3337, v3339); + int16x8_t v3418 = vsubq_s16(v3341, v3343); + int16x8_t v3419 = vqrdmulhq_n_s16(v3418, 24734); + int16x8_t v3420 = vaddq_s16(v3417, v3419); + int16x8_t v3421 = vqrdmulhq_n_s16(v3420, 17971); + int16x8_t v3422 = vaddq_s16(v3416, v3421); + int16x8_t v3423 = vsubq_s16(v3305, v3307); + int16x8_t v3424 = vsubq_s16(v3309, v3311); + int16x8_t v3425 = vqrdmulhq_n_s16(v3424, 25448); + int16x8_t v3426 = vaddq_s16(v3423, v3425); + int16x8_t v3427 = vsubq_s16(v3315, v3317); + int16x8_t v3428 = vsubq_s16(v3319, v3321); + int16x8_t v3429 = vqrdmulhq_n_s16(v3428, 25448); + int16x8_t v3430 = vaddq_s16(v3427, v3429); + int16x8_t v3431 = vqrdmulhq_n_s16(v3430, 18072); + int16x8_t v3432 = vaddq_s16(v3426, v3431); + int16x8_t v3433 = vsubq_s16(v3283, v3285); + int16x8_t v3434 = vsubq_s16(v3287, v3289); + int16x8_t v3435 = vqrdmulhq_n_s16(v3434, 26220); + int16x8_t v3436 = vaddq_s16(v3433, v3435); + int16x8_t v3437 = vsubq_s16(v3293, v3295); + int16x8_t v3438 = vsubq_s16(v3297, v3299); + int16x8_t v3439 = vqrdmulhq_n_s16(v3438, 26220); + int16x8_t v3440 = vaddq_s16(v3437, v3439); + int16x8_t v3441 = vqrdmulhq_n_s16(v3440, 18177); + int16x8_t v3442 = vaddq_s16(v3436, v3441); + int16x8_t v3443 = vsubq_s16(v3261, v3263); + int16x8_t v3444 = vsubq_s16(v3265, v3267); + int16x8_t v3445 = vqrdmulhq_n_s16(v3444, 27058); + int16x8_t v3446 = vaddq_s16(v3443, v3445); + int16x8_t v3447 = vsubq_s16(v3271, v3273); + int16x8_t v3448 = vsubq_s16(v3275, v3277); + int16x8_t v3449 = vqrdmulhq_n_s16(v3448, 27058); + int16x8_t v3450 = vaddq_s16(v3447, v3449); + int16x8_t v3451 = vqrdmulhq_n_s16(v3450, 18286); + int16x8_t v3452 = vaddq_s16(v3446, v3451); + int16x8_t v3453 = vsubq_s16(v3239, v3241); + int16x8_t v3454 = vsubq_s16(v3243, v3245); + int16x8_t v3455 = vqrdmulhq_n_s16(v3454, 27969); + int16x8_t v3456 = vaddq_s16(v3453, v3455); + int16x8_t v3457 = vsubq_s16(v3249, v3251); + int16x8_t v3458 = vsubq_s16(v3253, v3255); + int16x8_t v3459 = vqrdmulhq_n_s16(v3458, 27969); + int16x8_t v3460 = vaddq_s16(v3457, v3459); + int16x8_t v3461 = vqrdmulhq_n_s16(v3460, 18400); + int16x8_t v3462 = vaddq_s16(v3456, v3461); + int16x8_t v3463 = vsubq_s16(v3217, v3219); + int16x8_t v3464 = vsubq_s16(v3221, v3223); + int16x8_t v3465 = vqrdmulhq_n_s16(v3464, 28961); + int16x8_t v3466 = vaddq_s16(v3463, v3465); + int16x8_t v3467 = vsubq_s16(v3227, v3229); + int16x8_t v3468 = vsubq_s16(v3231, v3233); + int16x8_t v3469 = vqrdmulhq_n_s16(v3468, 28961); + int16x8_t v3470 = vaddq_s16(v3467, v3469); + int16x8_t v3471 = vqrdmulhq_n_s16(v3470, 18517); + int16x8_t v3472 = vaddq_s16(v3466, v3471); + int16x8_t v3473 = vsubq_s16(v3195, v3197); + int16x8_t v3474 = vsubq_s16(v3199, v3201); + int16x8_t v3475 = vqrdmulhq_n_s16(v3474, 30044); + int16x8_t v3476 = vaddq_s16(v3473, v3475); + int16x8_t v3477 = vsubq_s16(v3205, v3207); + int16x8_t v3478 = vsubq_s16(v3209, v3211); + int16x8_t v3479 = vqrdmulhq_n_s16(v3478, 30044); + int16x8_t v3480 = vaddq_s16(v3477, v3479); + int16x8_t v3481 = vqrdmulhq_n_s16(v3480, 18639); + int16x8_t v3482 = vaddq_s16(v3476, v3481); + int16x8_t v3483 = vsubq_s16(v3173, v3175); + int16x8_t v3484 = vsubq_s16(v3177, v3179); + int16x8_t v3485 = vqrdmulhq_n_s16(v3484, 31232); + int16x8_t v3486 = vaddq_s16(v3483, v3485); + int16x8_t v3487 = vsubq_s16(v3183, v3185); + int16x8_t v3488 = vsubq_s16(v3187, v3189); + int16x8_t v3489 = vqrdmulhq_n_s16(v3488, 31232); + int16x8_t v3490 = vaddq_s16(v3487, v3489); + int16x8_t v3491 = vqrdmulhq_n_s16(v3490, 18765); + int16x8_t v3492 = vaddq_s16(v3486, v3491); + int16x8_t v3493 = vsubq_s16(v3151, v3153); + int16x8_t v3494 = vsubq_s16(v3155, v3157); + int16x8_t v3495 = vqrdmulhq_n_s16(v3494, 32538); + int16x8_t v3496 = vaddq_s16(v3493, v3495); + int16x8_t v3497 = vsubq_s16(v3161, v3163); + int16x8_t v3498 = vsubq_s16(v3165, v3167); + int16x8_t v3499 = vqrdmulhq_n_s16(v3498, 32538); + int16x8_t v3500 = vaddq_s16(v3497, v3499); + int16x8_t v3501 = vqrdmulhq_n_s16(v3500, 18896); + int16x8_t v3502 = vaddq_s16(v3496, v3501); + int16x8_t v3503 = vsubq_s16(v3129, v3131); + int16x8_t v3504 = vsubq_s16(v3133, v3135); + int16x8_t v3505_tmp = vqrdmulhq_n_s16(v3504, 1211); + int16x8_t v3505 = vaddq_s16(v3505_tmp, v3504); + int16x8_t v3506 = vaddq_s16(v3503, v3505); + int16x8_t v3507 = vsubq_s16(v3139, v3141); + int16x8_t v3508 = vsubq_s16(v3143, v3145); + int16x8_t v3509_tmp = vqrdmulhq_n_s16(v3508, 1211); + int16x8_t v3509 = vaddq_s16(v3509_tmp, v3508); + int16x8_t v3510 = vaddq_s16(v3507, v3509); + int16x8_t v3511 = vqrdmulhq_n_s16(v3510, 19032); + int16x8_t v3512 = vaddq_s16(v3506, v3511); + int16x8_t v3513 = vsubq_s16(v3107, v3109); + int16x8_t v3514 = vsubq_s16(v3111, v3113); + int16x8_t v3515_tmp = vqrdmulhq_n_s16(v3514, 2808); + int16x8_t v3515 = vaddq_s16(v3515_tmp, v3514); + int16x8_t v3516 = vaddq_s16(v3513, v3515); + int16x8_t v3517 = vsubq_s16(v3117, v3119); + int16x8_t v3518 = vsubq_s16(v3121, v3123); + int16x8_t v3519_tmp = vqrdmulhq_n_s16(v3518, 2808); + int16x8_t v3519 = vaddq_s16(v3519_tmp, v3518); + int16x8_t v3520 = vaddq_s16(v3517, v3519); + int16x8_t v3521 = vqrdmulhq_n_s16(v3520, 19172); + int16x8_t v3522 = vaddq_s16(v3516, v3521); + int16x8_t v3523 = vsubq_s16(v3085, v3087); + int16x8_t v3524 = vsubq_s16(v3089, v3091); + int16x8_t v3525_tmp = vqrdmulhq_n_s16(v3524, 4586); + int16x8_t v3525 = vaddq_s16(v3525_tmp, v3524); + int16x8_t v3526 = vaddq_s16(v3523, v3525); + int16x8_t v3527 = vsubq_s16(v3095, v3097); + int16x8_t v3528 = vsubq_s16(v3099, v3101); + int16x8_t v3529_tmp = vqrdmulhq_n_s16(v3528, 4586); + int16x8_t v3529 = vaddq_s16(v3529_tmp, v3528); + int16x8_t v3530 = vaddq_s16(v3527, v3529); + int16x8_t v3531 = vqrdmulhq_n_s16(v3530, 19318); + int16x8_t v3532 = vaddq_s16(v3526, v3531); + int16x8_t v3533 = vsubq_s16(v3063, v3065); + int16x8_t v3534 = vsubq_s16(v3067, v3069); + int16x8_t v3535_tmp = vqrdmulhq_n_s16(v3534, 6576); + int16x8_t v3535 = vaddq_s16(v3535_tmp, v3534); + int16x8_t v3536 = vaddq_s16(v3533, v3535); + int16x8_t v3537 = vsubq_s16(v3073, v3075); + int16x8_t v3538 = vsubq_s16(v3077, v3079); + int16x8_t v3539_tmp = vqrdmulhq_n_s16(v3538, 6576); + int16x8_t v3539 = vaddq_s16(v3539_tmp, v3538); + int16x8_t v3540 = vaddq_s16(v3537, v3539); + int16x8_t v3541 = vqrdmulhq_n_s16(v3540, 19469); + int16x8_t v3542 = vaddq_s16(v3536, v3541); + int16x8_t v3543 = vsubq_s16(v3041, v3043); + int16x8_t v3544 = vsubq_s16(v3045, v3047); + int16x8_t v3545_tmp = vqrdmulhq_n_s16(v3544, 8817); + int16x8_t v3545 = vaddq_s16(v3545_tmp, v3544); + int16x8_t v3546 = vaddq_s16(v3543, v3545); + int16x8_t v3547 = vsubq_s16(v3051, v3053); + int16x8_t v3548 = vsubq_s16(v3055, v3057); + int16x8_t v3549_tmp = vqrdmulhq_n_s16(v3548, 8817); + int16x8_t v3549 = vaddq_s16(v3549_tmp, v3548); + int16x8_t v3550 = vaddq_s16(v3547, v3549); + int16x8_t v3551 = vqrdmulhq_n_s16(v3550, 19625); + int16x8_t v3552 = vaddq_s16(v3546, v3551); + int16x8_t v3553 = vsubq_s16(v2998, v3003); + int16x8_t v3554 = vsubq_s16(v3008, v3013); + int16x8_t v3555_tmp = vqrdmulhq_n_s16(v3554, 11356); + int16x8_t v3555 = vaddq_s16(v3555_tmp, v3554); + int16x8_t v3556 = vaddq_s16(v3553, v3555); + int16x8_t v3557 = vsubq_s16(v3020, v3025); + int16x8_t v3558 = vsubq_s16(v3030, v3035); + int16x8_t v3559_tmp = vqrdmulhq_n_s16(v3558, 11356); + int16x8_t v3559 = vaddq_s16(v3559_tmp, v3558); + int16x8_t v3560 = vaddq_s16(v3557, v3559); + int16x8_t v3561 = vqrdmulhq_n_s16(v3560, 19786); + int16x8_t v3562 = vaddq_s16(v3556, v3561); + int16x8_t v3563 = vsubq_s16(v2952, v2957); + int16x8_t v3564 = vsubq_s16(v2962, v2967); + int16x8_t v3565_tmp = vqrdmulhq_n_s16(v3564, 14256); + int16x8_t v3565 = vaddq_s16(v3565_tmp, v3564); + int16x8_t v3566 = vaddq_s16(v3563, v3565); + int16x8_t v3567 = vsubq_s16(v2974, v2979); + int16x8_t v3568 = vsubq_s16(v2984, v2989); + int16x8_t v3569_tmp = vqrdmulhq_n_s16(v3568, 14256); + int16x8_t v3569 = vaddq_s16(v3569_tmp, v3568); + int16x8_t v3570 = vaddq_s16(v3567, v3569); + int16x8_t v3571 = vqrdmulhq_n_s16(v3570, 19954); + int16x8_t v3572 = vaddq_s16(v3566, v3571); + int16x8_t v3573 = vsubq_s16(v2906, v2911); + int16x8_t v3574 = vsubq_s16(v2916, v2921); + int16x8_t v3575_tmp = vqrdmulhq_n_s16(v3574, 17596); + int16x8_t v3575 = vaddq_s16(v3575_tmp, v3574); + int16x8_t v3576 = vaddq_s16(v3573, v3575); + int16x8_t v3577 = vsubq_s16(v2928, v2933); + int16x8_t v3578 = vsubq_s16(v2938, v2943); + int16x8_t v3579_tmp = vqrdmulhq_n_s16(v3578, 17596); + int16x8_t v3579 = vaddq_s16(v3579_tmp, v3578); + int16x8_t v3580 = vaddq_s16(v3577, v3579); + int16x8_t v3581 = vqrdmulhq_n_s16(v3580, 20127); + int16x8_t v3582 = vaddq_s16(v3576, v3581); + int16x8_t v3583 = vsubq_s16(v2860, v2865); + int16x8_t v3584 = vsubq_s16(v2870, v2875); + int16x8_t v3585_tmp = vqrdmulhq_n_s16(v3584, 21483); + int16x8_t v3585 = vaddq_s16(v3585_tmp, v3584); + int16x8_t v3586 = vaddq_s16(v3583, v3585); + int16x8_t v3587 = vsubq_s16(v2882, v2887); + int16x8_t v3588 = vsubq_s16(v2892, v2897); + int16x8_t v3589_tmp = vqrdmulhq_n_s16(v3588, 21483); + int16x8_t v3589 = vaddq_s16(v3589_tmp, v3588); + int16x8_t v3590 = vaddq_s16(v3587, v3589); + int16x8_t v3591 = vqrdmulhq_n_s16(v3590, 20306); + int16x8_t v3592 = vaddq_s16(v3586, v3591); + int16x8_t v3593 = vsubq_s16(v2814, v2819); + int16x8_t v3594 = vsubq_s16(v2824, v2829); + int16x8_t v3595_tmp = vqrdmulhq_n_s16(v3594, 26057); + int16x8_t v3595 = vaddq_s16(v3595_tmp, v3594); + int16x8_t v3596 = vaddq_s16(v3593, v3595); + int16x8_t v3597 = vsubq_s16(v2836, v2841); + int16x8_t v3598 = vsubq_s16(v2846, v2851); + int16x8_t v3599_tmp = vqrdmulhq_n_s16(v3598, 26057); + int16x8_t v3599 = vaddq_s16(v3599_tmp, v3598); + int16x8_t v3600 = vaddq_s16(v3597, v3599); + int16x8_t v3601 = vqrdmulhq_n_s16(v3600, 20492); + int16x8_t v3602 = vaddq_s16(v3596, v3601); + int16x8_t v3603 = vsubq_s16(v2768, v2773); + int16x8_t v3604 = vsubq_s16(v2778, v2783); + int16x8_t v3605_tmp = vqrdmulhq_n_s16(v3604, 31517); + int16x8_t v3605 = vaddq_s16(v3605_tmp, v3604); + int16x8_t v3606 = vaddq_s16(v3603, v3605); + int16x8_t v3607 = vsubq_s16(v2790, v2795); + int16x8_t v3608 = vsubq_s16(v2800, v2805); + int16x8_t v3609_tmp = vqrdmulhq_n_s16(v3608, 31517); + int16x8_t v3609 = vaddq_s16(v3609_tmp, v3608); + int16x8_t v3610 = vaddq_s16(v3607, v3609); + int16x8_t v3611 = vqrdmulhq_n_s16(v3610, 20684); + int16x8_t v3612 = vaddq_s16(v3606, v3611); + int16x8_t v3613 = vsubq_s16(v2722, v2727); + int16x8_t v3614 = vsubq_s16(v2732, v2737); + int16x8_t v3615_tmp = vqrdmulhq_n_s16(v3614, 5373); + int16x8_t v3615 = vmlaq_n_s16(v3615_tmp, v3614, 2); + int16x8_t v3616 = vaddq_s16(v3613, v3615); + int16x8_t v3617 = vsubq_s16(v2744, v2749); + int16x8_t v3618 = vsubq_s16(v2754, v2759); + int16x8_t v3619_tmp = vqrdmulhq_n_s16(v3618, 5373); + int16x8_t v3619 = vmlaq_n_s16(v3619_tmp, v3618, 2); + int16x8_t v3620 = vaddq_s16(v3617, v3619); + int16x8_t v3621 = vqrdmulhq_n_s16(v3620, 20883); + int16x8_t v3622 = vaddq_s16(v3616, v3621); + int16x8_t v3623 = vsubq_s16(v2676, v2681); + int16x8_t v3624 = vsubq_s16(v2686, v2691); + int16x8_t v3625_tmp = vqrdmulhq_n_s16(v3624, 13571); + int16x8_t v3625 = vmlaq_n_s16(v3625_tmp, v3624, 2); + int16x8_t v3626 = vaddq_s16(v3623, v3625); + int16x8_t v3627 = vsubq_s16(v2698, v2703); + int16x8_t v3628 = vsubq_s16(v2708, v2713); + int16x8_t v3629_tmp = vqrdmulhq_n_s16(v3628, 13571); + int16x8_t v3629 = vmlaq_n_s16(v3629_tmp, v3628, 2); + int16x8_t v3630 = vaddq_s16(v3627, v3629); + int16x8_t v3631 = vqrdmulhq_n_s16(v3630, 21089); + int16x8_t v3632 = vaddq_s16(v3626, v3631); + int16x8_t v3633 = vsubq_s16(v2588, v2599); + int16x8_t v3634 = vsubq_s16(v2610, v2621); + int16x8_t v3635_tmp = vqrdmulhq_n_s16(v3634, 23975); + int16x8_t v3635 = vmlaq_n_s16(v3635_tmp, v3634, 2); + int16x8_t v3636 = vaddq_s16(v3633, v3635); + int16x8_t v3637 = vsubq_s16(v2634, v2645); + int16x8_t v3638 = vsubq_s16(v2656, v2667); + int16x8_t v3639_tmp = vqrdmulhq_n_s16(v3638, 23975); + int16x8_t v3639 = vmlaq_n_s16(v3639_tmp, v3638, 2); + int16x8_t v3640 = vaddq_s16(v3637, v3639); + int16x8_t v3641 = vqrdmulhq_n_s16(v3640, 21303); + int16x8_t v3642 = vaddq_s16(v3636, v3641); + int16x8_t v3643 = vsubq_s16(v2494, v2505); + int16x8_t v3644 = vsubq_s16(v2516, v2527); + int16x8_t v3645_tmp = vqrdmulhq_n_s16(v3644, 4832); + int16x8_t v3645 = vmlaq_n_s16(v3645_tmp, v3644, 3); + int16x8_t v3646 = vaddq_s16(v3643, v3645); + int16x8_t v3647 = vsubq_s16(v2540, v2551); + int16x8_t v3648 = vsubq_s16(v2562, v2573); + int16x8_t v3649_tmp = vqrdmulhq_n_s16(v3648, 4832); + int16x8_t v3649 = vmlaq_n_s16(v3649_tmp, v3648, 3); + int16x8_t v3650 = vaddq_s16(v3647, v3649); + int16x8_t v3651 = vqrdmulhq_n_s16(v3650, 21524); + int16x8_t v3652 = vaddq_s16(v3646, v3651); + int16x8_t v3653 = vsubq_s16(v2399, v2410); + int16x8_t v3654 = vsubq_s16(v2421, v2432); + int16x8_t v3655_tmp = vqrdmulhq_n_s16(v3654, 23437); + int16x8_t v3655 = vmlaq_n_s16(v3655_tmp, v3654, 3); + int16x8_t v3656 = vaddq_s16(v3653, v3655); + int16x8_t v3657 = vsubq_s16(v2445, v2456); + int16x8_t v3658 = vsubq_s16(v2468, v2479); + int16x8_t v3659_tmp = vqrdmulhq_n_s16(v3658, 23437); + int16x8_t v3659 = vmlaq_n_s16(v3659_tmp, v3658, 3); + int16x8_t v3660 = vaddq_s16(v3657, v3659); + int16x8_t v3661 = vqrdmulhq_n_s16(v3660, 21753); + int16x8_t v3662 = vaddq_s16(v3656, v3661); + int16x8_t v3663 = vsubq_s16(v2305, v2316); + int16x8_t v3664 = vsubq_s16(v2327, v2338); + int16x8_t v3665_tmp = vqrdmulhq_n_s16(v3664, 17573); + int16x8_t v3665 = vmlaq_n_s16(v3665_tmp, v3664, 4); + int16x8_t v3666 = vaddq_s16(v3663, v3665); + int16x8_t v3667 = vsubq_s16(v2351, v2362); + int16x8_t v3668 = vsubq_s16(v2373, v2384); + int16x8_t v3669_tmp = vqrdmulhq_n_s16(v3668, 17573); + int16x8_t v3669 = vmlaq_n_s16(v3669_tmp, v3668, 4); + int16x8_t v3670 = vaddq_s16(v3667, v3669); + int16x8_t v3671 = vqrdmulhq_n_s16(v3670, 21990); + int16x8_t v3672 = vaddq_s16(v3666, v3671); + int16x8_t v3673 = vsubq_s16(v2127, v2150); + int16x8_t v3674 = vsubq_s16(v2173, v2196); + int16x8_t v3675_tmp = vqrdmulhq_n_s16(v3674, 27122); + int16x8_t v3675 = vmlaq_n_s16(v3675_tmp, v3674, 5); + int16x8_t v3676 = vaddq_s16(v3673, v3675); + int16x8_t v3677 = vsubq_s16(v2221, v2244); + int16x8_t v3678 = vsubq_s16(v2267, v2290); + int16x8_t v3679_tmp = vqrdmulhq_n_s16(v3678, 27122); + int16x8_t v3679 = vmlaq_n_s16(v3679_tmp, v3678, 5); + int16x8_t v3680 = vaddq_s16(v3677, v3679); + int16x8_t v3681 = vqrdmulhq_n_s16(v3680, 22236); + int16x8_t v3682 = vaddq_s16(v3676, v3681); + int16x8_t v3683 = vsubq_s16(v1937, v1960); + int16x8_t v3684 = vsubq_s16(v1983, v2006); + int16x8_t v3685_tmp = vqrdmulhq_n_s16(v3684, 5041); + int16x8_t v3685 = vmlaq_n_s16(v3685_tmp, v3684, 8); + int16x8_t v3686 = vaddq_s16(v3683, v3685); + int16x8_t v3687 = vsubq_s16(v2031, v2054); + int16x8_t v3688 = vsubq_s16(v2077, v2100); + int16x8_t v3689_tmp = vqrdmulhq_n_s16(v3688, 5041); + int16x8_t v3689 = vmlaq_n_s16(v3689_tmp, v3688, 8); + int16x8_t v3690 = vaddq_s16(v3687, v3689); + int16x8_t v3691 = vqrdmulhq_n_s16(v3690, 22491); + int16x8_t v3692 = vaddq_s16(v3686, v3691); + int16x8_t v3693 = vsubq_s16(v1579, v1626); + int16x8_t v3694 = vsubq_s16(v1673, v1720); + int16x8_t v3695_tmp = vqrdmulhq_n_s16(v3694, 19146); + int16x8_t v3695 = vmlaq_n_s16(v3695_tmp, v3694, 13); + int16x8_t v3696 = vaddq_s16(v3693, v3695); + int16x8_t v3697 = vsubq_s16(v1769, v1816); + int16x8_t v3698 = vsubq_s16(v1863, v1910); + int16x8_t v3699_tmp = vqrdmulhq_n_s16(v3698, 19146); + int16x8_t v3699 = vmlaq_n_s16(v3699_tmp, v3698, 13); + int16x8_t v3700 = vaddq_s16(v3697, v3699); + int16x8_t v3701 = vqrdmulhq_n_s16(v3700, 22755); + int16x8_t v3702 = vaddq_s16(v3696, v3701); + int16x8_t v3703 = vsubq_s16(v141, v316); + int16x8_t v3704 = vsubq_s16(v522, v698); + int16x8_t v3705_tmp = vqrdmulhq_n_s16(v3704, 24402); + int16x8_t v3705 = vmlaq_n_s16(v3705_tmp, v3704, 40); + int16x8_t v3706 = vaddq_s16(v3703, v3705); + int16x8_t v3707 = vsubq_s16(v906, v1145); + int16x8_t v3708 = vsubq_s16(v1351, v1528); + int16x8_t v3709_tmp = vqrdmulhq_n_s16(v3708, 24402); + int16x8_t v3709 = vmlaq_n_s16(v3709_tmp, v3708, 40); + int16x8_t v3710 = vaddq_s16(v3707, v3709); + int16x8_t v3711 = vqrdmulhq_n_s16(v3710, 23030); + int16x8_t v3712 = vaddq_s16(v3706, v3711); + int16x8_t v3713 = vsubq_s16(v3703, v3705); + int16x8_t v3714 = vsubq_s16(v3707, v3709); + int16x8_t v3715 = vqrdmulhq_n_s16(v3714, 23314); + int16x8_t v3716 = vaddq_s16(v3713, v3715); + int16x8_t v3717 = vsubq_s16(v3693, v3695); + int16x8_t v3718 = vsubq_s16(v3697, v3699); + int16x8_t v3719 = vqrdmulhq_n_s16(v3718, 23609); + int16x8_t v3720 = vaddq_s16(v3717, v3719); + int16x8_t v3721 = vsubq_s16(v3683, v3685); + int16x8_t v3722 = vsubq_s16(v3687, v3689); + int16x8_t v3723 = vqrdmulhq_n_s16(v3722, 23915); + int16x8_t v3724 = vaddq_s16(v3721, v3723); + int16x8_t v3725 = vsubq_s16(v3673, v3675); + int16x8_t v3726 = vsubq_s16(v3677, v3679); + int16x8_t v3727 = vqrdmulhq_n_s16(v3726, 24233); + int16x8_t v3728 = vaddq_s16(v3725, v3727); + int16x8_t v3729 = vsubq_s16(v3663, v3665); + int16x8_t v3730 = vsubq_s16(v3667, v3669); + int16x8_t v3731 = vqrdmulhq_n_s16(v3730, 24564); + int16x8_t v3732 = vaddq_s16(v3729, v3731); + int16x8_t v3733 = vsubq_s16(v3653, v3655); + int16x8_t v3734 = vsubq_s16(v3657, v3659); + int16x8_t v3735 = vqrdmulhq_n_s16(v3734, 24907); + int16x8_t v3736 = vaddq_s16(v3733, v3735); + int16x8_t v3737 = vsubq_s16(v3643, v3645); + int16x8_t v3738 = vsubq_s16(v3647, v3649); + int16x8_t v3739 = vqrdmulhq_n_s16(v3738, 25264); + int16x8_t v3740 = vaddq_s16(v3737, v3739); + int16x8_t v3741 = vsubq_s16(v3633, v3635); + int16x8_t v3742 = vsubq_s16(v3637, v3639); + int16x8_t v3743 = vqrdmulhq_n_s16(v3742, 25635); + int16x8_t v3744 = vaddq_s16(v3741, v3743); + int16x8_t v3745 = vsubq_s16(v3623, v3625); + int16x8_t v3746 = vsubq_s16(v3627, v3629); + int16x8_t v3747 = vqrdmulhq_n_s16(v3746, 26021); + int16x8_t v3748 = vaddq_s16(v3745, v3747); + int16x8_t v3749 = vsubq_s16(v3613, v3615); + int16x8_t v3750 = vsubq_s16(v3617, v3619); + int16x8_t v3751 = vqrdmulhq_n_s16(v3750, 26423); + int16x8_t v3752 = vaddq_s16(v3749, v3751); + int16x8_t v3753 = vsubq_s16(v3603, v3605); + int16x8_t v3754 = vsubq_s16(v3607, v3609); + int16x8_t v3755 = vqrdmulhq_n_s16(v3754, 26842); + int16x8_t v3756 = vaddq_s16(v3753, v3755); + int16x8_t v3757 = vsubq_s16(v3593, v3595); + int16x8_t v3758 = vsubq_s16(v3597, v3599); + int16x8_t v3759 = vqrdmulhq_n_s16(v3758, 27279); + int16x8_t v3760 = vaddq_s16(v3757, v3759); + int16x8_t v3761 = vsubq_s16(v3583, v3585); + int16x8_t v3762 = vsubq_s16(v3587, v3589); + int16x8_t v3763 = vqrdmulhq_n_s16(v3762, 27734); + int16x8_t v3764 = vaddq_s16(v3761, v3763); + int16x8_t v3765 = vsubq_s16(v3573, v3575); + int16x8_t v3766 = vsubq_s16(v3577, v3579); + int16x8_t v3767 = vqrdmulhq_n_s16(v3766, 28209); + int16x8_t v3768 = vaddq_s16(v3765, v3767); + int16x8_t v3769 = vsubq_s16(v3563, v3565); + int16x8_t v3770 = vsubq_s16(v3567, v3569); + int16x8_t v3771 = vqrdmulhq_n_s16(v3770, 28705); + int16x8_t v3772 = vaddq_s16(v3769, v3771); + int16x8_t v3773 = vsubq_s16(v3553, v3555); + int16x8_t v3774 = vsubq_s16(v3557, v3559); + int16x8_t v3775 = vqrdmulhq_n_s16(v3774, 29223); + int16x8_t v3776 = vaddq_s16(v3773, v3775); + int16x8_t v3777 = vsubq_s16(v3543, v3545); + int16x8_t v3778 = vsubq_s16(v3547, v3549); + int16x8_t v3779 = vqrdmulhq_n_s16(v3778, 29764); + int16x8_t v3780 = vaddq_s16(v3777, v3779); + int16x8_t v3781 = vsubq_s16(v3533, v3535); + int16x8_t v3782 = vsubq_s16(v3537, v3539); + int16x8_t v3783 = vqrdmulhq_n_s16(v3782, 30331); + int16x8_t v3784 = vaddq_s16(v3781, v3783); + int16x8_t v3785 = vsubq_s16(v3523, v3525); + int16x8_t v3786 = vsubq_s16(v3527, v3529); + int16x8_t v3787 = vqrdmulhq_n_s16(v3786, 30925); + int16x8_t v3788 = vaddq_s16(v3785, v3787); + int16x8_t v3789 = vsubq_s16(v3513, v3515); + int16x8_t v3790 = vsubq_s16(v3517, v3519); + int16x8_t v3791 = vqrdmulhq_n_s16(v3790, 31547); + int16x8_t v3792 = vaddq_s16(v3789, v3791); + int16x8_t v3793 = vsubq_s16(v3503, v3505); + int16x8_t v3794 = vsubq_s16(v3507, v3509); + int16x8_t v3795 = vqrdmulhq_n_s16(v3794, 32199); + int16x8_t v3796 = vaddq_s16(v3793, v3795); + int16x8_t v3797 = vsubq_s16(v3493, v3495); + int16x8_t v3798 = vsubq_s16(v3497, v3499); + int16x8_t v3799_tmp = vqrdmulhq_n_s16(v3798, 117); + int16x8_t v3799 = vaddq_s16(v3799_tmp, v3798); + int16x8_t v3800 = vaddq_s16(v3797, v3799); + int16x8_t v3801 = vsubq_s16(v3483, v3485); + int16x8_t v3802 = vsubq_s16(v3487, v3489); + int16x8_t v3803_tmp = vqrdmulhq_n_s16(v3802, 837); + int16x8_t v3803 = vaddq_s16(v3803_tmp, v3802); + int16x8_t v3804 = vaddq_s16(v3801, v3803); + int16x8_t v3805 = vsubq_s16(v3473, v3475); + int16x8_t v3806 = vsubq_s16(v3477, v3479); + int16x8_t v3807_tmp = vqrdmulhq_n_s16(v3806, 1594); + int16x8_t v3807 = vaddq_s16(v3807_tmp, v3806); + int16x8_t v3808 = vaddq_s16(v3805, v3807); + int16x8_t v3809 = vsubq_s16(v3463, v3465); + int16x8_t v3810 = vsubq_s16(v3467, v3469); + int16x8_t v3811_tmp = vqrdmulhq_n_s16(v3810, 2393); + int16x8_t v3811 = vaddq_s16(v3811_tmp, v3810); + int16x8_t v3812 = vaddq_s16(v3809, v3811); + int16x8_t v3813 = vsubq_s16(v3453, v3455); + int16x8_t v3814 = vsubq_s16(v3457, v3459); + int16x8_t v3815_tmp = vqrdmulhq_n_s16(v3814, 3234); + int16x8_t v3815 = vaddq_s16(v3815_tmp, v3814); + int16x8_t v3816 = vaddq_s16(v3813, v3815); + int16x8_t v3817 = vsubq_s16(v3443, v3445); + int16x8_t v3818 = vsubq_s16(v3447, v3449); + int16x8_t v3819_tmp = vqrdmulhq_n_s16(v3818, 4123); + int16x8_t v3819 = vaddq_s16(v3819_tmp, v3818); + int16x8_t v3820 = vaddq_s16(v3817, v3819); + int16x8_t v3821 = vsubq_s16(v3433, v3435); + int16x8_t v3822 = vsubq_s16(v3437, v3439); + int16x8_t v3823_tmp = vqrdmulhq_n_s16(v3822, 5062); + int16x8_t v3823 = vaddq_s16(v3823_tmp, v3822); + int16x8_t v3824 = vaddq_s16(v3821, v3823); + int16x8_t v3825 = vsubq_s16(v3423, v3425); + int16x8_t v3826 = vsubq_s16(v3427, v3429); + int16x8_t v3827_tmp = vqrdmulhq_n_s16(v3826, 6057); + int16x8_t v3827 = vaddq_s16(v3827_tmp, v3826); + int16x8_t v3828 = vaddq_s16(v3825, v3827); + int16x8_t v3829 = vsubq_s16(v3413, v3415); + int16x8_t v3830 = vsubq_s16(v3417, v3419); + int16x8_t v3831_tmp = vqrdmulhq_n_s16(v3830, 7111); + int16x8_t v3831 = vaddq_s16(v3831_tmp, v3830); + int16x8_t v3832 = vaddq_s16(v3829, v3831); + int16x8_t v3833 = vsubq_s16(v3403, v3405); + int16x8_t v3834 = vsubq_s16(v3407, v3409); + int16x8_t v3835_tmp = vqrdmulhq_n_s16(v3834, 8231); + int16x8_t v3835 = vaddq_s16(v3835_tmp, v3834); + int16x8_t v3836 = vaddq_s16(v3833, v3835); + int16x8_t v3837 = vsubq_s16(v3393, v3395); + int16x8_t v3838 = vsubq_s16(v3397, v3399); + int16x8_t v3839_tmp = vqrdmulhq_n_s16(v3838, 9421); + int16x8_t v3839 = vaddq_s16(v3839_tmp, v3838); + int16x8_t v3840 = vaddq_s16(v3837, v3839); + int16x8_t v3841 = vsubq_s16(v3374, v3379); + int16x8_t v3842 = vsubq_s16(v3384, v3389); + int16x8_t v3843_tmp = vqrdmulhq_n_s16(v3842, 10690); + int16x8_t v3843 = vaddq_s16(v3843_tmp, v3842); + int16x8_t v3844 = vaddq_s16(v3841, v3843); + int16x8_t v3845 = vsubq_s16(v3352, v3357); + int16x8_t v3846 = vsubq_s16(v3362, v3367); + int16x8_t v3847_tmp = vqrdmulhq_n_s16(v3846, 12044); + int16x8_t v3847 = vaddq_s16(v3847_tmp, v3846); + int16x8_t v3848 = vaddq_s16(v3845, v3847); + int16x8_t v3849 = vsubq_s16(v3330, v3335); + int16x8_t v3850 = vsubq_s16(v3340, v3345); + int16x8_t v3851_tmp = vqrdmulhq_n_s16(v3850, 13493); + int16x8_t v3851 = vaddq_s16(v3851_tmp, v3850); + int16x8_t v3852 = vaddq_s16(v3849, v3851); + int16x8_t v3853 = vsubq_s16(v3308, v3313); + int16x8_t v3854 = vsubq_s16(v3318, v3323); + int16x8_t v3855_tmp = vqrdmulhq_n_s16(v3854, 15046); + int16x8_t v3855 = vaddq_s16(v3855_tmp, v3854); + int16x8_t v3856 = vaddq_s16(v3853, v3855); + int16x8_t v3857 = vsubq_s16(v3286, v3291); + int16x8_t v3858 = vsubq_s16(v3296, v3301); + int16x8_t v3859_tmp = vqrdmulhq_n_s16(v3858, 16715); + int16x8_t v3859 = vaddq_s16(v3859_tmp, v3858); + int16x8_t v3860 = vaddq_s16(v3857, v3859); + int16x8_t v3861 = vsubq_s16(v3264, v3269); + int16x8_t v3862 = vsubq_s16(v3274, v3279); + int16x8_t v3863_tmp = vqrdmulhq_n_s16(v3862, 18512); + int16x8_t v3863 = vaddq_s16(v3863_tmp, v3862); + int16x8_t v3864 = vaddq_s16(v3861, v3863); + int16x8_t v3865 = vsubq_s16(v3242, v3247); + int16x8_t v3866 = vsubq_s16(v3252, v3257); + int16x8_t v3867_tmp = vqrdmulhq_n_s16(v3866, 20453); + int16x8_t v3867 = vaddq_s16(v3867_tmp, v3866); + int16x8_t v3868 = vaddq_s16(v3865, v3867); + int16x8_t v3869 = vsubq_s16(v3220, v3225); + int16x8_t v3870 = vsubq_s16(v3230, v3235); + int16x8_t v3871_tmp = vqrdmulhq_n_s16(v3870, 22555); + int16x8_t v3871 = vaddq_s16(v3871_tmp, v3870); + int16x8_t v3872 = vaddq_s16(v3869, v3871); + int16x8_t v3873 = vsubq_s16(v3198, v3203); + int16x8_t v3874 = vsubq_s16(v3208, v3213); + int16x8_t v3875_tmp = vqrdmulhq_n_s16(v3874, 24839); + int16x8_t v3875 = vaddq_s16(v3875_tmp, v3874); + int16x8_t v3876 = vaddq_s16(v3873, v3875); + int16x8_t v3877 = vsubq_s16(v3176, v3181); + int16x8_t v3878 = vsubq_s16(v3186, v3191); + int16x8_t v3879_tmp = vqrdmulhq_n_s16(v3878, 27330); + int16x8_t v3879 = vaddq_s16(v3879_tmp, v3878); + int16x8_t v3880 = vaddq_s16(v3877, v3879); + int16x8_t v3881 = vsubq_s16(v3154, v3159); + int16x8_t v3882 = vsubq_s16(v3164, v3169); + int16x8_t v3883_tmp = vqrdmulhq_n_s16(v3882, 30056); + int16x8_t v3883 = vaddq_s16(v3883_tmp, v3882); + int16x8_t v3884 = vaddq_s16(v3881, v3883); + int16x8_t v3885 = vsubq_s16(v3132, v3137); + int16x8_t v3886 = vsubq_s16(v3142, v3147); + int16x8_t v3887_tmp = vqrdmulhq_n_s16(v3886, 282); + int16x8_t v3887 = vmlaq_n_s16(v3887_tmp, v3886, 2); + int16x8_t v3888 = vaddq_s16(v3885, v3887); + int16x8_t v3889 = vsubq_s16(v3110, v3115); + int16x8_t v3890 = vsubq_s16(v3120, v3125); + int16x8_t v3891_tmp = vqrdmulhq_n_s16(v3890, 3588); + int16x8_t v3891 = vmlaq_n_s16(v3891_tmp, v3890, 2); + int16x8_t v3892 = vaddq_s16(v3889, v3891); + int16x8_t v3893 = vsubq_s16(v3088, v3093); + int16x8_t v3894 = vsubq_s16(v3098, v3103); + int16x8_t v3895_tmp = vqrdmulhq_n_s16(v3894, 7255); + int16x8_t v3895 = vmlaq_n_s16(v3895_tmp, v3894, 2); + int16x8_t v3896 = vaddq_s16(v3893, v3895); + int16x8_t v3897 = vsubq_s16(v3066, v3071); + int16x8_t v3898 = vsubq_s16(v3076, v3081); + int16x8_t v3899_tmp = vqrdmulhq_n_s16(v3898, 11344); + int16x8_t v3899 = vmlaq_n_s16(v3899_tmp, v3898, 2); + int16x8_t v3900 = vaddq_s16(v3897, v3899); + int16x8_t v3901 = vsubq_s16(v3044, v3049); + int16x8_t v3902 = vsubq_s16(v3054, v3059); + int16x8_t v3903_tmp = vqrdmulhq_n_s16(v3902, 15934); + int16x8_t v3903 = vmlaq_n_s16(v3903_tmp, v3902, 2); + int16x8_t v3904 = vaddq_s16(v3901, v3903); + int16x8_t v3905 = vsubq_s16(v3004, v3015); + int16x8_t v3906 = vsubq_s16(v3026, v3037); + int16x8_t v3907_tmp = vqrdmulhq_n_s16(v3906, 21120); + int16x8_t v3907 = vmlaq_n_s16(v3907_tmp, v3906, 2); + int16x8_t v3908 = vaddq_s16(v3905, v3907); + int16x8_t v3909 = vsubq_s16(v2958, v2969); + int16x8_t v3910 = vsubq_s16(v2980, v2991); + int16x8_t v3911_tmp = vqrdmulhq_n_s16(v3910, 27027); + int16x8_t v3911 = vmlaq_n_s16(v3911_tmp, v3910, 2); + int16x8_t v3912 = vaddq_s16(v3909, v3911); + int16x8_t v3913 = vsubq_s16(v2912, v2923); + int16x8_t v3914 = vsubq_s16(v2934, v2945); + int16x8_t v3915_tmp = vqrdmulhq_n_s16(v3914, 1045); + int16x8_t v3915 = vmlaq_n_s16(v3915_tmp, v3914, 3); + int16x8_t v3916 = vaddq_s16(v3913, v3915); + int16x8_t v3917 = vsubq_s16(v2866, v2877); + int16x8_t v3918 = vsubq_s16(v2888, v2899); + int16x8_t v3919_tmp = vqrdmulhq_n_s16(v3918, 8923); + int16x8_t v3919 = vmlaq_n_s16(v3919_tmp, v3918, 3); + int16x8_t v3920 = vaddq_s16(v3917, v3919); + int16x8_t v3921 = vsubq_s16(v2820, v2831); + int16x8_t v3922 = vsubq_s16(v2842, v2853); + int16x8_t v3923_tmp = vqrdmulhq_n_s16(v3922, 18177); + int16x8_t v3923 = vmlaq_n_s16(v3923_tmp, v3922, 3); + int16x8_t v3924 = vaddq_s16(v3921, v3923); + int16x8_t v3925 = vsubq_s16(v2774, v2785); + int16x8_t v3926 = vsubq_s16(v2796, v2807); + int16x8_t v3927_tmp = vqrdmulhq_n_s16(v3926, 29200); + int16x8_t v3927 = vmlaq_n_s16(v3927_tmp, v3926, 3); + int16x8_t v3928 = vaddq_s16(v3925, v3927); + int16x8_t v3929 = vsubq_s16(v2728, v2739); + int16x8_t v3930 = vsubq_s16(v2750, v2761); + int16x8_t v3931_tmp = vqrdmulhq_n_s16(v3930, 9782); + int16x8_t v3931 = vmlaq_n_s16(v3931_tmp, v3930, 4); + int16x8_t v3932 = vaddq_s16(v3929, v3931); + int16x8_t v3933 = vsubq_s16(v2682, v2693); + int16x8_t v3934 = vsubq_s16(v2704, v2715); + int16x8_t v3935_tmp = vqrdmulhq_n_s16(v3934, 26282); + int16x8_t v3935 = vmlaq_n_s16(v3935_tmp, v3934, 4); + int16x8_t v3936 = vaddq_s16(v3933, v3935); + int16x8_t v3937 = vsubq_s16(v2600, v2623); + int16x8_t v3938 = vsubq_s16(v2646, v2669); + int16x8_t v3939_tmp = vqrdmulhq_n_s16(v3938, 14423); + int16x8_t v3939 = vmlaq_n_s16(v3939_tmp, v3938, 5); + int16x8_t v3940 = vaddq_s16(v3937, v3939); + int16x8_t v3941 = vsubq_s16(v2506, v2529); + int16x8_t v3942 = vsubq_s16(v2552, v2575); + int16x8_t v3943_tmp = vqrdmulhq_n_s16(v3942, 9008); + int16x8_t v3943 = vmlaq_n_s16(v3943_tmp, v3942, 6); + int16x8_t v3944 = vaddq_s16(v3941, v3943); + int16x8_t v3945 = vsubq_s16(v2411, v2434); + int16x8_t v3946 = vsubq_s16(v2457, v2481); + int16x8_t v3947_tmp = vqrdmulhq_n_s16(v3946, 13552); + int16x8_t v3947 = vmlaq_n_s16(v3947_tmp, v3946, 7); + int16x8_t v3948 = vaddq_s16(v3945, v3947); + int16x8_t v3949 = vsubq_s16(v2317, v2340); + int16x8_t v3950 = vsubq_s16(v2363, v2386); + int16x8_t v3951_tmp = vqrdmulhq_n_s16(v3950, 1925); + int16x8_t v3951 = vmlaq_n_s16(v3951_tmp, v3950, 9); + int16x8_t v3952 = vaddq_s16(v3949, v3951); + int16x8_t v3953 = vsubq_s16(v2151, v2198); + int16x8_t v3954 = vsubq_s16(v2245, v2292); + int16x8_t v3955_tmp = vqrdmulhq_n_s16(v3954, 21123); + int16x8_t v3955 = vmlaq_n_s16(v3955_tmp, v3954, 11); + int16x8_t v3956 = vaddq_s16(v3953, v3955); + int16x8_t v3957 = vsubq_s16(v1961, v2008); + int16x8_t v3958 = vsubq_s16(v2055, v2102); + int16x8_t v3959_tmp = vqrdmulhq_n_s16(v3958, 9831); + int16x8_t v3959 = vmlaq_n_s16(v3959_tmp, v3958, 16); + int16x8_t v3960 = vaddq_s16(v3957, v3959); + int16x8_t v3961 = vsubq_s16(v1627, v1722); + int16x8_t v3962 = vsubq_s16(v1817, v1912); + int16x8_t v3963_tmp = vqrdmulhq_n_s16(v3962, 5373); + int16x8_t v3963 = vmlaq_n_s16(v3963_tmp, v3962, 27); + int16x8_t v3964 = vaddq_s16(v3961, v3963); + int16x8_t v3965 = vsubq_s16(v317, v700); + int16x8_t v3966 = vsubq_s16(v1146, v1530); + int16x8_t v3967_tmp = vqrdmulhq_n_s16(v3966, 15986); + int16x8_t v3967 = vmlaq_n_s16(v3967_tmp, v3966, 81); + int16x8_t v3968 = vaddq_s16(v3965, v3967); + int16x8_t v3969 = vsubq_s16(v3965, v3967); + int16x8_t v3970 = vsubq_s16(v3961, v3963); + int16x8_t v3971 = vsubq_s16(v3957, v3959); + int16x8_t v3972 = vsubq_s16(v3953, v3955); + int16x8_t v3973 = vsubq_s16(v3949, v3951); + int16x8_t v3974 = vsubq_s16(v3945, v3947); + int16x8_t v3975 = vsubq_s16(v3941, v3943); + int16x8_t v3976 = vsubq_s16(v3937, v3939); + int16x8_t v3977 = vsubq_s16(v3933, v3935); + int16x8_t v3978 = vsubq_s16(v3929, v3931); + int16x8_t v3979 = vsubq_s16(v3925, v3927); + int16x8_t v3980 = vsubq_s16(v3921, v3923); + int16x8_t v3981 = vsubq_s16(v3917, v3919); + int16x8_t v3982 = vsubq_s16(v3913, v3915); + int16x8_t v3983 = vsubq_s16(v3909, v3911); + int16x8_t v3984 = vsubq_s16(v3905, v3907); + int16x8_t v3985 = vsubq_s16(v3901, v3903); + int16x8_t v3986 = vsubq_s16(v3897, v3899); + int16x8_t v3987 = vsubq_s16(v3893, v3895); + int16x8_t v3988 = vsubq_s16(v3889, v3891); + int16x8_t v3989 = vsubq_s16(v3885, v3887); + int16x8_t v3990 = vsubq_s16(v3881, v3883); + int16x8_t v3991 = vsubq_s16(v3877, v3879); + int16x8_t v3992 = vsubq_s16(v3873, v3875); + int16x8_t v3993 = vsubq_s16(v3869, v3871); + int16x8_t v3994 = vsubq_s16(v3865, v3867); + int16x8_t v3995 = vsubq_s16(v3861, v3863); + int16x8_t v3996 = vsubq_s16(v3857, v3859); + int16x8_t v3997 = vsubq_s16(v3853, v3855); + int16x8_t v3998 = vsubq_s16(v3849, v3851); + int16x8_t v3999 = vsubq_s16(v3845, v3847); + int16x8_t v4000 = vsubq_s16(v3841, v3843); + int16x8_t v4001 = vsubq_s16(v3837, v3839); + int16x8_t v4002 = vsubq_s16(v3833, v3835); + int16x8_t v4003 = vsubq_s16(v3829, v3831); + int16x8_t v4004 = vsubq_s16(v3825, v3827); + int16x8_t v4005 = vsubq_s16(v3821, v3823); + int16x8_t v4006 = vsubq_s16(v3817, v3819); + int16x8_t v4007 = vsubq_s16(v3813, v3815); + int16x8_t v4008 = vsubq_s16(v3809, v3811); + int16x8_t v4009 = vsubq_s16(v3805, v3807); + int16x8_t v4010 = vsubq_s16(v3801, v3803); + int16x8_t v4011 = vsubq_s16(v3797, v3799); + int16x8_t v4012 = vsubq_s16(v3793, v3795); + int16x8_t v4013 = vsubq_s16(v3789, v3791); + int16x8_t v4014 = vsubq_s16(v3785, v3787); + int16x8_t v4015 = vsubq_s16(v3781, v3783); + int16x8_t v4016 = vsubq_s16(v3777, v3779); + int16x8_t v4017 = vsubq_s16(v3773, v3775); + int16x8_t v4018 = vsubq_s16(v3769, v3771); + int16x8_t v4019 = vsubq_s16(v3765, v3767); + int16x8_t v4020 = vsubq_s16(v3761, v3763); + int16x8_t v4021 = vsubq_s16(v3757, v3759); + int16x8_t v4022 = vsubq_s16(v3753, v3755); + int16x8_t v4023 = vsubq_s16(v3749, v3751); + int16x8_t v4024 = vsubq_s16(v3745, v3747); + int16x8_t v4025 = vsubq_s16(v3741, v3743); + int16x8_t v4026 = vsubq_s16(v3737, v3739); + int16x8_t v4027 = vsubq_s16(v3733, v3735); + int16x8_t v4028 = vsubq_s16(v3729, v3731); + int16x8_t v4029 = vsubq_s16(v3725, v3727); + int16x8_t v4030 = vsubq_s16(v3721, v3723); + int16x8_t v4031 = vsubq_s16(v3717, v3719); + int16x8_t v4032 = vsubq_s16(v3713, v3715); + int16x8_t v4033 = vsubq_s16(v3706, v3711); + int16x8_t v4034 = vsubq_s16(v3696, v3701); + int16x8_t v4035 = vsubq_s16(v3686, v3691); + int16x8_t v4036 = vsubq_s16(v3676, v3681); + int16x8_t v4037 = vsubq_s16(v3666, v3671); + int16x8_t v4038 = vsubq_s16(v3656, v3661); + int16x8_t v4039 = vsubq_s16(v3646, v3651); + int16x8_t v4040 = vsubq_s16(v3636, v3641); + int16x8_t v4041 = vsubq_s16(v3626, v3631); + int16x8_t v4042 = vsubq_s16(v3616, v3621); + int16x8_t v4043 = vsubq_s16(v3606, v3611); + int16x8_t v4044 = vsubq_s16(v3596, v3601); + int16x8_t v4045 = vsubq_s16(v3586, v3591); + int16x8_t v4046 = vsubq_s16(v3576, v3581); + int16x8_t v4047 = vsubq_s16(v3566, v3571); + int16x8_t v4048 = vsubq_s16(v3556, v3561); + int16x8_t v4049 = vsubq_s16(v3546, v3551); + int16x8_t v4050 = vsubq_s16(v3536, v3541); + int16x8_t v4051 = vsubq_s16(v3526, v3531); + int16x8_t v4052 = vsubq_s16(v3516, v3521); + int16x8_t v4053 = vsubq_s16(v3506, v3511); + int16x8_t v4054 = vsubq_s16(v3496, v3501); + int16x8_t v4055 = vsubq_s16(v3486, v3491); + int16x8_t v4056 = vsubq_s16(v3476, v3481); + int16x8_t v4057 = vsubq_s16(v3466, v3471); + int16x8_t v4058 = vsubq_s16(v3456, v3461); + int16x8_t v4059 = vsubq_s16(v3446, v3451); + int16x8_t v4060 = vsubq_s16(v3436, v3441); + int16x8_t v4061 = vsubq_s16(v3426, v3431); + int16x8_t v4062 = vsubq_s16(v3416, v3421); + int16x8_t v4063 = vsubq_s16(v3406, v3411); + int16x8_t v4064 = vsubq_s16(v3396, v3401); + int16x8_t v4065 = vsubq_s16(v3380, v3391); + int16x8_t v4066 = vsubq_s16(v3358, v3369); + int16x8_t v4067 = vsubq_s16(v3336, v3347); + int16x8_t v4068 = vsubq_s16(v3314, v3325); + int16x8_t v4069 = vsubq_s16(v3292, v3303); + int16x8_t v4070 = vsubq_s16(v3270, v3281); + int16x8_t v4071 = vsubq_s16(v3248, v3259); + int16x8_t v4072 = vsubq_s16(v3226, v3237); + int16x8_t v4073 = vsubq_s16(v3204, v3215); + int16x8_t v4074 = vsubq_s16(v3182, v3193); + int16x8_t v4075 = vsubq_s16(v3160, v3171); + int16x8_t v4076 = vsubq_s16(v3138, v3149); + int16x8_t v4077 = vsubq_s16(v3116, v3127); + int16x8_t v4078 = vsubq_s16(v3094, v3105); + int16x8_t v4079 = vsubq_s16(v3072, v3083); + int16x8_t v4080 = vsubq_s16(v3050, v3061); + int16x8_t v4081 = vsubq_s16(v3016, v3039); + int16x8_t v4082 = vsubq_s16(v2970, v2993); + int16x8_t v4083 = vsubq_s16(v2924, v2947); + int16x8_t v4084 = vsubq_s16(v2878, v2901); + int16x8_t v4085 = vsubq_s16(v2832, v2855); + int16x8_t v4086 = vsubq_s16(v2786, v2809); + int16x8_t v4087 = vsubq_s16(v2740, v2763); + int16x8_t v4088 = vsubq_s16(v2694, v2717); + int16x8_t v4089 = vsubq_s16(v2624, v2671); + int16x8_t v4090 = vsubq_s16(v2530, v2577); + int16x8_t v4091 = vsubq_s16(v2435, v2483); + int16x8_t v4092 = vsubq_s16(v2341, v2388); + int16x8_t v4093 = vsubq_s16(v2199, v2294); + int16x8_t v4094 = vsubq_s16(v2009, v2104); + int16x8_t v4095 = vsubq_s16(v1723, v1914); + int16x8_t v4096 = vsubq_s16(v701, v1532); + vst1q_s16(out + out_stride * 0 + i, v1533); + vst1q_s16(out + out_stride * 1 + i, v1915); + vst1q_s16(out + out_stride * 2 + i, v2105); + vst1q_s16(out + out_stride * 3 + i, v2295); + vst1q_s16(out + out_stride * 4 + i, v2389); + vst1q_s16(out + out_stride * 5 + i, v2484); + vst1q_s16(out + out_stride * 6 + i, v2578); + vst1q_s16(out + out_stride * 7 + i, v2672); + vst1q_s16(out + out_stride * 8 + i, v2718); + vst1q_s16(out + out_stride * 9 + i, v2764); + vst1q_s16(out + out_stride * 10 + i, v2810); + vst1q_s16(out + out_stride * 11 + i, v2856); + vst1q_s16(out + out_stride * 12 + i, v2902); + vst1q_s16(out + out_stride * 13 + i, v2948); + vst1q_s16(out + out_stride * 14 + i, v2994); + vst1q_s16(out + out_stride * 15 + i, v3040); + vst1q_s16(out + out_stride * 16 + i, v3062); + vst1q_s16(out + out_stride * 17 + i, v3084); + vst1q_s16(out + out_stride * 18 + i, v3106); + vst1q_s16(out + out_stride * 19 + i, v3128); + vst1q_s16(out + out_stride * 20 + i, v3150); + vst1q_s16(out + out_stride * 21 + i, v3172); + vst1q_s16(out + out_stride * 22 + i, v3194); + vst1q_s16(out + out_stride * 23 + i, v3216); + vst1q_s16(out + out_stride * 24 + i, v3238); + vst1q_s16(out + out_stride * 25 + i, v3260); + vst1q_s16(out + out_stride * 26 + i, v3282); + vst1q_s16(out + out_stride * 27 + i, v3304); + vst1q_s16(out + out_stride * 28 + i, v3326); + vst1q_s16(out + out_stride * 29 + i, v3348); + vst1q_s16(out + out_stride * 30 + i, v3370); + vst1q_s16(out + out_stride * 31 + i, v3392); + vst1q_s16(out + out_stride * 32 + i, v3402); + vst1q_s16(out + out_stride * 33 + i, v3412); + vst1q_s16(out + out_stride * 34 + i, v3422); + vst1q_s16(out + out_stride * 35 + i, v3432); + vst1q_s16(out + out_stride * 36 + i, v3442); + vst1q_s16(out + out_stride * 37 + i, v3452); + vst1q_s16(out + out_stride * 38 + i, v3462); + vst1q_s16(out + out_stride * 39 + i, v3472); + vst1q_s16(out + out_stride * 40 + i, v3482); + vst1q_s16(out + out_stride * 41 + i, v3492); + vst1q_s16(out + out_stride * 42 + i, v3502); + vst1q_s16(out + out_stride * 43 + i, v3512); + vst1q_s16(out + out_stride * 44 + i, v3522); + vst1q_s16(out + out_stride * 45 + i, v3532); + vst1q_s16(out + out_stride * 46 + i, v3542); + vst1q_s16(out + out_stride * 47 + i, v3552); + vst1q_s16(out + out_stride * 48 + i, v3562); + vst1q_s16(out + out_stride * 49 + i, v3572); + vst1q_s16(out + out_stride * 50 + i, v3582); + vst1q_s16(out + out_stride * 51 + i, v3592); + vst1q_s16(out + out_stride * 52 + i, v3602); + vst1q_s16(out + out_stride * 53 + i, v3612); + vst1q_s16(out + out_stride * 54 + i, v3622); + vst1q_s16(out + out_stride * 55 + i, v3632); + vst1q_s16(out + out_stride * 56 + i, v3642); + vst1q_s16(out + out_stride * 57 + i, v3652); + vst1q_s16(out + out_stride * 58 + i, v3662); + vst1q_s16(out + out_stride * 59 + i, v3672); + vst1q_s16(out + out_stride * 60 + i, v3682); + vst1q_s16(out + out_stride * 61 + i, v3692); + vst1q_s16(out + out_stride * 62 + i, v3702); + vst1q_s16(out + out_stride * 63 + i, v3712); + vst1q_s16(out + out_stride * 64 + i, v3716); + vst1q_s16(out + out_stride * 65 + i, v3720); + vst1q_s16(out + out_stride * 66 + i, v3724); + vst1q_s16(out + out_stride * 67 + i, v3728); + vst1q_s16(out + out_stride * 68 + i, v3732); + vst1q_s16(out + out_stride * 69 + i, v3736); + vst1q_s16(out + out_stride * 70 + i, v3740); + vst1q_s16(out + out_stride * 71 + i, v3744); + vst1q_s16(out + out_stride * 72 + i, v3748); + vst1q_s16(out + out_stride * 73 + i, v3752); + vst1q_s16(out + out_stride * 74 + i, v3756); + vst1q_s16(out + out_stride * 75 + i, v3760); + vst1q_s16(out + out_stride * 76 + i, v3764); + vst1q_s16(out + out_stride * 77 + i, v3768); + vst1q_s16(out + out_stride * 78 + i, v3772); + vst1q_s16(out + out_stride * 79 + i, v3776); + vst1q_s16(out + out_stride * 80 + i, v3780); + vst1q_s16(out + out_stride * 81 + i, v3784); + vst1q_s16(out + out_stride * 82 + i, v3788); + vst1q_s16(out + out_stride * 83 + i, v3792); + vst1q_s16(out + out_stride * 84 + i, v3796); + vst1q_s16(out + out_stride * 85 + i, v3800); + vst1q_s16(out + out_stride * 86 + i, v3804); + vst1q_s16(out + out_stride * 87 + i, v3808); + vst1q_s16(out + out_stride * 88 + i, v3812); + vst1q_s16(out + out_stride * 89 + i, v3816); + vst1q_s16(out + out_stride * 90 + i, v3820); + vst1q_s16(out + out_stride * 91 + i, v3824); + vst1q_s16(out + out_stride * 92 + i, v3828); + vst1q_s16(out + out_stride * 93 + i, v3832); + vst1q_s16(out + out_stride * 94 + i, v3836); + vst1q_s16(out + out_stride * 95 + i, v3840); + vst1q_s16(out + out_stride * 96 + i, v3844); + vst1q_s16(out + out_stride * 97 + i, v3848); + vst1q_s16(out + out_stride * 98 + i, v3852); + vst1q_s16(out + out_stride * 99 + i, v3856); + vst1q_s16(out + out_stride * 100 + i, v3860); + vst1q_s16(out + out_stride * 101 + i, v3864); + vst1q_s16(out + out_stride * 102 + i, v3868); + vst1q_s16(out + out_stride * 103 + i, v3872); + vst1q_s16(out + out_stride * 104 + i, v3876); + vst1q_s16(out + out_stride * 105 + i, v3880); + vst1q_s16(out + out_stride * 106 + i, v3884); + vst1q_s16(out + out_stride * 107 + i, v3888); + vst1q_s16(out + out_stride * 108 + i, v3892); + vst1q_s16(out + out_stride * 109 + i, v3896); + vst1q_s16(out + out_stride * 110 + i, v3900); + vst1q_s16(out + out_stride * 111 + i, v3904); + vst1q_s16(out + out_stride * 112 + i, v3908); + vst1q_s16(out + out_stride * 113 + i, v3912); + vst1q_s16(out + out_stride * 114 + i, v3916); + vst1q_s16(out + out_stride * 115 + i, v3920); + vst1q_s16(out + out_stride * 116 + i, v3924); + vst1q_s16(out + out_stride * 117 + i, v3928); + vst1q_s16(out + out_stride * 118 + i, v3932); + vst1q_s16(out + out_stride * 119 + i, v3936); + vst1q_s16(out + out_stride * 120 + i, v3940); + vst1q_s16(out + out_stride * 121 + i, v3944); + vst1q_s16(out + out_stride * 122 + i, v3948); + vst1q_s16(out + out_stride * 123 + i, v3952); + vst1q_s16(out + out_stride * 124 + i, v3956); + vst1q_s16(out + out_stride * 125 + i, v3960); + vst1q_s16(out + out_stride * 126 + i, v3964); + vst1q_s16(out + out_stride * 127 + i, v3968); + vst1q_s16(out + out_stride * 128 + i, v3969); + vst1q_s16(out + out_stride * 129 + i, v3970); + vst1q_s16(out + out_stride * 130 + i, v3971); + vst1q_s16(out + out_stride * 131 + i, v3972); + vst1q_s16(out + out_stride * 132 + i, v3973); + vst1q_s16(out + out_stride * 133 + i, v3974); + vst1q_s16(out + out_stride * 134 + i, v3975); + vst1q_s16(out + out_stride * 135 + i, v3976); + vst1q_s16(out + out_stride * 136 + i, v3977); + vst1q_s16(out + out_stride * 137 + i, v3978); + vst1q_s16(out + out_stride * 138 + i, v3979); + vst1q_s16(out + out_stride * 139 + i, v3980); + vst1q_s16(out + out_stride * 140 + i, v3981); + vst1q_s16(out + out_stride * 141 + i, v3982); + vst1q_s16(out + out_stride * 142 + i, v3983); + vst1q_s16(out + out_stride * 143 + i, v3984); + vst1q_s16(out + out_stride * 144 + i, v3985); + vst1q_s16(out + out_stride * 145 + i, v3986); + vst1q_s16(out + out_stride * 146 + i, v3987); + vst1q_s16(out + out_stride * 147 + i, v3988); + vst1q_s16(out + out_stride * 148 + i, v3989); + vst1q_s16(out + out_stride * 149 + i, v3990); + vst1q_s16(out + out_stride * 150 + i, v3991); + vst1q_s16(out + out_stride * 151 + i, v3992); + vst1q_s16(out + out_stride * 152 + i, v3993); + vst1q_s16(out + out_stride * 153 + i, v3994); + vst1q_s16(out + out_stride * 154 + i, v3995); + vst1q_s16(out + out_stride * 155 + i, v3996); + vst1q_s16(out + out_stride * 156 + i, v3997); + vst1q_s16(out + out_stride * 157 + i, v3998); + vst1q_s16(out + out_stride * 158 + i, v3999); + vst1q_s16(out + out_stride * 159 + i, v4000); + vst1q_s16(out + out_stride * 160 + i, v4001); + vst1q_s16(out + out_stride * 161 + i, v4002); + vst1q_s16(out + out_stride * 162 + i, v4003); + vst1q_s16(out + out_stride * 163 + i, v4004); + vst1q_s16(out + out_stride * 164 + i, v4005); + vst1q_s16(out + out_stride * 165 + i, v4006); + vst1q_s16(out + out_stride * 166 + i, v4007); + vst1q_s16(out + out_stride * 167 + i, v4008); + vst1q_s16(out + out_stride * 168 + i, v4009); + vst1q_s16(out + out_stride * 169 + i, v4010); + vst1q_s16(out + out_stride * 170 + i, v4011); + vst1q_s16(out + out_stride * 171 + i, v4012); + vst1q_s16(out + out_stride * 172 + i, v4013); + vst1q_s16(out + out_stride * 173 + i, v4014); + vst1q_s16(out + out_stride * 174 + i, v4015); + vst1q_s16(out + out_stride * 175 + i, v4016); + vst1q_s16(out + out_stride * 176 + i, v4017); + vst1q_s16(out + out_stride * 177 + i, v4018); + vst1q_s16(out + out_stride * 178 + i, v4019); + vst1q_s16(out + out_stride * 179 + i, v4020); + vst1q_s16(out + out_stride * 180 + i, v4021); + vst1q_s16(out + out_stride * 181 + i, v4022); + vst1q_s16(out + out_stride * 182 + i, v4023); + vst1q_s16(out + out_stride * 183 + i, v4024); + vst1q_s16(out + out_stride * 184 + i, v4025); + vst1q_s16(out + out_stride * 185 + i, v4026); + vst1q_s16(out + out_stride * 186 + i, v4027); + vst1q_s16(out + out_stride * 187 + i, v4028); + vst1q_s16(out + out_stride * 188 + i, v4029); + vst1q_s16(out + out_stride * 189 + i, v4030); + vst1q_s16(out + out_stride * 190 + i, v4031); + vst1q_s16(out + out_stride * 191 + i, v4032); + vst1q_s16(out + out_stride * 192 + i, v4033); + vst1q_s16(out + out_stride * 193 + i, v4034); + vst1q_s16(out + out_stride * 194 + i, v4035); + vst1q_s16(out + out_stride * 195 + i, v4036); + vst1q_s16(out + out_stride * 196 + i, v4037); + vst1q_s16(out + out_stride * 197 + i, v4038); + vst1q_s16(out + out_stride * 198 + i, v4039); + vst1q_s16(out + out_stride * 199 + i, v4040); + vst1q_s16(out + out_stride * 200 + i, v4041); + vst1q_s16(out + out_stride * 201 + i, v4042); + vst1q_s16(out + out_stride * 202 + i, v4043); + vst1q_s16(out + out_stride * 203 + i, v4044); + vst1q_s16(out + out_stride * 204 + i, v4045); + vst1q_s16(out + out_stride * 205 + i, v4046); + vst1q_s16(out + out_stride * 206 + i, v4047); + vst1q_s16(out + out_stride * 207 + i, v4048); + vst1q_s16(out + out_stride * 208 + i, v4049); + vst1q_s16(out + out_stride * 209 + i, v4050); + vst1q_s16(out + out_stride * 210 + i, v4051); + vst1q_s16(out + out_stride * 211 + i, v4052); + vst1q_s16(out + out_stride * 212 + i, v4053); + vst1q_s16(out + out_stride * 213 + i, v4054); + vst1q_s16(out + out_stride * 214 + i, v4055); + vst1q_s16(out + out_stride * 215 + i, v4056); + vst1q_s16(out + out_stride * 216 + i, v4057); + vst1q_s16(out + out_stride * 217 + i, v4058); + vst1q_s16(out + out_stride * 218 + i, v4059); + vst1q_s16(out + out_stride * 219 + i, v4060); + vst1q_s16(out + out_stride * 220 + i, v4061); + vst1q_s16(out + out_stride * 221 + i, v4062); + vst1q_s16(out + out_stride * 222 + i, v4063); + vst1q_s16(out + out_stride * 223 + i, v4064); + vst1q_s16(out + out_stride * 224 + i, v4065); + vst1q_s16(out + out_stride * 225 + i, v4066); + vst1q_s16(out + out_stride * 226 + i, v4067); + vst1q_s16(out + out_stride * 227 + i, v4068); + vst1q_s16(out + out_stride * 228 + i, v4069); + vst1q_s16(out + out_stride * 229 + i, v4070); + vst1q_s16(out + out_stride * 230 + i, v4071); + vst1q_s16(out + out_stride * 231 + i, v4072); + vst1q_s16(out + out_stride * 232 + i, v4073); + vst1q_s16(out + out_stride * 233 + i, v4074); + vst1q_s16(out + out_stride * 234 + i, v4075); + vst1q_s16(out + out_stride * 235 + i, v4076); + vst1q_s16(out + out_stride * 236 + i, v4077); + vst1q_s16(out + out_stride * 237 + i, v4078); + vst1q_s16(out + out_stride * 238 + i, v4079); + vst1q_s16(out + out_stride * 239 + i, v4080); + vst1q_s16(out + out_stride * 240 + i, v4081); + vst1q_s16(out + out_stride * 241 + i, v4082); + vst1q_s16(out + out_stride * 242 + i, v4083); + vst1q_s16(out + out_stride * 243 + i, v4084); + vst1q_s16(out + out_stride * 244 + i, v4085); + vst1q_s16(out + out_stride * 245 + i, v4086); + vst1q_s16(out + out_stride * 246 + i, v4087); + vst1q_s16(out + out_stride * 247 + i, v4088); + vst1q_s16(out + out_stride * 248 + i, v4089); + vst1q_s16(out + out_stride * 249 + i, v4090); + vst1q_s16(out + out_stride * 250 + i, v4091); + vst1q_s16(out + out_stride * 251 + i, v4092); + vst1q_s16(out + out_stride * 252 + i, v4093); + vst1q_s16(out + out_stride * 253 + i, v4094); + vst1q_s16(out + out_stride * 254 + i, v4095); + vst1q_s16(out + out_stride * 255 + i, v4096); + } +} diff --git a/media/libjxl/src/lib/jxl/fast_dct32-inl.h b/media/libjxl/src/lib/jxl/fast_dct32-inl.h new file mode 100644 index 000000000..0f3b31cfe --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct32-inl.h @@ -0,0 +1,419 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<32>) { return 1; } + +void FastIDCT(FastDCTTag<32>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v17 = vaddq_s16(v16, v12); + int16x8_t v18 = vaddq_s16(v13, v10); + int16x8_t v19 = vaddq_s16(v17, v18); + int16x8_t v20 = vqrdmulhq_n_s16(v19, 17734); + int16x8_t v21 = vqrdmulhq_n_s16(v18, 25080); + int16x8_t v22 = vaddq_s16(v20, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35 = vqrdmulhq_n_s16(v34, 25080); + int16x8_t v36 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vqrdmulhq_n_s16(v39, 17734); + int16x8_t v41 = vaddq_s16(v35, v40); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v29, v32); + int16x8_t v46 = vaddq_s16(v37, v28); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vqrdmulhq_n_s16(v48, 16705); + int16x8_t v50 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v51 = vaddq_s16(v50, v36); + int16x8_t v52 = vaddq_s16(v51, v46); + int16x8_t v53 = vqrdmulhq_n_s16(v52, 17734); + int16x8_t v54 = vaddq_s16(v45, v43); + int16x8_t v55_tmp = vqrdmulhq_n_s16(v54, 10045); + int16x8_t v55 = vaddq_s16(v55_tmp, v54); + int16x8_t v56 = vaddq_s16(v53, v55); + int16x8_t v57 = vqrdmulhq_n_s16(v56, 16705); + int16x8_t v58 = vaddq_s16(v49, v57); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v63 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v64 = vaddq_s16(v62, v63); + int16x8_t v65 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v66 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v67 = vaddq_s16(v65, v66); + int16x8_t v68 = vaddq_s16(v64, v67); + int16x8_t v69_tmp = vqrdmulhq_n_s16(v68, 10045); + int16x8_t v69 = vaddq_s16(v69_tmp, v68); + int16x8_t v70 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v71 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v72 = vaddq_s16(v70, v71); + int16x8_t v73 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v74 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v75 = vaddq_s16(v73, v74); + int16x8_t v76 = vaddq_s16(v72, v75); + int16x8_t v77 = vqrdmulhq_n_s16(v76, 17734); + int16x8_t v78 = vaddq_s16(v69, v77); + int16x8_t v79 = vqrdmulhq_n_s16(v78, 16705); + int16x8_t v80_tmp = vqrdmulhq_n_s16(v67, 13573); + int16x8_t v80 = vaddq_s16(v80_tmp, v67); + int16x8_t v81 = vaddq_s16(v64, v72); + int16x8_t v82 = vaddq_s16(v80, v81); + int16x8_t v83 = vqrdmulhq_n_s16(v82, 16705); + int16x8_t v84 = vaddq_s16(v79, v83); + int16x8_t v85 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v86_tmp = vqrdmulhq_n_s16(v85, 13573); + int16x8_t v86 = vaddq_s16(v86_tmp, v85); + int16x8_t v87 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v88 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v89 = vaddq_s16(v87, v88); + int16x8_t v90 = vaddq_s16(v86, v89); + int16x8_t v91 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v92 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v93 = vaddq_s16(v91, v92); + int16x8_t v94 = vqrdmulhq_n_s16(v93, 25080); + int16x8_t v95 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v96 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v97 = vaddq_s16(v95, v96); + int16x8_t v98 = vaddq_s16(v97, v93); + int16x8_t v99 = vqrdmulhq_n_s16(v98, 17734); + int16x8_t v100 = vaddq_s16(v94, v99); + int16x8_t v101 = vaddq_s16(v90, v100); + int16x8_t v102 = vaddq_s16(v84, v101); + int16x8_t v103 = vaddq_s16(v92, v65); + int16x8_t v104 = vaddq_s16(v66, v85); + int16x8_t v105 = vaddq_s16(v103, v104); + int16x8_t v106_tmp = vqrdmulhq_n_s16(v105, 13573); + int16x8_t v106 = vaddq_s16(v106_tmp, v105); + int16x8_t v107 = vaddq_s16(v96, v70); + int16x8_t v108 = vaddq_s16(v71, v87); + int16x8_t v109 = vaddq_s16(v107, v108); + int16x8_t v110 = vaddq_s16(v63, v91); + int16x8_t v111 = vaddq_s16(v88, v62); + int16x8_t v112 = vaddq_s16(v110, v111); + int16x8_t v113 = vaddq_s16(v109, v112); + int16x8_t v114 = vaddq_s16(v106, v113); + int16x8_t v115 = vqrdmulhq_n_s16(v114, 16705); + int16x8_t v116 = vaddq_s16(v112, v105); + int16x8_t v117 = vqrdmulhq_n_s16(v116, 25080); + int16x8_t v118 = vqrdmulhq_n_s16(v116, 17734); + int16x8_t v119 = vaddq_s16(v74, v95); + int16x8_t v120 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v121 = vaddq_s16(v120, v73); + int16x8_t v122 = vaddq_s16(v119, v121); + int16x8_t v123 = vaddq_s16(v122, v109); + int16x8_t v124 = vqrdmulhq_n_s16(v123, 17734); + int16x8_t v125 = vaddq_s16(v118, v124); + int16x8_t v126 = vaddq_s16(v117, v125); + int16x8_t v127 = vqrdmulhq_n_s16(v126, 16705); + int16x8_t v128 = vaddq_s16(v115, v127); + int16x8_t v129 = vqrdmulhq_n_s16(v128, 16463); + int16x8_t v130_tmp = vqrdmulhq_n_s16(v104, 13573); + int16x8_t v130 = vaddq_s16(v130_tmp, v104); + int16x8_t v131 = vaddq_s16(v108, v111); + int16x8_t v132 = vaddq_s16(v130, v131); + int16x8_t v133 = vaddq_s16(v119, v107); + int16x8_t v134 = vqrdmulhq_n_s16(v133, 17734); + int16x8_t v135 = vaddq_s16(v110, v103); + int16x8_t v136_tmp = vqrdmulhq_n_s16(v135, 10045); + int16x8_t v136 = vaddq_s16(v136_tmp, v135); + int16x8_t v137 = vaddq_s16(v134, v136); + int16x8_t v138 = vaddq_s16(v132, v137); + int16x8_t v139 = vqrdmulhq_n_s16(v138, 16463); + int16x8_t v140 = vaddq_s16(v129, v139); + int16x8_t v141 = vaddq_s16(v102, v140); + int16x8_t v142 = vqrdmulhq_n_s16(v141, 16404); + int16x8_t v143 = vaddq_s16(v61, v142); + int16x8_t v144 = vsubq_s16(v0, v1); + int16x8_t v145 = vsubq_s16(v4, v6); + int16x8_t v146_tmp = vqrdmulhq_n_s16(v145, 10045); + int16x8_t v146 = vaddq_s16(v146_tmp, v145); + int16x8_t v147 = vaddq_s16(v144, v146); + int16x8_t v148 = vsubq_s16(v11, v14); + int16x8_t v149 = vqrdmulhq_n_s16(v18, 17734); + int16x8_t v150_tmp = vqrdmulhq_n_s16(v17, 10045); + int16x8_t v150 = vaddq_s16(v150_tmp, v17); + int16x8_t v151 = vsubq_s16(v149, v150); + int16x8_t v152 = vaddq_s16(v148, v151); + int16x8_t v153 = vqrdmulhq_n_s16(v152, 19705); + int16x8_t v154 = vaddq_s16(v147, v153); + int16x8_t v155 = vsubq_s16(v27, v30); + int16x8_t v156 = vqrdmulhq_n_s16(v34, 17734); + int16x8_t v157_tmp = vqrdmulhq_n_s16(v38, 10045); + int16x8_t v157 = vaddq_s16(v157_tmp, v38); + int16x8_t v158 = vsubq_s16(v156, v157); + int16x8_t v159 = vaddq_s16(v155, v158); + int16x8_t v160 = vqrdmulhq_n_s16(v54, 13573); + int16x8_t v161 = vsubq_s16(v160, v52); + int16x8_t v162 = vqrdmulhq_n_s16(v161, 25746); + int16x8_t v163 = vsubq_s16(v44, v47); + int16x8_t v164 = vqrdmulhq_n_s16(v163, 19705); + int16x8_t v165 = vaddq_s16(v162, v164); + int16x8_t v166 = vaddq_s16(v159, v165); + int16x8_t v167 = vqrdmulhq_n_s16(v166, 17121); + int16x8_t v168 = vaddq_s16(v154, v167); + int16x8_t v169 = vsubq_s16(v86, v89); + int16x8_t v170 = vqrdmulhq_n_s16(v93, 17734); + int16x8_t v171_tmp = vqrdmulhq_n_s16(v97, 10045); + int16x8_t v171 = vaddq_s16(v171_tmp, v97); + int16x8_t v172 = vsubq_s16(v170, v171); + int16x8_t v173 = vaddq_s16(v169, v172); + int16x8_t v174 = vsubq_s16(v80, v81); + int16x8_t v175 = vqrdmulhq_n_s16(v174, 19705); + int16x8_t v176 = vqrdmulhq_n_s16(v68, 13573); + int16x8_t v177 = vsubq_s16(v176, v76); + int16x8_t v178 = vqrdmulhq_n_s16(v177, 25746); + int16x8_t v179 = vaddq_s16(v175, v178); + int16x8_t v180 = vaddq_s16(v173, v179); + int16x8_t v181 = vsubq_s16(v130, v131); + int16x8_t v182 = vqrdmulhq_n_s16(v135, 13573); + int16x8_t v183 = vsubq_s16(v182, v133); + int16x8_t v184_tmp = vqrdmulhq_n_s16(v183, 10045); + int16x8_t v184 = vaddq_s16(v184_tmp, v183); + int16x8_t v185 = vaddq_s16(v181, v184); + int16x8_t v186 = vqrdmulhq_n_s16(v185, 17121); + int16x8_t v187 = vqrdmulhq_n_s16(v105, 27867); + int16x8_t v188 = vqrdmulhq_n_s16(v113, 19705); + int16x8_t v189 = vsubq_s16(v187, v188); + int16x8_t v190 = vqrdmulhq_n_s16(v116, 13573); + int16x8_t v191 = vsubq_s16(v190, v123); + int16x8_t v192 = vqrdmulhq_n_s16(v191, 25746); + int16x8_t v193 = vaddq_s16(v189, v192); + int16x8_t v194 = vqrdmulhq_n_s16(v193, 17121); + int16x8_t v195 = vaddq_s16(v186, v194); + int16x8_t v196 = vaddq_s16(v180, v195); + int16x8_t v197 = vqrdmulhq_n_s16(v196, 16563); + int16x8_t v198 = vaddq_s16(v168, v197); + int16x8_t v199 = vsubq_s16(v144, v146); + int16x8_t v200 = vsubq_s16(v148, v151); + int16x8_t v201 = vqrdmulhq_n_s16(v200, 29490); + int16x8_t v202 = vaddq_s16(v199, v201); + int16x8_t v203 = vsubq_s16(v155, v158); + int16x8_t v204 = vqrdmulhq_n_s16(v163, 29490); + int16x8_t v205_tmp = vqrdmulhq_n_s16(v161, 5763); + int16x8_t v205 = vaddq_s16(v205_tmp, v161); + int16x8_t v206 = vsubq_s16(v204, v205); + int16x8_t v207 = vaddq_s16(v203, v206); + int16x8_t v208 = vqrdmulhq_n_s16(v207, 18578); + int16x8_t v209 = vaddq_s16(v202, v208); + int16x8_t v210 = vsubq_s16(v169, v172); + int16x8_t v211 = vqrdmulhq_n_s16(v174, 29490); + int16x8_t v212_tmp = vqrdmulhq_n_s16(v177, 5763); + int16x8_t v212 = vaddq_s16(v212_tmp, v177); + int16x8_t v213 = vsubq_s16(v211, v212); + int16x8_t v214 = vaddq_s16(v210, v213); + int16x8_t v215 = vsubq_s16(v181, v184); + int16x8_t v216 = vqrdmulhq_n_s16(v215, 18578); + int16x8_t v217 = vqrdmulhq_n_s16(v189, 27803); + int16x8_t v218 = vqrdmulhq_n_s16(v191, 21845); + int16x8_t v219 = vsubq_s16(v217, v218); + int16x8_t v220 = vaddq_s16(v216, v219); + int16x8_t v221 = vaddq_s16(v214, v220); + int16x8_t v222 = vqrdmulhq_n_s16(v221, 16890); + int16x8_t v223 = vaddq_s16(v209, v222); + int16x8_t v224 = vsubq_s16(v2, v8); + int16x8_t v225 = vsubq_s16(v15, v22); + int16x8_t v226_tmp = vqrdmulhq_n_s16(v225, 18446); + int16x8_t v226 = vmlaq_n_s16(v226_tmp, v225, 2); + int16x8_t v227 = vaddq_s16(v224, v226); + int16x8_t v228 = vsubq_s16(v31, v41); + int16x8_t v229 = vsubq_s16(v48, v56); + int16x8_t v230_tmp = vqrdmulhq_n_s16(v229, 18446); + int16x8_t v230 = vmlaq_n_s16(v230_tmp, v229, 2); + int16x8_t v231 = vaddq_s16(v228, v230); + int16x8_t v232 = vqrdmulhq_n_s16(v231, 21195); + int16x8_t v233 = vaddq_s16(v227, v232); + int16x8_t v234 = vsubq_s16(v82, v78); + int16x8_t v235_tmp = vqrdmulhq_n_s16(v234, 18446); + int16x8_t v235 = vmlaq_n_s16(v235_tmp, v234, 2); + int16x8_t v236 = vsubq_s16(v90, v100); + int16x8_t v237 = vaddq_s16(v235, v236); + int16x8_t v238 = vsubq_s16(v132, v137); + int16x8_t v239 = vsubq_s16(v114, v126); + int16x8_t v240_tmp = vqrdmulhq_n_s16(v239, 18446); + int16x8_t v240 = vmlaq_n_s16(v240_tmp, v239, 2); + int16x8_t v241 = vaddq_s16(v238, v240); + int16x8_t v242 = vqrdmulhq_n_s16(v241, 21195); + int16x8_t v243 = vaddq_s16(v237, v242); + int16x8_t v244 = vqrdmulhq_n_s16(v243, 17401); + int16x8_t v245 = vaddq_s16(v233, v244); + int16x8_t v246 = vsubq_s16(v228, v230); + int16x8_t v247 = vqrdmulhq_n_s16(v246, 25826); + int16x8_t v248 = vsubq_s16(v224, v226); + int16x8_t v249 = vaddq_s16(v247, v248); + int16x8_t v250 = vsubq_s16(v238, v240); + int16x8_t v251 = vqrdmulhq_n_s16(v250, 25826); + int16x8_t v252 = vsubq_s16(v236, v235); + int16x8_t v253 = vaddq_s16(v251, v252); + int16x8_t v254 = vqrdmulhq_n_s16(v253, 18124); + int16x8_t v255 = vaddq_s16(v249, v254); + int16x8_t v256 = vsubq_s16(v199, v201); + int16x8_t v257 = vsubq_s16(v203, v206); + int16x8_t v258_tmp = vqrdmulhq_n_s16(v257, 1988); + int16x8_t v258 = vaddq_s16(v258_tmp, v257); + int16x8_t v259 = vaddq_s16(v256, v258); + int16x8_t v260 = vsubq_s16(v210, v213); + int16x8_t v261_tmp = vqrdmulhq_n_s16(v219, 25030); + int16x8_t v261 = vaddq_s16(v261_tmp, v219); + int16x8_t v262 = vsubq_s16(v215, v261); + int16x8_t v263_tmp = vqrdmulhq_n_s16(v262, 1988); + int16x8_t v263 = vaddq_s16(v263_tmp, v262); + int16x8_t v264 = vaddq_s16(v260, v263); + int16x8_t v265 = vqrdmulhq_n_s16(v264, 19102); + int16x8_t v266 = vaddq_s16(v259, v265); + int16x8_t v267 = vsubq_s16(v147, v153); + int16x8_t v268 = vsubq_s16(v159, v165); + int16x8_t v269_tmp = vqrdmulhq_n_s16(v268, 23673); + int16x8_t v269 = vaddq_s16(v269_tmp, v268); + int16x8_t v270 = vaddq_s16(v267, v269); + int16x8_t v271 = vsubq_s16(v173, v179); + int16x8_t v272 = vsubq_s16(v185, v193); + int16x8_t v273_tmp = vqrdmulhq_n_s16(v272, 23673); + int16x8_t v273 = vaddq_s16(v273_tmp, v272); + int16x8_t v274 = vaddq_s16(v271, v273); + int16x8_t v275 = vqrdmulhq_n_s16(v274, 20398); + int16x8_t v276 = vaddq_s16(v270, v275); + int16x8_t v277 = vsubq_s16(v9, v24); + int16x8_t v278 = vsubq_s16(v42, v58); + int16x8_t v279_tmp = vqrdmulhq_n_s16(v278, 3314); + int16x8_t v279 = vmlaq_n_s16(v279_tmp, v278, 5); + int16x8_t v280 = vaddq_s16(v277, v279); + int16x8_t v281 = vsubq_s16(v138, v128); + int16x8_t v282_tmp = vqrdmulhq_n_s16(v281, 3314); + int16x8_t v282 = vmlaq_n_s16(v282_tmp, v281, 5); + int16x8_t v283 = vsubq_s16(v101, v84); + int16x8_t v284 = vaddq_s16(v282, v283); + int16x8_t v285 = vqrdmulhq_n_s16(v284, 22112); + int16x8_t v286 = vaddq_s16(v280, v285); + int16x8_t v287 = vsubq_s16(v277, v279); + int16x8_t v288 = vsubq_s16(v283, v282); + int16x8_t v289 = vqrdmulhq_n_s16(v288, 24397); + int16x8_t v290 = vaddq_s16(v287, v289); + int16x8_t v291 = vsubq_s16(v267, v269); + int16x8_t v292 = vsubq_s16(v271, v273); + int16x8_t v293 = vqrdmulhq_n_s16(v292, 27504); + int16x8_t v294 = vaddq_s16(v291, v293); + int16x8_t v295 = vsubq_s16(v260, v263); + int16x8_t v296 = vqrdmulhq_n_s16(v295, 31869); + int16x8_t v297 = vsubq_s16(v256, v258); + int16x8_t v298 = vaddq_s16(v296, v297); + int16x8_t v299 = vsubq_s16(v248, v247); + int16x8_t v300 = vsubq_s16(v252, v251); + int16x8_t v301_tmp = vqrdmulhq_n_s16(v300, 5552); + int16x8_t v301 = vaddq_s16(v301_tmp, v300); + int16x8_t v302 = vaddq_s16(v299, v301); + int16x8_t v303 = vsubq_s16(v227, v232); + int16x8_t v304 = vsubq_s16(v237, v242); + int16x8_t v305_tmp = vqrdmulhq_n_s16(v304, 15865); + int16x8_t v305 = vaddq_s16(v305_tmp, v304); + int16x8_t v306 = vaddq_s16(v303, v305); + int16x8_t v307 = vsubq_s16(v202, v208); + int16x8_t v308 = vsubq_s16(v214, v220); + int16x8_t v309_tmp = vqrdmulhq_n_s16(v308, 1893); + int16x8_t v309 = vmlaq_n_s16(v309_tmp, v308, 2); + int16x8_t v310 = vaddq_s16(v307, v309); + int16x8_t v311 = vsubq_s16(v154, v167); + int16x8_t v312 = vsubq_s16(v180, v195); + int16x8_t v313_tmp = vqrdmulhq_n_s16(v312, 13357); + int16x8_t v313 = vmlaq_n_s16(v313_tmp, v312, 3); + int16x8_t v314 = vaddq_s16(v311, v313); + int16x8_t v315 = vsubq_s16(v102, v140); + int16x8_t v316_tmp = vqrdmulhq_n_s16(v315, 6226); + int16x8_t v316 = vmlaq_n_s16(v316_tmp, v315, 10); + int16x8_t v317 = vsubq_s16(v25, v60); + int16x8_t v318 = vaddq_s16(v316, v317); + int16x8_t v319 = vsubq_s16(v317, v316); + int16x8_t v320 = vsubq_s16(v311, v313); + int16x8_t v321 = vsubq_s16(v307, v309); + int16x8_t v322 = vsubq_s16(v303, v305); + int16x8_t v323 = vsubq_s16(v299, v301); + int16x8_t v324 = vsubq_s16(v297, v296); + int16x8_t v325 = vsubq_s16(v291, v293); + int16x8_t v326 = vsubq_s16(v287, v289); + int16x8_t v327 = vsubq_s16(v280, v285); + int16x8_t v328 = vsubq_s16(v270, v275); + int16x8_t v329 = vsubq_s16(v259, v265); + int16x8_t v330 = vsubq_s16(v249, v254); + int16x8_t v331 = vsubq_s16(v233, v244); + int16x8_t v332 = vsubq_s16(v209, v222); + int16x8_t v333 = vsubq_s16(v168, v197); + int16x8_t v334 = vsubq_s16(v61, v142); + vst1q_s16(out + out_stride * 0 + i, v143); + vst1q_s16(out + out_stride * 1 + i, v198); + vst1q_s16(out + out_stride * 2 + i, v223); + vst1q_s16(out + out_stride * 3 + i, v245); + vst1q_s16(out + out_stride * 4 + i, v255); + vst1q_s16(out + out_stride * 5 + i, v266); + vst1q_s16(out + out_stride * 6 + i, v276); + vst1q_s16(out + out_stride * 7 + i, v286); + vst1q_s16(out + out_stride * 8 + i, v290); + vst1q_s16(out + out_stride * 9 + i, v294); + vst1q_s16(out + out_stride * 10 + i, v298); + vst1q_s16(out + out_stride * 11 + i, v302); + vst1q_s16(out + out_stride * 12 + i, v306); + vst1q_s16(out + out_stride * 13 + i, v310); + vst1q_s16(out + out_stride * 14 + i, v314); + vst1q_s16(out + out_stride * 15 + i, v318); + vst1q_s16(out + out_stride * 16 + i, v319); + vst1q_s16(out + out_stride * 17 + i, v320); + vst1q_s16(out + out_stride * 18 + i, v321); + vst1q_s16(out + out_stride * 19 + i, v322); + vst1q_s16(out + out_stride * 20 + i, v323); + vst1q_s16(out + out_stride * 21 + i, v324); + vst1q_s16(out + out_stride * 22 + i, v325); + vst1q_s16(out + out_stride * 23 + i, v326); + vst1q_s16(out + out_stride * 24 + i, v327); + vst1q_s16(out + out_stride * 25 + i, v328); + vst1q_s16(out + out_stride * 26 + i, v329); + vst1q_s16(out + out_stride * 27 + i, v330); + vst1q_s16(out + out_stride * 28 + i, v331); + vst1q_s16(out + out_stride * 29 + i, v332); + vst1q_s16(out + out_stride * 30 + i, v333); + vst1q_s16(out + out_stride * 31 + i, v334); + } +} diff --git a/media/libjxl/src/lib/jxl/fast_dct64-inl.h b/media/libjxl/src/lib/jxl/fast_dct64-inl.h new file mode 100644 index 000000000..400da1a9d --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct64-inl.h @@ -0,0 +1,985 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<64>) { return 1; } + +void FastIDCT(FastDCTTag<64>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 32 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 48 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 40 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vld1q_s16(in + in_stride * 56 + i); + int16x8_t v17 = vaddq_s16(v16, v12); + int16x8_t v18 = vaddq_s16(v13, v10); + int16x8_t v19 = vaddq_s16(v17, v18); + int16x8_t v20 = vqrdmulhq_n_s16(v19, 17734); + int16x8_t v21 = vqrdmulhq_n_s16(v18, 25080); + int16x8_t v22 = vaddq_s16(v20, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 36 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35 = vqrdmulhq_n_s16(v34, 25080); + int16x8_t v36 = vld1q_s16(in + in_stride * 52 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 44 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vqrdmulhq_n_s16(v39, 17734); + int16x8_t v41 = vaddq_s16(v35, v40); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v37, v28); + int16x8_t v46 = vaddq_s16(v29, v32); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vqrdmulhq_n_s16(v48, 16705); + int16x8_t v50 = vaddq_s16(v46, v43); + int16x8_t v51_tmp = vqrdmulhq_n_s16(v50, 10045); + int16x8_t v51 = vaddq_s16(v51_tmp, v50); + int16x8_t v52 = vld1q_s16(in + in_stride * 60 + i); + int16x8_t v53 = vaddq_s16(v52, v36); + int16x8_t v54 = vaddq_s16(v53, v45); + int16x8_t v55 = vqrdmulhq_n_s16(v54, 17734); + int16x8_t v56 = vaddq_s16(v51, v55); + int16x8_t v57 = vqrdmulhq_n_s16(v56, 16705); + int16x8_t v58 = vaddq_s16(v49, v57); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v63_tmp = vqrdmulhq_n_s16(v62, 13573); + int16x8_t v63 = vaddq_s16(v63_tmp, v62); + int16x8_t v64 = vld1q_s16(in + in_stride * 34 + i); + int16x8_t v65 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v66 = vaddq_s16(v64, v65); + int16x8_t v67 = vaddq_s16(v63, v66); + int16x8_t v68 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v69 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v70 = vaddq_s16(v68, v69); + int16x8_t v71 = vqrdmulhq_n_s16(v70, 25080); + int16x8_t v72 = vld1q_s16(in + in_stride * 50 + i); + int16x8_t v73 = vld1q_s16(in + in_stride * 46 + i); + int16x8_t v74 = vaddq_s16(v72, v73); + int16x8_t v75 = vaddq_s16(v74, v70); + int16x8_t v76 = vqrdmulhq_n_s16(v75, 17734); + int16x8_t v77 = vaddq_s16(v71, v76); + int16x8_t v78 = vaddq_s16(v67, v77); + int16x8_t v79 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v80 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v81 = vaddq_s16(v79, v80); + int16x8_t v82_tmp = vqrdmulhq_n_s16(v81, 13573); + int16x8_t v82 = vaddq_s16(v82_tmp, v81); + int16x8_t v83 = vld1q_s16(in + in_stride * 42 + i); + int16x8_t v84 = vld1q_s16(in + in_stride * 38 + i); + int16x8_t v85 = vaddq_s16(v83, v84); + int16x8_t v86 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v87 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v88 = vaddq_s16(v86, v87); + int16x8_t v89 = vaddq_s16(v85, v88); + int16x8_t v90 = vaddq_s16(v82, v89); + int16x8_t v91 = vqrdmulhq_n_s16(v90, 16705); + int16x8_t v92 = vaddq_s16(v88, v81); + int16x8_t v93_tmp = vqrdmulhq_n_s16(v92, 10045); + int16x8_t v93 = vaddq_s16(v93_tmp, v92); + int16x8_t v94 = vld1q_s16(in + in_stride * 58 + i); + int16x8_t v95 = vld1q_s16(in + in_stride * 54 + i); + int16x8_t v96 = vaddq_s16(v94, v95); + int16x8_t v97 = vaddq_s16(v96, v85); + int16x8_t v98 = vqrdmulhq_n_s16(v97, 17734); + int16x8_t v99 = vaddq_s16(v93, v98); + int16x8_t v100 = vqrdmulhq_n_s16(v99, 16705); + int16x8_t v101 = vaddq_s16(v91, v100); + int16x8_t v102 = vaddq_s16(v78, v101); + int16x8_t v103 = vaddq_s16(v69, v79); + int16x8_t v104 = vaddq_s16(v80, v62); + int16x8_t v105 = vaddq_s16(v103, v104); + int16x8_t v106_tmp = vqrdmulhq_n_s16(v105, 13573); + int16x8_t v106 = vaddq_s16(v106_tmp, v105); + int16x8_t v107 = vaddq_s16(v73, v83); + int16x8_t v108 = vaddq_s16(v84, v64); + int16x8_t v109 = vaddq_s16(v107, v108); + int16x8_t v110 = vaddq_s16(v65, v86); + int16x8_t v111 = vaddq_s16(v87, v68); + int16x8_t v112 = vaddq_s16(v110, v111); + int16x8_t v113 = vaddq_s16(v109, v112); + int16x8_t v114 = vaddq_s16(v106, v113); + int16x8_t v115 = vqrdmulhq_n_s16(v114, 16705); + int16x8_t v116 = vaddq_s16(v112, v105); + int16x8_t v117 = vqrdmulhq_n_s16(v116, 25080); + int16x8_t v118 = vqrdmulhq_n_s16(v116, 17734); + int16x8_t v119 = vld1q_s16(in + in_stride * 62 + i); + int16x8_t v120 = vaddq_s16(v119, v94); + int16x8_t v121 = vaddq_s16(v95, v72); + int16x8_t v122 = vaddq_s16(v120, v121); + int16x8_t v123 = vaddq_s16(v122, v109); + int16x8_t v124 = vqrdmulhq_n_s16(v123, 17734); + int16x8_t v125 = vaddq_s16(v118, v124); + int16x8_t v126 = vaddq_s16(v117, v125); + int16x8_t v127 = vqrdmulhq_n_s16(v126, 16705); + int16x8_t v128 = vaddq_s16(v115, v127); + int16x8_t v129 = vqrdmulhq_n_s16(v128, 16463); + int16x8_t v130_tmp = vqrdmulhq_n_s16(v104, 13573); + int16x8_t v130 = vaddq_s16(v130_tmp, v104); + int16x8_t v131 = vaddq_s16(v108, v110); + int16x8_t v132 = vaddq_s16(v130, v131); + int16x8_t v133 = vaddq_s16(v111, v103); + int16x8_t v134_tmp = vqrdmulhq_n_s16(v133, 10045); + int16x8_t v134 = vaddq_s16(v134_tmp, v133); + int16x8_t v135 = vaddq_s16(v121, v107); + int16x8_t v136 = vqrdmulhq_n_s16(v135, 17734); + int16x8_t v137 = vaddq_s16(v134, v136); + int16x8_t v138 = vaddq_s16(v132, v137); + int16x8_t v139 = vqrdmulhq_n_s16(v138, 16463); + int16x8_t v140 = vaddq_s16(v129, v139); + int16x8_t v141 = vaddq_s16(v102, v140); + int16x8_t v142 = vqrdmulhq_n_s16(v141, 16404); + int16x8_t v143 = vaddq_s16(v61, v142); + int16x8_t v144 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v145_tmp = vqrdmulhq_n_s16(v144, 13573); + int16x8_t v145 = vaddq_s16(v145_tmp, v144); + int16x8_t v146 = vld1q_s16(in + in_stride * 33 + i); + int16x8_t v147 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v148 = vaddq_s16(v146, v147); + int16x8_t v149 = vaddq_s16(v145, v148); + int16x8_t v150 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v151 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v152 = vaddq_s16(v150, v151); + int16x8_t v153 = vqrdmulhq_n_s16(v152, 25080); + int16x8_t v154 = vld1q_s16(in + in_stride * 49 + i); + int16x8_t v155 = vld1q_s16(in + in_stride * 47 + i); + int16x8_t v156 = vaddq_s16(v154, v155); + int16x8_t v157 = vaddq_s16(v156, v152); + int16x8_t v158 = vqrdmulhq_n_s16(v157, 17734); + int16x8_t v159 = vaddq_s16(v153, v158); + int16x8_t v160 = vaddq_s16(v149, v159); + int16x8_t v161 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v162 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v163 = vaddq_s16(v161, v162); + int16x8_t v164_tmp = vqrdmulhq_n_s16(v163, 13573); + int16x8_t v164 = vaddq_s16(v164_tmp, v163); + int16x8_t v165 = vld1q_s16(in + in_stride * 41 + i); + int16x8_t v166 = vld1q_s16(in + in_stride * 39 + i); + int16x8_t v167 = vaddq_s16(v165, v166); + int16x8_t v168 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v169 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v170 = vaddq_s16(v168, v169); + int16x8_t v171 = vaddq_s16(v167, v170); + int16x8_t v172 = vaddq_s16(v164, v171); + int16x8_t v173 = vqrdmulhq_n_s16(v172, 16705); + int16x8_t v174 = vaddq_s16(v170, v163); + int16x8_t v175_tmp = vqrdmulhq_n_s16(v174, 10045); + int16x8_t v175 = vaddq_s16(v175_tmp, v174); + int16x8_t v176 = vld1q_s16(in + in_stride * 57 + i); + int16x8_t v177 = vld1q_s16(in + in_stride * 55 + i); + int16x8_t v178 = vaddq_s16(v176, v177); + int16x8_t v179 = vaddq_s16(v178, v167); + int16x8_t v180 = vqrdmulhq_n_s16(v179, 17734); + int16x8_t v181 = vaddq_s16(v175, v180); + int16x8_t v182 = vqrdmulhq_n_s16(v181, 16705); + int16x8_t v183 = vaddq_s16(v173, v182); + int16x8_t v184 = vaddq_s16(v160, v183); + int16x8_t v185 = vld1q_s16(in + in_stride * 37 + i); + int16x8_t v186 = vld1q_s16(in + in_stride * 35 + i); + int16x8_t v187 = vaddq_s16(v185, v186); + int16x8_t v188 = vld1q_s16(in + in_stride * 45 + i); + int16x8_t v189 = vld1q_s16(in + in_stride * 43 + i); + int16x8_t v190 = vaddq_s16(v188, v189); + int16x8_t v191 = vaddq_s16(v187, v190); + int16x8_t v192 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v193 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v194 = vaddq_s16(v192, v193); + int16x8_t v195 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v196 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v197 = vaddq_s16(v195, v196); + int16x8_t v198 = vaddq_s16(v194, v197); + int16x8_t v199 = vaddq_s16(v191, v198); + int16x8_t v200 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v201 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v202 = vaddq_s16(v200, v201); + int16x8_t v203 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v204 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v205 = vaddq_s16(v203, v204); + int16x8_t v206 = vaddq_s16(v202, v205); + int16x8_t v207_tmp = vqrdmulhq_n_s16(v206, 13573); + int16x8_t v207 = vaddq_s16(v207_tmp, v206); + int16x8_t v208 = vaddq_s16(v199, v207); + int16x8_t v209 = vqrdmulhq_n_s16(v208, 16705); + int16x8_t v210 = vaddq_s16(v198, v206); + int16x8_t v211 = vqrdmulhq_n_s16(v210, 25080); + int16x8_t v212 = vqrdmulhq_n_s16(v210, 17734); + int16x8_t v213 = vld1q_s16(in + in_stride * 53 + i); + int16x8_t v214 = vld1q_s16(in + in_stride * 51 + i); + int16x8_t v215 = vaddq_s16(v213, v214); + int16x8_t v216 = vld1q_s16(in + in_stride * 61 + i); + int16x8_t v217 = vld1q_s16(in + in_stride * 59 + i); + int16x8_t v218 = vaddq_s16(v216, v217); + int16x8_t v219 = vaddq_s16(v215, v218); + int16x8_t v220 = vaddq_s16(v219, v191); + int16x8_t v221 = vqrdmulhq_n_s16(v220, 17734); + int16x8_t v222 = vaddq_s16(v212, v221); + int16x8_t v223 = vaddq_s16(v211, v222); + int16x8_t v224 = vqrdmulhq_n_s16(v223, 16705); + int16x8_t v225 = vaddq_s16(v209, v224); + int16x8_t v226 = vqrdmulhq_n_s16(v225, 16463); + int16x8_t v227_tmp = vqrdmulhq_n_s16(v202, 13573); + int16x8_t v227 = vaddq_s16(v227_tmp, v202); + int16x8_t v228 = vaddq_s16(v187, v194); + int16x8_t v229 = vaddq_s16(v227, v228); + int16x8_t v230 = vaddq_s16(v215, v190); + int16x8_t v231 = vqrdmulhq_n_s16(v230, 17734); + int16x8_t v232 = vaddq_s16(v197, v205); + int16x8_t v233_tmp = vqrdmulhq_n_s16(v232, 10045); + int16x8_t v233 = vaddq_s16(v233_tmp, v232); + int16x8_t v234 = vaddq_s16(v231, v233); + int16x8_t v235 = vaddq_s16(v229, v234); + int16x8_t v236 = vqrdmulhq_n_s16(v235, 16463); + int16x8_t v237 = vaddq_s16(v226, v236); + int16x8_t v238 = vaddq_s16(v184, v237); + int16x8_t v239 = vaddq_s16(v201, v144); + int16x8_t v240_tmp = vqrdmulhq_n_s16(v239, 13573); + int16x8_t v240 = vaddq_s16(v240_tmp, v239); + int16x8_t v241 = vaddq_s16(v186, v146); + int16x8_t v242 = vaddq_s16(v147, v192); + int16x8_t v243 = vaddq_s16(v241, v242); + int16x8_t v244 = vaddq_s16(v240, v243); + int16x8_t v245 = vaddq_s16(v196, v150); + int16x8_t v246 = vaddq_s16(v151, v203); + int16x8_t v247 = vaddq_s16(v245, v246); + int16x8_t v248_tmp = vqrdmulhq_n_s16(v247, 10045); + int16x8_t v248 = vaddq_s16(v248_tmp, v247); + int16x8_t v249 = vaddq_s16(v155, v188); + int16x8_t v250 = vaddq_s16(v214, v154); + int16x8_t v251 = vaddq_s16(v249, v250); + int16x8_t v252 = vqrdmulhq_n_s16(v251, 17734); + int16x8_t v253 = vaddq_s16(v248, v252); + int16x8_t v254 = vaddq_s16(v244, v253); + int16x8_t v255 = vaddq_s16(v204, v161); + int16x8_t v256 = vaddq_s16(v162, v200); + int16x8_t v257 = vaddq_s16(v255, v256); + int16x8_t v258_tmp = vqrdmulhq_n_s16(v257, 13573); + int16x8_t v258 = vaddq_s16(v258_tmp, v257); + int16x8_t v259 = vaddq_s16(v189, v165); + int16x8_t v260 = vaddq_s16(v166, v185); + int16x8_t v261 = vaddq_s16(v259, v260); + int16x8_t v262 = vaddq_s16(v169, v195); + int16x8_t v263 = vaddq_s16(v193, v168); + int16x8_t v264 = vaddq_s16(v262, v263); + int16x8_t v265 = vaddq_s16(v261, v264); + int16x8_t v266 = vaddq_s16(v258, v265); + int16x8_t v267 = vqrdmulhq_n_s16(v266, 16705); + int16x8_t v268 = vaddq_s16(v264, v257); + int16x8_t v269 = vqrdmulhq_n_s16(v268, 25080); + int16x8_t v270 = vaddq_s16(v217, v176); + int16x8_t v271 = vaddq_s16(v177, v213); + int16x8_t v272 = vaddq_s16(v270, v271); + int16x8_t v273 = vaddq_s16(v272, v261); + int16x8_t v274 = vqrdmulhq_n_s16(v273, 17734); + int16x8_t v275 = vqrdmulhq_n_s16(v268, 17734); + int16x8_t v276 = vaddq_s16(v274, v275); + int16x8_t v277 = vaddq_s16(v269, v276); + int16x8_t v278 = vqrdmulhq_n_s16(v277, 16705); + int16x8_t v279 = vaddq_s16(v267, v278); + int16x8_t v280 = vaddq_s16(v254, v279); + int16x8_t v281 = vqrdmulhq_n_s16(v280, 16404); + int16x8_t v282 = vaddq_s16(v256, v239); + int16x8_t v283_tmp = vqrdmulhq_n_s16(v282, 13573); + int16x8_t v283 = vaddq_s16(v283_tmp, v282); + int16x8_t v284 = vaddq_s16(v260, v241); + int16x8_t v285 = vaddq_s16(v242, v263); + int16x8_t v286 = vaddq_s16(v284, v285); + int16x8_t v287 = vaddq_s16(v283, v286); + int16x8_t v288 = vaddq_s16(v262, v245); + int16x8_t v289 = vaddq_s16(v246, v255); + int16x8_t v290 = vaddq_s16(v288, v289); + int16x8_t v291 = vqrdmulhq_n_s16(v290, 25080); + int16x8_t v292 = vqrdmulhq_n_s16(v290, 17734); + int16x8_t v293 = vaddq_s16(v271, v250); + int16x8_t v294 = vaddq_s16(v249, v259); + int16x8_t v295 = vaddq_s16(v293, v294); + int16x8_t v296 = vqrdmulhq_n_s16(v295, 17734); + int16x8_t v297 = vaddq_s16(v292, v296); + int16x8_t v298 = vaddq_s16(v291, v297); + int16x8_t v299 = vaddq_s16(v287, v298); + int16x8_t v300 = vqrdmulhq_n_s16(v299, 16463); + int16x8_t v301 = vaddq_s16(v289, v282); + int16x8_t v302 = vqrdmulhq_n_s16(v301, 23624); + int16x8_t v303 = vaddq_s16(v294, v284); + int16x8_t v304 = vqrdmulhq_n_s16(v303, 19705); + int16x8_t v305 = vaddq_s16(v285, v288); + int16x8_t v306 = vqrdmulhq_n_s16(v305, 19705); + int16x8_t v307 = vaddq_s16(v304, v306); + int16x8_t v308 = vqrdmulhq_n_s16(v307, 27779); + int16x8_t v309 = vaddq_s16(v302, v308); + int16x8_t v310 = vaddq_s16(v305, v301); + int16x8_t v311 = vqrdmulhq_n_s16(v310, 25080); + int16x8_t v312 = vqrdmulhq_n_s16(v310, 17734); + int16x8_t v313 = vld1q_s16(in + in_stride * 63 + i); + int16x8_t v314 = vaddq_s16(v313, v216); + int16x8_t v315 = vaddq_s16(v314, v270); + int16x8_t v316 = vaddq_s16(v315, v293); + int16x8_t v317 = vqrdmulhq_n_s16(v316, 25746); + int16x8_t v318 = vqrdmulhq_n_s16(v303, 25746); + int16x8_t v319 = vaddq_s16(v317, v318); + int16x8_t v320 = vqrdmulhq_n_s16(v319, 22571); + int16x8_t v321 = vaddq_s16(v312, v320); + int16x8_t v322 = vaddq_s16(v311, v321); + int16x8_t v323 = vqrdmulhq_n_s16(v322, 16705); + int16x8_t v324 = vaddq_s16(v309, v323); + int16x8_t v325 = vqrdmulhq_n_s16(v324, 16463); + int16x8_t v326 = vaddq_s16(v300, v325); + int16x8_t v327 = vqrdmulhq_n_s16(v326, 16404); + int16x8_t v328 = vaddq_s16(v281, v327); + int16x8_t v329 = vaddq_s16(v238, v328); + int16x8_t v330 = vqrdmulhq_n_s16(v329, 16389); + int16x8_t v331 = vaddq_s16(v143, v330); + int16x8_t v332 = vsubq_s16(v82, v89); + int16x8_t v333 = vqrdmulhq_n_s16(v332, 19705); + int16x8_t v334 = vqrdmulhq_n_s16(v92, 13573); + int16x8_t v335 = vsubq_s16(v334, v97); + int16x8_t v336 = vqrdmulhq_n_s16(v335, 25746); + int16x8_t v337 = vaddq_s16(v333, v336); + int16x8_t v338 = vsubq_s16(v63, v66); + int16x8_t v339 = vqrdmulhq_n_s16(v70, 17734); + int16x8_t v340_tmp = vqrdmulhq_n_s16(v74, 10045); + int16x8_t v340 = vaddq_s16(v340_tmp, v74); + int16x8_t v341 = vsubq_s16(v339, v340); + int16x8_t v342 = vaddq_s16(v338, v341); + int16x8_t v343 = vaddq_s16(v337, v342); + int16x8_t v344 = vsubq_s16(v130, v131); + int16x8_t v345 = vqrdmulhq_n_s16(v133, 13573); + int16x8_t v346 = vsubq_s16(v345, v135); + int16x8_t v347_tmp = vqrdmulhq_n_s16(v346, 10045); + int16x8_t v347 = vaddq_s16(v347_tmp, v346); + int16x8_t v348 = vaddq_s16(v344, v347); + int16x8_t v349 = vqrdmulhq_n_s16(v348, 17121); + int16x8_t v350 = vqrdmulhq_n_s16(v105, 27867); + int16x8_t v351 = vqrdmulhq_n_s16(v113, 19705); + int16x8_t v352 = vsubq_s16(v350, v351); + int16x8_t v353 = vqrdmulhq_n_s16(v116, 13573); + int16x8_t v354 = vsubq_s16(v353, v123); + int16x8_t v355 = vqrdmulhq_n_s16(v354, 25746); + int16x8_t v356 = vaddq_s16(v352, v355); + int16x8_t v357 = vqrdmulhq_n_s16(v356, 17121); + int16x8_t v358 = vaddq_s16(v349, v357); + int16x8_t v359 = vaddq_s16(v343, v358); + int16x8_t v360 = vqrdmulhq_n_s16(v359, 16563); + int16x8_t v361 = vsubq_s16(v27, v30); + int16x8_t v362 = vqrdmulhq_n_s16(v34, 17734); + int16x8_t v363_tmp = vqrdmulhq_n_s16(v38, 10045); + int16x8_t v363 = vaddq_s16(v363_tmp, v38); + int16x8_t v364 = vsubq_s16(v362, v363); + int16x8_t v365 = vaddq_s16(v361, v364); + int16x8_t v366 = vsubq_s16(v44, v47); + int16x8_t v367 = vqrdmulhq_n_s16(v366, 19705); + int16x8_t v368 = vqrdmulhq_n_s16(v50, 13573); + int16x8_t v369 = vsubq_s16(v368, v54); + int16x8_t v370 = vqrdmulhq_n_s16(v369, 25746); + int16x8_t v371 = vaddq_s16(v367, v370); + int16x8_t v372 = vaddq_s16(v365, v371); + int16x8_t v373 = vqrdmulhq_n_s16(v372, 17121); + int16x8_t v374 = vsubq_s16(v0, v1); + int16x8_t v375 = vsubq_s16(v4, v6); + int16x8_t v376_tmp = vqrdmulhq_n_s16(v375, 10045); + int16x8_t v376 = vaddq_s16(v376_tmp, v375); + int16x8_t v377 = vaddq_s16(v374, v376); + int16x8_t v378 = vsubq_s16(v11, v14); + int16x8_t v379 = vqrdmulhq_n_s16(v18, 17734); + int16x8_t v380_tmp = vqrdmulhq_n_s16(v17, 10045); + int16x8_t v380 = vaddq_s16(v380_tmp, v17); + int16x8_t v381 = vsubq_s16(v379, v380); + int16x8_t v382 = vaddq_s16(v378, v381); + int16x8_t v383 = vqrdmulhq_n_s16(v382, 19705); + int16x8_t v384 = vaddq_s16(v377, v383); + int16x8_t v385 = vaddq_s16(v373, v384); + int16x8_t v386 = vaddq_s16(v360, v385); + int16x8_t v387 = vsubq_s16(v145, v148); + int16x8_t v388 = vqrdmulhq_n_s16(v152, 17734); + int16x8_t v389_tmp = vqrdmulhq_n_s16(v156, 10045); + int16x8_t v389 = vaddq_s16(v389_tmp, v156); + int16x8_t v390 = vsubq_s16(v388, v389); + int16x8_t v391 = vaddq_s16(v387, v390); + int16x8_t v392 = vsubq_s16(v164, v171); + int16x8_t v393 = vqrdmulhq_n_s16(v392, 19705); + int16x8_t v394 = vqrdmulhq_n_s16(v174, 13573); + int16x8_t v395 = vsubq_s16(v394, v179); + int16x8_t v396 = vqrdmulhq_n_s16(v395, 25746); + int16x8_t v397 = vaddq_s16(v393, v396); + int16x8_t v398 = vaddq_s16(v391, v397); + int16x8_t v399 = vsubq_s16(v227, v228); + int16x8_t v400 = vqrdmulhq_n_s16(v232, 13573); + int16x8_t v401 = vsubq_s16(v400, v230); + int16x8_t v402_tmp = vqrdmulhq_n_s16(v401, 10045); + int16x8_t v402 = vaddq_s16(v402_tmp, v401); + int16x8_t v403 = vaddq_s16(v399, v402); + int16x8_t v404 = vqrdmulhq_n_s16(v403, 17121); + int16x8_t v405 = vqrdmulhq_n_s16(v206, 27867); + int16x8_t v406 = vqrdmulhq_n_s16(v199, 19705); + int16x8_t v407 = vsubq_s16(v405, v406); + int16x8_t v408 = vqrdmulhq_n_s16(v210, 13573); + int16x8_t v409 = vsubq_s16(v408, v220); + int16x8_t v410 = vqrdmulhq_n_s16(v409, 25746); + int16x8_t v411 = vaddq_s16(v407, v410); + int16x8_t v412 = vqrdmulhq_n_s16(v411, 17121); + int16x8_t v413 = vaddq_s16(v404, v412); + int16x8_t v414 = vaddq_s16(v398, v413); + int16x8_t v415 = vsubq_s16(v240, v243); + int16x8_t v416 = vqrdmulhq_n_s16(v247, 13573); + int16x8_t v417 = vsubq_s16(v416, v251); + int16x8_t v418_tmp = vqrdmulhq_n_s16(v417, 10045); + int16x8_t v418 = vaddq_s16(v418_tmp, v417); + int16x8_t v419 = vaddq_s16(v415, v418); + int16x8_t v420 = vqrdmulhq_n_s16(v257, 27867); + int16x8_t v421 = vqrdmulhq_n_s16(v265, 19705); + int16x8_t v422 = vsubq_s16(v420, v421); + int16x8_t v423 = vqrdmulhq_n_s16(v268, 13573); + int16x8_t v424 = vsubq_s16(v423, v273); + int16x8_t v425 = vqrdmulhq_n_s16(v424, 25746); + int16x8_t v426 = vaddq_s16(v422, v425); + int16x8_t v427 = vaddq_s16(v419, v426); + int16x8_t v428 = vqrdmulhq_n_s16(v427, 16563); + int16x8_t v429 = vqrdmulhq_n_s16(v301, 27867); + int16x8_t v430 = vsubq_s16(v429, v307); + int16x8_t v431 = vqrdmulhq_n_s16(v310, 10664); + int16x8_t v432 = vsubq_s16(v431, v319); + int16x8_t v433 = vaddq_s16(v430, v432); + int16x8_t v434 = vqrdmulhq_n_s16(v433, 17121); + int16x8_t v435 = vsubq_s16(v283, v286); + int16x8_t v436 = vqrdmulhq_n_s16(v290, 13573); + int16x8_t v437 = vsubq_s16(v436, v295); + int16x8_t v438_tmp = vqrdmulhq_n_s16(v437, 10045); + int16x8_t v438 = vaddq_s16(v438_tmp, v437); + int16x8_t v439 = vaddq_s16(v435, v438); + int16x8_t v440 = vqrdmulhq_n_s16(v439, 17121); + int16x8_t v441 = vaddq_s16(v434, v440); + int16x8_t v442 = vqrdmulhq_n_s16(v441, 16563); + int16x8_t v443 = vaddq_s16(v428, v442); + int16x8_t v444 = vaddq_s16(v414, v443); + int16x8_t v445 = vqrdmulhq_n_s16(v444, 16429); + int16x8_t v446 = vaddq_s16(v386, v445); + int16x8_t v447 = vsubq_s16(v374, v376); + int16x8_t v448 = vsubq_s16(v378, v381); + int16x8_t v449 = vqrdmulhq_n_s16(v448, 29490); + int16x8_t v450 = vaddq_s16(v447, v449); + int16x8_t v451 = vsubq_s16(v361, v364); + int16x8_t v452 = vqrdmulhq_n_s16(v366, 29490); + int16x8_t v453_tmp = vqrdmulhq_n_s16(v369, 5763); + int16x8_t v453 = vaddq_s16(v453_tmp, v369); + int16x8_t v454 = vsubq_s16(v452, v453); + int16x8_t v455 = vaddq_s16(v451, v454); + int16x8_t v456 = vqrdmulhq_n_s16(v455, 18578); + int16x8_t v457 = vaddq_s16(v450, v456); + int16x8_t v458 = vsubq_s16(v338, v341); + int16x8_t v459 = vqrdmulhq_n_s16(v332, 29490); + int16x8_t v460_tmp = vqrdmulhq_n_s16(v335, 5763); + int16x8_t v460 = vaddq_s16(v460_tmp, v335); + int16x8_t v461 = vsubq_s16(v459, v460); + int16x8_t v462 = vaddq_s16(v458, v461); + int16x8_t v463 = vqrdmulhq_n_s16(v352, 27803); + int16x8_t v464 = vqrdmulhq_n_s16(v354, 21845); + int16x8_t v465 = vsubq_s16(v463, v464); + int16x8_t v466 = vsubq_s16(v344, v347); + int16x8_t v467 = vqrdmulhq_n_s16(v466, 18578); + int16x8_t v468 = vaddq_s16(v465, v467); + int16x8_t v469 = vaddq_s16(v462, v468); + int16x8_t v470 = vqrdmulhq_n_s16(v469, 16890); + int16x8_t v471 = vaddq_s16(v457, v470); + int16x8_t v472 = vsubq_s16(v415, v418); + int16x8_t v473_tmp = vqrdmulhq_n_s16(v422, 16273); + int16x8_t v473 = vaddq_s16(v473_tmp, v422); + int16x8_t v474_tmp = vqrdmulhq_n_s16(v424, 5763); + int16x8_t v474 = vaddq_s16(v474_tmp, v424); + int16x8_t v475 = vsubq_s16(v473, v474); + int16x8_t v476 = vaddq_s16(v472, v475); + int16x8_t v477 = vqrdmulhq_n_s16(v476, 16890); + int16x8_t v478 = vqrdmulhq_n_s16(v435, 20261); + int16x8_t v479 = vqrdmulhq_n_s16(v437, 26472); + int16x8_t v480 = vsubq_s16(v478, v479); + int16x8_t v481 = vqrdmulhq_n_s16(v480, 30046); + int16x8_t v482 = vqrdmulhq_n_s16(v430, 30322); + int16x8_t v483 = vqrdmulhq_n_s16(v432, 30322); + int16x8_t v484 = vsubq_s16(v482, v483); + int16x8_t v485 = vqrdmulhq_n_s16(v484, 30046); + int16x8_t v486 = vaddq_s16(v481, v485); + int16x8_t v487 = vqrdmulhq_n_s16(v486, 16890); + int16x8_t v488 = vaddq_s16(v477, v487); + int16x8_t v489 = vsubq_s16(v387, v390); + int16x8_t v490 = vqrdmulhq_n_s16(v392, 29490); + int16x8_t v491_tmp = vqrdmulhq_n_s16(v395, 5763); + int16x8_t v491 = vaddq_s16(v491_tmp, v395); + int16x8_t v492 = vsubq_s16(v490, v491); + int16x8_t v493 = vaddq_s16(v489, v492); + int16x8_t v494 = vsubq_s16(v399, v402); + int16x8_t v495 = vqrdmulhq_n_s16(v494, 18578); + int16x8_t v496 = vqrdmulhq_n_s16(v407, 27803); + int16x8_t v497 = vqrdmulhq_n_s16(v409, 21845); + int16x8_t v498 = vsubq_s16(v496, v497); + int16x8_t v499 = vaddq_s16(v495, v498); + int16x8_t v500 = vaddq_s16(v493, v499); + int16x8_t v501 = vaddq_s16(v488, v500); + int16x8_t v502 = vqrdmulhq_n_s16(v501, 16508); + int16x8_t v503 = vaddq_s16(v471, v502); + int16x8_t v504 = vsubq_s16(v2, v8); + int16x8_t v505 = vsubq_s16(v15, v22); + int16x8_t v506_tmp = vqrdmulhq_n_s16(v505, 18446); + int16x8_t v506 = vmlaq_n_s16(v506_tmp, v505, 2); + int16x8_t v507 = vaddq_s16(v504, v506); + int16x8_t v508 = vsubq_s16(v31, v41); + int16x8_t v509 = vsubq_s16(v48, v56); + int16x8_t v510_tmp = vqrdmulhq_n_s16(v509, 18446); + int16x8_t v510 = vmlaq_n_s16(v510_tmp, v509, 2); + int16x8_t v511 = vaddq_s16(v508, v510); + int16x8_t v512 = vqrdmulhq_n_s16(v511, 21195); + int16x8_t v513 = vaddq_s16(v507, v512); + int16x8_t v514 = vsubq_s16(v67, v77); + int16x8_t v515 = vsubq_s16(v90, v99); + int16x8_t v516_tmp = vqrdmulhq_n_s16(v515, 18446); + int16x8_t v516 = vmlaq_n_s16(v516_tmp, v515, 2); + int16x8_t v517 = vaddq_s16(v514, v516); + int16x8_t v518 = vsubq_s16(v114, v126); + int16x8_t v519_tmp = vqrdmulhq_n_s16(v518, 18446); + int16x8_t v519 = vmlaq_n_s16(v519_tmp, v518, 2); + int16x8_t v520 = vsubq_s16(v132, v137); + int16x8_t v521 = vaddq_s16(v519, v520); + int16x8_t v522 = vqrdmulhq_n_s16(v521, 21195); + int16x8_t v523 = vaddq_s16(v517, v522); + int16x8_t v524 = vqrdmulhq_n_s16(v523, 17401); + int16x8_t v525 = vaddq_s16(v513, v524); + int16x8_t v526 = vsubq_s16(v172, v181); + int16x8_t v527_tmp = vqrdmulhq_n_s16(v526, 18446); + int16x8_t v527 = vmlaq_n_s16(v527_tmp, v526, 2); + int16x8_t v528 = vsubq_s16(v149, v159); + int16x8_t v529 = vaddq_s16(v527, v528); + int16x8_t v530 = vsubq_s16(v229, v234); + int16x8_t v531 = vsubq_s16(v208, v223); + int16x8_t v532_tmp = vqrdmulhq_n_s16(v531, 18446); + int16x8_t v532 = vmlaq_n_s16(v532_tmp, v531, 2); + int16x8_t v533 = vaddq_s16(v530, v532); + int16x8_t v534 = vqrdmulhq_n_s16(v533, 21195); + int16x8_t v535 = vaddq_s16(v529, v534); + int16x8_t v536 = vsubq_s16(v244, v253); + int16x8_t v537 = vsubq_s16(v266, v277); + int16x8_t v538_tmp = vqrdmulhq_n_s16(v537, 18446); + int16x8_t v538 = vmlaq_n_s16(v538_tmp, v537, 2); + int16x8_t v539 = vaddq_s16(v536, v538); + int16x8_t v540 = vqrdmulhq_n_s16(v539, 17401); + int16x8_t v541 = vqrdmulhq_n_s16(v287, 25826); + int16x8_t v542 = vqrdmulhq_n_s16(v298, 25826); + int16x8_t v543 = vsubq_s16(v541, v542); + int16x8_t v544 = vqrdmulhq_n_s16(v543, 14281); + int16x8_t v545_tmp = vqrdmulhq_n_s16(v309, 31509); + int16x8_t v545 = vaddq_s16(v545_tmp, v309); + int16x8_t v546 = vsubq_s16(v545, v322); + int16x8_t v547 = vqrdmulhq_n_s16(v546, 28847); + int16x8_t v548 = vaddq_s16(v544, v547); + int16x8_t v549 = vaddq_s16(v540, v548); + int16x8_t v550 = vaddq_s16(v535, v549); + int16x8_t v551 = vqrdmulhq_n_s16(v550, 16629); + int16x8_t v552 = vaddq_s16(v525, v551); + int16x8_t v553 = vsubq_s16(v504, v506); + int16x8_t v554 = vsubq_s16(v508, v510); + int16x8_t v555 = vqrdmulhq_n_s16(v554, 25826); + int16x8_t v556 = vaddq_s16(v553, v555); + int16x8_t v557 = vsubq_s16(v514, v516); + int16x8_t v558 = vsubq_s16(v520, v519); + int16x8_t v559 = vqrdmulhq_n_s16(v558, 25826); + int16x8_t v560 = vaddq_s16(v557, v559); + int16x8_t v561 = vqrdmulhq_n_s16(v560, 18124); + int16x8_t v562 = vaddq_s16(v556, v561); + int16x8_t v563 = vsubq_s16(v528, v527); + int16x8_t v564 = vsubq_s16(v530, v532); + int16x8_t v565 = vqrdmulhq_n_s16(v564, 25826); + int16x8_t v566 = vaddq_s16(v563, v565); + int16x8_t v567 = vsubq_s16(v536, v538); + int16x8_t v568 = vqrdmulhq_n_s16(v567, 18124); + int16x8_t v569_tmp = vqrdmulhq_n_s16(v546, 654); + int16x8_t v569 = vmlaq_n_s16(v569_tmp, v546, 2); + int16x8_t v570 = vsubq_s16(v543, v569); + int16x8_t v571 = vqrdmulhq_n_s16(v570, 18124); + int16x8_t v572 = vaddq_s16(v568, v571); + int16x8_t v573 = vaddq_s16(v566, v572); + int16x8_t v574 = vqrdmulhq_n_s16(v573, 16792); + int16x8_t v575 = vaddq_s16(v562, v574); + int16x8_t v576 = vsubq_s16(v458, v461); + int16x8_t v577_tmp = vqrdmulhq_n_s16(v465, 25030); + int16x8_t v577 = vaddq_s16(v577_tmp, v465); + int16x8_t v578 = vsubq_s16(v466, v577); + int16x8_t v579_tmp = vqrdmulhq_n_s16(v578, 1988); + int16x8_t v579 = vaddq_s16(v579_tmp, v578); + int16x8_t v580 = vaddq_s16(v576, v579); + int16x8_t v581 = vqrdmulhq_n_s16(v580, 19102); + int16x8_t v582 = vsubq_s16(v447, v449); + int16x8_t v583 = vsubq_s16(v451, v454); + int16x8_t v584_tmp = vqrdmulhq_n_s16(v583, 1988); + int16x8_t v584 = vaddq_s16(v584_tmp, v583); + int16x8_t v585 = vaddq_s16(v582, v584); + int16x8_t v586 = vaddq_s16(v581, v585); + int16x8_t v587 = vsubq_s16(v489, v492); + int16x8_t v588_tmp = vqrdmulhq_n_s16(v498, 25030); + int16x8_t v588 = vaddq_s16(v588_tmp, v498); + int16x8_t v589 = vsubq_s16(v494, v588); + int16x8_t v590_tmp = vqrdmulhq_n_s16(v589, 1988); + int16x8_t v590 = vaddq_s16(v590_tmp, v589); + int16x8_t v591 = vaddq_s16(v587, v590); + int16x8_t v592 = vsubq_s16(v472, v475); + int16x8_t v593 = vqrdmulhq_n_s16(v592, 19102); + int16x8_t v594 = vsubq_s16(v480, v484); + int16x8_t v595 = vaddq_s16(v593, v594); + int16x8_t v596 = vaddq_s16(v591, v595); + int16x8_t v597 = vqrdmulhq_n_s16(v596, 17000); + int16x8_t v598 = vaddq_s16(v586, v597); + int16x8_t v599 = vsubq_s16(v365, v371); + int16x8_t v600_tmp = vqrdmulhq_n_s16(v599, 23673); + int16x8_t v600 = vaddq_s16(v600_tmp, v599); + int16x8_t v601 = vsubq_s16(v377, v383); + int16x8_t v602 = vaddq_s16(v600, v601); + int16x8_t v603 = vsubq_s16(v348, v356); + int16x8_t v604_tmp = vqrdmulhq_n_s16(v603, 23673); + int16x8_t v604 = vaddq_s16(v604_tmp, v603); + int16x8_t v605 = vsubq_s16(v342, v337); + int16x8_t v606 = vaddq_s16(v604, v605); + int16x8_t v607 = vqrdmulhq_n_s16(v606, 20398); + int16x8_t v608 = vaddq_s16(v602, v607); + int16x8_t v609 = vsubq_s16(v391, v397); + int16x8_t v610 = vsubq_s16(v403, v411); + int16x8_t v611_tmp = vqrdmulhq_n_s16(v610, 23673); + int16x8_t v611 = vaddq_s16(v611_tmp, v610); + int16x8_t v612 = vaddq_s16(v609, v611); + int16x8_t v613 = vsubq_s16(v419, v426); + int16x8_t v614 = vqrdmulhq_n_s16(v613, 20398); + int16x8_t v615 = vsubq_s16(v439, v433); + int16x8_t v616_tmp = vqrdmulhq_n_s16(v615, 2367); + int16x8_t v616 = vaddq_s16(v616_tmp, v615); + int16x8_t v617 = vaddq_s16(v614, v616); + int16x8_t v618 = vaddq_s16(v612, v617); + int16x8_t v619 = vqrdmulhq_n_s16(v618, 17255); + int16x8_t v620 = vaddq_s16(v608, v619); + int16x8_t v621 = vsubq_s16(v160, v183); + int16x8_t v622 = vsubq_s16(v235, v225); + int16x8_t v623_tmp = vqrdmulhq_n_s16(v622, 3314); + int16x8_t v623 = vmlaq_n_s16(v623_tmp, v622, 5); + int16x8_t v624 = vaddq_s16(v621, v623); + int16x8_t v625 = vsubq_s16(v254, v279); + int16x8_t v626 = vsubq_s16(v299, v324); + int16x8_t v627_tmp = vqrdmulhq_n_s16(v626, 3314); + int16x8_t v627 = vmlaq_n_s16(v627_tmp, v626, 5); + int16x8_t v628 = vaddq_s16(v625, v627); + int16x8_t v629 = vqrdmulhq_n_s16(v628, 22112); + int16x8_t v630 = vaddq_s16(v624, v629); + int16x8_t v631 = vqrdmulhq_n_s16(v630, 17561); + int16x8_t v632 = vsubq_s16(v9, v24); + int16x8_t v633 = vsubq_s16(v42, v58); + int16x8_t v634_tmp = vqrdmulhq_n_s16(v633, 3314); + int16x8_t v634 = vmlaq_n_s16(v634_tmp, v633, 5); + int16x8_t v635 = vaddq_s16(v632, v634); + int16x8_t v636 = vsubq_s16(v78, v101); + int16x8_t v637 = vsubq_s16(v138, v128); + int16x8_t v638_tmp = vqrdmulhq_n_s16(v637, 3314); + int16x8_t v638 = vmlaq_n_s16(v638_tmp, v637, 5); + int16x8_t v639 = vaddq_s16(v636, v638); + int16x8_t v640 = vqrdmulhq_n_s16(v639, 22112); + int16x8_t v641 = vaddq_s16(v635, v640); + int16x8_t v642 = vaddq_s16(v631, v641); + int16x8_t v643 = vsubq_s16(v632, v634); + int16x8_t v644 = vsubq_s16(v636, v638); + int16x8_t v645 = vqrdmulhq_n_s16(v644, 24397); + int16x8_t v646 = vaddq_s16(v643, v645); + int16x8_t v647 = vsubq_s16(v621, v623); + int16x8_t v648 = vsubq_s16(v625, v627); + int16x8_t v649 = vqrdmulhq_n_s16(v648, 24397); + int16x8_t v650 = vaddq_s16(v647, v649); + int16x8_t v651 = vqrdmulhq_n_s16(v650, 17921); + int16x8_t v652 = vaddq_s16(v646, v651); + int16x8_t v653 = vsubq_s16(v601, v600); + int16x8_t v654 = vsubq_s16(v605, v604); + int16x8_t v655 = vqrdmulhq_n_s16(v654, 27504); + int16x8_t v656 = vaddq_s16(v653, v655); + int16x8_t v657 = vsubq_s16(v609, v611); + int16x8_t v658 = vqrdmulhq_n_s16(v613, 27504); + int16x8_t v659_tmp = vqrdmulhq_n_s16(v615, 14606); + int16x8_t v659 = vaddq_s16(v659_tmp, v615); + int16x8_t v660 = vsubq_s16(v658, v659); + int16x8_t v661 = vaddq_s16(v657, v660); + int16x8_t v662 = vqrdmulhq_n_s16(v661, 18343); + int16x8_t v663 = vaddq_s16(v656, v662); + int16x8_t v664 = vsubq_s16(v582, v584); + int16x8_t v665 = vsubq_s16(v576, v579); + int16x8_t v666 = vqrdmulhq_n_s16(v665, 31869); + int16x8_t v667 = vaddq_s16(v664, v666); + int16x8_t v668 = vsubq_s16(v587, v590); + int16x8_t v669_tmp = vqrdmulhq_n_s16(v594, 23444); + int16x8_t v669 = vaddq_s16(v669_tmp, v594); + int16x8_t v670 = vsubq_s16(v592, v669); + int16x8_t v671 = vqrdmulhq_n_s16(v670, 31869); + int16x8_t v672 = vaddq_s16(v668, v671); + int16x8_t v673 = vqrdmulhq_n_s16(v672, 18830); + int16x8_t v674 = vaddq_s16(v667, v673); + int16x8_t v675 = vsubq_s16(v553, v555); + int16x8_t v676 = vsubq_s16(v557, v559); + int16x8_t v677_tmp = vqrdmulhq_n_s16(v676, 5552); + int16x8_t v677 = vaddq_s16(v677_tmp, v676); + int16x8_t v678 = vaddq_s16(v675, v677); + int16x8_t v679 = vsubq_s16(v563, v565); + int16x8_t v680 = vsubq_s16(v567, v570); + int16x8_t v681_tmp = vqrdmulhq_n_s16(v680, 5552); + int16x8_t v681 = vaddq_s16(v681_tmp, v680); + int16x8_t v682 = vaddq_s16(v679, v681); + int16x8_t v683 = vqrdmulhq_n_s16(v682, 19393); + int16x8_t v684 = vaddq_s16(v678, v683); + int16x8_t v685 = vsubq_s16(v507, v512); + int16x8_t v686 = vsubq_s16(v517, v522); + int16x8_t v687_tmp = vqrdmulhq_n_s16(v686, 15865); + int16x8_t v687 = vaddq_s16(v687_tmp, v686); + int16x8_t v688 = vaddq_s16(v685, v687); + int16x8_t v689 = vsubq_s16(v529, v534); + int16x8_t v690_tmp = vqrdmulhq_n_s16(v548, 28937); + int16x8_t v690 = vaddq_s16(v690_tmp, v548); + int16x8_t v691 = vsubq_s16(v539, v690); + int16x8_t v692_tmp = vqrdmulhq_n_s16(v691, 15865); + int16x8_t v692 = vaddq_s16(v692_tmp, v691); + int16x8_t v693 = vaddq_s16(v689, v692); + int16x8_t v694 = vqrdmulhq_n_s16(v693, 20040); + int16x8_t v695 = vaddq_s16(v688, v694); + int16x8_t v696 = vsubq_s16(v476, v486); + int16x8_t v697_tmp = vqrdmulhq_n_s16(v696, 1893); + int16x8_t v697 = vmlaq_n_s16(v697_tmp, v696, 2); + int16x8_t v698 = vsubq_s16(v493, v499); + int16x8_t v699 = vaddq_s16(v697, v698); + int16x8_t v700 = vqrdmulhq_n_s16(v699, 20783); + int16x8_t v701 = vsubq_s16(v450, v456); + int16x8_t v702 = vsubq_s16(v462, v468); + int16x8_t v703_tmp = vqrdmulhq_n_s16(v702, 1893); + int16x8_t v703 = vmlaq_n_s16(v703_tmp, v702, 2); + int16x8_t v704 = vaddq_s16(v701, v703); + int16x8_t v705 = vaddq_s16(v700, v704); + int16x8_t v706 = vsubq_s16(v384, v373); + int16x8_t v707 = vsubq_s16(v343, v358); + int16x8_t v708_tmp = vqrdmulhq_n_s16(v707, 13357); + int16x8_t v708 = vmlaq_n_s16(v708_tmp, v707, 3); + int16x8_t v709 = vaddq_s16(v706, v708); + int16x8_t v710 = vsubq_s16(v398, v413); + int16x8_t v711 = vsubq_s16(v427, v441); + int16x8_t v712_tmp = vqrdmulhq_n_s16(v711, 13357); + int16x8_t v712 = vmlaq_n_s16(v712_tmp, v711, 3); + int16x8_t v713 = vaddq_s16(v710, v712); + int16x8_t v714 = vqrdmulhq_n_s16(v713, 21637); + int16x8_t v715 = vaddq_s16(v709, v714); + int16x8_t v716 = vsubq_s16(v25, v60); + int16x8_t v717 = vsubq_s16(v102, v140); + int16x8_t v718_tmp = vqrdmulhq_n_s16(v717, 6226); + int16x8_t v718 = vmlaq_n_s16(v718_tmp, v717, 10); + int16x8_t v719 = vaddq_s16(v716, v718); + int16x8_t v720 = vsubq_s16(v280, v326); + int16x8_t v721_tmp = vqrdmulhq_n_s16(v720, 6226); + int16x8_t v721 = vmlaq_n_s16(v721_tmp, v720, 10); + int16x8_t v722 = vsubq_s16(v184, v237); + int16x8_t v723 = vaddq_s16(v721, v722); + int16x8_t v724 = vqrdmulhq_n_s16(v723, 22622); + int16x8_t v725 = vaddq_s16(v719, v724); + int16x8_t v726 = vsubq_s16(v716, v718); + int16x8_t v727 = vsubq_s16(v722, v721); + int16x8_t v728 = vqrdmulhq_n_s16(v727, 23761); + int16x8_t v729 = vaddq_s16(v726, v728); + int16x8_t v730 = vsubq_s16(v706, v708); + int16x8_t v731 = vsubq_s16(v710, v712); + int16x8_t v732 = vqrdmulhq_n_s16(v731, 25084); + int16x8_t v733 = vaddq_s16(v730, v732); + int16x8_t v734 = vsubq_s16(v701, v703); + int16x8_t v735 = vsubq_s16(v698, v697); + int16x8_t v736 = vqrdmulhq_n_s16(v735, 26631); + int16x8_t v737 = vaddq_s16(v734, v736); + int16x8_t v738 = vsubq_s16(v685, v687); + int16x8_t v739 = vsubq_s16(v689, v692); + int16x8_t v740 = vqrdmulhq_n_s16(v739, 28454); + int16x8_t v741 = vaddq_s16(v738, v740); + int16x8_t v742 = vsubq_s16(v675, v677); + int16x8_t v743 = vsubq_s16(v679, v681); + int16x8_t v744 = vqrdmulhq_n_s16(v743, 30624); + int16x8_t v745 = vaddq_s16(v742, v744); + int16x8_t v746 = vsubq_s16(v664, v666); + int16x8_t v747 = vsubq_s16(v668, v671); + int16x8_t v748_tmp = vqrdmulhq_n_s16(v747, 472); + int16x8_t v748 = vaddq_s16(v748_tmp, v747); + int16x8_t v749 = vaddq_s16(v746, v748); + int16x8_t v750 = vsubq_s16(v653, v655); + int16x8_t v751 = vsubq_s16(v657, v660); + int16x8_t v752_tmp = vqrdmulhq_n_s16(v751, 3672); + int16x8_t v752 = vaddq_s16(v752_tmp, v751); + int16x8_t v753 = vaddq_s16(v750, v752); + int16x8_t v754 = vsubq_s16(v643, v645); + int16x8_t v755 = vsubq_s16(v647, v649); + int16x8_t v756_tmp = vqrdmulhq_n_s16(v755, 7662); + int16x8_t v756 = vaddq_s16(v756_tmp, v755); + int16x8_t v757 = vaddq_s16(v754, v756); + int16x8_t v758 = vsubq_s16(v635, v640); + int16x8_t v759 = vsubq_s16(v624, v629); + int16x8_t v760_tmp = vqrdmulhq_n_s16(v759, 12756); + int16x8_t v760 = vaddq_s16(v760_tmp, v759); + int16x8_t v761 = vaddq_s16(v758, v760); + int16x8_t v762 = vsubq_s16(v602, v607); + int16x8_t v763 = vsubq_s16(v612, v617); + int16x8_t v764_tmp = vqrdmulhq_n_s16(v763, 19463); + int16x8_t v764 = vaddq_s16(v764_tmp, v763); + int16x8_t v765 = vaddq_s16(v762, v764); + int16x8_t v766 = vsubq_s16(v585, v581); + int16x8_t v767 = vsubq_s16(v591, v595); + int16x8_t v768_tmp = vqrdmulhq_n_s16(v767, 28661); + int16x8_t v768 = vaddq_s16(v768_tmp, v767); + int16x8_t v769 = vaddq_s16(v766, v768); + int16x8_t v770 = vsubq_s16(v556, v561); + int16x8_t v771 = vsubq_s16(v566, v572); + int16x8_t v772_tmp = vqrdmulhq_n_s16(v771, 9242); + int16x8_t v772 = vmlaq_n_s16(v772_tmp, v771, 2); + int16x8_t v773 = vaddq_s16(v770, v772); + int16x8_t v774 = vsubq_s16(v513, v524); + int16x8_t v775 = vsubq_s16(v535, v549); + int16x8_t v776_tmp = vqrdmulhq_n_s16(v775, 30298); + int16x8_t v776 = vmlaq_n_s16(v776_tmp, v775, 2); + int16x8_t v777 = vaddq_s16(v774, v776); + int16x8_t v778 = vsubq_s16(v457, v470); + int16x8_t v779 = vsubq_s16(v500, v488); + int16x8_t v780_tmp = vqrdmulhq_n_s16(v779, 2773); + int16x8_t v780 = vmlaq_n_s16(v780_tmp, v779, 4); + int16x8_t v781 = vaddq_s16(v778, v780); + int16x8_t v782 = vsubq_s16(v385, v360); + int16x8_t v783 = vsubq_s16(v414, v443); + int16x8_t v784_tmp = vqrdmulhq_n_s16(v783, 26108); + int16x8_t v784 = vmlaq_n_s16(v784_tmp, v783, 6); + int16x8_t v785 = vaddq_s16(v782, v784); + int16x8_t v786 = vsubq_s16(v61, v142); + int16x8_t v787 = vsubq_s16(v238, v328); + int16x8_t v788_tmp = vqrdmulhq_n_s16(v787, 12251); + int16x8_t v788 = vmlaq_n_s16(v788_tmp, v787, 20); + int16x8_t v789 = vaddq_s16(v786, v788); + int16x8_t v790 = vsubq_s16(v786, v788); + int16x8_t v791 = vsubq_s16(v782, v784); + int16x8_t v792 = vsubq_s16(v778, v780); + int16x8_t v793 = vsubq_s16(v774, v776); + int16x8_t v794 = vsubq_s16(v770, v772); + int16x8_t v795 = vsubq_s16(v766, v768); + int16x8_t v796 = vsubq_s16(v762, v764); + int16x8_t v797 = vsubq_s16(v758, v760); + int16x8_t v798 = vsubq_s16(v754, v756); + int16x8_t v799 = vsubq_s16(v750, v752); + int16x8_t v800 = vsubq_s16(v746, v748); + int16x8_t v801 = vsubq_s16(v742, v744); + int16x8_t v802 = vsubq_s16(v738, v740); + int16x8_t v803 = vsubq_s16(v734, v736); + int16x8_t v804 = vsubq_s16(v730, v732); + int16x8_t v805 = vsubq_s16(v726, v728); + int16x8_t v806 = vsubq_s16(v719, v724); + int16x8_t v807 = vsubq_s16(v709, v714); + int16x8_t v808 = vsubq_s16(v704, v700); + int16x8_t v809 = vsubq_s16(v688, v694); + int16x8_t v810 = vsubq_s16(v678, v683); + int16x8_t v811 = vsubq_s16(v667, v673); + int16x8_t v812 = vsubq_s16(v656, v662); + int16x8_t v813 = vsubq_s16(v646, v651); + int16x8_t v814 = vsubq_s16(v641, v631); + int16x8_t v815 = vsubq_s16(v608, v619); + int16x8_t v816 = vsubq_s16(v586, v597); + int16x8_t v817 = vsubq_s16(v562, v574); + int16x8_t v818 = vsubq_s16(v525, v551); + int16x8_t v819 = vsubq_s16(v471, v502); + int16x8_t v820 = vsubq_s16(v386, v445); + int16x8_t v821 = vsubq_s16(v143, v330); + vst1q_s16(out + out_stride * 0 + i, v331); + vst1q_s16(out + out_stride * 1 + i, v446); + vst1q_s16(out + out_stride * 2 + i, v503); + vst1q_s16(out + out_stride * 3 + i, v552); + vst1q_s16(out + out_stride * 4 + i, v575); + vst1q_s16(out + out_stride * 5 + i, v598); + vst1q_s16(out + out_stride * 6 + i, v620); + vst1q_s16(out + out_stride * 7 + i, v642); + vst1q_s16(out + out_stride * 8 + i, v652); + vst1q_s16(out + out_stride * 9 + i, v663); + vst1q_s16(out + out_stride * 10 + i, v674); + vst1q_s16(out + out_stride * 11 + i, v684); + vst1q_s16(out + out_stride * 12 + i, v695); + vst1q_s16(out + out_stride * 13 + i, v705); + vst1q_s16(out + out_stride * 14 + i, v715); + vst1q_s16(out + out_stride * 15 + i, v725); + vst1q_s16(out + out_stride * 16 + i, v729); + vst1q_s16(out + out_stride * 17 + i, v733); + vst1q_s16(out + out_stride * 18 + i, v737); + vst1q_s16(out + out_stride * 19 + i, v741); + vst1q_s16(out + out_stride * 20 + i, v745); + vst1q_s16(out + out_stride * 21 + i, v749); + vst1q_s16(out + out_stride * 22 + i, v753); + vst1q_s16(out + out_stride * 23 + i, v757); + vst1q_s16(out + out_stride * 24 + i, v761); + vst1q_s16(out + out_stride * 25 + i, v765); + vst1q_s16(out + out_stride * 26 + i, v769); + vst1q_s16(out + out_stride * 27 + i, v773); + vst1q_s16(out + out_stride * 28 + i, v777); + vst1q_s16(out + out_stride * 29 + i, v781); + vst1q_s16(out + out_stride * 30 + i, v785); + vst1q_s16(out + out_stride * 31 + i, v789); + vst1q_s16(out + out_stride * 32 + i, v790); + vst1q_s16(out + out_stride * 33 + i, v791); + vst1q_s16(out + out_stride * 34 + i, v792); + vst1q_s16(out + out_stride * 35 + i, v793); + vst1q_s16(out + out_stride * 36 + i, v794); + vst1q_s16(out + out_stride * 37 + i, v795); + vst1q_s16(out + out_stride * 38 + i, v796); + vst1q_s16(out + out_stride * 39 + i, v797); + vst1q_s16(out + out_stride * 40 + i, v798); + vst1q_s16(out + out_stride * 41 + i, v799); + vst1q_s16(out + out_stride * 42 + i, v800); + vst1q_s16(out + out_stride * 43 + i, v801); + vst1q_s16(out + out_stride * 44 + i, v802); + vst1q_s16(out + out_stride * 45 + i, v803); + vst1q_s16(out + out_stride * 46 + i, v804); + vst1q_s16(out + out_stride * 47 + i, v805); + vst1q_s16(out + out_stride * 48 + i, v806); + vst1q_s16(out + out_stride * 49 + i, v807); + vst1q_s16(out + out_stride * 50 + i, v808); + vst1q_s16(out + out_stride * 51 + i, v809); + vst1q_s16(out + out_stride * 52 + i, v810); + vst1q_s16(out + out_stride * 53 + i, v811); + vst1q_s16(out + out_stride * 54 + i, v812); + vst1q_s16(out + out_stride * 55 + i, v813); + vst1q_s16(out + out_stride * 56 + i, v814); + vst1q_s16(out + out_stride * 57 + i, v815); + vst1q_s16(out + out_stride * 58 + i, v816); + vst1q_s16(out + out_stride * 59 + i, v817); + vst1q_s16(out + out_stride * 60 + i, v818); + vst1q_s16(out + out_stride * 61 + i, v819); + vst1q_s16(out + out_stride * 62 + i, v820); + vst1q_s16(out + out_stride * 63 + i, v821); + } +} diff --git a/media/libjxl/src/lib/jxl/fast_dct8-inl.h b/media/libjxl/src/lib/jxl/fast_dct8-inl.h new file mode 100644 index 000000000..946ace4a0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct8-inl.h @@ -0,0 +1,80 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<8>) { return 1; } + +void FastIDCT(FastDCTTag<8>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17 = vqrdmulhq_n_s16(v16, 25080); + int16x8_t v18 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v16, v19); + int16x8_t v21 = vqrdmulhq_n_s16(v20, 17734); + int16x8_t v22 = vaddq_s16(v17, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vsubq_s16(v0, v1); + int16x8_t v27 = vsubq_s16(v4, v6); + int16x8_t v28_tmp = vqrdmulhq_n_s16(v27, 10045); + int16x8_t v28 = vaddq_s16(v28_tmp, v27); + int16x8_t v29 = vaddq_s16(v26, v28); + int16x8_t v30 = vsubq_s16(v11, v14); + int16x8_t v31 = vqrdmulhq_n_s16(v16, 17734); + int16x8_t v32_tmp = vqrdmulhq_n_s16(v19, 10045); + int16x8_t v32 = vaddq_s16(v32_tmp, v19); + int16x8_t v33 = vsubq_s16(v31, v32); + int16x8_t v34 = vaddq_s16(v30, v33); + int16x8_t v35 = vqrdmulhq_n_s16(v34, 19705); + int16x8_t v36 = vaddq_s16(v29, v35); + int16x8_t v37 = vsubq_s16(v26, v28); + int16x8_t v38 = vsubq_s16(v30, v33); + int16x8_t v39 = vqrdmulhq_n_s16(v38, 29490); + int16x8_t v40 = vaddq_s16(v37, v39); + int16x8_t v41 = vsubq_s16(v2, v8); + int16x8_t v42 = vsubq_s16(v15, v22); + int16x8_t v43_tmp = vqrdmulhq_n_s16(v42, 18446); + int16x8_t v43 = vmlaq_n_s16(v43_tmp, v42, 2); + int16x8_t v44 = vaddq_s16(v41, v43); + int16x8_t v45 = vsubq_s16(v41, v43); + int16x8_t v46 = vsubq_s16(v37, v39); + int16x8_t v47 = vsubq_s16(v29, v35); + int16x8_t v48 = vsubq_s16(v9, v24); + vst1q_s16(out + out_stride * 0 + i, v25); + vst1q_s16(out + out_stride * 1 + i, v36); + vst1q_s16(out + out_stride * 2 + i, v40); + vst1q_s16(out + out_stride * 3 + i, v44); + vst1q_s16(out + out_stride * 4 + i, v45); + vst1q_s16(out + out_stride * 5 + i, v46); + vst1q_s16(out + out_stride * 6 + i, v47); + vst1q_s16(out + out_stride * 7 + i, v48); + } +} diff --git a/media/libjxl/src/lib/jxl/fast_dct_test.cc b/media/libjxl/src/lib/jxl/fast_dct_test.cc new file mode 100644 index 000000000..d9d852f32 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_dct_test.cc @@ -0,0 +1,366 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/fast_dct_test.cc" +#include + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/fast_dct-inl.h" +#include "lib/jxl/fast_dct.h" +#include "lib/jxl/transpose-inl.h" + +// Test utils +#include +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +template +HWY_NOINLINE void TestFastTranspose() { +#if HWY_TARGET == HWY_NEON + auto array_mem = hwy::AllocateAligned(N * M); + int16_t* array = array_mem.get(); + auto transposed_mem = hwy::AllocateAligned(N * M); + int16_t* transposed = transposed_mem.get(); + std::iota(array, array + N * M, 0); + for (size_t j = 0; j < 100000000 / (N * M); j++) { + FastTransposeBlock(array, M, N, M, transposed, N); + } + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + EXPECT_EQ(array[j * M + i], transposed[i * N + j]); + } + } +#endif +} + +template +HWY_NOINLINE void TestFloatTranspose() { + auto array_mem = hwy::AllocateAligned(N * M); + float* array = array_mem.get(); + auto transposed_mem = hwy::AllocateAligned(N * M); + float* transposed = transposed_mem.get(); + std::iota(array, array + N * M, 0); + for (size_t j = 0; j < 100000000 / (N * M); j++) { + Transpose::Run(DCTFrom(array, M), DCTTo(transposed, N)); + } + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + EXPECT_EQ(array[j * M + i], transposed[i * N + j]); + } + } +} + +// TODO(sboukortt): re-enable the FloatIDCT tests once we find out why they fail +// in ASAN mode in the CI runners and seemingly not locally. + +HWY_NOINLINE void TestFastTranspose8x8() { TestFastTranspose<8, 8>(); } +HWY_NOINLINE void TestFloatTranspose8x8() { TestFloatTranspose<8, 8>(); } +HWY_NOINLINE void TestFastIDCT8x8() { TestFastIDCT<8, 8>(); } +HWY_NOINLINE void TestFloatIDCT8x8() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<8, 8>(); +#endif +} +HWY_NOINLINE void TestFastTranspose8x16() { TestFastTranspose<8, 16>(); } +HWY_NOINLINE void TestFloatTranspose8x16() { TestFloatTranspose<8, 16>(); } +HWY_NOINLINE void TestFastIDCT8x16() { TestFastIDCT<8, 16>(); } +HWY_NOINLINE void TestFloatIDCT8x16() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<8, 16>(); +#endif +} +HWY_NOINLINE void TestFastTranspose8x32() { TestFastTranspose<8, 32>(); } +HWY_NOINLINE void TestFloatTranspose8x32() { TestFloatTranspose<8, 32>(); } +HWY_NOINLINE void TestFastIDCT8x32() { TestFastIDCT<8, 32>(); } +HWY_NOINLINE void TestFloatIDCT8x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<8, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose16x8() { TestFastTranspose<16, 8>(); } +HWY_NOINLINE void TestFloatTranspose16x8() { TestFloatTranspose<16, 8>(); } +HWY_NOINLINE void TestFastIDCT16x8() { TestFastIDCT<16, 8>(); } +HWY_NOINLINE void TestFloatIDCT16x8() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<16, 8>(); +#endif +} +HWY_NOINLINE void TestFastTranspose16x16() { TestFastTranspose<16, 16>(); } +HWY_NOINLINE void TestFloatTranspose16x16() { TestFloatTranspose<16, 16>(); } +HWY_NOINLINE void TestFastIDCT16x16() { TestFastIDCT<16, 16>(); } +HWY_NOINLINE void TestFloatIDCT16x16() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<16, 16>(); +#endif +} +HWY_NOINLINE void TestFastTranspose16x32() { TestFastTranspose<16, 32>(); } +HWY_NOINLINE void TestFloatTranspose16x32() { TestFloatTranspose<16, 32>(); } +HWY_NOINLINE void TestFastIDCT16x32() { TestFastIDCT<16, 32>(); } +HWY_NOINLINE void TestFloatIDCT16x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<16, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x8() { TestFastTranspose<32, 8>(); } +HWY_NOINLINE void TestFloatTranspose32x8() { TestFloatTranspose<32, 8>(); } +HWY_NOINLINE void TestFastIDCT32x8() { TestFastIDCT<32, 8>(); } +HWY_NOINLINE void TestFloatIDCT32x8() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 8>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x16() { TestFastTranspose<32, 16>(); } +HWY_NOINLINE void TestFloatTranspose32x16() { TestFloatTranspose<32, 16>(); } +HWY_NOINLINE void TestFastIDCT32x16() { TestFastIDCT<32, 16>(); } +HWY_NOINLINE void TestFloatIDCT32x16() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 16>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x32() { TestFastTranspose<32, 32>(); } +HWY_NOINLINE void TestFloatTranspose32x32() { TestFloatTranspose<32, 32>(); } +HWY_NOINLINE void TestFastIDCT32x32() { TestFastIDCT<32, 32>(); } +HWY_NOINLINE void TestFloatIDCT32x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x64() { TestFastTranspose<32, 64>(); } +HWY_NOINLINE void TestFloatTranspose32x64() { TestFloatTranspose<32, 64>(); } +HWY_NOINLINE void TestFastIDCT32x64() { TestFastIDCT<32, 64>(); } +HWY_NOINLINE void TestFloatIDCT32x64() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 64>(); +#endif +} +HWY_NOINLINE void TestFastTranspose64x32() { TestFastTranspose<64, 32>(); } +HWY_NOINLINE void TestFloatTranspose64x32() { TestFloatTranspose<64, 32>(); } +HWY_NOINLINE void TestFastIDCT64x32() { TestFastIDCT<64, 32>(); } +HWY_NOINLINE void TestFloatIDCT64x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<64, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose64x64() { TestFastTranspose<64, 64>(); } +HWY_NOINLINE void TestFloatTranspose64x64() { TestFloatTranspose<64, 64>(); } +HWY_NOINLINE void TestFastIDCT64x64() { TestFastIDCT<64, 64>(); } +HWY_NOINLINE void TestFloatIDCT64x64() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<64, 64>(); +#endif +} +HWY_NOINLINE void TestFastTranspose64x128() { TestFastTranspose<64, 128>(); } +HWY_NOINLINE void TestFloatTranspose64x128() { TestFloatTranspose<64, 128>(); } +HWY_NOINLINE void TestFastIDCT64x128() { TestFastIDCT<64, 128>(); } +HWY_NOINLINE void TestFloatIDCT64x128() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<64, 128>(); +#endif +} +HWY_NOINLINE void TestFastTranspose128x64() { TestFastTranspose<128, 64>(); } +HWY_NOINLINE void TestFloatTranspose128x64() { TestFloatTranspose<128, 64>(); } +HWY_NOINLINE void TestFastIDCT128x64() { TestFastIDCT<128, 64>(); } +HWY_NOINLINE void TestFloatIDCT128x64() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<128, 64>(); +#endif +} +HWY_NOINLINE void TestFastTranspose128x128() { TestFastTranspose<128, 128>(); } +HWY_NOINLINE void TestFloatTranspose128x128() { + TestFloatTranspose<128, 128>(); +} +HWY_NOINLINE void TestFastIDCT128x128() { TestFastIDCT<128, 128>(); } +HWY_NOINLINE void TestFloatIDCT128x128() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<128, 128>(); +#endif +} +HWY_NOINLINE void TestFastTranspose128x256() { TestFastTranspose<128, 256>(); } +HWY_NOINLINE void TestFloatTranspose128x256() { + TestFloatTranspose<128, 256>(); +} +HWY_NOINLINE void TestFastIDCT128x256() { TestFastIDCT<128, 256>(); } +HWY_NOINLINE void TestFloatIDCT128x256() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<128, 256>(); +#endif +} +HWY_NOINLINE void TestFastTranspose256x128() { TestFastTranspose<256, 128>(); } +HWY_NOINLINE void TestFloatTranspose256x128() { + TestFloatTranspose<256, 128>(); +} +HWY_NOINLINE void TestFastIDCT256x128() { TestFastIDCT<256, 128>(); } +HWY_NOINLINE void TestFloatIDCT256x128() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<256, 128>(); +#endif +} +HWY_NOINLINE void TestFastTranspose256x256() { TestFastTranspose<256, 256>(); } +HWY_NOINLINE void TestFloatTranspose256x256() { + TestFloatTranspose<256, 256>(); +} +HWY_NOINLINE void TestFastIDCT256x256() { TestFastIDCT<256, 256>(); } +HWY_NOINLINE void TestFloatIDCT256x256() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<256, 256>(); +#endif +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class FastDCTTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(FastDCTTargetTest); + +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose64x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose64x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose256x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose256x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT64x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT64x64); +/* + * DCT-128 and above have very large errors just by rounding inputs. +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT256x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT256x256); +*/ + +TEST(FastDCTTest, TestWrapperFloat) { BenchmarkFloatIDCT32x32(); } +TEST(FastDCTTest, TestWrapperFast) { BenchmarkFastIDCT32x32(); } + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/fast_math-inl.h b/media/libjxl/src/lib/jxl/fast_math-inl.h new file mode 100644 index 000000000..5c4803429 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_math-inl.h @@ -0,0 +1,236 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast SIMD math ops (log2, encoder only, cos, erf for splines) + +#if defined(LIB_JXL_FAST_MATH_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_FAST_MATH_INL_H_ +#undef LIB_JXL_FAST_MATH_INL_H_ +#else +#define LIB_JXL_FAST_MATH_INL_H_ +#endif + +#include + +#include "lib/jxl/common.h" +#include "lib/jxl/rational_polynomial-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::Floor; +using hwy::HWY_NAMESPACE::Ge; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenZeroElse; +using hwy::HWY_NAMESPACE::Le; +using hwy::HWY_NAMESPACE::Min; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::NegMulAdd; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Xor; + +// Computes base-2 logarithm like std::log2. Undefined if negative / NaN. +// L1 error ~3.9E-6 +template +V FastLog2f(const DF df, V x) { + // 2,2 rational polynomial approximation of std::log1p(x) / std::log(2). + HWY_ALIGN const float p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06f), + HWY_REP4(1.4287160470083755E+00f), + HWY_REP4(7.4245873327820566E-01f)}; + HWY_ALIGN const float q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01f), + HWY_REP4(1.0096718572241148E+00f), + HWY_REP4(1.7409343003366853E-01f)}; + + const Rebind di; + const auto x_bits = BitCast(di, x); + + // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops + const auto exp_bits = Sub(x_bits, Set(di, 0x3f2aaaab)); // = 2/3 + // Shifted exponent = log2; also used to clear mantissa. + const auto exp_shifted = ShiftRight<23>(exp_bits); + const auto mantissa = BitCast(df, Sub(x_bits, ShiftLeft<23>(exp_shifted))); + const auto exp_val = ConvertTo(df, exp_shifted); + return Add(EvalRationalPolynomial(df, Sub(mantissa, Set(df, 1.0f)), p, q), + exp_val); +} + +// max relative error ~3e-7 +template +V FastPow2f(const DF df, V x) { + const Rebind di; + auto floorx = Floor(x); + auto exp = + BitCast(df, ShiftLeft<23>(Add(ConvertTo(di, floorx), Set(di, 127)))); + auto frac = Sub(x, floorx); + auto num = Add(frac, Set(df, 1.01749063e+01)); + num = MulAdd(num, frac, Set(df, 4.88687798e+01)); + num = MulAdd(num, frac, Set(df, 9.85506591e+01)); + num = Mul(num, exp); + auto den = MulAdd(frac, Set(df, 2.10242958e-01), Set(df, -2.22328856e-02)); + den = MulAdd(den, frac, Set(df, -1.94414990e+01)); + den = MulAdd(den, frac, Set(df, 9.85506633e+01)); + return Div(num, den); +} + +// max relative error ~3e-5 +template +V FastPowf(const DF df, V base, V exponent) { + return FastPow2f(df, Mul(FastLog2f(df, base), exponent)); +} + +// Computes cosine like std::cos. +// L1 error 7e-5. +template +V FastCosf(const DF df, V x) { + // Step 1: range reduction to [0, 2pi) + const auto pi2 = Set(df, kPi * 2.0f); + const auto pi2_inv = Set(df, 0.5f / kPi); + const auto npi2 = Mul(Floor(Mul(x, pi2_inv)), pi2); + const auto xmodpi2 = Sub(x, npi2); + // Step 2: range reduction to [0, pi] + const auto x_pi = Min(xmodpi2, Sub(pi2, xmodpi2)); + // Step 3: range reduction to [0, pi/2] + const auto above_pihalf = Ge(x_pi, Set(df, kPi / 2.0f)); + const auto x_pihalf = IfThenElse(above_pihalf, Sub(Set(df, kPi), x_pi), x_pi); + // Step 4: Taylor-like approximation, scaled by 2**0.75 to make angle + // duplication steps faster, on x/4. + const auto xs = Mul(x_pihalf, Set(df, 0.25f)); + const auto x2 = Mul(xs, xs); + const auto x4 = Mul(x2, x2); + const auto cosx_prescaling = + MulAdd(x4, Set(df, 0.06960438), + MulAdd(x2, Set(df, -0.84087373), Set(df, 1.68179268))); + // Step 5: angle duplication. + const auto cosx_scale1 = + MulAdd(cosx_prescaling, cosx_prescaling, Set(df, -1.414213562)); + const auto cosx_scale2 = MulAdd(cosx_scale1, cosx_scale1, Set(df, -1)); + // Step 6: change sign if needed. + const Rebind du; + auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, above_pihalf))); + return BitCast(df, Xor(signbit, BitCast(du, cosx_scale2))); +} + +// Computes the error function like std::erf. +// L1 error 7e-4. +template +V FastErff(const DF df, V x) { + // Formula from + // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations + // but constants have been recomputed. + const auto xle0 = Le(x, Zero(df)); + const auto absx = Abs(x); + // Compute 1 - 1 / ((((x * a + b) * x + c) * x + d) * x + 1)**4 + const auto denom1 = + MulAdd(absx, Set(df, 7.77394369e-02), Set(df, 2.05260015e-04)); + const auto denom2 = MulAdd(denom1, absx, Set(df, 2.32120216e-01)); + const auto denom3 = MulAdd(denom2, absx, Set(df, 2.77820801e-01)); + const auto denom4 = MulAdd(denom3, absx, Set(df, 1.0f)); + const auto denom5 = Mul(denom4, denom4); + const auto inv_denom5 = Div(Set(df, 1.0f), denom5); + const auto result = NegMulAdd(inv_denom5, inv_denom5, Set(df, 1.0f)); + // Change sign if needed. + const Rebind du; + auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, xle0))); + return BitCast(df, Xor(signbit, BitCast(du, result))); +} + +inline float FastLog2f(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastLog2f(D, Set(D, f))); +} + +inline float FastPow2f(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastPow2f(D, Set(D, f))); +} + +inline float FastPowf(float b, float e) { + HWY_CAPPED(float, 1) D; + return GetLane(FastPowf(D, Set(D, b), Set(D, e))); +} + +inline float FastCosf(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastCosf(D, Set(D, f))); +} + +inline float FastErff(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastErff(D, Set(D, f))); +} + +// Returns cbrt(x) + add with 6 ulp max error. +// Modified from vectormath_exp.h, Apache 2 license. +// https://www.agner.org/optimize/vectorclass.zip +template +V CubeRootAndAdd(const V x, const V add) { + const HWY_FULL(float) df; + const HWY_FULL(int32_t) di; + + const auto kExpBias = Set(di, 0x54800000); // cast(1.) + cast(1.) / 3 + const auto kExpMul = Set(di, 0x002AAAAA); // shifted 1/3 + const auto k1_3 = Set(df, 1.0f / 3); + const auto k4_3 = Set(df, 4.0f / 3); + + const auto xa = x; // assume inputs never negative + const auto xa_3 = Mul(k1_3, xa); + + // Multiply exponent by -1/3 + const auto m1 = BitCast(di, xa); + // Special case for 0. 0 is represented with an exponent of 0, so the + // "kExpBias - 1/3 * exp" below gives the wrong result. The IfThenZeroElse() + // sets those values as 0, which prevents having NaNs in the computations + // below. + // TODO(eustas): use fused op + const auto m2 = IfThenZeroElse( + Eq(m1, Zero(di)), Sub(kExpBias, Mul((ShiftRight<23>(m1)), kExpMul))); + auto r = BitCast(df, m2); + + // Newton-Raphson iterations + for (int i = 0; i < 3; i++) { + const auto r2 = Mul(r, r); + r = NegMulAdd(xa_3, Mul(r2, r2), Mul(k4_3, r)); + } + // Final iteration + auto r2 = Mul(r, r); + r = MulAdd(k1_3, NegMulAdd(xa, Mul(r2, r2), r), r); + r2 = Mul(r, r); + r = MulAdd(r2, x, add); + + return r; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_FAST_MATH_INL_H_ + +#if HWY_ONCE +#ifndef FAST_MATH_ONCE +#define FAST_MATH_ONCE + +namespace jxl { +inline float FastLog2f(float f) { return HWY_STATIC_DISPATCH(FastLog2f)(f); } +inline float FastPow2f(float f) { return HWY_STATIC_DISPATCH(FastPow2f)(f); } +inline float FastPowf(float b, float e) { + return HWY_STATIC_DISPATCH(FastPowf)(b, e); +} +inline float FastCosf(float f) { return HWY_STATIC_DISPATCH(FastCosf)(f); } +inline float FastErff(float f) { return HWY_STATIC_DISPATCH(FastErff)(f); } +} // namespace jxl + +#endif // FAST_MATH_ONCE +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/fast_math_test.cc b/media/libjxl/src/lib/jxl/fast_math_test.cc new file mode 100644 index 000000000..897aadc12 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fast_math_test.cc @@ -0,0 +1,288 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/fast_math_test.cc" +#include + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/transfer_functions-inl.h" + +// Test utils +#include +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +HWY_NOINLINE void TestFastLog2() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(1e-7f, 1e3f); + const auto actual_v = FastLog2f(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float abs_err = std::abs(std::log2(f) - actual); + EXPECT_LT(abs_err, 3.1E-6) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestFastPow2() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_rel_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(-100, 100); + const auto actual_v = FastPow2f(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float expected = std::pow(2, f); + const float rel_err = std::abs(expected - actual) / expected; + EXPECT_LT(rel_err, 3.1E-6) << "f = " << f; + max_rel_err = std::max(max_rel_err, rel_err); + } + printf("max rel err %e\n", static_cast(max_rel_err)); +} + +HWY_NOINLINE void TestFastPow() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_rel_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float b = rng.UniformF(1e-3f, 1e3f); + const float e = rng.UniformF(-10, 10); + const auto actual_v = FastPowf(d, Set(d, b), Set(d, e)); + const float actual = GetLane(actual_v); + const float expected = std::pow(b, e); + const float rel_err = std::abs(expected - actual) / expected; + EXPECT_LT(rel_err, 3E-5) << "b = " << b << " e = " << e; + max_rel_err = std::max(max_rel_err, rel_err); + } + printf("max rel err %e\n", static_cast(max_rel_err)); +} + +HWY_NOINLINE void TestFastCos() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(-1e3f, 1e3f); + const auto actual_v = FastCosf(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float abs_err = std::abs(std::cos(f) - actual); + EXPECT_LT(abs_err, 7E-5) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestFastErf() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(-5.f, 5.f); + const auto actual_v = FastErff(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float abs_err = std::abs(std::erf(f) - actual); + EXPECT_LT(abs_err, 7E-4) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestCubeRoot() { + const HWY_FULL(float) d; + for (uint64_t x5 = 0; x5 < 2000000; x5++) { + const float x = x5 * 1E-5f; + const float expected = cbrtf(x); + HWY_ALIGN float approx[MaxLanes(d)]; + Store(CubeRootAndAdd(Set(d, x), Zero(d)), d, approx); + + // All lanes are same + for (size_t i = 1; i < Lanes(d); ++i) { + EXPECT_NEAR(approx[0], approx[i], 5E-7f); + } + EXPECT_NEAR(approx[0], expected, 8E-7f); + } +} + +HWY_NOINLINE void TestFastSRGB() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const auto actual_v = FastLinearToSRGB(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float expected = GetLane(TF_SRGB().EncodedFromDisplay(d, Set(d, f))); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 1.2E-4) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestFastPQEFD() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(TF_PQ().EncodedFromDisplay(d, Set(d, f))); + const float expected = TF_PQ().EncodedFromDisplay(f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 7e-7) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestFastHLGEFD() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(TF_HLG().EncodedFromDisplay(d, Set(d, f))); + const float expected = TF_HLG().EncodedFromDisplay(f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 5e-7) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestFast709EFD() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(TF_709().EncodedFromDisplay(d, Set(d, f))); + const float expected = TF_709().EncodedFromDisplay(f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 2e-6) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestFastPQDFE() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(TF_PQ().DisplayFromEncoded(d, Set(d, f))); + const float expected = TF_PQ().DisplayFromEncoded(f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 3E-6) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast(max_abs_err)); +} + +HWY_NOINLINE void TestFastXYB() { + if (!HasFastXYBTosRGB8()) return; + ImageMetadata metadata; + ImageBundle ib(&metadata); + int scaling = 1; + int n = 256 * scaling; + float inv_scaling = 1.0f / scaling; + int kChunk = 32; + // The image is divided in chunks to reduce total memory usage. + for (int cr = 0; cr < n; cr += kChunk) { + for (int cg = 0; cg < n; cg += kChunk) { + for (int cb = 0; cb < n; cb += kChunk) { + Image3F chunk(kChunk * kChunk, kChunk); + for (int ir = 0; ir < kChunk; ir++) { + for (int ig = 0; ig < kChunk; ig++) { + for (int ib = 0; ib < kChunk; ib++) { + float r = (cr + ir) * inv_scaling; + float g = (cg + ig) * inv_scaling; + float b = (cb + ib) * inv_scaling; + chunk.PlaneRow(0, ir)[ig * kChunk + ib] = r * (1.0f / 255); + chunk.PlaneRow(1, ir)[ig * kChunk + ib] = g * (1.0f / 255); + chunk.PlaneRow(2, ir)[ig * kChunk + ib] = b * (1.0f / 255); + } + } + } + ib.SetFromImage(std::move(chunk), ColorEncoding::SRGB()); + Image3F xyb(kChunk * kChunk, kChunk); + std::vector roundtrip(kChunk * kChunk * kChunk * 3); + ToXYB(ib, nullptr, &xyb, GetJxlCms()); + for (int y = 0; y < kChunk; y++) { + const float* xyba[4] = {xyb.PlaneRow(0, y), xyb.PlaneRow(1, y), + xyb.PlaneRow(2, y), nullptr}; + jxl::HWY_NAMESPACE::FastXYBTosRGB8( + xyba, roundtrip.data() + 3 * xyb.xsize() * y, false, xyb.xsize()); + } + for (int ir = 0; ir < kChunk; ir++) { + for (int ig = 0; ig < kChunk; ig++) { + for (int ib = 0; ib < kChunk; ib++) { + float r = (cr + ir) * inv_scaling; + float g = (cg + ig) * inv_scaling; + float b = (cb + ib) * inv_scaling; + size_t idx = ir * kChunk * kChunk + ig * kChunk + ib; + int rr = roundtrip[3 * idx]; + int rg = roundtrip[3 * idx + 1]; + int rb = roundtrip[3 * idx + 2]; + EXPECT_LT(abs(r - rr), 2) << "expected " << r << " got " << rr; + EXPECT_LT(abs(g - rg), 2) << "expected " << g << " got " << rg; + EXPECT_LT(abs(b - rb), 2) << "expected " << b << " got " << rb; + } + } + } + } + } + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class FastMathTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(FastMathTargetTest); + +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastLog2); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow2); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastCos); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastErf); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestCubeRoot); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastSRGB); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPQDFE); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPQEFD); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastHLGEFD); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFast709EFD); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastXYB); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/field_encodings.h b/media/libjxl/src/lib/jxl/field_encodings.h new file mode 100644 index 000000000..5af749b2c --- /dev/null +++ b/media/libjxl/src/lib/jxl/field_encodings.h @@ -0,0 +1,134 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FIELD_ENCODINGS_H_ +#define LIB_JXL_FIELD_ENCODINGS_H_ + +// Constants needed to encode/decode fields; avoids including the full fields.h. + +#include +#include + +#include + +#include "hwy/base.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Macro to define the Fields' derived class Name when compiling with debug +// names. +#if JXL_IS_DEBUG_BUILD +#define JXL_FIELDS_NAME(X) \ + const char* Name() const override { return #X; } +#else +#define JXL_FIELDS_NAME(X) +#endif // JXL_IS_DEBUG_BUILD + +class Visitor; +class Fields { + public: + virtual ~Fields() = default; +#if JXL_IS_DEBUG_BUILD + virtual const char* Name() const = 0; +#endif // JXL_IS_DEBUG_BUILD + virtual Status VisitFields(Visitor* JXL_RESTRICT visitor) = 0; +}; + +// Distribution of U32 values for one particular selector. Represents either a +// power of two-sized range, or a single value. A separate type ensures this is +// only passed to the U32Enc ctor. +struct U32Distr { + // No need to validate - all `d` are legitimate. + constexpr explicit U32Distr(uint32_t d) : d(d) {} + + static constexpr uint32_t kDirect = 0x80000000u; + + constexpr bool IsDirect() const { return (d & kDirect) != 0; } + + // Only call if IsDirect(). + constexpr uint32_t Direct() const { return d & (kDirect - 1); } + + // Only call if !IsDirect(). + constexpr size_t ExtraBits() const { return (d & 0x1F) + 1; } + uint32_t Offset() const { return (d >> 5) & 0x3FFFFFF; } + + uint32_t d; +}; + +// A direct-coded 31-bit value occupying 2 bits in the bitstream. +constexpr U32Distr Val(uint32_t value) { + return U32Distr(value | U32Distr::kDirect); +} + +// Value - `offset` will be signaled in `bits` extra bits. +constexpr U32Distr BitsOffset(uint32_t bits, uint32_t offset) { + return U32Distr(((bits - 1) & 0x1F) + ((offset & 0x3FFFFFF) << 5)); +} + +// Value will be signaled in `bits` extra bits. +constexpr U32Distr Bits(uint32_t bits) { return BitsOffset(bits, 0); } + +// See U32Coder documentation in fields.h. +class U32Enc { + public: + constexpr U32Enc(const U32Distr d0, const U32Distr d1, const U32Distr d2, + const U32Distr d3) + : d_{d0, d1, d2, d3} {} + + // Returns the U32Distr at `selector` = 0..3, least-significant first. + U32Distr GetDistr(const uint32_t selector) const { + JXL_ASSERT(selector < 4); + return d_[selector]; + } + + private: + U32Distr d_[4]; +}; + +// Returns bit with the given `index` (0 = least significant). +template +static inline constexpr uint64_t MakeBit(T index) { + return 1ULL << static_cast(index); +} + +// Returns vector of all possible values of an Enum type. Relies on each Enum +// providing an overload of EnumBits() that returns a bit array of its values, +// which implies values must be in [0, 64). +template +std::vector Values() { + uint64_t bits = EnumBits(Enum()); + + std::vector values; + values.reserve(hwy::PopCount(bits)); + + // For each 1-bit in bits: add its index as value + while (bits != 0) { + const int index = Num0BitsBelowLS1Bit_Nonzero(bits); + values.push_back(static_cast(index)); + bits &= bits - 1; // clear least-significant bit + } + return values; +} + +// Returns true if value is one of Values(). +template +Status EnumValid(const Enum value) { + if (static_cast(value) >= 64) { + return JXL_FAILURE("Value %u too large for %s\n", + static_cast(value), EnumName(Enum())); + } + const uint64_t bit = MakeBit(value); + if ((EnumBits(Enum()) & bit) == 0) { + return JXL_FAILURE("Invalid value %u for %s\n", + static_cast(value), EnumName(Enum())); + } + return true; +} + +} // namespace jxl + +#endif // LIB_JXL_FIELD_ENCODINGS_H_ diff --git a/media/libjxl/src/lib/jxl/fields.cc b/media/libjxl/src/lib/jxl/fields.cc new file mode 100644 index 000000000..e8d602585 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fields.cc @@ -0,0 +1,906 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/fields.h" + +#include + +#include +#include + +#include "hwy/base.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/printf_macros.h" + +namespace jxl { + +namespace { + +// A bundle can be in one of three states concerning extensions: not-begun, +// active, ended. Bundles may be nested, so we need a stack of states. +class ExtensionStates { + public: + void Push() { + // Initial state = not-begun. + begun_ <<= 1; + ended_ <<= 1; + } + + // Clears current state; caller must check IsEnded beforehand. + void Pop() { + begun_ >>= 1; + ended_ >>= 1; + } + + // Returns true if state == active || state == ended. + Status IsBegun() const { return (begun_ & 1) != 0; } + // Returns true if state != not-begun && state != active. + Status IsEnded() const { return (ended_ & 1) != 0; } + + void Begin() { + JXL_ASSERT(!IsBegun()); + JXL_ASSERT(!IsEnded()); + begun_ += 1; + } + + void End() { + JXL_ASSERT(IsBegun()); + JXL_ASSERT(!IsEnded()); + ended_ += 1; + } + + private: + // Current state := least-significant bit of begun_ and ended_. + uint64_t begun_ = 0; + uint64_t ended_ = 0; +}; + +// Visitors generate Init/AllDefault/Read/Write logic for all fields. Each +// bundle's VisitFields member function calls visitor->U32 etc. We do not +// overload operator() because a function name is easier to search for. + +class VisitorBase : public Visitor { + public: + explicit VisitorBase() {} + ~VisitorBase() override { JXL_ASSERT(depth_ == 0); } + + // This is the only call site of Fields::VisitFields. + // Ensures EndExtensions was called. + Status Visit(Fields* fields) override { + depth_ += 1; + JXL_ASSERT(depth_ <= Bundle::kMaxExtensions); + extension_states_.Push(); + + const Status ok = fields->VisitFields(this); + + if (ok) { + // If VisitFields called BeginExtensions, must also call + // EndExtensions. + JXL_ASSERT(!extension_states_.IsBegun() || extension_states_.IsEnded()); + } else { + // Failed, undefined state: don't care whether EndExtensions was + // called. + } + + extension_states_.Pop(); + JXL_ASSERT(depth_ != 0); + depth_ -= 1; + + return ok; + } + + // For visitors accepting a const Visitor, need to const-cast so we can call + // the non-const Visitor::VisitFields. NOTE: C is not modified except the + // `all_default` field by CanEncodeVisitor. + Status VisitConst(const Fields& t) { return Visit(const_cast(&t)); } + + // Derived types (overridden by InitVisitor because it is unsafe to read + // from *value there) + + Status Bool(bool default_value, bool* JXL_RESTRICT value) override { + uint32_t bits = *value ? 1 : 0; + JXL_RETURN_IF_ERROR(Bits(1, static_cast(default_value), &bits)); + JXL_DASSERT(bits <= 1); + *value = bits == 1; + return true; + } + + // Overridden by ReadVisitor and WriteVisitor. + // Called before any conditional visit based on "extensions". + // Overridden by ReadVisitor, CanEncodeVisitor and WriteVisitor. + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_RETURN_IF_ERROR(U64(0, extensions)); + + extension_states_.Begin(); + return true; + } + + // Called after all extension fields (if any). Although non-extension + // fields could be visited afterward, we prefer the convention that + // extension fields are always the last to be visited. Overridden by + // ReadVisitor. + Status EndExtensions() override { + extension_states_.End(); + return true; + } + + private: + size_t depth_ = 0; // to check nesting + ExtensionStates extension_states_; +}; + +struct InitVisitor : public VisitorBase { + Status Bits(const size_t /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U32(const U32Enc /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U64(const uint64_t default_value, + uint64_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status Bool(bool default_value, bool* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status F16(const float default_value, float* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + // Always visit conditional fields to ensure they are initialized. + Status Conditional(bool /*condition*/) override { return true; } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) override { + // Just initialize this field and don't skip initializing others. + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return false; + } + + Status VisitNested(Fields* /*fields*/) override { + // Avoid re-initializing nested bundles (their ctors already called + // Bundle::Init for their fields). + return true; + } +}; + +// Similar to InitVisitor, but also initializes nested fields. +struct SetDefaultVisitor : public VisitorBase { + Status Bits(const size_t /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U32(const U32Enc /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U64(const uint64_t default_value, + uint64_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status Bool(bool default_value, bool* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status F16(const float default_value, float* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + // Always visit conditional fields to ensure they are initialized. + Status Conditional(bool /*condition*/) override { return true; } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) override { + // Just initialize this field and don't skip initializing others. + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return false; + } +}; + +class AllDefaultVisitor : public VisitorBase { + public: + explicit AllDefaultVisitor() : VisitorBase() {} + + Status Bits(const size_t bits, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + all_default_ &= *value == default_value; + return true; + } + + Status U32(const U32Enc /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + all_default_ &= *value == default_value; + return true; + } + + Status U64(const uint64_t default_value, + uint64_t* JXL_RESTRICT value) override { + all_default_ &= *value == default_value; + return true; + } + + Status F16(const float default_value, float* JXL_RESTRICT value) override { + all_default_ &= std::abs(*value - default_value) < 1E-6f; + return true; + } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT /*all_default*/) override { + // Visit all fields so we can compute the actual all_default_ value. + return false; + } + + bool AllDefault() const { return all_default_; } + + private: + bool all_default_ = true; +}; + +class ReadVisitor : public VisitorBase { + public: + explicit ReadVisitor(BitReader* reader) : VisitorBase(), reader_(reader) {} + + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + *value = BitsCoder::Read(bits, reader_); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + Status U32(const U32Enc dist, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + *value = U32Coder::Read(dist, reader_); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT value) override { + *value = U64Coder::Read(reader_); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT value) override { + ok_ &= F16Coder::Read(reader_, value); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + void SetDefault(Fields* fields) override { Bundle::SetDefault(fields); } + + bool IsReading() const override { return true; } + + // This never fails because visitors are expected to keep reading until + // EndExtensions, see comment there. + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); + if (*extensions == 0) return true; + + // For each nonzero bit, i.e. extension that is present: + for (uint64_t remaining_extensions = *extensions; remaining_extensions != 0; + remaining_extensions &= remaining_extensions - 1) { + const size_t idx_extension = + Num0BitsBelowLS1Bit_Nonzero(remaining_extensions); + // Read additional U64 (one per extension) indicating the number of bits + // (allows skipping individual extensions). + JXL_RETURN_IF_ERROR(U64(0, &extension_bits_[idx_extension])); + if (!SafeAdd(total_extension_bits_, extension_bits_[idx_extension], + total_extension_bits_)) { + return JXL_FAILURE("Extension bits overflowed, invalid codestream"); + } + } + // Used by EndExtensions to skip past any _remaining_ extensions. + pos_after_ext_size_ = reader_->TotalBitsConsumed(); + JXL_ASSERT(pos_after_ext_size_ != 0); + return true; + } + + Status EndExtensions() override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::EndExtensions()); + // Happens if extensions == 0: don't read size, done. + if (pos_after_ext_size_ == 0) return true; + + // Not enough bytes as set by BeginExtensions or earlier. Do not return + // this as an JXL_FAILURE or false (which can also propagate to error + // through e.g. JXL_RETURN_IF_ERROR), since this may be used while + // silently checking whether there are enough bytes. If this case must be + // treated as an error, reader_>Close() will do this, just like is already + // done for non-extension fields. + if (!enough_bytes_) return true; + + // Skip new fields this (old?) decoder didn't know about, if any. + const size_t bits_read = reader_->TotalBitsConsumed(); + uint64_t end; + if (!SafeAdd(pos_after_ext_size_, total_extension_bits_, end)) { + return JXL_FAILURE("Invalid extension size, caused overflow"); + } + if (bits_read > end) { + return JXL_FAILURE("Read more extension bits than budgeted"); + } + const size_t remaining_bits = end - bits_read; + if (remaining_bits != 0) { + JXL_WARNING("Skipping %" PRIuS "-bit extension(s)", remaining_bits); + reader_->SkipBits(remaining_bits); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + } + return true; + } + + Status OK() const { return ok_; } + + private: + // Whether any error other than not enough bytes occurred. + bool ok_ = true; + + // Whether there are enough input bytes to read from. + bool enough_bytes_ = true; + BitReader* const reader_; + // May be 0 even if the corresponding extension is present. + uint64_t extension_bits_[Bundle::kMaxExtensions] = {0}; + uint64_t total_extension_bits_ = 0; + size_t pos_after_ext_size_ = 0; // 0 iff extensions == 0. +}; + +class MaxBitsVisitor : public VisitorBase { + public: + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT /*value*/) override { + max_bits_ += BitsCoder::MaxEncodedBits(bits); + return true; + } + + Status U32(const U32Enc enc, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT /*value*/) override { + max_bits_ += U32Coder::MaxEncodedBits(enc); + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT /*value*/) override { + max_bits_ += U64Coder::MaxEncodedBits(); + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT /*value*/) override { + max_bits_ += F16Coder::MaxEncodedBits(); + return true; + } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) override { + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return false; // For max bits, assume nothing is default + } + + // Always visit conditional fields to get a (loose) upper bound. + Status Conditional(bool /*condition*/) override { return true; } + + Status BeginExtensions(uint64_t* JXL_RESTRICT /*extensions*/) override { + // Skip - extensions are not included in "MaxBits" because their length + // is potentially unbounded. + return true; + } + + Status EndExtensions() override { return true; } + + size_t MaxBits() const { return max_bits_; } + + private: + size_t max_bits_ = 0; +}; + +class CanEncodeVisitor : public VisitorBase { + public: + explicit CanEncodeVisitor() : VisitorBase() {} + + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= BitsCoder::CanEncode(bits, *value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status U32(const U32Enc enc, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= U32Coder::CanEncode(enc, *value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= U64Coder::CanEncode(*value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= F16Coder::CanEncode(*value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status AllDefault(const Fields& fields, + bool* JXL_RESTRICT all_default) override { + *all_default = Bundle::AllDefault(fields); + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return *all_default; + } + + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); + extensions_ = *extensions; + if (*extensions != 0) { + JXL_ASSERT(pos_after_ext_ == 0); + pos_after_ext_ = encoded_bits_; + JXL_ASSERT(pos_after_ext_ != 0); // visited "extensions" + } + return true; + } + // EndExtensions = default. + + Status GetSizes(size_t* JXL_RESTRICT extension_bits, + size_t* JXL_RESTRICT total_bits) { + JXL_RETURN_IF_ERROR(ok_); + *extension_bits = 0; + *total_bits = encoded_bits_; + // Only if extension field was nonzero will we encode their sizes. + if (pos_after_ext_ != 0) { + JXL_ASSERT(encoded_bits_ >= pos_after_ext_); + *extension_bits = encoded_bits_ - pos_after_ext_; + // Also need to encode *extension_bits and bill it to *total_bits. + size_t encoded_bits = 0; + ok_ &= U64Coder::CanEncode(*extension_bits, &encoded_bits); + *total_bits += encoded_bits; + + // TODO(janwas): support encoding individual extension sizes. We + // currently ascribe all bits to the first and send zeros for the + // others. + for (size_t i = 1; i < hwy::PopCount(extensions_); ++i) { + encoded_bits = 0; + ok_ &= U64Coder::CanEncode(0, &encoded_bits); + *total_bits += encoded_bits; + } + } + return true; + } + + private: + bool ok_ = true; + size_t encoded_bits_ = 0; + uint64_t extensions_ = 0; + // Snapshot of encoded_bits_ after visiting the extension field, but NOT + // including the hidden extension sizes. + uint64_t pos_after_ext_ = 0; +}; + +class WriteVisitor : public VisitorBase { + public: + WriteVisitor(const size_t extension_bits, BitWriter* JXL_RESTRICT writer) + : extension_bits_(extension_bits), writer_(writer) {} + + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + ok_ &= BitsCoder::Write(bits, *value, writer_); + return true; + } + Status U32(const U32Enc enc, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + ok_ &= U32Coder::Write(enc, *value, writer_); + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT value) override { + ok_ &= U64Coder::Write(*value, writer_); + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT value) override { + ok_ &= F16Coder::Write(*value, writer_); + return true; + } + + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); + if (*extensions == 0) { + JXL_ASSERT(extension_bits_ == 0); + return true; + } + // TODO(janwas): extend API to pass in array of extension_bits, one per + // extension. We currently ascribe all bits to the first extension, but + // this is only an encoder limitation. NOTE: extension_bits_ can be zero + // if an extension does not require any additional fields. + ok_ &= U64Coder::Write(extension_bits_, writer_); + // For each nonzero bit except the lowest/first (already written): + for (uint64_t remaining_extensions = *extensions & (*extensions - 1); + remaining_extensions != 0; + remaining_extensions &= remaining_extensions - 1) { + ok_ &= U64Coder::Write(0, writer_); + } + return true; + } + // EndExtensions = default. + + Status OK() const { return ok_; } + + private: + const size_t extension_bits_; + BitWriter* JXL_RESTRICT writer_; + bool ok_ = true; +}; + +} // namespace + +void Bundle::Init(Fields* fields) { + InitVisitor visitor; + if (!visitor.Visit(fields)) { + JXL_ABORT("Init should never fail"); + } +} +void Bundle::SetDefault(Fields* fields) { + SetDefaultVisitor visitor; + if (!visitor.Visit(fields)) { + JXL_ABORT("SetDefault should never fail"); + } +} +bool Bundle::AllDefault(const Fields& fields) { + AllDefaultVisitor visitor; + if (!visitor.VisitConst(fields)) { + JXL_ABORT("AllDefault should never fail"); + } + return visitor.AllDefault(); +} +size_t Bundle::MaxBits(const Fields& fields) { + MaxBitsVisitor visitor; +#if JXL_ENABLE_ASSERT + Status ret = +#else + (void) +#endif // JXL_ENABLE_ASSERT + visitor.VisitConst(fields); + JXL_ASSERT(ret); + return visitor.MaxBits(); +} +Status Bundle::CanEncode(const Fields& fields, size_t* extension_bits, + size_t* total_bits) { + CanEncodeVisitor visitor; + JXL_QUIET_RETURN_IF_ERROR(visitor.VisitConst(fields)); + JXL_QUIET_RETURN_IF_ERROR(visitor.GetSizes(extension_bits, total_bits)); + return true; +} +Status Bundle::Read(BitReader* reader, Fields* fields) { + ReadVisitor visitor(reader); + JXL_RETURN_IF_ERROR(visitor.Visit(fields)); + return visitor.OK(); +} +bool Bundle::CanRead(BitReader* reader, Fields* fields) { + ReadVisitor visitor(reader); + Status status = visitor.Visit(fields); + // We are only checking here whether there are enough bytes. We still return + // true for other errors because it means there are enough bytes to determine + // there's an error. Use Read() to determine which error it is. + return status.code() != StatusCode::kNotEnoughBytes; +} +Status Bundle::Write(const Fields& fields, BitWriter* writer, size_t layer, + AuxOut* aux_out) { + size_t extension_bits, total_bits; + JXL_RETURN_IF_ERROR(CanEncode(fields, &extension_bits, &total_bits)); + + BitWriter::Allotment allotment(writer, total_bits); + WriteVisitor visitor(extension_bits, writer); + JXL_RETURN_IF_ERROR(visitor.VisitConst(fields)); + JXL_RETURN_IF_ERROR(visitor.OK()); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + return true; +} + +size_t U32Coder::MaxEncodedBits(const U32Enc enc) { + size_t extra_bits = 0; + for (uint32_t selector = 0; selector < 4; ++selector) { + const U32Distr d = enc.GetDistr(selector); + if (d.IsDirect()) { + continue; + } else { + extra_bits = std::max(extra_bits, d.ExtraBits()); + } + } + return 2 + extra_bits; +} + +Status U32Coder::CanEncode(const U32Enc enc, const uint32_t value, + size_t* JXL_RESTRICT encoded_bits) { + uint32_t selector; + size_t total_bits; + const Status ok = ChooseSelector(enc, value, &selector, &total_bits); + *encoded_bits = ok ? total_bits : 0; + return ok; +} + +uint32_t U32Coder::Read(const U32Enc enc, BitReader* JXL_RESTRICT reader) { + const uint32_t selector = reader->ReadFixedBits<2>(); + const U32Distr d = enc.GetDistr(selector); + if (d.IsDirect()) { + return d.Direct(); + } else { + return reader->ReadBits(d.ExtraBits()) + d.Offset(); + } +} + +// Returns false if the value is too large to encode. +Status U32Coder::Write(const U32Enc enc, const uint32_t value, + BitWriter* JXL_RESTRICT writer) { + uint32_t selector; + size_t total_bits; + JXL_RETURN_IF_ERROR(ChooseSelector(enc, value, &selector, &total_bits)); + + writer->Write(2, selector); + + const U32Distr d = enc.GetDistr(selector); + if (!d.IsDirect()) { // Nothing more to write for direct encoding + const uint32_t offset = d.Offset(); + JXL_ASSERT(value >= offset); + writer->Write(total_bits - 2, value - offset); + } + + return true; +} + +Status U32Coder::ChooseSelector(const U32Enc enc, const uint32_t value, + uint32_t* JXL_RESTRICT selector, + size_t* JXL_RESTRICT total_bits) { +#if JXL_ENABLE_ASSERT + const size_t bits_required = 32 - Num0BitsAboveMS1Bit(value); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(bits_required <= 32); + + *selector = 0; + *total_bits = 0; + + // It is difficult to verify whether Dist32Byte are sorted, so check all + // selectors and keep the one with the fewest total_bits. + *total_bits = 64; // more than any valid encoding + for (uint32_t s = 0; s < 4; ++s) { + const U32Distr d = enc.GetDistr(s); + if (d.IsDirect()) { + if (d.Direct() == value) { + *selector = s; + *total_bits = 2; + return true; // Done, direct is always the best possible. + } + continue; + } + const size_t extra_bits = d.ExtraBits(); + const uint32_t offset = d.Offset(); + if (value < offset || value >= offset + (1ULL << extra_bits)) continue; + + // Better than prior encoding, remember it: + if (2 + extra_bits < *total_bits) { + *selector = s; + *total_bits = 2 + extra_bits; + } + } + + if (*total_bits == 64) { + return JXL_FAILURE("No feasible selector for %u", value); + } + + return true; +} + +uint64_t U64Coder::Read(BitReader* JXL_RESTRICT reader) { + uint64_t selector = reader->ReadFixedBits<2>(); + if (selector == 0) { + return 0; + } + if (selector == 1) { + return 1 + reader->ReadFixedBits<4>(); + } + if (selector == 2) { + return 17 + reader->ReadFixedBits<8>(); + } + + // selector 3, varint, groups have first 12, then 8, and last 4 bits. + uint64_t result = reader->ReadFixedBits<12>(); + + uint64_t shift = 12; + while (reader->ReadFixedBits<1>()) { + if (shift == 60) { + result |= static_cast(reader->ReadFixedBits<4>()) << shift; + break; + } + result |= static_cast(reader->ReadFixedBits<8>()) << shift; + shift += 8; + } + + return result; +} + +// Returns false if the value is too large to encode. +Status U64Coder::Write(uint64_t value, BitWriter* JXL_RESTRICT writer) { + if (value == 0) { + // Selector: use 0 bits, value 0 + writer->Write(2, 0); + } else if (value <= 16) { + // Selector: use 4 bits, value 1..16 + writer->Write(2, 1); + writer->Write(4, value - 1); + } else if (value <= 272) { + // Selector: use 8 bits, value 17..272 + writer->Write(2, 2); + writer->Write(8, value - 17); + } else { + // Selector: varint, first a 12-bit group, after that per 8-bit group. + writer->Write(2, 3); + writer->Write(12, value & 4095); + value >>= 12; + int shift = 12; + while (value > 0 && shift < 60) { + // Indicate varint not done + writer->Write(1, 1); + writer->Write(8, value & 255); + value >>= 8; + shift += 8; + } + if (value > 0) { + // This only could happen if shift == N - 4. + writer->Write(1, 1); + writer->Write(4, value & 15); + // Implicitly closed sequence, no extra stop bit is required. + } else { + // Indicate end of varint + writer->Write(1, 0); + } + } + + return true; +} + +// Can always encode, but useful because it also returns bit size. +Status U64Coder::CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits) { + if (value == 0) { + *encoded_bits = 2; // 2 selector bits + } else if (value <= 16) { + *encoded_bits = 2 + 4; // 2 selector bits + 4 payload bits + } else if (value <= 272) { + *encoded_bits = 2 + 8; // 2 selector bits + 8 payload bits + } else { + *encoded_bits = 2 + 12; // 2 selector bits + 12 payload bits + value >>= 12; + int shift = 12; + while (value > 0 && shift < 60) { + *encoded_bits += 1 + 8; // 1 continuation bit + 8 payload bits + value >>= 8; + shift += 8; + } + if (value > 0) { + // This only could happen if shift == N - 4. + *encoded_bits += 1 + 4; // 1 continuation bit + 4 payload bits + } else { + *encoded_bits += 1; // 1 stop bit + } + } + + return true; +} + +Status F16Coder::Read(BitReader* JXL_RESTRICT reader, + float* JXL_RESTRICT value) { + const uint32_t bits16 = reader->ReadFixedBits<16>(); + const uint32_t sign = bits16 >> 15; + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + if (JXL_UNLIKELY(biased_exp == 31)) { + return JXL_FAILURE("F16 infinity or NaN are not supported"); + } + + // Subnormal or zero + if (JXL_UNLIKELY(biased_exp == 0)) { + *value = (1.0f / 16384) * (mantissa * (1.0f / 1024)); + if (sign) *value = -*value; + return true; + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + memcpy(value, &bits32, sizeof(bits32)); + return true; +} + +Status F16Coder::Write(float value, BitWriter* JXL_RESTRICT writer) { + uint32_t bits32; + memcpy(&bits32, &value, sizeof(bits32)); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = static_cast(biased_exp32) - 127; + if (JXL_UNLIKELY(exp > 15)) { + return JXL_FAILURE("Too big to encode, CanEncode should return false"); + } + + // Tiny or zero => zero. + if (exp < -24) { + writer->Write(16, 0); + return true; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (JXL_UNLIKELY(exp < -14)) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast(-14 - exp); + JXL_ASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = (1 << (10 - sub_exp)) + (mantissa32 >> (13 + sub_exp)); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast(exp + 15); + JXL_ASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + JXL_ASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + JXL_ASSERT(bits16 < 0x10000); + writer->Write(16, bits16); + return true; +} + +Status F16Coder::CanEncode(float value, size_t* JXL_RESTRICT encoded_bits) { + *encoded_bits = MaxEncodedBits(); + if (std::isnan(value) || std::isinf(value)) { + return JXL_FAILURE("Should not attempt to store NaN and infinity"); + } + return std::abs(value) <= 65504.0f; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/fields.h b/media/libjxl/src/lib/jxl/fields.h new file mode 100644 index 000000000..18a57cfca --- /dev/null +++ b/media/libjxl/src/lib/jxl/fields.h @@ -0,0 +1,290 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FIELDS_H_ +#define LIB_JXL_FIELDS_H_ + +// Forward/backward-compatible 'bundles' with auto-serialized 'fields'. + +#include +#include +#include +#include +#include +#include + +#include +#include // abs +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +// Integer coders: BitsCoder (raw), U32Coder (table), U64Coder (varint). + +// Reads/writes a given (fixed) number of bits <= 32. +class BitsCoder { + public: + static size_t MaxEncodedBits(const size_t bits) { return bits; } + + static Status CanEncode(const size_t bits, const uint32_t value, + size_t* JXL_RESTRICT encoded_bits) { + *encoded_bits = bits; + if (value >= (1ULL << bits)) { + return JXL_FAILURE("Value %u too large for %" PRIu64 " bits", value, + static_cast(bits)); + } + return true; + } + + static uint32_t Read(const size_t bits, BitReader* JXL_RESTRICT reader) { + return reader->ReadBits(bits); + } + + // Returns false if the value is too large to encode. + static Status Write(const size_t bits, const uint32_t value, + BitWriter* JXL_RESTRICT writer) { + if (value >= (1ULL << bits)) { + return JXL_FAILURE("Value %d too large to encode in %" PRIu64 " bits", + value, static_cast(bits)); + } + writer->Write(bits, value); + return true; + } +}; + +// Encodes u32 using a lookup table and/or extra bits, governed by a per-field +// encoding `enc` which consists of four distributions `d` chosen via a 2-bit +// selector (least significant = 0). Each d may have two modes: +// - direct: if d.IsDirect(), the value is d.Direct(); +// - offset: the value is derived from d.ExtraBits() extra bits plus d.Offset(); +// This encoding is denser than Exp-Golomb or Gamma codes when both small and +// large values occur. +// +// Examples: +// Direct: U32Enc(Val(8), Val(16), Val(32), Bits(6)), value 32 => 10b. +// Offset: U32Enc(Val(0), BitsOffset(1, 1), BitsOffset(2, 3), BitsOffset(8, 8)) +// defines the following prefix code: +// 00 -> 0 +// 01x -> 1..2 +// 10xx -> 3..7 +// 11xxxxxxxx -> 8..263 +class U32Coder { + public: + static size_t MaxEncodedBits(U32Enc enc); + static Status CanEncode(U32Enc enc, uint32_t value, + size_t* JXL_RESTRICT encoded_bits); + static uint32_t Read(U32Enc enc, BitReader* JXL_RESTRICT reader); + + // Returns false if the value is too large to encode. + static Status Write(U32Enc enc, uint32_t value, + BitWriter* JXL_RESTRICT writer); + + private: + static Status ChooseSelector(U32Enc enc, uint32_t value, + uint32_t* JXL_RESTRICT selector, + size_t* JXL_RESTRICT total_bits); +}; + +// Encodes 64-bit unsigned integers with a fixed distribution, taking 2 bits +// to encode 0, 6 bits to encode 1 to 16, 10 bits to encode 17 to 272, 15 bits +// to encode up to 4095, and on the order of log2(value) * 1.125 bits for +// larger values. +class U64Coder { + public: + static constexpr size_t MaxEncodedBits() { + return 2 + 12 + 6 * (8 + 1) + (4 + 1); + } + + static uint64_t Read(BitReader* JXL_RESTRICT reader); + + // Returns false if the value is too large to encode. + static Status Write(uint64_t value, BitWriter* JXL_RESTRICT writer); + + // Can always encode, but useful because it also returns bit size. + static Status CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits); +}; + +// IEEE 754 half-precision (binary16). Refuses to read/write NaN/Inf. +class F16Coder { + public: + static constexpr size_t MaxEncodedBits() { return 16; } + + // Returns false if the bit representation is NaN or infinity + static Status Read(BitReader* JXL_RESTRICT reader, float* JXL_RESTRICT value); + + // Returns false if the value is too large to encode. + static Status Write(float value, BitWriter* JXL_RESTRICT writer); + static Status CanEncode(float value, size_t* JXL_RESTRICT encoded_bits); +}; + +// A "bundle" is a forward- and backward compatible collection of fields. +// They are used for SizeHeader/FrameHeader/GroupHeader. Bundles can be +// extended by appending(!) fields. Optional fields may be omitted from the +// bitstream by conditionally visiting them. When reading new bitstreams with +// old code, we skip unknown fields at the end of the bundle. This requires +// storing the amount of extra appended bits, and that fields are visited in +// chronological order of being added to the format, because old decoders +// cannot skip some future fields and resume reading old fields. Similarly, +// new readers query bits in an "extensions" field to skip (groups of) fields +// not present in old bitstreams. Note that each bundle must include an +// "extensions" field prior to freezing the format, otherwise it cannot be +// extended. +// +// To ensure interoperability, there will be no opaque fields. +// +// HOWTO: +// - basic usage: define a struct with member variables ("fields") and a +// VisitFields(v) member function that calls v->U32/Bool etc. for each +// field, specifying their default values. The ctor must call +// Bundle::Init(this). +// +// - print a trace of visitors: ensure each bundle has a static Name() member +// function, and change Bundle::Print* to return true. +// +// - optional fields: in VisitFields, add if (v->Conditional(your_condition)) +// { v->Bool(default, &field); }. This prevents reading/writing field +// if !your_condition, which is typically computed from a prior field. +// WARNING: to ensure all fields are initialized, do not add an else branch; +// instead add another if (v->Conditional(!your_condition)). +// +// - repeated fields: for dynamic sizes, use e.g. std::vector and in +// VisitFields, if (v->IsReading()) field.resize(size) before accessing field. +// For static or bounded sizes, use an array or std::array. In all cases, +// simply visit each array element as if it were a normal field. +// +// - nested bundles: add a bundle as a normal field and in VisitFields call +// JXL_RETURN_IF_ERROR(v->VisitNested(&nested)); +// +// - allow future extensions: define a "uint64_t extensions" field and call +// v->BeginExtensions(&extensions) after visiting all non-extension fields, +// and `return v->EndExtensions();` after the last extension field. +// +// - encode an entire bundle in one bit if ALL its fields equal their default +// values: add a "mutable bool all_default" field and as the first visitor: +// if (v->AllDefault(*this, &all_default)) { +// // Overwrite all serialized fields, but not any nonserialized_*. +// v->SetDefault(this); +// return true; +// } +// Note: if extensions are present, AllDefault() == false. + +class Bundle { + public: + static constexpr size_t kMaxExtensions = 64; // bits in u64 + + // Initializes fields to the default values. It is not recursive to nested + // fields, this function is intended to be called in the constructors so + // each nested field will already Init itself. + static void Init(Fields* JXL_RESTRICT fields); + + // Similar to Init, but recursive to nested fields. + static void SetDefault(Fields* JXL_RESTRICT fields); + + // Returns whether ALL fields (including `extensions`, if present) are equal + // to their default value. + static bool AllDefault(const Fields& fields); + + // Returns max number of bits required to encode a T. + static size_t MaxBits(const Fields& fields); + + // Returns whether a header's fields can all be encoded, i.e. they have a + // valid representation. If so, "*total_bits" is the exact number of bits + // required. Called by Write. + static Status CanEncode(const Fields& fields, + size_t* JXL_RESTRICT extension_bits, + size_t* JXL_RESTRICT total_bits); + + static Status Read(BitReader* reader, Fields* JXL_RESTRICT fields); + + // Returns whether enough bits are available to fully read this bundle using + // Read. Also returns true in case of a codestream error (other than not being + // large enough): that means enough bits are available to determine there's an + // error, use Read to get such error status. + // NOTE: this advances the BitReader, a different one pointing back at the + // original bit position in the codestream must be created to use Read after + // this. + static bool CanRead(BitReader* reader, Fields* JXL_RESTRICT fields); + + static Status Write(const Fields& fields, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* aux_out); + + private: +}; + +// Different subclasses of Visitor are passed to implementations of Fields +// throughout their lifetime. Templates used to be used for this but dynamic +// polymorphism produces more compact executables than template reification did. +class Visitor { + public: + virtual ~Visitor() = default; + virtual Status Visit(Fields* fields) = 0; + + virtual Status Bool(bool default_value, bool* JXL_RESTRICT value) = 0; + virtual Status U32(U32Enc, uint32_t, uint32_t*) = 0; + + // Helper to construct U32Enc from U32Distr. + Status U32(const U32Distr d0, const U32Distr d1, const U32Distr d2, + const U32Distr d3, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) { + return U32(U32Enc(d0, d1, d2, d3), default_value, value); + } + + template + Status Enum(const EnumT default_value, EnumT* JXL_RESTRICT value) { + uint32_t u32 = static_cast(*value); + // 00 -> 0 + // 01 -> 1 + // 10xxxx -> 2..17 + // 11yyyyyy -> 18..81 + JXL_RETURN_IF_ERROR(U32(Val(0), Val(1), BitsOffset(4, 2), BitsOffset(6, 18), + static_cast(default_value), &u32)); + *value = static_cast(u32); + return EnumValid(*value); + } + + virtual Status Bits(size_t bits, uint32_t default_value, + uint32_t* JXL_RESTRICT value) = 0; + virtual Status U64(uint64_t default_value, uint64_t* JXL_RESTRICT value) = 0; + virtual Status F16(float default_value, float* JXL_RESTRICT value) = 0; + + // Returns whether VisitFields should visit some subsequent fields. + // "condition" is typically from prior fields, e.g. flags. + // Overridden by InitVisitor and MaxBitsVisitor. + virtual Status Conditional(bool condition) { return condition; } + + // Overridden by InitVisitor, AllDefaultVisitor and CanEncodeVisitor. + virtual Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) { + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return *all_default; + } + + virtual void SetDefault(Fields* /*fields*/) { + // Do nothing by default, this is overridden by ReadVisitor. + } + + // Returns the result of visiting a nested Bundle. + // Overridden by InitVisitor. + virtual Status VisitNested(Fields* fields) { return Visit(fields); } + + // Overridden by ReadVisitor. Enables dynamically-sized fields. + virtual bool IsReading() const { return false; } + + virtual Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) = 0; + virtual Status EndExtensions() = 0; +}; + +} // namespace jxl + +#endif // LIB_JXL_FIELDS_H_ diff --git a/media/libjxl/src/lib/jxl/fields_test.cc b/media/libjxl/src/lib/jxl/fields_test.cc new file mode 100644 index 000000000..c11b05230 --- /dev/null +++ b/media/libjxl/src/lib/jxl/fields_test.cc @@ -0,0 +1,434 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/fields.h" + +#include +#include + +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/common.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" + +namespace jxl { +namespace { + +// Ensures `value` round-trips and in exactly `expected_bits_written`. +void TestU32Coder(const uint32_t value, const size_t expected_bits_written) { + U32Coder coder; + const U32Enc enc(Val(0), Bits(4), Val(0x7FFFFFFF), Bits(32)); + + BitWriter writer; + BitWriter::Allotment allotment( + &writer, RoundUpBitsToByteMultiple(U32Coder::MaxEncodedBits(enc))); + + size_t precheck_pos; + EXPECT_TRUE(coder.CanEncode(enc, value, &precheck_pos)); + EXPECT_EQ(expected_bits_written, precheck_pos); + + EXPECT_TRUE(coder.Write(enc, value, &writer)); + EXPECT_EQ(expected_bits_written, writer.BitsWritten()); + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + + BitReader reader(writer.GetSpan()); + const uint32_t decoded_value = coder.Read(enc, &reader); + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); +} + +TEST(FieldsTest, U32CoderTest) { + TestU32Coder(0, 2); + TestU32Coder(1, 6); + TestU32Coder(15, 6); + TestU32Coder(0x7FFFFFFF, 2); + TestU32Coder(128, 34); + TestU32Coder(0x7FFFFFFEu, 34); + TestU32Coder(0x80000000u, 34); + TestU32Coder(0xFFFFFFFFu, 34); +} + +void TestU64Coder(const uint64_t value, const size_t expected_bits_written) { + U64Coder coder; + + BitWriter writer; + BitWriter::Allotment allotment( + &writer, RoundUpBitsToByteMultiple(U64Coder::MaxEncodedBits())); + + size_t precheck_pos; + EXPECT_TRUE(coder.CanEncode(value, &precheck_pos)); + EXPECT_EQ(expected_bits_written, precheck_pos); + + EXPECT_TRUE(coder.Write(value, &writer)); + EXPECT_EQ(expected_bits_written, writer.BitsWritten()); + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + + BitReader reader(writer.GetSpan()); + const uint64_t decoded_value = coder.Read(&reader); + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); +} + +TEST(FieldsTest, U64CoderTest) { + // Values that should take 2 bits (selector 00): 0 + TestU64Coder(0, 2); + + // Values that should take 6 bits (2 for selector, 4 for value): 1..16 + TestU64Coder(1, 6); + TestU64Coder(2, 6); + TestU64Coder(8, 6); + TestU64Coder(15, 6); + TestU64Coder(16, 6); + + // Values that should take 10 bits (2 for selector, 8 for value): 17..272 + TestU64Coder(17, 10); + TestU64Coder(18, 10); + TestU64Coder(100, 10); + TestU64Coder(271, 10); + TestU64Coder(272, 10); + + // Values that should take 15 bits (2 for selector, 12 for value, 1 for varint + // end): (0)..273..4095 + TestU64Coder(273, 15); + TestU64Coder(274, 15); + TestU64Coder(1000, 15); + TestU64Coder(4094, 15); + TestU64Coder(4095, 15); + + // Take 24 bits (of which 20 actual value): (0)..4096..1048575 + TestU64Coder(4096, 24); + TestU64Coder(4097, 24); + TestU64Coder(10000, 24); + TestU64Coder(1048574, 24); + TestU64Coder(1048575, 24); + + // Take 33 bits (of which 28 actual value): (0)..1048576..268435455 + TestU64Coder(1048576, 33); + TestU64Coder(1048577, 33); + TestU64Coder(10000000, 33); + TestU64Coder(268435454, 33); + TestU64Coder(268435455, 33); + + // Take 42 bits (of which 36 actual value): (0)..268435456..68719476735 + TestU64Coder(268435456ull, 42); + TestU64Coder(268435457ull, 42); + TestU64Coder(1000000000ull, 42); + TestU64Coder(68719476734ull, 42); + TestU64Coder(68719476735ull, 42); + + // Take 51 bits (of which 44 actual value): (0)..68719476736..17592186044415 + TestU64Coder(68719476736ull, 51); + TestU64Coder(68719476737ull, 51); + TestU64Coder(1000000000000ull, 51); + TestU64Coder(17592186044414ull, 51); + TestU64Coder(17592186044415ull, 51); + + // Take 60 bits (of which 52 actual value): + // (0)..17592186044416..4503599627370495 + TestU64Coder(17592186044416ull, 60); + TestU64Coder(17592186044417ull, 60); + TestU64Coder(100000000000000ull, 60); + TestU64Coder(4503599627370494ull, 60); + TestU64Coder(4503599627370495ull, 60); + + // Take 69 bits (of which 60 actual value): + // (0)..4503599627370496..1152921504606846975 + TestU64Coder(4503599627370496ull, 69); + TestU64Coder(4503599627370497ull, 69); + TestU64Coder(10000000000000000ull, 69); + TestU64Coder(1152921504606846974ull, 69); + TestU64Coder(1152921504606846975ull, 69); + + // Take 73 bits (of which 64 actual value): + // (0)..1152921504606846976..18446744073709551615 + TestU64Coder(1152921504606846976ull, 73); + TestU64Coder(1152921504606846977ull, 73); + TestU64Coder(10000000000000000000ull, 73); + TestU64Coder(18446744073709551614ull, 73); + TestU64Coder(18446744073709551615ull, 73); +} + +Status TestF16Coder(const float value) { + F16Coder coder; + + size_t max_encoded_bits; + // It is not a fatal error if it can't be encoded. + if (!coder.CanEncode(value, &max_encoded_bits)) return false; + EXPECT_EQ(F16Coder::MaxEncodedBits(), max_encoded_bits); + + BitWriter writer; + BitWriter::Allotment allotment(&writer, + RoundUpBitsToByteMultiple(max_encoded_bits)); + + EXPECT_TRUE(coder.Write(value, &writer)); + EXPECT_EQ(F16Coder::MaxEncodedBits(), writer.BitsWritten()); + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + + BitReader reader(writer.GetSpan()); + float decoded_value; + EXPECT_TRUE(coder.Read(&reader, &decoded_value)); + // All values we test can be represented exactly. + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); + return true; +} + +TEST(FieldsTest, F16CoderTest) { + for (float sign : {-1.0f, 1.0f}) { + // (anything less than 1E-3 are subnormals) + for (float mag : {0.0f, 0.5f, 1.0f, 2.0f, 2.5f, 16.015625f, 1.0f / 4096, + 1.0f / 16384, 65504.0f}) { + EXPECT_TRUE(TestF16Coder(sign * mag)); + } + } + + // Out of range + EXPECT_FALSE(TestF16Coder(65504.01f)); + EXPECT_FALSE(TestF16Coder(-65505.0f)); +} + +// Ensures Read(Write()) returns the same fields. +TEST(FieldsTest, TestRoundtripSize) { + for (int i = 0; i < 8; i++) { + SizeHeader size; + ASSERT_TRUE(size.Set(123 + 77 * i, 7 + i)); + + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. + ASSERT_TRUE(Bundle::CanEncode(size, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + + BitWriter writer; + ASSERT_TRUE(WriteSizeHeader(size, &writer, 0, nullptr)); + EXPECT_EQ(total_bits, writer.BitsWritten()); + writer.ZeroPadToByte(); + + SizeHeader size2; + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(ReadSizeHeader(&reader, &size2)); + EXPECT_EQ(total_bits, reader.TotalBitsConsumed()); + EXPECT_TRUE(reader.Close()); + + EXPECT_EQ(size.xsize(), size2.xsize()); + EXPECT_EQ(size.ysize(), size2.ysize()); + } +} + +// Ensure all values can be reached by the encoding. +TEST(FieldsTest, TestCropRect) { + CodecMetadata metadata; + for (int32_t i = -999; i < 19000; ++i) { + FrameHeader f(&metadata); + f.custom_size_or_origin = true; + f.frame_origin.x0 = i; + f.frame_origin.y0 = i; + f.frame_size.xsize = 1000 + i; + f.frame_size.ysize = 1000 + i; + size_t extension_bits = 0, total_bits = 0; + ASSERT_TRUE(Bundle::CanEncode(f, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + EXPECT_GE(total_bits, 9u); + } +} +TEST(FieldsTest, TestPreview) { + // (div8 cannot represent 4360, but !div8 can go a little higher) + for (uint32_t i = 1; i < 4360; ++i) { + PreviewHeader p; + ASSERT_TRUE(p.Set(i, i)); + size_t extension_bits = 0, total_bits = 0; + ASSERT_TRUE(Bundle::CanEncode(p, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + EXPECT_GE(total_bits, 6u); + } +} + +// Ensures Read(Write()) returns the same fields. +TEST(FieldsTest, TestRoundtripFrame) { + CodecMetadata metadata; + FrameHeader h(&metadata); + h.extensions = 0x800; + + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. + ASSERT_TRUE(Bundle::CanEncode(h, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + BitWriter writer; + ASSERT_TRUE(WriteFrameHeader(h, &writer, nullptr)); + EXPECT_EQ(total_bits, writer.BitsWritten()); + writer.ZeroPadToByte(); + + FrameHeader h2(&metadata); + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(ReadFrameHeader(&reader, &h2)); + EXPECT_EQ(total_bits, reader.TotalBitsConsumed()); + EXPECT_TRUE(reader.Close()); + + EXPECT_EQ(h.extensions, h2.extensions); + EXPECT_EQ(h.flags, h2.flags); +} + +#ifndef JXL_CRASH_ON_ERROR +// Ensure out-of-bounds values cause an error. +TEST(FieldsTest, TestOutOfRange) { + SizeHeader h; + ASSERT_TRUE(h.Set(0xFFFFFFFFull, 0xFFFFFFFFull)); + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. + ASSERT_FALSE(Bundle::CanEncode(h, &extension_bits, &total_bits)); +} +#endif + +struct OldBundle : public Fields { + OldBundle() { Bundle::Init(this); } + JXL_FIELDS_NAME(OldBundle) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Bits(2), Bits(3), Bits(4), 1, &old_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.125f, &old_f)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(7), Bits(12), Bits(16), Bits(32), 0, &old_large)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + return visitor->EndExtensions(); + } + + uint32_t old_small; + float old_f; + uint32_t old_large; + uint64_t extensions; +}; + +struct NewBundle : public Fields { + NewBundle() { Bundle::Init(this); } + JXL_FIELDS_NAME(NewBundle) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Bits(2), Bits(3), Bits(4), 1, &old_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.125f, &old_f)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(7), Bits(12), Bits(16), Bits(32), 0, &old_large)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + if (visitor->Conditional(extensions & 1)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(2), Bits(2), Bits(3), Bits(4), 2, &new_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(-2.0f, &new_f)); + } + if (visitor->Conditional(extensions & 2)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(9), Bits(12), Bits(16), Bits(32), 0, &new_large)); + } + return visitor->EndExtensions(); + } + + uint32_t old_small; + float old_f; + uint32_t old_large; + uint64_t extensions; + + // If extensions & 1 + uint32_t new_small = 2; + float new_f = -2.0f; + // If extensions & 2 + uint32_t new_large = 0; +}; + +TEST(FieldsTest, TestNewDecoderOldData) { + OldBundle old_bundle; + old_bundle.old_large = 123; + old_bundle.old_f = 3.75f; + old_bundle.extensions = 0; + + // Write to bit stream + const size_t kMaxOutBytes = 999; + BitWriter writer; + // Make sure values are initialized by code under test. + size_t extension_bits = 12345, total_bits = 12345; + ASSERT_TRUE(Bundle::CanEncode(old_bundle, &extension_bits, &total_bits)); + ASSERT_LE(total_bits, kMaxOutBytes * kBitsPerByte); + EXPECT_EQ(0u, extension_bits); + AuxOut aux_out; + ASSERT_TRUE(Bundle::Write(old_bundle, &writer, kLayerHeader, &aux_out)); + + BitWriter::Allotment allotment(&writer, + kMaxOutBytes * kBitsPerByte - total_bits); + writer.Write(20, 0xA55A); // sentinel + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, kLayerHeader, nullptr); + + ASSERT_LE(writer.GetSpan().size(), kMaxOutBytes); + BitReader reader(writer.GetSpan()); + NewBundle new_bundle; + ASSERT_TRUE(Bundle::Read(&reader, &new_bundle)); + EXPECT_EQ(reader.TotalBitsConsumed(), + aux_out.layers[kLayerHeader].total_bits); + EXPECT_EQ(reader.ReadBits(20), 0xA55Au); + EXPECT_TRUE(reader.Close()); + + // Old fields are the same in both + EXPECT_EQ(old_bundle.extensions, new_bundle.extensions); + EXPECT_EQ(old_bundle.old_small, new_bundle.old_small); + EXPECT_EQ(old_bundle.old_f, new_bundle.old_f); + EXPECT_EQ(old_bundle.old_large, new_bundle.old_large); + // New fields match their defaults + EXPECT_EQ(2u, new_bundle.new_small); + EXPECT_EQ(-2.0f, new_bundle.new_f); + EXPECT_EQ(0u, new_bundle.new_large); +} + +TEST(FieldsTest, TestOldDecoderNewData) { + NewBundle new_bundle; + new_bundle.old_large = 123; + new_bundle.extensions = 3; + new_bundle.new_f = 999.0f; + new_bundle.new_large = 456; + + // Write to bit stream + constexpr size_t kMaxOutBytes = 999; + BitWriter writer; + // Make sure values are initialized by code under test. + size_t extension_bits = 12345, total_bits = 12345; + ASSERT_TRUE(Bundle::CanEncode(new_bundle, &extension_bits, &total_bits)); + EXPECT_NE(0u, extension_bits); + AuxOut aux_out; + ASSERT_TRUE(Bundle::Write(new_bundle, &writer, kLayerHeader, &aux_out)); + ASSERT_LE(aux_out.layers[kLayerHeader].total_bits, + kMaxOutBytes * kBitsPerByte); + + BitWriter::Allotment allotment( + &writer, + kMaxOutBytes * kBitsPerByte - aux_out.layers[kLayerHeader].total_bits); + // Ensure Read skips the additional fields + writer.Write(20, 0xA55A); // sentinel + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, kLayerHeader, nullptr); + + BitReader reader(writer.GetSpan()); + OldBundle old_bundle; + ASSERT_TRUE(Bundle::Read(&reader, &old_bundle)); + EXPECT_EQ(reader.TotalBitsConsumed(), + aux_out.layers[kLayerHeader].total_bits); + EXPECT_EQ(reader.ReadBits(20), 0xA55Au); + EXPECT_TRUE(reader.Close()); + + // Old fields are the same in both + EXPECT_EQ(new_bundle.extensions, old_bundle.extensions); + EXPECT_EQ(new_bundle.old_small, old_bundle.old_small); + EXPECT_EQ(new_bundle.old_f, old_bundle.old_f); + EXPECT_EQ(new_bundle.old_large, old_bundle.old_large); + // (Can't check new fields because old decoder doesn't know about them) +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/frame_header.cc b/media/libjxl/src/lib/jxl/frame_header.cc new file mode 100644 index 000000000..e69a12c51 --- /dev/null +++ b/media/libjxl/src/lib/jxl/frame_header.cc @@ -0,0 +1,470 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/frame_header.h" + +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +constexpr uint8_t YCbCrChromaSubsampling::kHShift[]; +constexpr uint8_t YCbCrChromaSubsampling::kVShift[]; + +static Status VisitBlendMode(Visitor* JXL_RESTRICT visitor, + BlendMode default_value, BlendMode* blend_mode) { + uint32_t encoded = static_cast(*blend_mode); + + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(static_cast(BlendMode::kReplace)), + Val(static_cast(BlendMode::kAdd)), + Val(static_cast(BlendMode::kBlend)), BitsOffset(2, 3), + static_cast(default_value), &encoded)); + if (encoded > 4) { + return JXL_FAILURE("Invalid blend_mode"); + } + *blend_mode = static_cast(encoded); + return true; +} + +static Status VisitFrameType(Visitor* JXL_RESTRICT visitor, + FrameType default_value, FrameType* frame_type) { + uint32_t encoded = static_cast(*frame_type); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(static_cast(FrameType::kRegularFrame)), + Val(static_cast(FrameType::kDCFrame)), + Val(static_cast(FrameType::kReferenceOnly)), + Val(static_cast(FrameType::kSkipProgressive)), + static_cast(default_value), &encoded)); + *frame_type = static_cast(encoded); + return true; +} + +BlendingInfo::BlendingInfo() { Bundle::Init(this); } + +Status BlendingInfo::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR( + VisitBlendMode(visitor, BlendMode::kReplace, &mode)); + if (visitor->Conditional(nonserialized_num_extra_channels > 0 && + (mode == BlendMode::kBlend || + mode == BlendMode::kAlphaWeightedAdd))) { + // Up to 11 alpha channels for blending. + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(0), Val(1), Val(2), BitsOffset(3, 3), 0, &alpha_channel)); + if (visitor->IsReading() && + alpha_channel >= nonserialized_num_extra_channels) { + return JXL_FAILURE("Invalid alpha channel for blending"); + } + } + if (visitor->Conditional((nonserialized_num_extra_channels > 0 && + (mode == BlendMode::kBlend || + mode == BlendMode::kAlphaWeightedAdd)) || + mode == BlendMode::kMul)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &clamp)); + } + // 'old' frame for blending. Only necessary if this is not a full frame, or + // blending is not kReplace. + if (visitor->Conditional(mode != BlendMode::kReplace || + nonserialized_is_partial_frame)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Val(2), Val(3), 0, &source)); + } + return true; +} + +AnimationFrame::AnimationFrame(const CodecMetadata* metadata) + : nonserialized_metadata(metadata) { + Bundle::Init(this); +} +Status AnimationFrame::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->Conditional(nonserialized_metadata != nullptr && + nonserialized_metadata->m.have_animation)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Bits(8), Bits(32), 0, &duration)); + } + + if (visitor->Conditional( + nonserialized_metadata != nullptr && + nonserialized_metadata->m.animation.have_timecodes)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(32, 0, &timecode)); + } + return true; +} + +YCbCrChromaSubsampling::YCbCrChromaSubsampling() { Bundle::Init(this); } +Passes::Passes() { Bundle::Init(this); } +Status Passes::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), BitsOffset(3, 4), 1, &num_passes)); + JXL_ASSERT(num_passes <= kMaxNumPasses); // Cannot happen when reading + + if (visitor->Conditional(num_passes != 1)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(0), Val(1), Val(2), BitsOffset(1, 3), 0, &num_downsample)); + JXL_ASSERT(num_downsample <= 4); // 1,2,4,8 + if (num_downsample > num_passes) { + return JXL_FAILURE("num_downsample %u > num_passes %u", num_downsample, + num_passes); + } + + for (uint32_t i = 0; i < num_passes - 1; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 0, &shift[i])); + } + shift[num_passes - 1] = 0; + + for (uint32_t i = 0; i < num_downsample; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &downsample[i])); + if (i > 0 && downsample[i] >= downsample[i - 1]) { + return JXL_FAILURE("downsample sequence should be decreasing"); + } + } + for (uint32_t i = 0; i < num_downsample; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Val(2), Bits(3), 0, &last_pass[i])); + if (i > 0 && last_pass[i] <= last_pass[i - 1]) { + return JXL_FAILURE("last_pass sequence should be increasing"); + } + if (last_pass[i] >= num_passes) { + return JXL_FAILURE("last_pass %u >= num_passes %u", last_pass[i], + num_passes); + } + } + } + + return true; +} + +std::string Passes::DebugString() const { + std::ostringstream os; + os << "p=" << num_passes; + if (num_downsample) { + os << ",ds="; + for (uint32_t i = 0; i < num_downsample; ++i) { + os << last_pass[i] << ":" << downsample[i]; + if (i + 1 < num_downsample) os << ";"; + } + } + bool have_shifts = false; + for (uint32_t i = 0; i < num_passes; ++i) { + if (shift[i]) have_shifts = true; + } + if (have_shifts) { + os << ",shifts="; + for (uint32_t i = 0; i < num_passes; ++i) { + os << shift[i]; + if (i + 1 < num_passes) os << ";"; + } + } + return os.str(); +} + +FrameHeader::FrameHeader(const CodecMetadata* metadata) + : animation_frame(metadata), nonserialized_metadata(metadata) { + Bundle::Init(this); +} + +Status ReadFrameHeader(BitReader* JXL_RESTRICT reader, + FrameHeader* JXL_RESTRICT frame) { + return Bundle::Read(reader, frame); +} + +Status WriteFrameHeader(const FrameHeader& frame, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) { + return Bundle::Write(frame, writer, kLayerHeader, aux_out); +} + +Status FrameHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR( + VisitFrameType(visitor, FrameType::kRegularFrame, &frame_type)); + if (visitor->IsReading() && nonserialized_is_preview && + frame_type != kRegularFrame) { + return JXL_FAILURE("Only regular frame could be a preview"); + } + + // FrameEncoding. + bool is_modular = (encoding == FrameEncoding::kModular); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &is_modular)); + encoding = (is_modular ? FrameEncoding::kModular : FrameEncoding::kVarDCT); + + // Flags + JXL_QUIET_RETURN_IF_ERROR(visitor->U64(0, &flags)); + + // Color transform + bool xyb_encoded = nonserialized_metadata == nullptr || + nonserialized_metadata->m.xyb_encoded; + + if (xyb_encoded) { + color_transform = ColorTransform::kXYB; + } else { + // Alternate if kYCbCr. + bool alternate = color_transform == ColorTransform::kYCbCr; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &alternate)); + color_transform = + (alternate ? ColorTransform::kYCbCr : ColorTransform::kNone); + } + + // Chroma subsampling for YCbCr, if no DC frame is used. + if (visitor->Conditional(color_transform == ColorTransform::kYCbCr && + ((flags & kUseDcFrame) == 0))) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&chroma_subsampling)); + } + + size_t num_extra_channels = + nonserialized_metadata != nullptr + ? nonserialized_metadata->m.extra_channel_info.size() + : 0; + + // Upsampling + if (visitor->Conditional((flags & kUseDcFrame) == 0)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &upsampling)); + if (nonserialized_metadata != nullptr && + visitor->Conditional(num_extra_channels != 0)) { + const std::vector& extra_channels = + nonserialized_metadata->m.extra_channel_info; + extra_channel_upsampling.resize(extra_channels.size(), 1); + for (size_t i = 0; i < extra_channels.size(); ++i) { + uint32_t dim_shift = + nonserialized_metadata->m.extra_channel_info[i].dim_shift; + uint32_t& ec_upsampling = extra_channel_upsampling[i]; + ec_upsampling >>= dim_shift; + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &ec_upsampling)); + ec_upsampling <<= dim_shift; + if (ec_upsampling < upsampling) { + return JXL_FAILURE( + "EC upsampling (%u) < color upsampling (%u), which is invalid.", + ec_upsampling, upsampling); + } + if (ec_upsampling > 8) { + return JXL_FAILURE("EC upsampling too large (%u)", ec_upsampling); + } + } + } else { + extra_channel_upsampling.clear(); + } + } + + // Modular- or VarDCT-specific data. + if (visitor->Conditional(encoding == FrameEncoding::kModular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 1, &group_size_shift)); + } + if (visitor->Conditional(encoding == FrameEncoding::kVarDCT && + color_transform == ColorTransform::kXYB)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 3, &x_qm_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 2, &b_qm_scale)); + } else { + x_qm_scale = b_qm_scale = 2; // noop + } + + // Not useful for kPatchSource + if (visitor->Conditional(frame_type != FrameType::kReferenceOnly)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&passes)); + } + + if (visitor->Conditional(frame_type == FrameType::kDCFrame)) { + // Up to 4 pyramid levels - for up to 16384x downsampling. + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 1, &dc_level)); + } + if (frame_type != FrameType::kDCFrame) { + dc_level = 0; + } + + bool is_partial_frame = false; + if (visitor->Conditional(frame_type != FrameType::kDCFrame)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &custom_size_or_origin)); + if (visitor->Conditional(custom_size_or_origin)) { + const U32Enc enc(Bits(8), BitsOffset(11, 256), BitsOffset(14, 2304), + BitsOffset(30, 18688)); + // Frame offset, only if kRegularFrame or kSkipProgressive. + if (visitor->Conditional(frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive)) { + uint32_t ux0 = PackSigned(frame_origin.x0); + uint32_t uy0 = PackSigned(frame_origin.y0); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &ux0)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &uy0)); + frame_origin.x0 = UnpackSigned(ux0); + frame_origin.y0 = UnpackSigned(uy0); + } + // Frame size + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &frame_size.xsize)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &frame_size.ysize)); + if (custom_size_or_origin && + (frame_size.xsize == 0 || frame_size.ysize == 0)) { + return JXL_FAILURE( + "Invalid crop dimensions for frame: zero width or height"); + } + int32_t image_xsize = default_xsize(); + int32_t image_ysize = default_ysize(); + if (frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive) { + is_partial_frame |= frame_origin.x0 > 0; + is_partial_frame |= frame_origin.y0 > 0; + is_partial_frame |= (static_cast(frame_size.xsize) + + frame_origin.x0) < image_xsize; + is_partial_frame |= (static_cast(frame_size.ysize) + + frame_origin.y0) < image_ysize; + } + } + } + + // Blending info, animation info and whether this is the last frame or not. + if (visitor->Conditional(frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive)) { + blending_info.nonserialized_num_extra_channels = num_extra_channels; + blending_info.nonserialized_is_partial_frame = is_partial_frame; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&blending_info)); + bool replace_all = (blending_info.mode == BlendMode::kReplace); + extra_channel_blending_info.resize(num_extra_channels); + for (size_t i = 0; i < num_extra_channels; i++) { + auto& ec_blending_info = extra_channel_blending_info[i]; + ec_blending_info.nonserialized_is_partial_frame = is_partial_frame; + ec_blending_info.nonserialized_num_extra_channels = num_extra_channels; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&ec_blending_info)); + replace_all &= (ec_blending_info.mode == BlendMode::kReplace); + } + if (visitor->IsReading() && nonserialized_is_preview) { + if (!replace_all || custom_size_or_origin) { + return JXL_FAILURE("Preview is not compatible with blending"); + } + } + if (visitor->Conditional(nonserialized_metadata != nullptr && + nonserialized_metadata->m.have_animation)) { + animation_frame.nonserialized_metadata = nonserialized_metadata; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&animation_frame)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &is_last)); + } + if (frame_type != FrameType::kRegularFrame) { + is_last = false; + } + + // ID of that can be used to refer to this frame. 0 for a non-zero-duration + // frame means that it will not be referenced. Not necessary for the last + // frame. + if (visitor->Conditional(frame_type != kDCFrame && !is_last)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Val(2), Val(3), 0, &save_as_reference)); + } + + // If this frame is not blended on another frame post-color-transform, it may + // be stored for being referenced either before or after the color transform. + // If it is blended post-color-transform, it must be blended after. It must + // also be blended after if this is a kRegular frame that does not cover the + // full frame, as samples outside the partial region are from a + // post-color-transform frame. + if (frame_type != FrameType::kDCFrame) { + if (visitor->Conditional(CanBeReferenced() && + blending_info.mode == BlendMode::kReplace && + !is_partial_frame && + (frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive))) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bool(false, &save_before_color_transform)); + } else if (visitor->Conditional(frame_type == FrameType::kReferenceOnly)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bool(true, &save_before_color_transform)); + if (!save_before_color_transform && + (frame_size.xsize < nonserialized_metadata->xsize() || + frame_size.ysize < nonserialized_metadata->ysize() || + frame_origin.x0 != 0 || frame_origin.y0 != 0)) { + return JXL_FAILURE( + "non-patch reference frame with invalid crop: %" PRIuS "x%" PRIuS + "%+d%+d", + static_cast(frame_size.xsize), + static_cast(frame_size.ysize), + static_cast(frame_origin.x0), + static_cast(frame_origin.y0)); + } + } + } else { + save_before_color_transform = true; + } + + JXL_QUIET_RETURN_IF_ERROR(VisitNameString(visitor, &name)); + + loop_filter.nonserialized_is_modular = is_modular; + JXL_RETURN_IF_ERROR(visitor->VisitNested(&loop_filter)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + // Extensions: in chronological order of being added to the format. + return visitor->EndExtensions(); +} + +std::string FrameHeader::DebugString() const { + std::ostringstream os; + os << (encoding == FrameEncoding::kVarDCT ? "VarDCT" : "Modular"); + os << ","; + os << (frame_type == FrameType::kRegularFrame ? "Regular" + : frame_type == FrameType::kDCFrame ? "DC" + : frame_type == FrameType::kReferenceOnly ? "Reference" + : "SkipProgressive"); + if (frame_type == FrameType::kDCFrame) { + os << "(lv" << dc_level << ")"; + } + + if (flags) { + os << ","; + uint32_t remaining = flags; + +#define TEST_FLAG(name) \ + if (flags & Flags::k##name) { \ + remaining &= ~Flags::k##name; \ + os << #name; \ + if (remaining) os << "|"; \ + } + TEST_FLAG(Noise); + TEST_FLAG(Patches); + TEST_FLAG(Splines); + TEST_FLAG(UseDcFrame); + TEST_FLAG(SkipAdaptiveDCSmoothing); +#undef TEST_FLAG + } + + os << ","; + os << (color_transform == ColorTransform::kXYB ? "XYB" + : color_transform == ColorTransform::kYCbCr ? "YCbCr" + : "None"); + + if (encoding == FrameEncoding::kModular) { + os << ",shift=" << group_size_shift; + } else if (color_transform == ColorTransform::kXYB) { + os << ",qm=" << x_qm_scale << ";" << b_qm_scale; + } + if (frame_type != FrameType::kReferenceOnly) { + os << "," << passes.DebugString(); + } + if (custom_size_or_origin) { + os << ",xs=" << frame_size.xsize; + os << ",ys=" << frame_size.ysize; + if (frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive) { + os << ",x0=" << frame_origin.x0; + os << ",y0=" << frame_origin.y0; + } + } + if (upsampling > 1) os << ",up=" << upsampling; + if (loop_filter.gab) os << ",Gaborish"; + if (loop_filter.epf_iters > 0) os << ",epf=" << loop_filter.epf_iters; + if (animation_frame.duration > 0) os << ",dur=" << animation_frame.duration; + if (save_as_reference > 0) os << ",ref=" << save_as_reference; + if (is_last) os << ",last"; + return os.str(); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/frame_header.h b/media/libjxl/src/lib/jxl/frame_header.h new file mode 100644 index 000000000..7eb2f3578 --- /dev/null +++ b/media/libjxl/src/lib/jxl/frame_header.h @@ -0,0 +1,506 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FRAME_HEADER_H_ +#define LIB_JXL_FRAME_HEADER_H_ + +// Frame header with backward and forward-compatible extension capability and +// compressed integer fields. + +#include +#include + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/loop_filter.h" + +namespace jxl { + +// Also used by extra channel names. +static inline Status VisitNameString(Visitor* JXL_RESTRICT visitor, + std::string* name) { + uint32_t name_length = static_cast(name->length()); + // Allows layer name lengths up to 1071 bytes + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Bits(4), BitsOffset(5, 16), + BitsOffset(10, 48), 0, &name_length)); + if (visitor->IsReading()) { + name->resize(name_length); + } + for (size_t i = 0; i < name_length; i++) { + uint32_t c = static_cast((*name)[i]); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(8, 0, &c)); + (*name)[i] = static_cast(c); + } + return true; +} + +enum class FrameEncoding : uint32_t { + kVarDCT, + kModular, +}; + +enum class ColorTransform : uint32_t { + kXYB, // Values are encoded with XYB. May only be used if + // ImageBundle::xyb_encoded. + kNone, // Values are encoded according to the attached color profile. May + // only be used if !ImageBundle::xyb_encoded. + kYCbCr, // Values are encoded according to the attached color profile, but + // transformed to YCbCr. May only be used if + // !ImageBundle::xyb_encoded. +}; + +inline std::array JpegOrder(ColorTransform ct, bool is_gray) { + if (is_gray) { + return {{0, 0, 0}}; + } + JXL_ASSERT(ct != ColorTransform::kXYB); + if (ct == ColorTransform::kYCbCr) { + return {{1, 0, 2}}; + } else { + return {{0, 1, 2}}; + } +} + +struct YCbCrChromaSubsampling : public Fields { + YCbCrChromaSubsampling(); + JXL_FIELDS_NAME(YCbCrChromaSubsampling) + size_t HShift(size_t c) const { return maxhs_ - kHShift[channel_mode_[c]]; } + size_t VShift(size_t c) const { return maxvs_ - kVShift[channel_mode_[c]]; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + // TODO(veluca): consider allowing 4x downsamples + for (size_t i = 0; i < 3; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 0, &channel_mode_[i])); + } + Recompute(); + return true; + } + + uint8_t MaxHShift() const { return maxhs_; } + uint8_t MaxVShift() const { return maxvs_; } + + uint8_t RawHShift(size_t c) const { return kHShift[channel_mode_[c]]; } + uint8_t RawVShift(size_t c) const { return kVShift[channel_mode_[c]]; } + + // Uses JPEG channel order (Y, Cb, Cr). + Status Set(const uint8_t* hsample, const uint8_t* vsample) { + for (size_t c = 0; c < 3; c++) { + size_t cjpeg = c < 2 ? c ^ 1 : c; + size_t i = 0; + for (; i < 4; i++) { + if (1 << kHShift[i] == hsample[cjpeg] && + 1 << kVShift[i] == vsample[cjpeg]) { + channel_mode_[c] = i; + break; + } + } + if (i == 4) { + return JXL_FAILURE("Invalid subsample mode"); + } + } + Recompute(); + return true; + } + + bool Is444() const { + for (size_t c : {0, 2}) { + if (channel_mode_[c] != channel_mode_[1]) { + return false; + } + } + return true; + } + + bool Is420() const { + return channel_mode_[0] == 1 && channel_mode_[1] == 0 && + channel_mode_[2] == 1; + } + + bool Is422() const { + for (size_t c : {0, 2}) { + if (kHShift[channel_mode_[c]] == kHShift[channel_mode_[1]] + 1 && + kVShift[channel_mode_[c]] == kVShift[channel_mode_[1]]) { + return false; + } + } + return true; + } + + bool Is440() const { + for (size_t c : {0, 2}) { + if (kHShift[channel_mode_[c]] == kHShift[channel_mode_[1]] && + kVShift[channel_mode_[c]] == kVShift[channel_mode_[1]] + 1) { + return false; + } + } + return true; + } + + private: + void Recompute() { + maxhs_ = 0; + maxvs_ = 0; + for (size_t i = 0; i < 3; i++) { + maxhs_ = std::max(maxhs_, kHShift[channel_mode_[i]]); + maxvs_ = std::max(maxvs_, kVShift[channel_mode_[i]]); + } + } + static constexpr uint8_t kHShift[4] = {0, 1, 1, 0}; + static constexpr uint8_t kVShift[4] = {0, 1, 0, 1}; + uint32_t channel_mode_[3]; + uint8_t maxhs_; + uint8_t maxvs_; +}; + +// Indicates how to combine the current frame with a previously-saved one. Can +// be independently controlled for color and extra channels. Formulas are +// indicative and treat alpha as if it is in range 0.0-1.0. In descriptions +// below, alpha channel is the extra channel of type alpha used for blending +// according to the blend_channel, or fully opaque if there is no alpha channel. +// The blending specified here is used for performing blending *after* color +// transforms - in linear sRGB if blending a XYB-encoded frame on another +// XYB-encoded frame, in sRGB if blending a frame with kColorSpace == kSRGB, or +// in the original colorspace otherwise. Blending in XYB or YCbCr is done by +// using patches. +enum class BlendMode { + // The new values (in the crop) replace the old ones: sample = new + kReplace = 0, + // The new values (in the crop) get added to the old ones: sample = old + new + kAdd = 1, + // The new values (in the crop) replace the old ones if alpha>0: + // For the alpha channel that is used as source: + // alpha = old + new * (1 - old) + // For other channels if !alpha_associated: + // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha + // For other channels if alpha_associated: + // sample = (1 - new_alpha) * old + new + // The alpha formula applies to the alpha used for the division in the other + // channels formula, and applies to the alpha channel itself if its + // blend_channel value matches itself. + kBlend = 2, + // The new values (in the crop) are added to the old ones if alpha>0: + // For the alpha channel that is used as source: + // sample = sample = old + new * (1 - old) + // For other channels: sample = old + alpha * new + kAlphaWeightedAdd = 3, + // The new values (in the crop) get multiplied by the old ones: + // sample = old * new + // The range of the new value matters for multiplication purposes, and its + // nominal range of 0..1 is computed the same way as this is done for the + // alpha values in kBlend and kAlphaWeightedAdd. + // If using kMul as a blend mode for color channels, no color transform is + // performed on the current frame. + kMul = 4, +}; + +struct BlendingInfo : public Fields { + BlendingInfo(); + JXL_FIELDS_NAME(BlendingInfo) + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + BlendMode mode; + // Which extra channel to use as alpha channel for blending, only encoded + // for blend modes that involve alpha and if there are more than 1 extra + // channels. + uint32_t alpha_channel; + // Clamp alpha or channel values to 0-1 range. + bool clamp; + // Frame ID to copy from (0-3). Only encoded if blend_mode is not kReplace. + uint32_t source; + + size_t nonserialized_num_extra_channels = 0; + bool nonserialized_is_partial_frame = false; +}; + +// Origin of the current frame. Not present for frames of type +// kOnlyPatches. +struct FrameOrigin { + int32_t x0, y0; // can be negative. +}; + +// Size of the current frame. +struct FrameSize { + uint32_t xsize, ysize; +}; + +// AnimationFrame defines duration of animation frames. +struct AnimationFrame : public Fields { + explicit AnimationFrame(const CodecMetadata* metadata); + JXL_FIELDS_NAME(AnimationFrame) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // How long to wait [in ticks, see Animation{}] after rendering. + // May be 0 if the current frame serves as a foundation for another frame. + uint32_t duration; + + uint32_t timecode; // 0xHHMMSSFF + + // Must be set to the one ImageMetadata acting as the full codestream header, + // with correct xyb_encoded, list of extra channels, etc... + const CodecMetadata* nonserialized_metadata = nullptr; +}; + +// For decoding to lower resolutions. Only used for kRegular frames. +struct Passes : public Fields { + Passes(); + JXL_FIELDS_NAME(Passes) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + void GetDownsamplingBracket(size_t pass, int& minShift, int& maxShift) const { + maxShift = 2; + minShift = 3; + for (size_t i = 0;; i++) { + for (uint32_t j = 0; j < num_downsample; ++j) { + if (i == last_pass[j]) { + if (downsample[j] == 8) minShift = 3; + if (downsample[j] == 4) minShift = 2; + if (downsample[j] == 2) minShift = 1; + if (downsample[j] == 1) minShift = 0; + } + } + if (i == num_passes - 1) minShift = 0; + if (i == pass) return; + maxShift = minShift - 1; + } + } + + uint32_t GetDownsamplingTargetForCompletedPasses(uint32_t num_p) const { + if (num_p >= num_passes) return 1; + uint32_t retval = 8; + for (uint32_t i = 0; i < num_downsample; ++i) { + if (num_p > last_pass[i]) { + retval = std::min(retval, downsample[i]); + } + } + return retval; + } + + std::string DebugString() const; + + uint32_t num_passes; // <= kMaxNumPasses + uint32_t num_downsample; // <= num_passes + + // Array of num_downsample pairs. downsample=1/last_pass=num_passes-1 and + // downsample=8/last_pass=0 need not be specified; they are implicit. + uint32_t downsample[kMaxNumPasses]; + uint32_t last_pass[kMaxNumPasses]; + // Array of shift values for each pass. It is implicitly assumed to be 0 for + // the last pass. + uint32_t shift[kMaxNumPasses]; +}; + +enum FrameType { + // A "regular" frame: might be a crop, and will be blended on a previous + // frame, if any, and displayed or blended in future frames. + kRegularFrame = 0, + // A DC frame: this frame is downsampled and will be *only* used as the DC of + // a future frame and, possibly, for previews. Cannot be cropped, blended, or + // referenced by patches or blending modes. Frames that *use* a DC frame + // cannot have non-default sizes either. + kDCFrame = 1, + // A PatchesSource frame: this frame will be only used as a source frame for + // taking patches. Can be cropped, but cannot have non-(0, 0) x0 and y0. + kReferenceOnly = 2, + // Same as kRegularFrame, but not used for progressive rendering. This also + // implies no early display of DC. + kSkipProgressive = 3, +}; + +// Image/frame := one of more of these, where the last has is_last = true. +// Starts at a byte-aligned address "a"; the next pass starts at "a + size". +struct FrameHeader : public Fields { + // Optional postprocessing steps. These flags are the source of truth; + // Override must set/clear them rather than change their meaning. Values + // chosen such that typical flags == 0 (encoded in only two bits). + enum Flags { + // Often but not always off => low bit value: + + // Inject noise into decoded output. + kNoise = 1, + + // Overlay patches. + kPatches = 2, + + // 4, 8 = reserved for future sometimes-off + + // Overlay splines. + kSplines = 16, + + kUseDcFrame = 32, // Implies kSkipAdaptiveDCSmoothing. + + // 64 = reserved for future often-off + + // Almost always on => negated: + + kSkipAdaptiveDCSmoothing = 128, + }; + + explicit FrameHeader(const CodecMetadata* metadata); + JXL_FIELDS_NAME(FrameHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Sets/clears `flag` based upon `condition`. + void UpdateFlag(const bool condition, const uint64_t flag) { + if (condition) { + flags |= flag; + } else { + flags &= ~flag; + } + } + + // Returns true if this frame is supposed to be saved for future usage by + // other frames. + bool CanBeReferenced() const { + // DC frames cannot be referenced. The last frame cannot be referenced. A + // duration 0 frame makes little sense if it is not referenced. A + // non-duration 0 frame may or may not be referenced. + return !is_last && frame_type != FrameType::kDCFrame && + (animation_frame.duration == 0 || save_as_reference != 0); + } + + mutable bool all_default; + + // Always present + FrameEncoding encoding; + // Some versions of UBSAN complain in VisitFrameType if not initialized. + FrameType frame_type = FrameType::kRegularFrame; + + uint64_t flags; + + ColorTransform color_transform; + YCbCrChromaSubsampling chroma_subsampling; + + uint32_t group_size_shift; // only if encoding == kModular; + + uint32_t x_qm_scale; // only if VarDCT and color_transform == kXYB + uint32_t b_qm_scale; // only if VarDCT and color_transform == kXYB + + std::string name; + + // Skipped for kReferenceOnly. + Passes passes; + + // Skipped for kDCFrame + bool custom_size_or_origin; + FrameSize frame_size; + + // upsampling factors for color and extra channels. + // Upsampling is always performed before applying any inverse color transform. + // Skipped (1) if kUseDCFrame + uint32_t upsampling; + std::vector extra_channel_upsampling; + + // Only for kRegular frames. + FrameOrigin frame_origin; + + BlendingInfo blending_info; + std::vector extra_channel_blending_info; + + // Animation info for this frame. + AnimationFrame animation_frame; + + // This is the last frame. + bool is_last; + + // ID to refer to this frame with. 0-3, not present if kDCFrame. + // 0 has a special meaning for kRegular frames of nonzero duration: it defines + // a frame that will not be referenced in the future. + uint32_t save_as_reference; + + // Whether to save this frame before or after the color transform. A frame + // that is saved before the color tansform can only be used for blending + // through patches. On the contrary, a frame that is saved after the color + // transform can only be used for blending through blending modes. + // Irrelevant for extra channel blending. Can only be true if + // blending_info.mode == kReplace and this is not a partial kRegularFrame; if + // this is a DC frame, it is always true. + bool save_before_color_transform; + + uint32_t dc_level; // 1-4 if kDCFrame (0 otherwise). + + // Must be set to the one ImageMetadata acting as the full codestream header, + // with correct xyb_encoded, list of extra channels, etc... + const CodecMetadata* nonserialized_metadata = nullptr; + + // NOTE: This is ignored by AllDefault. + LoopFilter loop_filter; + + bool nonserialized_is_preview = false; + + size_t default_xsize() const { + if (!nonserialized_metadata) return 0; + if (nonserialized_is_preview) { + return nonserialized_metadata->m.preview_size.xsize(); + } + return nonserialized_metadata->xsize(); + } + + size_t default_ysize() const { + if (!nonserialized_metadata) return 0; + if (nonserialized_is_preview) { + return nonserialized_metadata->m.preview_size.ysize(); + } + return nonserialized_metadata->ysize(); + } + + FrameDimensions ToFrameDimensions() const { + size_t xsize = default_xsize(); + size_t ysize = default_ysize(); + + xsize = frame_size.xsize ? frame_size.xsize : xsize; + ysize = frame_size.ysize ? frame_size.ysize : ysize; + + if (dc_level != 0) { + xsize = DivCeil(xsize, 1 << (3 * dc_level)); + ysize = DivCeil(ysize, 1 << (3 * dc_level)); + } + + FrameDimensions frame_dim; + frame_dim.Set(xsize, ysize, group_size_shift, + chroma_subsampling.MaxHShift(), + chroma_subsampling.MaxVShift(), + encoding == FrameEncoding::kModular, upsampling); + return frame_dim; + } + + // True if a color transform should be applied to this frame. + bool needs_color_transform() const { + return !save_before_color_transform || + frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive; + } + + std::string DebugString() const; + + uint64_t extensions; +}; + +Status ReadFrameHeader(BitReader* JXL_RESTRICT reader, + FrameHeader* JXL_RESTRICT frame); + +Status WriteFrameHeader(const FrameHeader& frame, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out); + +// Shared by enc/dec. 5F and 13 are by far the most common for d1/2/4/8, 0 +// ensures low overhead for small images. +static constexpr U32Enc kOrderEnc = + U32Enc(Val(0x5F), Val(0x13), Val(0), Bits(kNumOrders)); + +} // namespace jxl + +#endif // LIB_JXL_FRAME_HEADER_H_ diff --git a/media/libjxl/src/lib/jxl/gaborish.cc b/media/libjxl/src/lib/jxl/gaborish.cc new file mode 100644 index 000000000..6a187c46e --- /dev/null +++ b/media/libjxl/src/lib/jxl/gaborish.cc @@ -0,0 +1,70 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/gaborish.h" + +#include + +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +void GaborishInverse(Image3F* in_out, float mul, ThreadPool* pool) { + JXL_ASSERT(mul >= 0.0f); + + // Only an approximation. One or even two 3x3, and rank-1 (separable) 5x5 + // are insufficient. + constexpr float kGaborish[5] = { + -0.092359145662814029f, -0.039253623634014627f, 0.016176494530216929f, + 0.00083458437774987476f, 0.004512465323949319f, + }; + /* + better would be: + 1.0 - mul * (4 * (kGaborish[0] + kGaborish[1] + + kGaborish[2] + kGaborish[4]) + + 8 * (kGaborish[3])); + */ + WeightsSymmetric5 weights = {{HWY_REP4(1.0f)}, + {HWY_REP4(mul * kGaborish[0])}, + {HWY_REP4(mul * kGaborish[2])}, + {HWY_REP4(mul * kGaborish[1])}, + {HWY_REP4(mul * kGaborish[4])}, + {HWY_REP4(mul * kGaborish[3])}}; + double sum = static_cast(weights.c[0]); + sum += 4 * weights.r[0]; + sum += 4 * weights.R[0]; + sum += 4 * weights.d[0]; + sum += 4 * weights.D[0]; + sum += 8 * weights.L[0]; + const float normalize = static_cast(1.0 / sum); + for (size_t i = 0; i < 4; ++i) { + weights.c[i] *= normalize; + weights.r[i] *= normalize; + weights.R[i] *= normalize; + weights.d[i] *= normalize; + weights.D[i] *= normalize; + weights.L[i] *= normalize; + } + + // Reduce memory footprint by only allocating a single plane and swapping it + // into the output Image3F. Better still would be tiling. + // Note that we cannot *allocate* a plane, as doing so might cause Image3F to + // have planes of different stride. Instead, we copy one plane in a temporary + // image and reuse the existing planes of the in/out image. + ImageF temp = CopyImage(in_out->Plane(2)); + Symmetric5(in_out->Plane(0), Rect(*in_out), weights, pool, &in_out->Plane(2)); + Symmetric5(in_out->Plane(1), Rect(*in_out), weights, pool, &in_out->Plane(0)); + Symmetric5(temp, Rect(*in_out), weights, pool, &in_out->Plane(1)); + // Now planes are 1, 2, 0. + in_out->Plane(0).Swap(in_out->Plane(1)); + // 2 1 0 + in_out->Plane(0).Swap(in_out->Plane(2)); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/gaborish.h b/media/libjxl/src/lib/jxl/gaborish.h new file mode 100644 index 000000000..e43411dd9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/gaborish.h @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_GABORISH_H_ +#define LIB_JXL_GABORISH_H_ + +// Linear smoothing (3x3 convolution) for deblocking without too much blur. + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Used in encoder to reduce the impact of the decoder's smoothing. +// This is not exact. Works in-place to reduce memory use. +// The input is typically in XYB space. +void GaborishInverse(Image3F* in_out, float mul, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_GABORISH_H_ diff --git a/media/libjxl/src/lib/jxl/gaborish_test.cc b/media/libjxl/src/lib/jxl/gaborish_test.cc new file mode 100644 index 000000000..55b17a060 --- /dev/null +++ b/media/libjxl/src/lib/jxl/gaborish_test.cc @@ -0,0 +1,71 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/gaborish.h" + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +// weight1,2 need not be normalized. +WeightsSymmetric3 GaborishKernel(float weight1, float weight2) { + constexpr float weight0 = 1.0f; + + // Normalize + const float mul = 1.0f / (weight0 + 4 * (weight1 + weight2)); + const float w0 = weight0 * mul; + const float w1 = weight1 * mul; + const float w2 = weight2 * mul; + + const WeightsSymmetric3 w = {{HWY_REP4(w0)}, {HWY_REP4(w1)}, {HWY_REP4(w2)}}; + return w; +} + +void ConvolveGaborish(const ImageF& in, float weight1, float weight2, + ThreadPool* pool, ImageF* JXL_RESTRICT out) { + JXL_CHECK(SameSize(in, *out)); + Symmetric3(in, Rect(in), GaborishKernel(weight1, weight2), pool, out); +} + +void TestRoundTrip(const Image3F& in, float max_l1) { + Image3F fwd(in.xsize(), in.ysize()); + ThreadPool* null_pool = nullptr; + ConvolveGaborish(in.Plane(0), 0, 0, null_pool, &fwd.Plane(0)); + ConvolveGaborish(in.Plane(1), 0, 0, null_pool, &fwd.Plane(1)); + ConvolveGaborish(in.Plane(2), 0, 0, null_pool, &fwd.Plane(2)); + GaborishInverse(&fwd, 0.92718927264540152f, null_pool); + VerifyRelativeError(in, fwd, max_l1, 1E-4f); +} + +TEST(GaborishTest, TestZero) { + Image3F in(20, 20); + ZeroFillImage(&in); + TestRoundTrip(in, 0.0f); +} + +// Disabled: large difference. +#if 0 +TEST(GaborishTest, TestDirac) { + Image3F in(20, 20); + ZeroFillImage(&in); + in.PlaneRow(1, 10)[10] = 10.0f; + TestRoundTrip(in, 0.26f); +} +#endif + +TEST(GaborishTest, TestFlat) { + Image3F in(20, 20); + FillImage(1.0f, &in); + TestRoundTrip(in, 1E-5f); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/gamma_correct_test.cc b/media/libjxl/src/lib/jxl/gamma_correct_test.cc new file mode 100644 index 000000000..d17ce899b --- /dev/null +++ b/media/libjxl/src/lib/jxl/gamma_correct_test.cc @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/enc_gamma_correct.h" + +namespace jxl { +namespace { + +TEST(GammaCorrectTest, TestLinearToSrgbEdgeCases) { + EXPECT_EQ(0, LinearToSrgb8Direct(0.0)); + EXPECT_NEAR(0, LinearToSrgb8Direct(1E-6f), 2E-5); + EXPECT_EQ(0, LinearToSrgb8Direct(-1E-6f)); + EXPECT_EQ(0, LinearToSrgb8Direct(-1E6)); + EXPECT_NEAR(1, LinearToSrgb8Direct(1 - 1E-6f), 1E-5); + EXPECT_EQ(1, LinearToSrgb8Direct(1 + 1E-6f)); + EXPECT_EQ(1, LinearToSrgb8Direct(1E6)); +} + +TEST(GammaCorrectTest, TestRoundTrip) { + // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter) + for (double linear = 0.0; linear <= 1.0; linear += 1E-7) { + const double srgb = LinearToSrgb8Direct(linear); + const double linear2 = Srgb8ToLinearDirect(srgb); + ASSERT_LT(std::abs(linear - linear2), 2E-13) + << "linear = " << linear << ", linear2 = " << linear2; + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/gauss_blur.cc b/media/libjxl/src/lib/jxl/gauss_blur.cc new file mode 100644 index 000000000..930ffb4a3 --- /dev/null +++ b/media/libjxl/src/lib/jxl/gauss_blur.cc @@ -0,0 +1,623 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/gauss_blur.h" + +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/gauss_blur.cc" +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/linalg.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Broadcast; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::NegMulSub; +#if HWY_TARGET != HWY_SCALAR +using hwy::HWY_NAMESPACE::ShiftLeftLanes; +#endif +using hwy::HWY_NAMESPACE::Vec; + +void FastGaussian1D(const hwy::AlignedUniquePtr& rg, + const float* JXL_RESTRICT in, intptr_t width, + float* JXL_RESTRICT out) { + // Although the current output depends on the previous output, we can unroll + // up to 4x by precomputing up to fourth powers of the constants. Beyond that, + // numerical precision might become a problem. Macro because this is tested + // in #if alongside HWY_TARGET. +#define JXL_GAUSS_MAX_LANES 4 + using D = HWY_CAPPED(float, JXL_GAUSS_MAX_LANES); + using V = Vec; + const D d; + const V mul_in_1 = Load(d, rg->mul_in + 0 * 4); + const V mul_in_3 = Load(d, rg->mul_in + 1 * 4); + const V mul_in_5 = Load(d, rg->mul_in + 2 * 4); + const V mul_prev_1 = Load(d, rg->mul_prev + 0 * 4); + const V mul_prev_3 = Load(d, rg->mul_prev + 1 * 4); + const V mul_prev_5 = Load(d, rg->mul_prev + 2 * 4); + const V mul_prev2_1 = Load(d, rg->mul_prev2 + 0 * 4); + const V mul_prev2_3 = Load(d, rg->mul_prev2 + 1 * 4); + const V mul_prev2_5 = Load(d, rg->mul_prev2 + 2 * 4); + V prev_1 = Zero(d); + V prev_3 = Zero(d); + V prev_5 = Zero(d); + V prev2_1 = Zero(d); + V prev2_3 = Zero(d); + V prev2_5 = Zero(d); + + const intptr_t N = rg->radius; + + intptr_t n = -N + 1; + // Left side with bounds checks and only write output after n >= 0. + const intptr_t first_aligned = RoundUpTo(N + 1, Lanes(d)); + for (; n < std::min(first_aligned, width); ++n) { + const intptr_t left = n - N - 1; + const intptr_t right = n + N - 1; + const float left_val = left >= 0 ? in[left] : 0.0f; + const float right_val = right < width ? in[right] : 0.0f; + const V sum = Set(d, left_val + right_val); + + // (Only processing a single lane here, no need to broadcast) + V out_1 = Mul(sum, mul_in_1); + V out_3 = Mul(sum, mul_in_3); + V out_5 = Mul(sum, mul_in_5); + + out_1 = MulAdd(mul_prev2_1, prev2_1, out_1); + out_3 = MulAdd(mul_prev2_3, prev2_3, out_3); + out_5 = MulAdd(mul_prev2_5, prev2_5, out_5); + prev2_1 = prev_1; + prev2_3 = prev_3; + prev2_5 = prev_5; + + out_1 = MulAdd(mul_prev_1, prev_1, out_1); + out_3 = MulAdd(mul_prev_3, prev_3, out_3); + out_5 = MulAdd(mul_prev_5, prev_5, out_5); + prev_1 = out_1; + prev_3 = out_3; + prev_5 = out_5; + + if (n >= 0) { + out[n] = GetLane(Add(out_1, Add(out_3, out_5))); + } + } + + // The above loop is effectively scalar but it is convenient to use the same + // prev/prev2 variables, so broadcast to each lane before the unrolled loop. +#if HWY_TARGET != HWY_SCALAR && JXL_GAUSS_MAX_LANES > 1 + prev2_1 = Broadcast<0>(prev2_1); + prev2_3 = Broadcast<0>(prev2_3); + prev2_5 = Broadcast<0>(prev2_5); + prev_1 = Broadcast<0>(prev_1); + prev_3 = Broadcast<0>(prev_3); + prev_5 = Broadcast<0>(prev_5); +#endif + + // Unrolled, no bounds checking needed. + for (; n < width - N + 1 - (JXL_GAUSS_MAX_LANES - 1); n += Lanes(d)) { + const V sum = Add(LoadU(d, in + n - N - 1), LoadU(d, in + n + N - 1)); + + // To get a vector of output(s), we multiply broadcasted vectors (of each + // input plus the two previous outputs) and add them all together. + // Incremental broadcasting and shifting is expected to be cheaper than + // horizontal adds or transposing 4x4 values because they run on a different + // port, concurrently with the FMA. + const V in0 = Broadcast<0>(sum); + V out_1 = Mul(in0, mul_in_1); + V out_3 = Mul(in0, mul_in_3); + V out_5 = Mul(in0, mul_in_5); + +#if HWY_TARGET != HWY_SCALAR && JXL_GAUSS_MAX_LANES >= 2 + const V in1 = Broadcast<1>(sum); + out_1 = MulAdd(ShiftLeftLanes<1>(mul_in_1), in1, out_1); + out_3 = MulAdd(ShiftLeftLanes<1>(mul_in_3), in1, out_3); + out_5 = MulAdd(ShiftLeftLanes<1>(mul_in_5), in1, out_5); + +#if JXL_GAUSS_MAX_LANES >= 4 + const V in2 = Broadcast<2>(sum); + out_1 = MulAdd(ShiftLeftLanes<2>(mul_in_1), in2, out_1); + out_3 = MulAdd(ShiftLeftLanes<2>(mul_in_3), in2, out_3); + out_5 = MulAdd(ShiftLeftLanes<2>(mul_in_5), in2, out_5); + + const V in3 = Broadcast<3>(sum); + out_1 = MulAdd(ShiftLeftLanes<3>(mul_in_1), in3, out_1); + out_3 = MulAdd(ShiftLeftLanes<3>(mul_in_3), in3, out_3); + out_5 = MulAdd(ShiftLeftLanes<3>(mul_in_5), in3, out_5); +#endif +#endif + + out_1 = MulAdd(mul_prev2_1, prev2_1, out_1); + out_3 = MulAdd(mul_prev2_3, prev2_3, out_3); + out_5 = MulAdd(mul_prev2_5, prev2_5, out_5); + + out_1 = MulAdd(mul_prev_1, prev_1, out_1); + out_3 = MulAdd(mul_prev_3, prev_3, out_3); + out_5 = MulAdd(mul_prev_5, prev_5, out_5); +#if HWY_TARGET == HWY_SCALAR || JXL_GAUSS_MAX_LANES == 1 + prev2_1 = prev_1; + prev2_3 = prev_3; + prev2_5 = prev_5; + prev_1 = out_1; + prev_3 = out_3; + prev_5 = out_5; +#else + prev2_1 = Broadcast(out_1); + prev2_3 = Broadcast(out_3); + prev2_5 = Broadcast(out_5); + prev_1 = Broadcast(out_1); + prev_3 = Broadcast(out_3); + prev_5 = Broadcast(out_5); +#endif + + Store(Add(out_1, Add(out_3, out_5)), d, out + n); + } + + // Remainder handling with bounds checks + for (; n < width; ++n) { + const intptr_t left = n - N - 1; + const intptr_t right = n + N - 1; + const float left_val = left >= 0 ? in[left] : 0.0f; + const float right_val = right < width ? in[right] : 0.0f; + const V sum = Set(d, left_val + right_val); + + // (Only processing a single lane here, no need to broadcast) + V out_1 = Mul(sum, mul_in_1); + V out_3 = Mul(sum, mul_in_3); + V out_5 = Mul(sum, mul_in_5); + + out_1 = MulAdd(mul_prev2_1, prev2_1, out_1); + out_3 = MulAdd(mul_prev2_3, prev2_3, out_3); + out_5 = MulAdd(mul_prev2_5, prev2_5, out_5); + prev2_1 = prev_1; + prev2_3 = prev_3; + prev2_5 = prev_5; + + out_1 = MulAdd(mul_prev_1, prev_1, out_1); + out_3 = MulAdd(mul_prev_3, prev_3, out_3); + out_5 = MulAdd(mul_prev_5, prev_5, out_5); + prev_1 = out_1; + prev_3 = out_3; + prev_5 = out_5; + + out[n] = GetLane(Add(out_1, Add(out_3, out_5))); + } +} + +// Ring buffer is for n, n-1, n-2; round up to 4 for faster modulo. +constexpr size_t kMod = 4; + +// Avoids an unnecessary store during warmup. +struct OutputNone { + template + void operator()(const V& /*unused*/, float* JXL_RESTRICT /*pos*/, + ptrdiff_t /*offset*/) const {} +}; + +// Common case: write output vectors in all VerticalBlock except warmup. +struct OutputStore { + template + void operator()(const V& out, float* JXL_RESTRICT pos, + ptrdiff_t offset) const { + // Stream helps for large images but is slower for images that fit in cache. + Store(out, HWY_FULL(float)(), pos + offset); + } +}; + +// At top/bottom borders, we don't have two inputs to load, so avoid addition. +// pos may even point to all zeros if the row is outside the input image. +class SingleInput { + public: + explicit SingleInput(const float* pos) : pos_(pos) {} + Vec operator()(const size_t offset) const { + return Load(HWY_FULL(float)(), pos_ + offset); + } + const float* pos_; +}; + +// In the middle of the image, we need to load from a row above and below, and +// return the sum. +class TwoInputs { + public: + TwoInputs(const float* pos1, const float* pos2) : pos1_(pos1), pos2_(pos2) {} + Vec operator()(const size_t offset) const { + const auto in1 = Load(HWY_FULL(float)(), pos1_ + offset); + const auto in2 = Load(HWY_FULL(float)(), pos2_ + offset); + return Add(in1, in2); + } + + private: + const float* pos1_; + const float* pos2_; +}; + +// Block := kVectors consecutive full vectors (one cache line except on the +// right boundary, where we can only rely on having one vector). Unrolling to +// the cache line size improves cache utilization. +template +void VerticalBlock(const V& d1_1, const V& d1_3, const V& d1_5, const V& n2_1, + const V& n2_3, const V& n2_5, const Input& input, + size_t& ctr, float* ring_buffer, const Output output, + float* JXL_RESTRICT out_pos) { + const HWY_FULL(float) d; + constexpr size_t kVN = MaxLanes(d); + // More cache-friendly to process an entirely cache line at a time + constexpr size_t kLanes = kVectors * kVN; + + float* JXL_RESTRICT y_1 = ring_buffer + 0 * kLanes * kMod; + float* JXL_RESTRICT y_3 = ring_buffer + 1 * kLanes * kMod; + float* JXL_RESTRICT y_5 = ring_buffer + 2 * kLanes * kMod; + + const size_t n_0 = (++ctr) % kMod; + const size_t n_1 = (ctr - 1) % kMod; + const size_t n_2 = (ctr - 2) % kMod; + + for (size_t idx_vec = 0; idx_vec < kVectors; ++idx_vec) { + const V sum = input(idx_vec * kVN); + + const V y_n1_1 = Load(d, y_1 + kLanes * n_1 + idx_vec * kVN); + const V y_n1_3 = Load(d, y_3 + kLanes * n_1 + idx_vec * kVN); + const V y_n1_5 = Load(d, y_5 + kLanes * n_1 + idx_vec * kVN); + const V y_n2_1 = Load(d, y_1 + kLanes * n_2 + idx_vec * kVN); + const V y_n2_3 = Load(d, y_3 + kLanes * n_2 + idx_vec * kVN); + const V y_n2_5 = Load(d, y_5 + kLanes * n_2 + idx_vec * kVN); + // (35) + const V y1 = MulAdd(n2_1, sum, NegMulSub(d1_1, y_n1_1, y_n2_1)); + const V y3 = MulAdd(n2_3, sum, NegMulSub(d1_3, y_n1_3, y_n2_3)); + const V y5 = MulAdd(n2_5, sum, NegMulSub(d1_5, y_n1_5, y_n2_5)); + Store(y1, d, y_1 + kLanes * n_0 + idx_vec * kVN); + Store(y3, d, y_3 + kLanes * n_0 + idx_vec * kVN); + Store(y5, d, y_5 + kLanes * n_0 + idx_vec * kVN); + output(Add(y1, Add(y3, y5)), out_pos, idx_vec * kVN); + } + // NOTE: flushing cache line out_pos hurts performance - less so with + // clflushopt than clflush but still a significant slowdown. +} + +// Reads/writes one block (kVectors full vectors) in each row. +template +void VerticalStrip(const hwy::AlignedUniquePtr& rg, + const ImageF& in, const size_t x, ImageF* JXL_RESTRICT out) { + // We're iterating vertically, so use multiple full-length vectors (each lane + // is one column of row n). + using D = HWY_FULL(float); + using V = Vec; + const D d; + constexpr size_t kVN = MaxLanes(d); + // More cache-friendly to process an entirely cache line at a time + constexpr size_t kLanes = kVectors * kVN; +#if HWY_TARGET == HWY_SCALAR + const V d1_1 = Set(d, rg->d1[0 * 4]); + const V d1_3 = Set(d, rg->d1[1 * 4]); + const V d1_5 = Set(d, rg->d1[2 * 4]); + const V n2_1 = Set(d, rg->n2[0 * 4]); + const V n2_3 = Set(d, rg->n2[1 * 4]); + const V n2_5 = Set(d, rg->n2[2 * 4]); +#else + const V d1_1 = LoadDup128(d, rg->d1 + 0 * 4); + const V d1_3 = LoadDup128(d, rg->d1 + 1 * 4); + const V d1_5 = LoadDup128(d, rg->d1 + 2 * 4); + const V n2_1 = LoadDup128(d, rg->n2 + 0 * 4); + const V n2_3 = LoadDup128(d, rg->n2 + 1 * 4); + const V n2_5 = LoadDup128(d, rg->n2 + 2 * 4); +#endif + + const size_t N = rg->radius; + const size_t ysize = in.ysize(); + + size_t ctr = 0; + HWY_ALIGN float ring_buffer[3 * kLanes * kMod] = {0}; + HWY_ALIGN static constexpr float zero[kLanes] = {0}; + + // Warmup: top is out of bounds (zero padded), bottom is usually in-bounds. + ssize_t n = -static_cast(N) + 1; + for (; n < 0; ++n) { + // bottom is always non-negative since n is initialized in -N + 1. + const size_t bottom = n + N - 1; + VerticalBlock( + d1_1, d1_3, d1_5, n2_1, n2_3, n2_5, + SingleInput(bottom < ysize ? in.ConstRow(bottom) + x : zero), ctr, + ring_buffer, OutputNone(), nullptr); + } + JXL_DASSERT(n >= 0); + + // Start producing output; top is still out of bounds. + for (; static_cast(n) < std::min(N + 1, ysize); ++n) { + const size_t bottom = n + N - 1; + VerticalBlock( + d1_1, d1_3, d1_5, n2_1, n2_3, n2_5, + SingleInput(bottom < ysize ? in.ConstRow(bottom) + x : zero), ctr, + ring_buffer, OutputStore(), out->Row(n) + x); + } + + // Interior outputs with prefetching and without bounds checks. + constexpr size_t kPrefetchRows = 8; + for (; n < static_cast(ysize - N + 1 - kPrefetchRows); ++n) { + const size_t top = n - N - 1; + const size_t bottom = n + N - 1; + VerticalBlock( + d1_1, d1_3, d1_5, n2_1, n2_3, n2_5, + TwoInputs(in.ConstRow(top) + x, in.ConstRow(bottom) + x), ctr, + ring_buffer, OutputStore(), out->Row(n) + x); + hwy::Prefetch(in.ConstRow(top + kPrefetchRows) + x); + hwy::Prefetch(in.ConstRow(bottom + kPrefetchRows) + x); + } + + // Bottom border without prefetching and with bounds checks. + for (; static_cast(n) < ysize; ++n) { + const size_t top = n - N - 1; + const size_t bottom = n + N - 1; + VerticalBlock( + d1_1, d1_3, d1_5, n2_1, n2_3, n2_5, + TwoInputs(in.ConstRow(top) + x, + bottom < ysize ? in.ConstRow(bottom) + x : zero), + ctr, ring_buffer, OutputStore(), out->Row(n) + x); + } +} + +// Apply 1D vertical scan to multiple columns (one per vector lane). +// Not yet parallelized. +void FastGaussianVertical(const hwy::AlignedUniquePtr& rg, + const ImageF& in, ThreadPool* /*pool*/, + ImageF* JXL_RESTRICT out) { + PROFILER_FUNC; + JXL_CHECK(SameSize(in, *out)); + + constexpr size_t kCacheLineLanes = 64 / sizeof(float); + constexpr size_t kVN = MaxLanes(HWY_FULL(float)()); + constexpr size_t kCacheLineVectors = + (kVN < kCacheLineLanes) ? (kCacheLineLanes / kVN) : 4; + constexpr size_t kFastPace = kCacheLineVectors * kVN; + + size_t x = 0; + for (; x + kFastPace <= in.xsize(); x += kFastPace) { + VerticalStrip(rg, in, x, out); + } + for (; x < in.xsize(); x += kVN) { + VerticalStrip<1>(rg, in, x, out); + } +} + +// TODO(veluca): consider replacing with FastGaussian. +ImageF ConvolveXSampleAndTranspose(const ImageF& in, + const std::vector& kernel, + const size_t res) { + JXL_ASSERT(kernel.size() % 2 == 1); + JXL_ASSERT(in.xsize() % res == 0); + const size_t offset = res / 2; + const size_t out_xsize = in.xsize() / res; + ImageF out(in.ysize(), out_xsize); + const int r = kernel.size() / 2; + HWY_FULL(float) df; + std::vector row_tmp(in.xsize() + 2 * r + Lanes(df)); + float* const JXL_RESTRICT rowp = &row_tmp[r]; + std::vector padded_k = kernel; + padded_k.resize(padded_k.size() + Lanes(df)); + const float* const kernelp = &padded_k[r]; + for (size_t y = 0; y < in.ysize(); ++y) { + ExtrapolateBorders(in.Row(y), rowp, in.xsize(), r); + size_t x = offset, ox = 0; + for (; x < static_cast(r) && x < in.xsize(); x += res, ++ox) { + float sum = 0.0f; + for (int i = -r; i <= r; ++i) { + sum += rowp[std::max( + 0, std::min(static_cast(x) + i, in.xsize()))] * + kernelp[i]; + } + out.Row(ox)[y] = sum; + } + for (; x + r < in.xsize(); x += res, ++ox) { + auto sum = Zero(df); + for (int i = -r; i <= r; i += Lanes(df)) { + sum = MulAdd(LoadU(df, rowp + x + i), LoadU(df, kernelp + i), sum); + } + out.Row(ox)[y] = GetLane(SumOfLanes(df, sum)); + } + for (; x < in.xsize(); x += res, ++ox) { + float sum = 0.0f; + for (int i = -r; i <= r; ++i) { + sum += rowp[std::max( + 0, std::min(static_cast(x) + i, in.xsize()))] * + kernelp[i]; + } + out.Row(ox)[y] = sum; + } + } + return out; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(FastGaussian1D); +HWY_EXPORT(ConvolveXSampleAndTranspose); +void FastGaussian1D(const hwy::AlignedUniquePtr& rg, + const float* JXL_RESTRICT in, intptr_t width, + float* JXL_RESTRICT out) { + return HWY_DYNAMIC_DISPATCH(FastGaussian1D)(rg, in, width, out); +} + +HWY_EXPORT(FastGaussianVertical); // Local function. + +void ExtrapolateBorders(const float* const JXL_RESTRICT row_in, + float* const JXL_RESTRICT row_out, const int xsize, + const int radius) { + const int lastcol = xsize - 1; + for (int x = 1; x <= radius; ++x) { + row_out[-x] = row_in[std::min(x, xsize - 1)]; + } + memcpy(row_out, row_in, xsize * sizeof(row_out[0])); + for (int x = 1; x <= radius; ++x) { + row_out[lastcol + x] = row_in[std::max(0, lastcol - x)]; + } +} + +ImageF ConvolveXSampleAndTranspose(const ImageF& in, + const std::vector& kernel, + const size_t res) { + return HWY_DYNAMIC_DISPATCH(ConvolveXSampleAndTranspose)(in, kernel, res); +} + +Image3F ConvolveXSampleAndTranspose(const Image3F& in, + const std::vector& kernel, + const size_t res) { + return Image3F(ConvolveXSampleAndTranspose(in.Plane(0), kernel, res), + ConvolveXSampleAndTranspose(in.Plane(1), kernel, res), + ConvolveXSampleAndTranspose(in.Plane(2), kernel, res)); +} + +ImageF ConvolveAndSample(const ImageF& in, const std::vector& kernel, + const size_t res) { + ImageF tmp = ConvolveXSampleAndTranspose(in, kernel, res); + return ConvolveXSampleAndTranspose(tmp, kernel, res); +} + +// Implements "Recursive Implementation of the Gaussian Filter Using Truncated +// Cosine Functions" by Charalampidis [2016]. +hwy::AlignedUniquePtr CreateRecursiveGaussian(double sigma) { + PROFILER_FUNC; + auto rg = hwy::MakeUniqueAligned(); + constexpr double kPi = 3.141592653589793238; + + const double radius = roundf(3.2795 * sigma + 0.2546); // (57), "N" + + // Table I, first row + const double pi_div_2r = kPi / (2.0 * radius); + const double omega[3] = {pi_div_2r, 3.0 * pi_div_2r, 5.0 * pi_div_2r}; + + // (37), k={1,3,5} + const double p_1 = +1.0 / std::tan(0.5 * omega[0]); + const double p_3 = -1.0 / std::tan(0.5 * omega[1]); + const double p_5 = +1.0 / std::tan(0.5 * omega[2]); + + // (44), k={1,3,5} + const double r_1 = +p_1 * p_1 / std::sin(omega[0]); + const double r_3 = -p_3 * p_3 / std::sin(omega[1]); + const double r_5 = +p_5 * p_5 / std::sin(omega[2]); + + // (50), k={1,3,5} + const double neg_half_sigma2 = -0.5 * sigma * sigma; + const double recip_radius = 1.0 / radius; + double rho[3]; + for (size_t i = 0; i < 3; ++i) { + rho[i] = std::exp(neg_half_sigma2 * omega[i] * omega[i]) * recip_radius; + } + + // second part of (52), k1,k2 = 1,3; 3,5; 5,1 + const double D_13 = p_1 * r_3 - r_1 * p_3; + const double D_35 = p_3 * r_5 - r_3 * p_5; + const double D_51 = p_5 * r_1 - r_5 * p_1; + + // (52), k=5 + const double recip_d13 = 1.0 / D_13; + const double zeta_15 = D_35 * recip_d13; + const double zeta_35 = D_51 * recip_d13; + + double A[9] = {p_1, p_3, p_5, // + r_1, r_3, r_5, // (56) + zeta_15, zeta_35, 1}; + JXL_CHECK(Inv3x3Matrix(A)); + const double gamma[3] = {1, radius * radius - sigma * sigma, // (55) + zeta_15 * rho[0] + zeta_35 * rho[1] + rho[2]}; + double beta[3]; + MatMul(A, gamma, 3, 3, 1, beta); // (53) + + // Sanity check: correctly solved for beta (IIR filter weights are normalized) + const double sum = beta[0] * p_1 + beta[1] * p_3 + beta[2] * p_5; // (39) + JXL_ASSERT(std::abs(sum - 1) < 1E-12); + (void)sum; + + rg->radius = static_cast(radius); + + double n2[3]; + double d1[3]; + for (size_t i = 0; i < 3; ++i) { + n2[i] = -beta[i] * std::cos(omega[i] * (radius + 1.0)); // (33) + d1[i] = -2.0 * std::cos(omega[i]); // (33) + + for (size_t lane = 0; lane < 4; ++lane) { + rg->n2[4 * i + lane] = static_cast(n2[i]); + rg->d1[4 * i + lane] = static_cast(d1[i]); + } + + const double d_2 = d1[i] * d1[i]; + + // Obtained by expanding (35) for four consecutive outputs via sympy: + // n, d, p, pp = symbols('n d p pp') + // i0, i1, i2, i3 = symbols('i0 i1 i2 i3') + // o0, o1, o2, o3 = symbols('o0 o1 o2 o3') + // o0 = n*i0 - d*p - pp + // o1 = n*i1 - d*o0 - p + // o2 = n*i2 - d*o1 - o0 + // o3 = n*i3 - d*o2 - o1 + // Then expand(o3) and gather terms for p(prev), pp(prev2) etc. + rg->mul_prev[4 * i + 0] = -d1[i]; + rg->mul_prev[4 * i + 1] = d_2 - 1.0; + rg->mul_prev[4 * i + 2] = -d_2 * d1[i] + 2.0 * d1[i]; + rg->mul_prev[4 * i + 3] = d_2 * d_2 - 3.0 * d_2 + 1.0; + rg->mul_prev2[4 * i + 0] = -1.0; + rg->mul_prev2[4 * i + 1] = d1[i]; + rg->mul_prev2[4 * i + 2] = -d_2 + 1.0; + rg->mul_prev2[4 * i + 3] = d_2 * d1[i] - 2.0 * d1[i]; + rg->mul_in[4 * i + 0] = n2[i]; + rg->mul_in[4 * i + 1] = -d1[i] * n2[i]; + rg->mul_in[4 * i + 2] = d_2 * n2[i] - n2[i]; + rg->mul_in[4 * i + 3] = -d_2 * d1[i] * n2[i] + 2.0 * d1[i] * n2[i]; + } + return rg; +} + +namespace { + +// Apply 1D horizontal scan to each row. +void FastGaussianHorizontal(const hwy::AlignedUniquePtr& rg, + const ImageF& in, ThreadPool* pool, + ImageF* JXL_RESTRICT out) { + PROFILER_FUNC; + JXL_CHECK(SameSize(in, *out)); + + const intptr_t xsize = in.xsize(); + JXL_CHECK(RunOnPool( + pool, 0, in.ysize(), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + const float* row_in = in.ConstRow(y); + float* JXL_RESTRICT row_out = out->Row(y); + FastGaussian1D(rg, row_in, xsize, row_out); + }, + "FastGaussianHorizontal")); +} + +} // namespace + +void FastGaussian(const hwy::AlignedUniquePtr& rg, + const ImageF& in, ThreadPool* pool, ImageF* JXL_RESTRICT temp, + ImageF* JXL_RESTRICT out) { + FastGaussianHorizontal(rg, in, pool, temp); + HWY_DYNAMIC_DISPATCH(FastGaussianVertical)(rg, *temp, pool, out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/gauss_blur.h b/media/libjxl/src/lib/jxl/gauss_blur.h new file mode 100644 index 000000000..fb4741f03 --- /dev/null +++ b/media/libjxl/src/lib/jxl/gauss_blur.h @@ -0,0 +1,94 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_GAUSS_BLUR_H_ +#define LIB_JXL_GAUSS_BLUR_H_ + +#include + +#include +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" + +namespace jxl { + +template +std::vector GaussianKernel(int radius, T sigma) { + JXL_ASSERT(sigma > 0.0); + std::vector kernel(2 * radius + 1); + const T scaler = -1.0 / (2 * sigma * sigma); + double sum = 0.0; + for (int i = -radius; i <= radius; ++i) { + const T val = std::exp(scaler * i * i); + kernel[i + radius] = val; + sum += val; + } + for (size_t i = 0; i < kernel.size(); ++i) { + kernel[i] /= sum; + } + return kernel; +} + +// All convolution functions below apply mirroring of the input on the borders +// in the following way: +// +// input: [a0 a1 a2 ... aN] +// mirrored input: [aR ... a1 | a0 a1 a2 .... aN | aN-1 ... aN-R] +// +// where R is the radius of the kernel (i.e. kernel size is 2*R+1). + +// REQUIRES: in.xsize() and in.ysize() are integer multiples of res. +ImageF ConvolveAndSample(const ImageF& in, const std::vector& kernel, + const size_t res); + +// Private, used by test. +void ExtrapolateBorders(const float* const JXL_RESTRICT row_in, + float* const JXL_RESTRICT row_out, const int xsize, + const int radius); + +// Only for use by CreateRecursiveGaussian and FastGaussian*. +#pragma pack(push, 1) +struct RecursiveGaussian { + // For k={1,3,5} in that order, each broadcasted 4x for LoadDup128. Used only + // for vertical passes. + float n2[3 * 4]; + float d1[3 * 4]; + + // We unroll horizontal passes 4x - one output per lane. These are each lane's + // multiplier for the previous output (relative to the first of the four + // outputs). Indexing: 4 * 0..2 (for {1,3,5}) + 0..3 for the lane index. + float mul_prev[3 * 4]; + // Ditto for the second to last output. + float mul_prev2[3 * 4]; + + // We multiply a vector of inputs 0..3 by a vector shifted from this array. + // in=0 uses all 4 (nonzero) terms; for in=3, the lower three lanes are 0. + float mul_in[3 * 4]; + + size_t radius; +}; +#pragma pack(pop) + +// Precomputation for FastGaussian*; users may use the same pointer/storage in +// subsequent calls to FastGaussian* with the same sigma. +hwy::AlignedUniquePtr CreateRecursiveGaussian(double sigma); + +// 1D Gaussian with zero-pad boundary handling and runtime independent of sigma. +void FastGaussian1D(const hwy::AlignedUniquePtr& rg, + const float* JXL_RESTRICT in, intptr_t width, + float* JXL_RESTRICT out); + +// 2D Gaussian with zero-pad boundary handling and runtime independent of sigma. +void FastGaussian(const hwy::AlignedUniquePtr& rg, + const ImageF& in, ThreadPool* pool, ImageF* JXL_RESTRICT temp, + ImageF* JXL_RESTRICT out); + +} // namespace jxl + +#endif // LIB_JXL_GAUSS_BLUR_H_ diff --git a/media/libjxl/src/lib/jxl/gauss_blur_gbench.cc b/media/libjxl/src/lib/jxl/gauss_blur_gbench.cc new file mode 100644 index 000000000..b1bb64abc --- /dev/null +++ b/media/libjxl/src/lib/jxl/gauss_blur_gbench.cc @@ -0,0 +1,126 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include "benchmark/benchmark.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/gauss_blur.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +JXL_MAYBE_UNUSED ImageF Convolve(const ImageF& in, + const std::vector& kernel) { + return ConvolveAndSample(in, kernel, 1); +} + +void BM_GaussBlur1d(benchmark::State& state) { + // Uncomment to disable SIMD and force and scalar implementation + // hwy::DisableTargets(~HWY_SCALAR); + // Uncomment to run AVX2 + // hwy::DisableTargets(HWY_AVX3); + + const size_t length = state.range(); + const double sigma = 7.0; // (from Butteraugli application) + ImageF in(length, 1); + const float expected = length; + FillImage(expected, &in); + + ImageF temp(length, 1); + ImageF out(length, 1); + const auto rg = CreateRecursiveGaussian(sigma); + for (auto _ : state) { + FastGaussian1D(rg, in.Row(0), length, out.Row(0)); + // Prevent optimizing out + JXL_ASSERT(std::abs(out.ConstRow(0)[length / 2] - expected) / expected < + 9E-5); + } + state.SetItemsProcessed(length * state.iterations()); +} + +void BM_GaussBlur2d(benchmark::State& state) { + // See GaussBlur1d for SIMD changes. + + const size_t xsize = state.range(); + const size_t ysize = xsize; + const double sigma = 7.0; // (from Butteraugli application) + ImageF in(xsize, ysize); + const float expected = xsize + ysize; + FillImage(expected, &in); + + ImageF temp(xsize, ysize); + ImageF out(xsize, ysize); + ThreadPool* null_pool = nullptr; + const auto rg = CreateRecursiveGaussian(sigma); + for (auto _ : state) { + FastGaussian(rg, in, null_pool, &temp, &out); + // Prevent optimizing out + JXL_ASSERT(std::abs(out.ConstRow(ysize / 2)[xsize / 2] - expected) / + expected < + 9E-5); + } + state.SetItemsProcessed(xsize * ysize * state.iterations()); +} + +void BM_GaussBlurFir(benchmark::State& state) { + // See GaussBlur1d for SIMD changes. + + const size_t xsize = state.range(); + const size_t ysize = xsize; + const double sigma = 7.0; // (from Butteraugli application) + ImageF in(xsize, ysize); + const float expected = xsize + ysize; + FillImage(expected, &in); + + ImageF temp(xsize, ysize); + ImageF out(xsize, ysize); + const std::vector kernel = + GaussianKernel(static_cast(4 * sigma), static_cast(sigma)); + for (auto _ : state) { + // Prevent optimizing out + JXL_ASSERT(std::abs(Convolve(in, kernel).ConstRow(ysize / 2)[xsize / 2] - + expected) / + expected < + 9E-5); + } + state.SetItemsProcessed(xsize * ysize * state.iterations()); +} + +void BM_GaussBlurSep7(benchmark::State& state) { + // See GaussBlur1d for SIMD changes. + + const size_t xsize = state.range(); + const size_t ysize = xsize; + ImageF in(xsize, ysize); + const float expected = xsize + ysize; + FillImage(expected, &in); + + ImageF temp(xsize, ysize); + ImageF out(xsize, ysize); + ThreadPool* null_pool = nullptr; + // Gaussian with sigma 1 + const WeightsSeparable7 weights = {{HWY_REP4(0.383103f), HWY_REP4(0.241843f), + HWY_REP4(0.060626f), HWY_REP4(0.00598f)}, + {HWY_REP4(0.383103f), HWY_REP4(0.241843f), + HWY_REP4(0.060626f), HWY_REP4(0.00598f)}}; + for (auto _ : state) { + Separable7(in, Rect(in), weights, null_pool, &out); + // Prevent optimizing out + JXL_ASSERT(std::abs(out.ConstRow(ysize / 2)[xsize / 2] - expected) / + expected < + 9E-5); + } + state.SetItemsProcessed(xsize * ysize * state.iterations()); +} + +BENCHMARK(BM_GaussBlur1d)->Range(1 << 8, 1 << 14); +BENCHMARK(BM_GaussBlur2d)->Range(1 << 7, 1 << 10); +BENCHMARK(BM_GaussBlurFir)->Range(1 << 7, 1 << 10); +BENCHMARK(BM_GaussBlurSep7)->Range(1 << 7, 1 << 10); + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/gauss_blur_test.cc b/media/libjxl/src/lib/jxl/gauss_blur_test.cc new file mode 100644 index 000000000..2aa94f786 --- /dev/null +++ b/media/libjxl/src/lib/jxl/gauss_blur_test.cc @@ -0,0 +1,452 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/gauss_blur.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/time.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { + +bool NearEdge(const int64_t width, const int64_t peak) { + // When around 3*sigma from the edge, there is negligible truncation. + return peak < 10 || peak > width - 10; +} + +// Follow the curve downwards by scanning right from `peak` and verifying +// identical values at the same offset to the left. +void VerifySymmetric(const int64_t width, const int64_t peak, + const float* out) { + const double tolerance = NearEdge(width, peak) ? 0.015 : 6E-7; + for (int64_t i = 1;; ++i) { + // Stop if we passed either end of the array + if (peak - i < 0 || peak + i >= width) break; + EXPECT_GT(out[peak + i - 1] + tolerance, out[peak + i]); // descending + EXPECT_NEAR(out[peak - i], out[peak + i], tolerance); // symmetric + } +} + +void TestImpulseResponse(size_t width, size_t peak) { + const auto rg3 = CreateRecursiveGaussian(3.0); + const auto rg4 = CreateRecursiveGaussian(4.0); + const auto rg5 = CreateRecursiveGaussian(5.0); + + // Extra padding for 4x unrolling + auto in = hwy::AllocateAligned(width + 3); + memset(in.get(), 0, sizeof(float) * (width + 3)); + in[peak] = 1.0f; + + auto out3 = hwy::AllocateAligned(width + 3); + auto out4 = hwy::AllocateAligned(width + 3); + auto out5 = hwy::AllocateAligned(width + 3); + FastGaussian1D(rg3, in.get(), width, out3.get()); + FastGaussian1D(rg4, out3.get(), width, out4.get()); + FastGaussian1D(rg5, in.get(), width, out5.get()); + + VerifySymmetric(width, peak, out3.get()); + VerifySymmetric(width, peak, out4.get()); + VerifySymmetric(width, peak, out5.get()); + + // Wider kernel has flatter peak + EXPECT_LT(out5[peak] + 0.05, out3[peak]); + + // Gauss3 o Gauss4 ~= Gauss5 + const double tolerance = NearEdge(width, peak) ? 0.04 : 0.01; + for (size_t i = 0; i < width; ++i) { + EXPECT_NEAR(out4[i], out5[i], tolerance); + } +} + +void TestImpulseResponseForWidth(size_t width) { + for (size_t i = 0; i < width; ++i) { + TestImpulseResponse(width, i); + } +} + +TEST(GaussBlurTest, ImpulseResponse) { + TestImpulseResponseForWidth(10); // tiny even + TestImpulseResponseForWidth(15); // small odd + TestImpulseResponseForWidth(32); // power of two + TestImpulseResponseForWidth(31); // power of two - 1 + TestImpulseResponseForWidth(33); // power of two + 1 +} + +ImageF Convolve(const ImageF& in, const std::vector& kernel) { + return ConvolveAndSample(in, kernel, 1); +} + +// Higher-precision version for accuracy test. +ImageF ConvolveAndTransposeF64(const ImageF& in, + const std::vector& kernel) { + JXL_ASSERT(kernel.size() % 2 == 1); + ImageF out(in.ysize(), in.xsize()); + const int r = kernel.size() / 2; + std::vector row_tmp(in.xsize() + 2 * r); + float* const JXL_RESTRICT rowp = &row_tmp[r]; + const double* const kernelp = &kernel[r]; + for (size_t y = 0; y < in.ysize(); ++y) { + ExtrapolateBorders(in.Row(y), rowp, in.xsize(), r); + for (size_t x = 0, ox = 0; x < in.xsize(); ++x, ++ox) { + double sum = 0.0; + for (int i = -r; i <= r; ++i) { + sum += rowp[std::max( + 0, std::min(static_cast(x) + i, in.xsize()))] * + kernelp[i]; + } + out.Row(ox)[y] = static_cast(sum); + } + } + return out; +} + +ImageF ConvolveF64(const ImageF& in, const std::vector& kernel) { + ImageF tmp = ConvolveAndTransposeF64(in, kernel); + return ConvolveAndTransposeF64(tmp, kernel); +} + +void TestDirac2D(size_t xsize, size_t ysize, double sigma) { + ImageF in(xsize, ysize); + ZeroFillImage(&in); + // We anyway ignore the border below, so might as well choose the middle. + in.Row(ysize / 2)[xsize / 2] = 1.0f; + + ImageF temp(xsize, ysize); + ImageF out(xsize, ysize); + const auto rg = CreateRecursiveGaussian(sigma); + ThreadPool* null_pool = nullptr; + FastGaussian(rg, in, null_pool, &temp, &out); + + const std::vector kernel = + GaussianKernel(static_cast(4 * sigma), static_cast(sigma)); + const ImageF expected = Convolve(in, kernel); + + const double max_l1 = sigma < 1.5 ? 5E-3 : 6E-4; + const size_t border = 2 * sigma; + VerifyRelativeError(expected, out, max_l1, 1E-8, border); +} + +TEST(GaussBlurTest, Test2D) { + const std::vector dimensions{6, 15, 17, 64, 50, 49}; + for (int xsize : dimensions) { + for (int ysize : dimensions) { + for (double sigma : {1.0, 2.5, 3.6, 7.0}) { + TestDirac2D(static_cast(xsize), static_cast(ysize), + sigma); + } + } + } +} + +// Slow (44 sec). To run, remove the disabled prefix. +TEST(GaussBlurTest, DISABLED_SlowTestDirac1D) { + const double sigma = 7.0; + const auto rg = CreateRecursiveGaussian(sigma); + + // IPOL accuracy test uses 10^-15 tolerance, this is 2*10^-11. + const size_t radius = static_cast(7 * sigma); + const std::vector kernel = GaussianKernel(radius, sigma); + + const size_t length = 16384; + ImageF inputs(length, 1); + ZeroFillImage(&inputs); + + auto outputs = hwy::AllocateAligned(length); + + // One per center position + auto sum_abs_err = hwy::AllocateAligned(length); + std::fill(sum_abs_err.get(), sum_abs_err.get() + length, 0.0); + + for (size_t center = radius; center < length - radius; ++center) { + inputs.Row(0)[center - 1] = 0.0f; // reset last peak, entire array now 0 + inputs.Row(0)[center] = 1.0f; + FastGaussian1D(rg, inputs.Row(0), length, outputs.get()); + + const ImageF outputs_fir = ConvolveF64(inputs, kernel); + + for (size_t i = 0; i < length; ++i) { + const float abs_err = std::abs(outputs[i] - outputs_fir.Row(0)[i]); + sum_abs_err[i] += static_cast(abs_err); + } + } + + const double max_abs_err = + *std::max_element(sum_abs_err.get(), sum_abs_err.get() + length); + printf("Max abs err: %.8e\n", max_abs_err); +} + +void TestRandom(size_t xsize, size_t ysize, float min, float max, double sigma, + double max_l1, double max_rel) { + printf("%4" PRIuS " x %4" PRIuS " %4.1f %4.1f sigma %.1f\n", xsize, ysize, + min, max, sigma); + ImageF in(xsize, ysize); + RandomFillImage(&in, min, max, 65537 + xsize * 129 + ysize); + // FastGaussian/Convolve handle borders differently, so keep those pixels 0. + const size_t border = 4 * sigma; + SetBorder(border, 0.0f, &in); + + ImageF temp(xsize, ysize); + ImageF out(xsize, ysize); + const auto rg = CreateRecursiveGaussian(sigma); + ThreadPool* null_pool = nullptr; + FastGaussian(rg, in, null_pool, &temp, &out); + + const std::vector kernel = + GaussianKernel(static_cast(4 * sigma), static_cast(sigma)); + const ImageF expected = Convolve(in, kernel); + + VerifyRelativeError(expected, out, max_l1, max_rel, border); +} + +void TestRandomForSizes(float min, float max, double sigma) { + double max_l1 = 6E-3; + double max_rel = 3E-3; + TestRandom(128, 1, min, max, sigma, max_l1, max_rel); + TestRandom(1, 128, min, max, sigma, max_l1, max_rel); + TestRandom(30, 201, min, max, sigma, max_l1 * 1.6, max_rel * 1.2); + TestRandom(201, 30, min, max, sigma, max_l1 * 1.6, max_rel * 1.2); + TestRandom(201, 201, min, max, sigma, max_l1 * 2.0, max_rel * 1.2); +} + +TEST(GaussBlurTest, TestRandom) { + // small non-negative + TestRandomForSizes(0.0f, 10.0f, 3.0f); + TestRandomForSizes(0.0f, 10.0f, 7.0f); + + // small negative + TestRandomForSizes(-4.0f, -1.0f, 3.0f); + TestRandomForSizes(-4.0f, -1.0f, 7.0f); + + // mixed positive/negative + TestRandomForSizes(-6.0f, 6.0f, 3.0f); + TestRandomForSizes(-6.0f, 6.0f, 7.0f); +} + +TEST(GaussBlurTest, TestSign) { + const size_t xsize = 500; + const size_t ysize = 606; + ImageF in(xsize, ysize); + + ZeroFillImage(&in); + const float center[33 * 33] = { + -0.128445f, -0.098473f, -0.121883f, -0.093601f, 0.095665f, -0.271332f, + -0.705475f, -1.324005f, -2.020741f, -1.329464f, 1.834064f, 4.787300f, + 5.834560f, 5.272720f, 3.967960f, 3.547935f, 3.432732f, 3.383015f, + 3.239326f, 3.290806f, 3.298954f, 3.397808f, 3.359730f, 3.533844f, + 3.511856f, 3.436787f, 3.428310f, 3.460209f, 3.550011f, 3.590942f, + 3.593109f, 3.560005f, 3.443165f, 0.089741f, 0.179230f, -0.032997f, + -0.182610f, 0.005669f, -0.244759f, -0.395123f, -0.514961f, -1.003529f, + -1.798656f, -2.377975f, 0.222191f, 3.957664f, 5.946804f, 5.543129f, + 4.290096f, 3.621010f, 3.407257f, 3.392494f, 3.345367f, 3.391903f, + 3.441605f, 3.429260f, 3.444969f, 3.507130f, 3.518612f, 3.443111f, + 3.475948f, 3.536148f, 3.470333f, 3.628311f, 3.600243f, 3.292892f, + -0.226730f, -0.573616f, -0.762165f, -0.398739f, -0.189842f, -0.275921f, + -0.446739f, -0.550037f, -0.461033f, -0.724792f, -1.448349f, -1.814064f, + -0.491032f, 2.817703f, 5.213242f, 5.675629f, 4.864548f, 3.876324f, + 3.535587f, 3.530312f, 3.413765f, 3.386261f, 3.404854f, 3.383472f, + 3.420830f, 3.326496f, 3.257877f, 3.362152f, 3.489609f, 3.619587f, + 3.555805f, 3.423164f, 3.309708f, -0.483940f, -0.502926f, -0.592983f, + -0.492527f, -0.413616f, -0.482555f, -0.475506f, -0.447990f, -0.338120f, + -0.189072f, -0.376427f, -0.910828f, -1.878044f, -1.937927f, 1.423218f, + 4.871609f, 5.767548f, 5.103741f, 3.983868f, 3.633003f, 3.458263f, + 3.507309f, 3.247021f, 3.220612f, 3.326061f, 3.352814f, 3.291061f, + 3.322739f, 3.444302f, 3.506207f, 3.556839f, 3.529575f, 3.457024f, + -0.408161f, -0.431343f, -0.454369f, -0.356419f, -0.380924f, -0.399452f, + -0.439476f, -0.412189f, -0.306816f, -0.008213f, -0.325813f, -0.537842f, + -0.984100f, -1.805332f, -2.028198f, 0.773205f, 4.423046f, 5.604839f, + 5.231617f, 4.080299f, 3.603008f, 3.498741f, 3.517010f, 3.333897f, + 3.381336f, 3.342617f, 3.369686f, 3.434155f, 3.490452f, 3.607029f, + 3.555298f, 3.702297f, 3.618679f, -0.503609f, -0.578564f, -0.419014f, + -0.239883f, 0.269836f, 0.022984f, -0.455067f, -0.621777f, -0.304176f, + -0.163792f, -0.490250f, -0.466637f, -0.391792f, -0.657940f, -1.498035f, + -1.895836f, 0.036537f, 3.462456f, 5.586445f, 5.658791f, 4.434784f, + 3.423435f, 3.318848f, 3.202328f, 3.532764f, 3.436687f, 3.354881f, + 3.356941f, 3.382645f, 3.503902f, 3.512867f, 3.632366f, 3.537312f, + -0.274734f, -0.658829f, -0.726532f, -0.281254f, 0.053196f, -0.064991f, + -0.608517f, -0.720966f, -0.070602f, -0.111320f, -0.440956f, -0.492180f, + -0.488762f, -0.569283f, -1.012741f, -1.582779f, -2.101479f, -1.392380f, + 2.451153f, 5.555855f, 6.096313f, 5.230045f, 4.068172f, 3.404274f, + 3.392586f, 3.326065f, 3.156670f, 3.284828f, 3.347012f, 3.319252f, + 3.352310f, 3.610790f, 3.499847f, -0.150600f, -0.314445f, -0.093575f, + -0.057384f, 0.053688f, -0.189255f, -0.263515f, -0.318653f, 0.053246f, + 0.080627f, -0.119553f, -0.152454f, -0.305420f, -0.404869f, -0.385944f, + -0.689949f, -1.204914f, -1.985748f, -1.711361f, 1.260658f, 4.626896f, + 5.888351f, 5.450989f, 4.070587f, 3.539200f, 3.383492f, 3.296318f, + 3.267334f, 3.436028f, 3.463005f, 3.502625f, 3.522282f, 3.403763f, + -0.348049f, -0.302303f, -0.137016f, -0.041737f, -0.164001f, -0.358849f, + -0.469627f, -0.428291f, -0.375797f, -0.246346f, -0.118950f, -0.084229f, + -0.205681f, -0.241199f, -0.391796f, -0.323151f, -0.241211f, -0.834137f, + -1.684219f, -1.972137f, 0.448399f, 4.019985f, 5.648144f, 5.647846f, + 4.295094f, 3.641884f, 3.374790f, 3.197342f, 3.425545f, 3.507481f, + 3.478065f, 3.430889f, 3.341900f, -1.016304f, -0.959221f, -0.909466f, + -0.810715f, -0.590729f, -0.594467f, -0.646721f, -0.629364f, -0.528561f, + -0.551819f, -0.301086f, -0.149101f, -0.060146f, -0.162220f, -0.326210f, + -0.156548f, -0.036293f, -0.426098f, -1.145470f, -1.628998f, -2.003052f, + -1.142891f, 2.885162f, 5.652863f, 5.718426f, 4.911140f, 3.234222f, + 3.473373f, 3.577183f, 3.271603f, 3.410435f, 3.505489f, 3.434032f, + -0.508911f, -0.438797f, -0.437450f, -0.627426f, -0.511745f, -0.304874f, + -0.274246f, -0.261841f, -0.228466f, -0.342491f, -0.528206f, -0.490082f, + -0.516350f, -0.361694f, -0.398514f, -0.276020f, -0.210369f, -0.355938f, + -0.402622f, -0.538864f, -1.249573f, -2.100105f, -0.996178f, 1.886410f, + 4.929745f, 5.630871f, 5.444199f, 4.042740f, 3.739189f, 3.691399f, + 3.391956f, 3.469696f, 3.431232f, 0.204849f, 0.205433f, -0.131927f, + -0.367908f, -0.374378f, -0.126820f, -0.186951f, -0.228565f, -0.081776f, + -0.143143f, -0.379230f, -0.598701f, -0.458019f, -0.295586f, -0.407730f, + -0.245853f, -0.043140f, 0.024242f, -0.038998f, -0.044151f, -0.425991f, + -1.240753f, -1.943146f, -2.174755f, 0.523415f, 4.376751f, 5.956558f, + 5.850082f, 4.403152f, 3.517399f, 3.560753f, 3.554836f, 3.471985f, + -0.508503f, -0.109783f, 0.057747f, 0.190079f, -0.257153f, -0.591980f, + -0.666771f, -0.525391f, -0.293060f, -0.489731f, -0.304855f, -0.259644f, + -0.367825f, -0.346977f, -0.292889f, -0.215652f, -0.120705f, -0.176010f, + -0.422905f, -0.114647f, -0.289749f, -0.374203f, -0.606754f, -1.127949f, + -1.994583f, -0.588058f, 3.415840f, 5.603470f, 5.811581f, 4.959423f, + 3.721760f, 3.710499f, 3.785461f, -0.554588f, -0.565517f, -0.434578f, + -0.012482f, -0.284660f, -0.699795f, -0.957535f, -0.755135f, -0.382034f, + -0.321552f, -0.287571f, -0.279537f, -0.314972f, -0.256287f, -0.372818f, + -0.316017f, -0.287975f, -0.365639f, -0.512589f, -0.420692f, -0.436485f, + -0.295353f, -0.451958f, -0.755459f, -1.272358f, -2.301353f, -1.776161f, + 1.572483f, 4.826286f, 5.741898f, 5.162853f, 4.028049f, 3.686325f, + -0.495590f, -0.664413f, -0.760044f, -0.152634f, -0.286480f, -0.340462f, + 0.076477f, 0.187706f, -0.068787f, -0.293491f, -0.361145f, -0.292515f, + -0.140671f, -0.190723f, -0.333302f, -0.368168f, -0.192581f, -0.154499f, + -0.236544f, -0.124405f, -0.208321f, -0.465607f, -0.883080f, -1.104813f, + -1.210567f, -1.415665f, -1.924683f, -1.634758f, 0.601017f, 4.276672f, + 5.501350f, 5.331257f, 3.809288f, -0.727722f, -0.533619f, -0.511524f, + -0.470688f, -0.610710f, -0.575130f, -0.311115f, -0.090420f, -0.297676f, + -0.646118f, -0.742805f, -0.485050f, -0.330910f, -0.275417f, -0.357037f, + -0.425598f, -0.481876f, -0.488941f, -0.393551f, -0.051105f, -0.090755f, + -0.328674f, -0.536369f, -0.533684f, -0.336960f, -0.689194f, -1.187195f, + -1.860954f, -2.290253f, -0.424774f, 3.050060f, 5.083332f, 5.291920f, + -0.343605f, -0.190975f, -0.303692f, -0.456512f, -0.681820f, -0.690693f, + -0.416729f, -0.286446f, -0.442055f, -0.709148f, -0.569160f, -0.382423f, + -0.402321f, -0.383362f, -0.366413f, -0.290718f, -0.110069f, -0.220280f, + -0.279018f, -0.255424f, -0.262081f, -0.487556f, -0.444492f, -0.250500f, + -0.119583f, -0.291557f, -0.537781f, -1.104073f, -1.737091f, -1.697441f, + -0.323456f, 2.042049f, 4.605103f, -0.310631f, -0.279568f, -0.012695f, + -0.160130f, -0.358746f, -0.421101f, -0.559677f, -0.474136f, -0.416565f, + -0.561817f, -0.534672f, -0.519157f, -0.767197f, -0.605831f, -0.186523f, + 0.219872f, 0.264984f, -0.193432f, -0.363182f, -0.467472f, -0.462009f, + -0.571053f, -0.522476f, -0.315903f, -0.237427f, -0.147320f, -0.100201f, + -0.237568f, -0.763435f, -1.242043f, -2.135159f, -1.409485f, 1.236370f, + -0.474247f, -0.517906f, -0.410217f, -0.542244f, -0.795986f, -0.590004f, + -0.388863f, -0.462921f, -0.810627f, -0.778637f, -0.512486f, -0.718025f, + -0.710854f, -0.482513f, -0.318233f, -0.194962f, -0.220116f, -0.421673f, + -0.534233f, -0.403339f, -0.389332f, -0.407303f, -0.437355f, -0.469730f, + -0.359600f, -0.352745f, -0.466755f, -0.414585f, -0.430756f, -0.656822f, + -1.237038f, -2.046097f, -1.574898f, -0.593815f, -0.582165f, -0.336098f, + -0.372612f, -0.554386f, -0.410603f, -0.428276f, -0.647644f, -0.640720f, + -0.582207f, -0.414112f, -0.435547f, -0.435505f, -0.332561f, -0.248116f, + -0.340221f, -0.277855f, -0.352699f, -0.377319f, -0.230850f, -0.313267f, + -0.446270f, -0.346237f, -0.420422f, -0.530781f, -0.400341f, -0.463661f, + -0.209091f, -0.056705f, -0.011772f, -0.169388f, -0.736275f, -1.463017f, + -0.752701f, -0.668865f, -0.329765f, -0.299347f, -0.245667f, -0.286999f, + -0.520420f, -0.675438f, -0.255753f, 0.141357f, -0.079639f, -0.419476f, + -0.374069f, -0.046253f, 0.116116f, -0.145847f, -0.380371f, -0.563412f, + -0.638634f, -0.310116f, -0.260914f, -0.508404f, -0.465508f, -0.527824f, + -0.370979f, -0.305595f, -0.244694f, -0.254490f, 0.009968f, -0.050201f, + -0.331219f, -0.614960f, -0.788208f, -0.483242f, -0.367516f, -0.186951f, + -0.180031f, 0.129711f, -0.127811f, -0.384750f, -0.499542f, -0.418613f, + -0.121635f, 0.203197f, -0.167290f, -0.397270f, -0.355461f, -0.218746f, + -0.376785f, -0.521698f, -0.721581f, -0.845741f, -0.535439f, -0.220882f, + -0.309067f, -0.555248f, -0.690342f, -0.664948f, -0.390102f, 0.020355f, + -0.130447f, -0.173252f, -0.170059f, -0.633663f, -0.956001f, -0.621696f, + -0.388302f, -0.342262f, -0.244370f, -0.386948f, -0.401421f, -0.172979f, + -0.206163f, -0.450058f, -0.525789f, -0.549274f, -0.349251f, -0.474613f, + -0.667976f, -0.435600f, -0.175369f, -0.196877f, -0.202976f, -0.242481f, + -0.258369f, -0.189133f, -0.395397f, -0.765499f, -0.944016f, -0.850967f, + -0.631561f, -0.152493f, -0.046432f, -0.262066f, -0.195919f, 0.048218f, + 0.084972f, 0.039902f, 0.000618f, -0.404430f, -0.447456f, -0.418076f, + -0.631935f, -0.717415f, -0.502888f, -0.530514f, -0.747826f, -0.704041f, + -0.674969f, -0.516853f, -0.418446f, -0.327740f, -0.308815f, -0.481636f, + -0.440083f, -0.481720f, -0.341053f, -0.283897f, -0.324368f, -0.352829f, + -0.434349f, -0.545589f, -0.533104f, -0.472755f, -0.570496f, -0.557735f, + -0.708176f, -0.493332f, -0.194416f, -0.186249f, -0.256710f, -0.271835f, + -0.304752f, -0.431267f, -0.422398f, -0.646725f, -0.680801f, -0.249031f, + -0.058567f, -0.213890f, -0.383949f, -0.540291f, -0.549877f, -0.225567f, + -0.037174f, -0.499874f, -0.641010f, -0.628044f, -0.390549f, -0.311497f, + -0.542313f, -0.569565f, -0.473408f, -0.331245f, -0.357197f, -0.285599f, + -0.200157f, -0.201866f, -0.124428f, -0.346016f, -0.392311f, -0.264496f, + -0.285370f, -0.436974f, -0.523483f, -0.410461f, -0.267925f, -0.055016f, + -0.382458f, -0.319771f, -0.049927f, 0.124329f, 0.266102f, -0.106606f, + -0.773647f, -0.973053f, -0.708206f, -0.486137f, -0.319923f, -0.493900f, + -0.490860f, -0.324986f, -0.147346f, -0.146088f, -0.161758f, -0.084396f, + -0.379494f, 0.041626f, -0.113361f, -0.277767f, 0.083366f, 0.126476f, + 0.139057f, 0.038040f, 0.038162f, -0.242126f, -0.411736f, -0.370049f, + -0.455357f, -0.039257f, 0.264442f, -0.271492f, -0.425346f, -0.514847f, + -0.448650f, -0.580399f, -0.652603f, -0.774803f, -0.692524f, -0.579578f, + -0.465206f, -0.386265f, -0.458012f, -0.446594f, -0.284893f, -0.345448f, + -0.350876f, -0.440350f, -0.360378f, -0.270428f, 0.237213f, -0.063602f, + -0.364529f, -0.179867f, 0.078197f, 0.117947f, -0.093410f, -0.359119f, + -0.480961f, -0.540638f, -0.436287f, -0.598576f, -0.253735f, -0.060093f, + -0.549145f, -0.808327f, -0.698593f, -0.595764f, -0.582508f, -0.497353f, + -0.480892f, -0.584240f, -0.665791f, -0.690903f, -0.743446f, -0.796677f, + -0.782391f, -0.649010f, -0.628139f, -0.880848f, -0.829361f, -0.373272f, + -0.223667f, 0.174572f, -0.348743f, -0.798901f, -0.692307f, -0.607609f, + -0.401455f, -0.480919f, -0.450798f, -0.435413f, -0.322338f, -0.228382f, + -0.450466f, -0.504440f, -0.477402f, -0.662224f, -0.583397f, -0.217445f, + -0.157459f, -0.079584f, -0.226168f, -0.488720f, -0.669624f, -0.666878f, + -0.565311f, -0.549625f, -0.364601f, -0.497627f, -0.736897f, -0.763023f, + -0.741020f, -0.404503f, 0.184814f, -0.075315f, -0.281513f, -0.532906f, + -0.405800f, -0.313438f, -0.536652f, -0.403381f, 0.011967f, 0.103310f, + -0.269848f, -0.508656f, -0.445923f, -0.644859f, -0.617870f, -0.500927f, + -0.371559f, -0.125580f, 0.028625f, -0.154713f, -0.442024f, -0.492764f, + -0.199371f, 0.236305f, 0.225925f, 0.075577f, -0.285812f, -0.437145f, + -0.374260f, -0.156693f, -0.129635f, -0.243206f, -0.123058f, 0.162148f, + -0.313152f, -0.337982f, -0.358421f, 0.040070f, 0.038925f, -0.333313f, + -0.351662f, 0.023014f, 0.091362f, -0.282890f, -0.373253f, -0.389050f, + -0.532707f, -0.423347f, -0.349968f, -0.287045f, -0.202442f, -0.308430f, + -0.222801f, -0.106323f, -0.056358f, 0.027222f, 0.390732f, 0.033558f, + -0.160088f, -0.382217f, -0.535282f, -0.515900f, -0.022736f, 0.165665f, + -0.111408f, -0.233784f, -0.312357f, -0.541885f, -0.480022f, -0.482513f, + -0.246254f, 0.132244f, 0.090134f, 0.234634f, -0.089249f, -0.460854f, + -0.515457f, -0.450874f, -0.311031f, -0.387680f, -0.360554f, -0.179241f, + -0.283817f, -0.475815f, -0.246399f, -0.388958f, -0.551140f, -0.496239f, + -0.559879f, -0.379761f, -0.254288f, -0.395111f, -0.613018f, -0.459427f, + -0.263580f, -0.268929f, 0.080826f, 0.115616f, -0.097324f, -0.325310f, + -0.480450f, -0.313286f, -0.310371f, -0.517361f, -0.288288f, -0.112679f, + -0.173241f, -0.221664f, -0.039452f, -0.107578f, -0.089630f, -0.483768f, + -0.571087f, -0.497108f, -0.321533f, -0.375492f, -0.540363f, -0.406815f, + -0.388512f, -0.514561f, -0.540192f, -0.402412f, -0.232246f, -0.304749f, + -0.383724f, -0.679596f, -0.685463f, -0.694538f, -0.642937f, -0.425789f, + 0.103271f, -0.194862f, -0.487999f, -0.717281f, -0.681850f, -0.709286f, + -0.615398f, -0.554245f, -0.254681f, -0.049950f, -0.002914f, -0.095383f, + -0.370911f, -0.564224f, -0.242714f}; + const size_t xtest = xsize / 2; + const size_t ytest = ysize / 2; + + for (intptr_t dy = -16; dy <= 16; ++dy) { + float* row = in.Row(ytest + dy); + for (intptr_t dx = -16; dx <= 16; ++dx) + row[xtest + dx] = center[(dy + 16) * 33 + (dx + 16)]; + } + + const double sigma = 7.155933; + + ImageF temp(xsize, ysize); + ImageF out_rg(xsize, ysize); + const auto rg = CreateRecursiveGaussian(sigma); + ThreadPool* null_pool = nullptr; + FastGaussian(rg, in, null_pool, &temp, &out_rg); + + ImageF out_old; + { + const std::vector kernel = + GaussianKernel(static_cast(4 * sigma), static_cast(sigma)); + printf("old kernel size %" PRIuS "\n", kernel.size()); + out_old = Convolve(in, kernel); + } + + printf("rg %.4f old %.4f\n", out_rg.Row(ytest)[xtest], + out_old.Row(ytest)[xtest]); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/gradient_test.cc b/media/libjxl/src/lib/jxl/gradient_test.cc new file mode 100644 index 000000000..0351904d3 --- /dev/null +++ b/media/libjxl/src/lib/jxl/gradient_test.cc @@ -0,0 +1,205 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/test_utils.h" + +namespace jxl { +namespace { + +// Returns distance of point p to line p0..p1, the result is signed and is not +// normalized. +double PointLineDist(double x0, double y0, double x1, double y1, double x, + double y) { + return (y1 - y0) * x - (x1 - x0) * y + x1 * y0 - y1 * x0; +} + +// Generates a test image with a gradient from one color to another. +// Angle in degrees, colors can be given in hex as 0xRRGGBB. The angle is the +// angle in which the change direction happens. +Image3F GenerateTestGradient(uint32_t color0, uint32_t color1, double angle, + size_t xsize, size_t ysize) { + Image3F image(xsize, ysize); + + double x0 = xsize / 2; + double y0 = ysize / 2; + double x1 = x0 + std::sin(angle / 360.0 * 2.0 * kPi); + double y1 = y0 + std::cos(angle / 360.0 * 2.0 * kPi); + + double maxdist = + std::max(fabs(PointLineDist(x0, y0, x1, y1, 0, 0)), + fabs(PointLineDist(x0, y0, x1, y1, xsize, 0))); + + for (size_t c = 0; c < 3; ++c) { + float c0 = ((color0 >> (8 * (2 - c))) & 255); + float c1 = ((color1 >> (8 * (2 - c))) & 255); + for (size_t y = 0; y < ysize; ++y) { + float* row = image.PlaneRow(c, y); + for (size_t x = 0; x < xsize; ++x) { + double dist = PointLineDist(x0, y0, x1, y1, x, y); + double v = ((dist / maxdist) + 1.0) / 2.0; + float color = c0 * (1.0 - v) + c1 * v; + row[x] = color; + } + } + } + + return image; +} + +// Computes the max of the horizontal and vertical second derivative for each +// pixel, where second derivative means absolute value of difference of left +// delta and right delta (top/bottom for vertical direction). +// The radius over which the derivative is computed is only 1 pixel and it only +// checks two angles (hor and ver), but this approximation works well enough. +static ImageF Gradient2(const ImageF& image) { + size_t xsize = image.xsize(); + size_t ysize = image.ysize(); + ImageF image2(image.xsize(), image.ysize()); + for (size_t y = 1; y + 1 < ysize; y++) { + const auto* JXL_RESTRICT row0 = image.Row(y - 1); + const auto* JXL_RESTRICT row1 = image.Row(y); + const auto* JXL_RESTRICT row2 = image.Row(y + 1); + auto* row_out = image2.Row(y); + for (size_t x = 1; x + 1 < xsize; x++) { + float ddx = (row1[x] - row1[x - 1]) - (row1[x + 1] - row1[x]); + float ddy = (row1[x] - row0[x]) - (row2[x] - row1[x]); + row_out[x] = std::max(fabsf(ddx), fabsf(ddy)); + } + } + // Copy to the borders + if (ysize > 2) { + auto* JXL_RESTRICT row0 = image2.Row(0); + const auto* JXL_RESTRICT row1 = image2.Row(1); + const auto* JXL_RESTRICT row2 = image2.Row(ysize - 2); + auto* JXL_RESTRICT row3 = image2.Row(ysize - 1); + for (size_t x = 1; x + 1 < xsize; x++) { + row0[x] = row1[x]; + row3[x] = row2[x]; + } + } else { + const auto* row0_in = image.Row(0); + const auto* row1_in = image.Row(ysize - 1); + auto* row0_out = image2.Row(0); + auto* row1_out = image2.Row(ysize - 1); + for (size_t x = 1; x + 1 < xsize; x++) { + // Image too narrow, take first derivative instead + row0_out[x] = row1_out[x] = fabsf(row0_in[x] - row1_in[x]); + } + } + if (xsize > 2) { + for (size_t y = 0; y < ysize; y++) { + auto* row = image2.Row(y); + row[0] = row[1]; + row[xsize - 1] = row[xsize - 2]; + } + } else { + for (size_t y = 0; y < ysize; y++) { + const auto* JXL_RESTRICT row_in = image.Row(y); + auto* row_out = image2.Row(y); + // Image too narrow, take first derivative instead + row_out[0] = row_out[xsize - 1] = fabsf(row_in[0] - row_in[xsize - 1]); + } + } + return image2; +} + +static Image3F Gradient2(const Image3F& image) { + return Image3F(Gradient2(image.Plane(0)), Gradient2(image.Plane(1)), + Gradient2(image.Plane(2))); +} + +/* +Tests if roundtrip with jxl on a gradient image doesn't cause banding. +Only tests if use_gradient is true. Set to false for debugging to see the +distance values. +Angle in degrees, colors can be given in hex as 0xRRGGBB. +*/ +void TestGradient(ThreadPool* pool, uint32_t color0, uint32_t color1, + size_t xsize, size_t ysize, float angle, bool fast_mode, + float butteraugli_distance, bool use_gradient = true) { + CompressParams cparams; + cparams.butteraugli_distance = butteraugli_distance; + if (fast_mode) { + cparams.speed_tier = SpeedTier::kSquirrel; + } + Image3F gradient = GenerateTestGradient(color0, color1, angle, xsize, ysize); + + CodecInOut io; + io.metadata.m.SetUintSamples(8); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + io.SetFromImage(std::move(gradient), io.metadata.m.color_encoding); + + CodecInOut io2; + + PaddedBytes compressed; + AuxOut* aux_out = nullptr; + PassesEncoderState enc_state; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + EXPECT_TRUE( + io2.Main().TransformTo(io2.metadata.m.color_encoding, GetJxlCms(), pool)); + + if (use_gradient) { + // Test that the gradient map worked. For that, we take a second derivative + // of the image with Gradient2 to measure how linear the change is in x and + // y direction. For a well handled gradient, we expect max values around + // 0.1, while if there is noticeable banding, which means the gradient map + // failed, the values are around 0.5-1.0 (regardless of + // butteraugli_distance). + Image3F gradient2 = Gradient2(*io2.Main().color()); + + std::array image_max; + Image3Max(gradient2, &image_max); + + // TODO(jyrki): These values used to work with 0.2, 0.2, 0.2. + EXPECT_LE(image_max[0], 3.15); + EXPECT_LE(image_max[1], 1.72); + EXPECT_LE(image_max[2], 5.05); + } +} + +static constexpr bool fast_mode = true; + +TEST(GradientTest, SteepGradient) { + ThreadPoolInternal pool(8); + // Relatively steep gradients, colors from the sky of stp.png + TestGradient(&pool, 0xd99d58, 0x889ab1, 512, 512, 90, fast_mode, 3.0); +} + +TEST(GradientTest, SubtleGradient) { + ThreadPoolInternal pool(8); + // Very subtle gradient + TestGradient(&pool, 0xb89b7b, 0xa89b8d, 512, 512, 90, fast_mode, 4.0); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/headers.cc b/media/libjxl/src/lib/jxl/headers.cc new file mode 100644 index 000000000..7c560e52a --- /dev/null +++ b/media/libjxl/src/lib/jxl/headers.cc @@ -0,0 +1,200 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/headers.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/common.h" +#include "lib/jxl/fields.h" + +namespace jxl { +namespace { + +struct Rational { + constexpr explicit Rational(uint32_t num, uint32_t den) + : num(num), den(den) {} + + // Returns floor(multiplicand * rational). + constexpr uint32_t MulTruncate(uint32_t multiplicand) const { + return uint64_t(multiplicand) * num / den; + } + + uint32_t num; + uint32_t den; +}; + +Rational FixedAspectRatios(uint32_t ratio) { + JXL_ASSERT(0 != ratio && ratio < 8); + // Other candidates: 5/4, 7/5, 14/9, 16/10, 5/3, 21/9, 12/5 + constexpr Rational kRatios[7] = {Rational(1, 1), // square + Rational(12, 10), // + Rational(4, 3), // camera + Rational(3, 2), // mobile camera + Rational(16, 9), // camera/display + Rational(5, 4), // + Rational(2, 1)}; // + return kRatios[ratio - 1]; +} + +uint32_t FindAspectRatio(uint32_t xsize, uint32_t ysize) { + for (uint32_t r = 1; r < 8; ++r) { + if (xsize == FixedAspectRatios(r).MulTruncate(ysize)) { + return r; + } + } + return 0; // Must send xsize instead +} + +} // namespace + +size_t SizeHeader::xsize() const { + if (ratio_ != 0) { + return FixedAspectRatios(ratio_).MulTruncate( + static_cast(ysize())); + } + return small_ ? ((xsize_div8_minus_1_ + 1) * 8) : xsize_; +} + +Status SizeHeader::Set(size_t xsize64, size_t ysize64) { + if (xsize64 > 0xFFFFFFFFull || ysize64 > 0xFFFFFFFFull) { + return JXL_FAILURE("Image too large"); + } + const uint32_t xsize32 = static_cast(xsize64); + const uint32_t ysize32 = static_cast(ysize64); + if (xsize64 == 0 || ysize64 == 0) return JXL_FAILURE("Empty image"); + ratio_ = FindAspectRatio(xsize32, ysize32); + small_ = ysize64 <= 256 && (ysize64 % kBlockDim) == 0 && + (ratio_ != 0 || (xsize64 <= 256 && (xsize64 % kBlockDim) == 0)); + if (small_) { + ysize_div8_minus_1_ = ysize32 / 8 - 1; + } else { + ysize_ = ysize32; + } + + if (ratio_ == 0) { + if (small_) { + xsize_div8_minus_1_ = xsize32 / 8 - 1; + } else { + xsize_ = xsize32; + } + } + JXL_ASSERT(xsize() == xsize64); + JXL_ASSERT(ysize() == ysize64); + return true; +} + +Status PreviewHeader::Set(size_t xsize64, size_t ysize64) { + const uint32_t xsize32 = static_cast(xsize64); + const uint32_t ysize32 = static_cast(ysize64); + if (xsize64 == 0 || ysize64 == 0) return JXL_FAILURE("Empty preview"); + div8_ = (xsize64 % kBlockDim) == 0 && (ysize64 % kBlockDim) == 0; + if (div8_) { + ysize_div8_ = ysize32 / 8; + } else { + ysize_ = ysize32; + } + + ratio_ = FindAspectRatio(xsize32, ysize32); + if (ratio_ == 0) { + if (div8_) { + xsize_div8_ = xsize32 / 8; + } else { + xsize_ = xsize32; + } + } + JXL_ASSERT(xsize() == xsize64); + JXL_ASSERT(ysize() == ysize64); + return true; +} + +size_t PreviewHeader::xsize() const { + if (ratio_ != 0) { + return FixedAspectRatios(ratio_).MulTruncate( + static_cast(ysize())); + } + return div8_ ? (xsize_div8_ * 8) : xsize_; +} + +SizeHeader::SizeHeader() { Bundle::Init(this); } +Status SizeHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &small_)); + + if (visitor->Conditional(small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, 0, &ysize_div8_minus_1_)); + } + if (visitor->Conditional(!small_)) { + // (Could still be small, but non-multiple of 8.) + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(9, 1), BitsOffset(13, 1), + BitsOffset(18, 1), BitsOffset(30, 1), + 1, &ysize_)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &ratio_)); + if (visitor->Conditional(ratio_ == 0 && small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, 0, &xsize_div8_minus_1_)); + } + if (visitor->Conditional(ratio_ == 0 && !small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(9, 1), BitsOffset(13, 1), + BitsOffset(18, 1), BitsOffset(30, 1), + 1, &xsize_)); + } + + return true; +} + +PreviewHeader::PreviewHeader() { Bundle::Init(this); } +Status PreviewHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &div8_)); + + if (visitor->Conditional(div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), Val(32), BitsOffset(5, 1), + BitsOffset(9, 33), 1, &ysize_div8_)); + } + if (visitor->Conditional(!div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(6, 1), BitsOffset(8, 65), + BitsOffset(10, 321), + BitsOffset(12, 1345), 1, &ysize_)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &ratio_)); + if (visitor->Conditional(ratio_ == 0 && div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), Val(32), BitsOffset(5, 1), + BitsOffset(9, 33), 1, &xsize_div8_)); + } + if (visitor->Conditional(ratio_ == 0 && !div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(6, 1), BitsOffset(8, 65), + BitsOffset(10, 321), + BitsOffset(12, 1345), 1, &xsize_)); + } + + return true; +} + +AnimationHeader::AnimationHeader() { Bundle::Init(this); } +Status AnimationHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(100), Val(1000), BitsOffset(10, 1), + BitsOffset(30, 1), 1, &tps_numerator)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(1), Val(1001), BitsOffset(8, 1), + BitsOffset(10, 1), 1, + &tps_denominator)); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Bits(3), Bits(16), Bits(32), 0, &num_loops)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_timecodes)); + return true; +} + +Status ReadSizeHeader(BitReader* JXL_RESTRICT reader, + SizeHeader* JXL_RESTRICT size) { + return Bundle::Read(reader, size); +} + +Status WriteSizeHeader(const SizeHeader& size, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* aux_out) { + return Bundle::Write(size, writer, layer, aux_out); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/headers.h b/media/libjxl/src/lib/jxl/headers.h new file mode 100644 index 000000000..a9be252c2 --- /dev/null +++ b/media/libjxl/src/lib/jxl/headers.h @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_HEADERS_H_ +#define LIB_JXL_HEADERS_H_ + +// Codestream headers, also stored in CodecInOut. + +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +// Reserved by ISO/IEC 10918-1. LF causes files opened in text mode to be +// rejected because the marker changes to 0x0D instead. The 0xFF prefix also +// ensures there were no 7-bit transmission limitations. +static constexpr uint8_t kCodestreamMarker = 0x0A; + +// Compact representation of image dimensions (best case: 9 bits) so decoders +// can preallocate early. +class SizeHeader : public Fields { + public: + SizeHeader(); + JXL_FIELDS_NAME(SizeHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + Status Set(size_t xsize, size_t ysize); + + size_t xsize() const; + size_t ysize() const { + return small_ ? ((ysize_div8_minus_1_ + 1) * 8) : ysize_; + } + + private: + bool small_; // xsize and ysize <= 256 and divisible by 8. + + uint32_t ysize_div8_minus_1_; + uint32_t ysize_; + + uint32_t ratio_; + uint32_t xsize_div8_minus_1_; + uint32_t xsize_; +}; + +// (Similar to SizeHeader but different encoding because previews are smaller) +class PreviewHeader : public Fields { + public: + PreviewHeader(); + JXL_FIELDS_NAME(PreviewHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + Status Set(size_t xsize, size_t ysize); + + size_t xsize() const; + size_t ysize() const { return div8_ ? (ysize_div8_ * 8) : ysize_; } + + private: + bool div8_; // xsize and ysize divisible by 8. + + uint32_t ysize_div8_; + uint32_t ysize_; + + uint32_t ratio_; + uint32_t xsize_div8_; + uint32_t xsize_; +}; + +struct AnimationHeader : public Fields { + AnimationHeader(); + JXL_FIELDS_NAME(AnimationHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Ticks per second (expressed as rational number to support NTSC) + uint32_t tps_numerator; + uint32_t tps_denominator; + + uint32_t num_loops; // 0 means to repeat infinitely. + + bool have_timecodes; +}; + +Status ReadSizeHeader(BitReader* JXL_RESTRICT reader, + SizeHeader* JXL_RESTRICT size); + +Status WriteSizeHeader(const SizeHeader& size, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_HEADERS_H_ diff --git a/media/libjxl/src/lib/jxl/huffman_table.cc b/media/libjxl/src/lib/jxl/huffman_table.cc new file mode 100644 index 000000000..9ae7865af --- /dev/null +++ b/media/libjxl/src/lib/jxl/huffman_table.cc @@ -0,0 +1,161 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/huffman_table.h" + +#include /* for memcpy */ +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/dec_huffman.h" + +namespace jxl { + +/* Returns reverse(reverse(key, len) + 1, len), where reverse(key, len) is the + bit-wise reversal of the len least significant bits of key. */ +static inline int GetNextKey(int key, int len) { + int step = 1u << (len - 1); + while (key & step) { + step >>= 1; + } + return (key & (step - 1)) + step; +} + +/* Stores code in table[0], table[step], table[2*step], ..., table[end] */ +/* Assumes that end is an integer multiple of step */ +static inline void ReplicateValue(HuffmanCode* table, int step, int end, + HuffmanCode code) { + do { + end -= step; + table[end] = code; + } while (end > 0); +} + +/* Returns the table width of the next 2nd level table. count is the histogram + of bit lengths for the remaining symbols, len is the code length of the next + processed symbol */ +static inline size_t NextTableBitSize(const uint16_t* const count, size_t len, + int root_bits) { + size_t left = 1u << (len - root_bits); + while (len < PREFIX_MAX_BITS) { + if (left <= count[len]) break; + left -= count[len]; + ++len; + left <<= 1; + } + return len - root_bits; +} + +uint32_t BuildHuffmanTable(HuffmanCode* root_table, int root_bits, + const uint8_t* const code_lengths, + size_t code_lengths_size, uint16_t* count) { + HuffmanCode code; /* current table entry */ + HuffmanCode* table; /* next available space in table */ + size_t len; /* current code length */ + size_t symbol; /* symbol index in original or sorted table */ + int key; /* reversed prefix code */ + int step; /* step size to replicate values in current table */ + int low; /* low bits for current root entry */ + int mask; /* mask for low bits */ + size_t table_bits; /* key length of current table */ + int table_size; /* size of current table */ + int total_size; /* sum of root table size and 2nd level table sizes */ + /* offsets in sorted table for each length */ + uint16_t offset[PREFIX_MAX_BITS + 1]; + size_t max_length = 1; + + if (code_lengths_size > 1u << PREFIX_MAX_BITS) return 0; + + /* symbols sorted by code length */ + std::vector sorted_storage(code_lengths_size); + uint16_t* sorted = sorted_storage.data(); + + /* generate offsets into sorted symbol table by code length */ + { + uint16_t sum = 0; + for (len = 1; len <= PREFIX_MAX_BITS; len++) { + offset[len] = sum; + if (count[len]) { + sum = static_cast(sum + count[len]); + max_length = len; + } + } + } + + /* sort symbols by length, by symbol order within each length */ + for (symbol = 0; symbol < code_lengths_size; symbol++) { + if (code_lengths[symbol] != 0) { + sorted[offset[code_lengths[symbol]]++] = symbol; + } + } + + table = root_table; + table_bits = root_bits; + table_size = 1u << table_bits; + total_size = table_size; + + /* special case code with only one value */ + if (offset[PREFIX_MAX_BITS] == 1) { + code.bits = 0; + code.value = static_cast(sorted[0]); + for (key = 0; key < total_size; ++key) { + table[key] = code; + } + return total_size; + } + + /* fill in root table */ + /* let's reduce the table size to a smaller size if possible, and */ + /* create the repetitions by memcpy if possible in the coming loop */ + if (table_bits > max_length) { + table_bits = max_length; + table_size = 1u << table_bits; + } + key = 0; + symbol = 0; + code.bits = 1; + step = 2; + do { + for (; count[code.bits] != 0; --count[code.bits]) { + code.value = static_cast(sorted[symbol++]); + ReplicateValue(&table[key], step, table_size, code); + key = GetNextKey(key, code.bits); + } + step <<= 1; + } while (++code.bits <= table_bits); + + /* if root_bits != table_bits we only created one fraction of the */ + /* table, and we need to replicate it now. */ + while (total_size != table_size) { + memcpy(&table[table_size], &table[0], table_size * sizeof(table[0])); + table_size <<= 1; + } + + /* fill in 2nd level tables and add pointers to root table */ + mask = total_size - 1; + low = -1; + for (len = root_bits + 1, step = 2; len <= max_length; ++len, step <<= 1) { + for (; count[len] != 0; --count[len]) { + if ((key & mask) != low) { + table += table_size; + table_bits = NextTableBitSize(count, len, root_bits); + table_size = 1u << table_bits; + total_size += table_size; + low = key & mask; + root_table[low].bits = static_cast(table_bits + root_bits); + root_table[low].value = + static_cast((table - root_table) - low); + } + code.bits = static_cast(len - root_bits); + code.value = static_cast(sorted[symbol++]); + ReplicateValue(&table[key >> root_bits], step, table_size, code); + key = GetNextKey(key, len); + } + } + + return total_size; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/huffman_table.h b/media/libjxl/src/lib/jxl/huffman_table.h new file mode 100644 index 000000000..11cdb2fc4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/huffman_table.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_HUFFMAN_TABLE_H_ +#define LIB_JXL_HUFFMAN_TABLE_H_ + +#include +#include + +namespace jxl { + +struct HuffmanCode { + uint8_t bits; /* number of bits used for this symbol */ + uint16_t value; /* symbol value or table offset */ +}; + +/* Builds Huffman lookup table assuming code lengths are in symbol order. */ +/* Returns 0 in case of error (invalid tree or memory error), otherwise + populated size of table. */ +uint32_t BuildHuffmanTable(HuffmanCode* root_table, int root_bits, + const uint8_t* code_lengths, + size_t code_lengths_size, uint16_t* count); + +} // namespace jxl + +#endif // LIB_JXL_HUFFMAN_TABLE_H_ diff --git a/media/libjxl/src/lib/jxl/huffman_tree.cc b/media/libjxl/src/lib/jxl/huffman_tree.cc new file mode 100644 index 000000000..77107b08d --- /dev/null +++ b/media/libjxl/src/lib/jxl/huffman_tree.cc @@ -0,0 +1,328 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/huffman_tree.h" + +#include +#include +#include + +#include "lib/jxl/base/status.h" + +namespace jxl { + +void SetDepth(const HuffmanTree& p, HuffmanTree* pool, uint8_t* depth, + uint8_t level) { + if (p.index_left >= 0) { + ++level; + SetDepth(pool[p.index_left], pool, depth, level); + SetDepth(pool[p.index_right_or_value], pool, depth, level); + } else { + depth[p.index_right_or_value] = level; + } +} + +// Sort the root nodes, least popular first. +static JXL_INLINE bool Compare(const HuffmanTree& v0, const HuffmanTree& v1) { + return v0.total_count < v1.total_count; +} + +// This function will create a Huffman tree. +// +// The catch here is that the tree cannot be arbitrarily deep. +// Brotli specifies a maximum depth of 15 bits for "code trees" +// and 7 bits for "code length code trees." +// +// count_limit is the value that is to be faked as the minimum value +// and this minimum value is raised until the tree matches the +// maximum length requirement. +// +// This algorithm is not of excellent performance for very long data blocks, +// especially when population counts are longer than 2**tree_limit, but +// we are not planning to use this with extremely long blocks. +// +// See http://en.wikipedia.org/wiki/Huffman_coding +void CreateHuffmanTree(const uint32_t* data, const size_t length, + const int tree_limit, uint8_t* depth) { + // For block sizes below 64 kB, we never need to do a second iteration + // of this loop. Probably all of our block sizes will be smaller than + // that, so this loop is mostly of academic interest. If we actually + // would need this, we would be better off with the Katajainen algorithm. + for (uint32_t count_limit = 1;; count_limit *= 2) { + std::vector tree; + tree.reserve(2 * length + 1); + + for (size_t i = length; i != 0;) { + --i; + if (data[i]) { + const uint32_t count = std::max(data[i], count_limit - 1); + tree.emplace_back(count, -1, static_cast(i)); + } + } + + const size_t n = tree.size(); + if (n == 1) { + // Fake value; will be fixed on upper level. + depth[tree[0].index_right_or_value] = 1; + break; + } + + std::stable_sort(tree.begin(), tree.end(), Compare); + + // The nodes are: + // [0, n): the sorted leaf nodes that we start with. + // [n]: we add a sentinel here. + // [n + 1, 2n): new parent nodes are added here, starting from + // (n+1). These are naturally in ascending order. + // [2n]: we add a sentinel at the end as well. + // There will be (2n+1) elements at the end. + const HuffmanTree sentinel(std::numeric_limits::max(), -1, -1); + tree.push_back(sentinel); + tree.push_back(sentinel); + + size_t i = 0; // Points to the next leaf node. + size_t j = n + 1; // Points to the next non-leaf node. + for (size_t k = n - 1; k != 0; --k) { + size_t left, right; + if (tree[i].total_count <= tree[j].total_count) { + left = i; + ++i; + } else { + left = j; + ++j; + } + if (tree[i].total_count <= tree[j].total_count) { + right = i; + ++i; + } else { + right = j; + ++j; + } + + // The sentinel node becomes the parent node. + size_t j_end = tree.size() - 1; + tree[j_end].total_count = + tree[left].total_count + tree[right].total_count; + tree[j_end].index_left = static_cast(left); + tree[j_end].index_right_or_value = static_cast(right); + + // Add back the last sentinel node. + tree.push_back(sentinel); + } + JXL_DASSERT(tree.size() == 2 * n + 1); + SetDepth(tree[2 * n - 1], &tree[0], depth, 0); + + // We need to pack the Huffman tree in tree_limit bits. + // If this was not successful, add fake entities to the lowest values + // and retry. + if (*std::max_element(&depth[0], &depth[length]) <= tree_limit) { + break; + } + } +} + +void Reverse(uint8_t* v, size_t start, size_t end) { + --end; + while (start < end) { + uint8_t tmp = v[start]; + v[start] = v[end]; + v[end] = tmp; + ++start; + --end; + } +} + +void WriteHuffmanTreeRepetitions(const uint8_t previous_value, + const uint8_t value, size_t repetitions, + size_t* tree_size, uint8_t* tree, + uint8_t* extra_bits_data) { + JXL_DASSERT(repetitions > 0); + if (previous_value != value) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions == 7) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions < 3) { + for (size_t i = 0; i < repetitions; ++i) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + } + } else { + repetitions -= 3; + size_t start = *tree_size; + while (true) { + tree[*tree_size] = 16; + extra_bits_data[*tree_size] = repetitions & 0x3; + ++(*tree_size); + repetitions >>= 2; + if (repetitions == 0) { + break; + } + --repetitions; + } + Reverse(tree, start, *tree_size); + Reverse(extra_bits_data, start, *tree_size); + } +} + +void WriteHuffmanTreeRepetitionsZeros(size_t repetitions, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data) { + if (repetitions == 11) { + tree[*tree_size] = 0; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions < 3) { + for (size_t i = 0; i < repetitions; ++i) { + tree[*tree_size] = 0; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + } + } else { + repetitions -= 3; + size_t start = *tree_size; + while (true) { + tree[*tree_size] = 17; + extra_bits_data[*tree_size] = repetitions & 0x7; + ++(*tree_size); + repetitions >>= 3; + if (repetitions == 0) { + break; + } + --repetitions; + } + Reverse(tree, start, *tree_size); + Reverse(extra_bits_data, start, *tree_size); + } +} + +static void DecideOverRleUse(const uint8_t* depth, const size_t length, + bool* use_rle_for_non_zero, + bool* use_rle_for_zero) { + size_t total_reps_zero = 0; + size_t total_reps_non_zero = 0; + size_t count_reps_zero = 1; + size_t count_reps_non_zero = 1; + for (size_t i = 0; i < length;) { + const uint8_t value = depth[i]; + size_t reps = 1; + for (size_t k = i + 1; k < length && depth[k] == value; ++k) { + ++reps; + } + if (reps >= 3 && value == 0) { + total_reps_zero += reps; + ++count_reps_zero; + } + if (reps >= 4 && value != 0) { + total_reps_non_zero += reps; + ++count_reps_non_zero; + } + i += reps; + } + *use_rle_for_non_zero = total_reps_non_zero > count_reps_non_zero * 2; + *use_rle_for_zero = total_reps_zero > count_reps_zero * 2; +} + +void WriteHuffmanTree(const uint8_t* depth, size_t length, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data) { + uint8_t previous_value = 8; + + // Throw away trailing zeros. + size_t new_length = length; + for (size_t i = 0; i < length; ++i) { + if (depth[length - i - 1] == 0) { + --new_length; + } else { + break; + } + } + + // First gather statistics on if it is a good idea to do rle. + bool use_rle_for_non_zero = false; + bool use_rle_for_zero = false; + if (length > 50) { + // Find rle coding for longer codes. + // Shorter codes seem not to benefit from rle. + DecideOverRleUse(depth, new_length, &use_rle_for_non_zero, + &use_rle_for_zero); + } + + // Actual rle coding. + for (size_t i = 0; i < new_length;) { + const uint8_t value = depth[i]; + size_t reps = 1; + if ((value != 0 && use_rle_for_non_zero) || + (value == 0 && use_rle_for_zero)) { + for (size_t k = i + 1; k < new_length && depth[k] == value; ++k) { + ++reps; + } + } + if (value == 0) { + WriteHuffmanTreeRepetitionsZeros(reps, tree_size, tree, extra_bits_data); + } else { + WriteHuffmanTreeRepetitions(previous_value, value, reps, tree_size, tree, + extra_bits_data); + previous_value = value; + } + i += reps; + } +} + +namespace { + +uint16_t ReverseBits(int num_bits, uint16_t bits) { + static const size_t kLut[16] = {// Pre-reversed 4-bit values. + 0x0, 0x8, 0x4, 0xc, 0x2, 0xa, 0x6, 0xe, + 0x1, 0x9, 0x5, 0xd, 0x3, 0xb, 0x7, 0xf}; + size_t retval = kLut[bits & 0xf]; + for (int i = 4; i < num_bits; i += 4) { + retval <<= 4; + bits = static_cast(bits >> 4); + retval |= kLut[bits & 0xf]; + } + retval >>= (-num_bits & 0x3); + return static_cast(retval); +} + +} // namespace + +void ConvertBitDepthsToSymbols(const uint8_t* depth, size_t len, + uint16_t* bits) { + // In Brotli, all bit depths are [1..15] + // 0 bit depth means that the symbol does not exist. + const int kMaxBits = 16; // 0..15 are values for bits + uint16_t bl_count[kMaxBits] = {0}; + { + for (size_t i = 0; i < len; ++i) { + ++bl_count[depth[i]]; + } + bl_count[0] = 0; + } + uint16_t next_code[kMaxBits]; + next_code[0] = 0; + { + int code = 0; + for (size_t i = 1; i < kMaxBits; ++i) { + code = (code + bl_count[i - 1]) << 1; + next_code[i] = static_cast(code); + } + } + for (size_t i = 0; i < len; ++i) { + if (depth[i]) { + bits[i] = ReverseBits(depth[i], next_code[depth[i]]++); + } + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/huffman_tree.h b/media/libjxl/src/lib/jxl/huffman_tree.h new file mode 100644 index 000000000..e4ccac49b --- /dev/null +++ b/media/libjxl/src/lib/jxl/huffman_tree.h @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Library for creating Huffman codes from population counts. + +#ifndef LIB_JXL_HUFFMAN_TREE_H_ +#define LIB_JXL_HUFFMAN_TREE_H_ + +#include +#include + +namespace jxl { + +// A node of a Huffman tree. +struct HuffmanTree { + HuffmanTree(uint32_t count, int16_t left, int16_t right) + : total_count(count), index_left(left), index_right_or_value(right) {} + uint32_t total_count; + int16_t index_left; + int16_t index_right_or_value; +}; + +void SetDepth(const HuffmanTree& p, HuffmanTree* pool, uint8_t* depth, + uint8_t level); + +// This function will create a Huffman tree. +// +// The (data,length) contains the population counts. +// The tree_limit is the maximum bit depth of the Huffman codes. +// +// The depth contains the tree, i.e., how many bits are used for +// the symbol. +// +// See http://en.wikipedia.org/wiki/Huffman_coding +void CreateHuffmanTree(const uint32_t* data, const size_t length, + const int tree_limit, uint8_t* depth); + +// Write a Huffman tree from bit depths into the bitstream representation +// of a Huffman tree. The generated Huffman tree is to be compressed once +// more using a Huffman tree +void WriteHuffmanTree(const uint8_t* depth, size_t length, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data); + +// Get the actual bit values for a tree of bit depths. +void ConvertBitDepthsToSymbols(const uint8_t* depth, size_t len, + uint16_t* bits); + +} // namespace jxl + +#endif // LIB_JXL_HUFFMAN_TREE_H_ diff --git a/media/libjxl/src/lib/jxl/iaca_test.cc b/media/libjxl/src/lib/jxl/iaca_test.cc new file mode 100644 index 000000000..9b2e8ea25 --- /dev/null +++ b/media/libjxl/src/lib/jxl/iaca_test.cc @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/iaca.h" + +#include "gtest/gtest.h" + +namespace jxl { +namespace { + +TEST(IacaTest, MarkersDefaultToDisabledAndDoNotCrash) { + BeginIACA(); + EndIACA(); +} + +TEST(IacaTest, ScopeDefaultToDisabledAndDoNotCrash) { ScopeIACA iaca; } + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/icc_codec.cc b/media/libjxl/src/lib/jxl/icc_codec.cc new file mode 100644 index 000000000..dd83fbe83 --- /dev/null +++ b/media/libjxl/src/lib/jxl/icc_codec.cc @@ -0,0 +1,391 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/icc_codec.h" + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/icc_codec_common.h" + +namespace jxl { +namespace { + +// Shuffles or interleaves bytes, for example with width 2, turns "ABCDabcd" +// into "AaBbCcDc". Transposes a matrix of ceil(size / width) columns and +// width rows. There are size elements, size may be < width * height, if so the +// last elements of the rightmost column are missing, the missing spots are +// transposed along with the filled spots, and the result has the missing +// elements at the end of the bottom row. The input is the input matrix in +// scanline order but with missing elements skipped (which may occur in multiple +// locations), the output is the result matrix in scanline order (with +// no need to skip missing elements as they are past the end of the data). +void Shuffle(uint8_t* data, size_t size, size_t width) { + size_t height = (size + width - 1) / width; // amount of rows of output + PaddedBytes result(size); + // i = output index, j input index + size_t s = 0, j = 0; + for (size_t i = 0; i < size; i++) { + result[i] = data[j]; + j += height; + if (j >= size) j = ++s; + } + + for (size_t i = 0; i < size; i++) { + data[i] = result[i]; + } +} + +// TODO(eustas): should be 20, or even 18, once DecodeVarInt is improved; +// currently DecodeVarInt does not signal the errors, and marks +// 11 bytes as used even if only 10 are used (and 9 is enough for +// 63-bit values). +constexpr const size_t kPreambleSize = 22; // enough for reading 2 VarInts + +} // namespace + +// Mimics the beginning of UnpredictICC for quick validity check. +// At least kPreambleSize bytes of data should be valid at invocation time. +Status CheckPreamble(const PaddedBytes& data, size_t enc_size, + size_t output_limit) { + const uint8_t* enc = data.data(); + size_t size = data.size(); + size_t pos = 0; + uint64_t osize = DecodeVarInt(enc, size, &pos); + JXL_RETURN_IF_ERROR(CheckIs32Bit(osize)); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t csize = DecodeVarInt(enc, size, &pos); + JXL_RETURN_IF_ERROR(CheckIs32Bit(csize)); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size)); + // We expect that UnpredictICC inflates input, not the other way round. + if (osize + 65536 < enc_size) return JXL_FAILURE("Malformed ICC"); + if (output_limit && osize > output_limit) { + return JXL_FAILURE("Decoded ICC is too large"); + } + return true; +} + +// Decodes the result of PredictICC back to a valid ICC profile. +Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { + if (!result->empty()) return JXL_FAILURE("result must be empty initially"); + size_t pos = 0; + // TODO(lode): technically speaking we need to check that the entire varint + // decoding never goes out of bounds, not just the first byte. This requires + // a DecodeVarInt function that returns an error code. It is safe to use + // DecodeVarInt with out of bounds values, it silently returns, but the + // specification requires an error. Idem for all DecodeVarInt below. + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t osize = DecodeVarInt(enc, size, &pos); // Output size + JXL_RETURN_IF_ERROR(CheckIs32Bit(osize)); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t csize = DecodeVarInt(enc, size, &pos); // Commands size + // Every command is translated to at least on byte. + JXL_RETURN_IF_ERROR(CheckIs32Bit(csize)); + size_t cpos = pos; // pos in commands stream + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size)); + size_t commands_end = cpos + csize; + pos = commands_end; // pos in data stream + + // Header + PaddedBytes header = ICCInitialHeaderPrediction(); + EncodeUint32(0, osize, &header); + for (size_t i = 0; i <= kICCHeaderSize; i++) { + if (result->size() == osize) { + if (cpos != commands_end) return JXL_FAILURE("Not all commands used"); + if (pos != size) return JXL_FAILURE("Not all data used"); + return true; // Valid end + } + if (i == kICCHeaderSize) break; // Done + ICCPredictHeader(result->data(), result->size(), header.data(), i); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + result->push_back(enc[pos++] + header[i]); + } + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + + // Tag list + uint64_t numtags = DecodeVarInt(enc, size, &cpos); + + if (numtags != 0) { + numtags--; + JXL_RETURN_IF_ERROR(CheckIs32Bit(numtags)); + AppendUint32(numtags, result); + uint64_t prevtagstart = kICCHeaderSize + numtags * 12; + uint64_t prevtagsize = 0; + for (;;) { + if (result->size() > osize) return JXL_FAILURE("Invalid result size"); + if (cpos > commands_end) return JXL_FAILURE("Out of bounds"); + if (cpos == commands_end) break; // Valid end + uint8_t command = enc[cpos++]; + uint8_t tagcode = command & 63; + Tag tag; + if (tagcode == 0) { + break; + } else if (tagcode == kCommandTagUnknown) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 4, size)); + tag = DecodeKeyword(enc, size, pos); + pos += 4; + } else if (tagcode == kCommandTagTRC) { + tag = kRtrcTag; + } else if (tagcode == kCommandTagXYZ) { + tag = kRxyzTag; + } else { + if (tagcode - kCommandTagStringFirst >= kNumTagStrings) { + return JXL_FAILURE("Unknown tagcode"); + } + tag = *kTagStrings[tagcode - kCommandTagStringFirst]; + } + AppendKeyword(tag, result); + + uint64_t tagstart; + uint64_t tagsize = prevtagsize; + if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag || + tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag || + tag == kLumiTag) { + tagsize = 20; + } + + if (command & kFlagBitOffset) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + tagstart = DecodeVarInt(enc, size, &cpos); + } else { + JXL_RETURN_IF_ERROR(CheckIs32Bit(prevtagstart)); + tagstart = prevtagstart + prevtagsize; + } + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart)); + AppendUint32(tagstart, result); + if (command & kFlagBitSize) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + tagsize = DecodeVarInt(enc, size, &cpos); + } + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagsize)); + AppendUint32(tagsize, result); + prevtagstart = tagstart; + prevtagsize = tagsize; + + if (tagcode == kCommandTagTRC) { + AppendKeyword(kGtrcTag, result); + AppendUint32(tagstart, result); + AppendUint32(tagsize, result); + AppendKeyword(kBtrcTag, result); + AppendUint32(tagstart, result); + AppendUint32(tagsize, result); + } + + if (tagcode == kCommandTagXYZ) { + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart + tagsize * 2)); + AppendKeyword(kGxyzTag, result); + AppendUint32(tagstart + tagsize, result); + AppendUint32(tagsize, result); + AppendKeyword(kBxyzTag, result); + AppendUint32(tagstart + tagsize * 2, result); + AppendUint32(tagsize, result); + } + } + } + + // Main Content + for (;;) { + if (result->size() > osize) return JXL_FAILURE("Invalid result size"); + if (cpos > commands_end) return JXL_FAILURE("Out of bounds"); + if (cpos == commands_end) break; // Valid end + uint8_t command = enc[cpos++]; + if (command == kCommandInsert) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + for (size_t i = 0; i < num; i++) { + result->push_back(enc[pos++]); + } + } else if (command == kCommandShuffle2 || command == kCommandShuffle4) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + PaddedBytes shuffled(num); + for (size_t i = 0; i < num; i++) { + shuffled[i] = enc[pos + i]; + } + if (command == kCommandShuffle2) { + Shuffle(shuffled.data(), num, 2); + } else if (command == kCommandShuffle4) { + Shuffle(shuffled.data(), num, 4); + } + for (size_t i = 0; i < num; i++) { + result->push_back(shuffled[i]); + pos++; + } + } else if (command == kCommandPredict) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(cpos, 2, commands_end)); + uint8_t flags = enc[cpos++]; + + size_t width = (flags & 3) + 1; + if (width == 3) return JXL_FAILURE("Invalid width"); + + int order = (flags & 12) >> 2; + if (order == 3) return JXL_FAILURE("Invalid order"); + + uint64_t stride = width; + if (flags & 16) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + stride = DecodeVarInt(enc, size, &cpos); + if (stride < width) { + return JXL_FAILURE("Invalid stride"); + } + } + // If stride * 4 >= result->size(), return failure. The check + // "size == 0 || ((size - 1) >> 2) < stride" corresponds to + // "stride * 4 >= size", but does not suffer from integer overflow. + // This check is more strict than necessary but follows the specification + // and the encoder should ensure this is followed. + if (result->empty() || ((result->size() - 1u) >> 2u) < stride) { + return JXL_FAILURE("Invalid stride"); + } + + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); // in bytes + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + + PaddedBytes shuffled(num); + for (size_t i = 0; i < num; i++) { + shuffled[i] = enc[pos + i]; + } + if (width > 1) Shuffle(shuffled.data(), num, width); + + size_t start = result->size(); + for (size_t i = 0; i < num; i++) { + uint8_t predicted = LinearPredictICCValue(result->data(), start, i, + stride, width, order); + result->push_back(predicted + shuffled[i]); + } + pos += num; + } else if (command == kCommandXYZ) { + AppendKeyword(kXyz_Tag, result); + for (int i = 0; i < 4; i++) result->push_back(0); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 12, size)); + for (size_t i = 0; i < 12; i++) { + result->push_back(enc[pos++]); + } + } else if (command >= kCommandTypeStartFirst && + command < kCommandTypeStartFirst + kNumTypeStrings) { + AppendKeyword(*kTypeStrings[command - kCommandTypeStartFirst], result); + for (size_t i = 0; i < 4; i++) { + result->push_back(0); + } + } else { + return JXL_FAILURE("Unknown command"); + } + } + + if (pos != size) return JXL_FAILURE("Not all data used"); + if (result->size() != osize) return JXL_FAILURE("Invalid result size"); + + return true; +} + +Status ICCReader::Init(BitReader* reader, size_t output_limit) { + JXL_RETURN_IF_ERROR(CheckEOI(reader)); + used_bits_base_ = reader->TotalBitsConsumed(); + if (bits_to_skip_ == 0) { + enc_size_ = U64Coder::Read(reader); + if (enc_size_ > 268435456) { + // Avoid too large memory allocation for invalid file. + return JXL_FAILURE("Too large encoded profile"); + } + JXL_RETURN_IF_ERROR( + DecodeHistograms(reader, kNumICCContexts, &code_, &context_map_)); + ans_reader_ = ANSSymbolReader(&code_, reader); + i_ = 0; + decompressed_.resize(std::min(i_ + 0x400, enc_size_)); + for (; i_ < std::min(2, enc_size_); i_++) { + decompressed_[i_] = ans_reader_.ReadHybridUint( + ICCANSContext(i_, i_ > 0 ? decompressed_[i_ - 1] : 0, + i_ > 1 ? decompressed_[i_ - 2] : 0), + reader, context_map_); + } + if (enc_size_ > kPreambleSize) { + for (; i_ < kPreambleSize; i_++) { + decompressed_[i_] = ans_reader_.ReadHybridUint( + ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]), + reader, context_map_); + } + JXL_RETURN_IF_ERROR(CheckEOI(reader)); + JXL_RETURN_IF_ERROR( + CheckPreamble(decompressed_, enc_size_, output_limit)); + } + bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_; + } else { + reader->SkipBits(bits_to_skip_); + } + return true; +} + +Status ICCReader::Process(BitReader* reader, PaddedBytes* icc) { + ANSSymbolReader::Checkpoint checkpoint; + size_t saved_i = 0; + auto save = [&]() { + ans_reader_.Save(&checkpoint); + bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_; + saved_i = i_; + }; + save(); + auto check_and_restore = [&]() { + Status status = CheckEOI(reader); + if (!status) { + // not enough bytes. + ans_reader_.Restore(checkpoint); + i_ = saved_i; + return status; + } + return Status(true); + }; + for (; i_ < enc_size_; i_++) { + if (i_ % ANSSymbolReader::kMaxCheckpointInterval == 0 && i_ > 0) { + JXL_RETURN_IF_ERROR(check_and_restore()); + save(); + if ((i_ > 0) && (((i_ & 0xFFFF) == 0))) { + float used_bytes = + (reader->TotalBitsConsumed() - used_bits_base_) / 8.0f; + if (i_ > used_bytes * 256) return JXL_FAILURE("Corrupted stream"); + } + decompressed_.resize(std::min(i_ + 0x400, enc_size_)); + } + JXL_DASSERT(i_ >= 2); + decompressed_[i_] = ans_reader_.ReadHybridUint( + ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]), reader, + context_map_); + } + JXL_RETURN_IF_ERROR(check_and_restore()); + bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_; + if (!ans_reader_.CheckANSFinalState()) { + return JXL_FAILURE("Corrupted ICC profile"); + } + + icc->clear(); + return UnpredictICC(decompressed_.data(), decompressed_.size(), icc); +} + +Status ICCReader::CheckEOI(BitReader* reader) { + if (reader->AllReadsWithinBounds()) return true; + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for reading ICC profile"); +} + +Status ReadICC(BitReader* JXL_RESTRICT reader, PaddedBytes* JXL_RESTRICT icc, + size_t output_limit) { + ICCReader icc_reader; + JXL_RETURN_IF_ERROR(icc_reader.Init(reader, output_limit)); + JXL_RETURN_IF_ERROR(icc_reader.Process(reader, icc)); + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/icc_codec.h b/media/libjxl/src/lib/jxl/icc_codec.h new file mode 100644 index 000000000..d55b31695 --- /dev/null +++ b/media/libjxl/src/lib/jxl/icc_codec.h @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ICC_CODEC_H_ +#define LIB_JXL_ICC_CODEC_H_ + +// Compressed representation of ICC profiles. + +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Should still be called if `icc.empty()` - if so, writes only 1 bit. +Status WriteICC(const PaddedBytes& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out); + +struct ICCReader { + Status Init(BitReader* reader, size_t output_limit); + Status Process(BitReader* reader, PaddedBytes* icc); + void Reset() { + bits_to_skip_ = 0; + decompressed_.clear(); + } + + private: + Status CheckEOI(BitReader* reader); + size_t i_ = 0; + size_t bits_to_skip_ = 0; + size_t used_bits_base_ = 0; + uint64_t enc_size_ = 0; + std::vector context_map_; + ANSCode code_; + ANSSymbolReader ans_reader_; + PaddedBytes decompressed_; +}; + +// `icc` may be empty afterwards - if so, call CreateProfile. Does not append, +// clears any original data that was in icc. +// If `output_limit` is not 0, then returns error if resulting profile would be +// longer than `output_limit` +Status ReadICC(BitReader* JXL_RESTRICT reader, PaddedBytes* JXL_RESTRICT icc, + size_t output_limit = 0); + +// Exposed only for testing +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result); + +// Exposed only for testing +Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result); + +} // namespace jxl + +#endif // LIB_JXL_ICC_CODEC_H_ diff --git a/media/libjxl/src/lib/jxl/icc_codec_common.cc b/media/libjxl/src/lib/jxl/icc_codec_common.cc new file mode 100644 index 000000000..3e6004849 --- /dev/null +++ b/media/libjxl/src/lib/jxl/icc_codec_common.cc @@ -0,0 +1,192 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/icc_codec_common.h" + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/fields.h" + +namespace jxl { +namespace { +static uint8_t ByteKind1(uint8_t b) { + if ('a' <= b && b <= 'z') return 0; + if ('A' <= b && b <= 'Z') return 0; + if ('0' <= b && b <= '9') return 1; + if (b == '.' || b == ',') return 1; + if (b == 0) return 2; + if (b == 1) return 3; + if (b < 16) return 4; + if (b == 255) return 6; + if (b > 240) return 5; + return 7; +} + +static uint8_t ByteKind2(uint8_t b) { + if ('a' <= b && b <= 'z') return 0; + if ('A' <= b && b <= 'Z') return 0; + if ('0' <= b && b <= '9') return 1; + if (b == '.' || b == ',') return 1; + if (b < 16) return 2; + if (b > 240) return 3; + return 4; +} + +template +T PredictValue(T p1, T p2, T p3, int order) { + if (order == 0) return p1; + if (order == 1) return 2 * p1 - p2; + if (order == 2) return 3 * p1 - 3 * p2 + p3; + return 0; +} +} // namespace + +uint32_t DecodeUint32(const uint8_t* data, size_t size, size_t pos) { + return pos + 4 > size ? 0 : LoadBE32(data + pos); +} + +void EncodeUint32(size_t pos, uint32_t value, PaddedBytes* data) { + if (pos + 4 > data->size()) return; + StoreBE32(value, data->data() + pos); +} + +void AppendUint32(uint32_t value, PaddedBytes* data) { + data->resize(data->size() + 4); + EncodeUint32(data->size() - 4, value, data); +} + +typedef std::array Tag; + +Tag DecodeKeyword(const uint8_t* data, size_t size, size_t pos) { + if (pos + 4 > size) return {{' ', ' ', ' ', ' '}}; + return {{data[pos], data[pos + 1], data[pos + 2], data[pos + 3]}}; +} + +void EncodeKeyword(const Tag& keyword, uint8_t* data, size_t size, size_t pos) { + if (keyword.size() != 4 || pos + 3 >= size) return; + for (size_t i = 0; i < 4; ++i) data[pos + i] = keyword[i]; +} + +void AppendKeyword(const Tag& keyword, PaddedBytes* data) { + JXL_ASSERT(keyword.size() == 4); + data->append(keyword); +} + +// Checks if a + b > size, taking possible integer overflow into account. +Status CheckOutOfBounds(size_t a, size_t b, size_t size) { + size_t pos = a + b; + if (pos > size) return JXL_FAILURE("Out of bounds"); + if (pos < a) return JXL_FAILURE("Out of bounds"); // overflow happened + return true; +} + +Status CheckIs32Bit(uint64_t v) { + static constexpr const uint64_t kUpper32 = ~static_cast(0xFFFFFFFF); + if ((v & kUpper32) != 0) return JXL_FAILURE("32-bit value expected"); + return true; +} + +PaddedBytes ICCInitialHeaderPrediction() { + PaddedBytes result(kICCHeaderSize); + for (size_t i = 0; i < kICCHeaderSize; i++) { + result[i] = 0; + } + result[8] = 4; + EncodeKeyword(kMntrTag, result.data(), result.size(), 12); + EncodeKeyword(kRgb_Tag, result.data(), result.size(), 16); + EncodeKeyword(kXyz_Tag, result.data(), result.size(), 20); + EncodeKeyword(kAcspTag, result.data(), result.size(), 36); + result[68] = 0; + result[69] = 0; + result[70] = 246; + result[71] = 214; + result[72] = 0; + result[73] = 1; + result[74] = 0; + result[75] = 0; + result[76] = 0; + result[77] = 0; + result[78] = 211; + result[79] = 45; + return result; +} + +void ICCPredictHeader(const uint8_t* icc, size_t size, uint8_t* header, + size_t pos) { + if (pos == 8 && size >= 8) { + header[80] = icc[4]; + header[81] = icc[5]; + header[82] = icc[6]; + header[83] = icc[7]; + } + if (pos == 41 && size >= 41) { + if (icc[40] == 'A') { + header[41] = 'P'; + header[42] = 'P'; + header[43] = 'L'; + } + if (icc[40] == 'M') { + header[41] = 'S'; + header[42] = 'F'; + header[43] = 'T'; + } + } + if (pos == 42 && size >= 42) { + if (icc[40] == 'S' && icc[41] == 'G') { + header[42] = 'I'; + header[43] = ' '; + } + if (icc[40] == 'S' && icc[41] == 'U') { + header[42] = 'N'; + header[43] = 'W'; + } + } +} + +// Predicts a value with linear prediction of given order (0-2), for integers +// with width bytes and given stride in bytes between values. +// The start position is at start + i, and the relevant modulus of i describes +// which byte of the multi-byte integer is being handled. +// The value start + i must be at least stride * 4. +uint8_t LinearPredictICCValue(const uint8_t* data, size_t start, size_t i, + size_t stride, size_t width, int order) { + size_t pos = start + i; + if (width == 1) { + uint8_t p1 = data[pos - stride]; + uint8_t p2 = data[pos - stride * 2]; + uint8_t p3 = data[pos - stride * 3]; + return PredictValue(p1, p2, p3, order); + } else if (width == 2) { + size_t p = start + (i & ~1); + uint16_t p1 = (data[p - stride * 1] << 8) + data[p - stride * 1 + 1]; + uint16_t p2 = (data[p - stride * 2] << 8) + data[p - stride * 2 + 1]; + uint16_t p3 = (data[p - stride * 3] << 8) + data[p - stride * 3 + 1]; + uint16_t pred = PredictValue(p1, p2, p3, order); + return (i & 1) ? (pred & 255) : ((pred >> 8) & 255); + } else { + size_t p = start + (i & ~3); + uint32_t p1 = DecodeUint32(data, pos, p - stride); + uint32_t p2 = DecodeUint32(data, pos, p - stride * 2); + uint32_t p3 = DecodeUint32(data, pos, p - stride * 3); + uint32_t pred = PredictValue(p1, p2, p3, order); + unsigned shiftbytes = 3 - (i & 3); + return (pred >> (shiftbytes * 8)) & 255; + } +} + +size_t ICCANSContext(size_t i, size_t b1, size_t b2) { + if (i <= 128) return 0; + return 1 + ByteKind1(b1) + ByteKind2(b2) * 8; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/icc_codec_common.h b/media/libjxl/src/lib/jxl/icc_codec_common.h new file mode 100644 index 000000000..e91e90866 --- /dev/null +++ b/media/libjxl/src/lib/jxl/icc_codec_common.h @@ -0,0 +1,106 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ICC_CODEC_COMMON_H_ +#define LIB_JXL_ICC_CODEC_COMMON_H_ + +// Compressed representation of ICC profiles. + +#include +#include + +#include + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +static constexpr size_t kICCHeaderSize = 128; + +typedef std::array Tag; + +static const Tag kAcspTag = {{'a', 'c', 's', 'p'}}; +static const Tag kBkptTag = {{'b', 'k', 'p', 't'}}; +static const Tag kBtrcTag = {{'b', 'T', 'R', 'C'}}; +static const Tag kBxyzTag = {{'b', 'X', 'Y', 'Z'}}; +static const Tag kChadTag = {{'c', 'h', 'a', 'd'}}; +static const Tag kChrmTag = {{'c', 'h', 'r', 'm'}}; +static const Tag kCprtTag = {{'c', 'p', 'r', 't'}}; +static const Tag kCurvTag = {{'c', 'u', 'r', 'v'}}; +static const Tag kDescTag = {{'d', 'e', 's', 'c'}}; +static const Tag kDmddTag = {{'d', 'm', 'd', 'd'}}; +static const Tag kDmndTag = {{'d', 'm', 'n', 'd'}}; +static const Tag kGbd_Tag = {{'g', 'b', 'd', ' '}}; +static const Tag kGtrcTag = {{'g', 'T', 'R', 'C'}}; +static const Tag kGxyzTag = {{'g', 'X', 'Y', 'Z'}}; +static const Tag kKtrcTag = {{'k', 'T', 'R', 'C'}}; +static const Tag kKxyzTag = {{'k', 'X', 'Y', 'Z'}}; +static const Tag kLumiTag = {{'l', 'u', 'm', 'i'}}; +static const Tag kMab_Tag = {{'m', 'A', 'B', ' '}}; +static const Tag kMba_Tag = {{'m', 'B', 'A', ' '}}; +static const Tag kMlucTag = {{'m', 'l', 'u', 'c'}}; +static const Tag kMntrTag = {{'m', 'n', 't', 'r'}}; +static const Tag kParaTag = {{'p', 'a', 'r', 'a'}}; +static const Tag kRgb_Tag = {{'R', 'G', 'B', ' '}}; +static const Tag kRtrcTag = {{'r', 'T', 'R', 'C'}}; +static const Tag kRxyzTag = {{'r', 'X', 'Y', 'Z'}}; +static const Tag kSf32Tag = {{'s', 'f', '3', '2'}}; +static const Tag kTextTag = {{'t', 'e', 'x', 't'}}; +static const Tag kVcgtTag = {{'v', 'c', 'g', 't'}}; +static const Tag kWtptTag = {{'w', 't', 'p', 't'}}; +static const Tag kXyz_Tag = {{'X', 'Y', 'Z', ' '}}; + +// Tag names focused on RGB and GRAY monitor profiles +static constexpr size_t kNumTagStrings = 17; +static constexpr const Tag* kTagStrings[kNumTagStrings] = { + &kCprtTag, &kWtptTag, &kBkptTag, &kRxyzTag, &kGxyzTag, &kBxyzTag, + &kKxyzTag, &kRtrcTag, &kGtrcTag, &kBtrcTag, &kKtrcTag, &kChadTag, + &kDescTag, &kChrmTag, &kDmndTag, &kDmddTag, &kLumiTag}; + +static constexpr size_t kCommandTagUnknown = 1; +static constexpr size_t kCommandTagTRC = 2; +static constexpr size_t kCommandTagXYZ = 3; +static constexpr size_t kCommandTagStringFirst = 4; + +// Tag types focused on RGB and GRAY monitor profiles +static constexpr size_t kNumTypeStrings = 8; +static constexpr const Tag* kTypeStrings[kNumTypeStrings] = { + &kXyz_Tag, &kDescTag, &kTextTag, &kMlucTag, + &kParaTag, &kCurvTag, &kSf32Tag, &kGbd_Tag}; + +static constexpr size_t kCommandInsert = 1; +static constexpr size_t kCommandShuffle2 = 2; +static constexpr size_t kCommandShuffle4 = 3; +static constexpr size_t kCommandPredict = 4; +static constexpr size_t kCommandXYZ = 10; +static constexpr size_t kCommandTypeStartFirst = 16; + +static constexpr size_t kFlagBitOffset = 64; +static constexpr size_t kFlagBitSize = 128; + +static constexpr size_t kNumICCContexts = 41; + +uint32_t DecodeUint32(const uint8_t* data, size_t size, size_t pos); +void EncodeUint32(size_t pos, uint32_t value, PaddedBytes* data); +void AppendUint32(uint32_t value, PaddedBytes* data); +Tag DecodeKeyword(const uint8_t* data, size_t size, size_t pos); +void EncodeKeyword(const Tag& keyword, uint8_t* data, size_t size, size_t pos); +void AppendKeyword(const Tag& keyword, PaddedBytes* data); + +// Checks if a + b > size, taking possible integer overflow into account. +Status CheckOutOfBounds(size_t a, size_t b, size_t size); +Status CheckIs32Bit(uint64_t v); + +PaddedBytes ICCInitialHeaderPrediction(); +void ICCPredictHeader(const uint8_t* icc, size_t size, uint8_t* header, + size_t pos); +uint8_t LinearPredictICCValue(const uint8_t* data, size_t start, size_t i, + size_t stride, size_t width, int order); +size_t ICCANSContext(size_t i, size_t b1, size_t b2); + +} // namespace jxl + +#endif // LIB_JXL_ICC_CODEC_COMMON_H_ diff --git a/media/libjxl/src/lib/jxl/icc_codec_test.cc b/media/libjxl/src/lib/jxl/icc_codec_test.cc new file mode 100644 index 000000000..d365471af --- /dev/null +++ b/media/libjxl/src/lib/jxl/icc_codec_test.cc @@ -0,0 +1,207 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/icc_codec.h" + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/enc_icc_codec.h" + +namespace jxl { +namespace { + +void TestProfile(const PaddedBytes& icc) { + BitWriter writer; + ASSERT_TRUE(WriteICC(icc, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + PaddedBytes dec; + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(ReadICC(&reader, &dec)); + ASSERT_TRUE(reader.Close()); + EXPECT_EQ(icc.size(), dec.size()); + if (icc.size() == dec.size()) { + for (size_t i = 0; i < icc.size(); i++) { + EXPECT_EQ(icc[i], dec[i]); + if (icc[i] != dec[i]) break; // One output is enough + } + } +} + +void TestProfile(const std::string& icc) { + PaddedBytes bytes(icc.size()); + for (size_t i = 0; i < icc.size(); i++) { + bytes[i] = icc[i]; + } + TestProfile(bytes); +} + +// Valid profile from one of the images output by the decoder. +static const unsigned char kTestProfile[] = { + 0x00, 0x00, 0x03, 0x80, 0x6c, 0x63, 0x6d, 0x73, 0x04, 0x30, 0x00, 0x00, + 0x6d, 0x6e, 0x74, 0x72, 0x52, 0x47, 0x42, 0x20, 0x58, 0x59, 0x5a, 0x20, + 0x07, 0xe3, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x0f, 0x00, 0x32, 0x00, 0x2e, + 0x61, 0x63, 0x73, 0x70, 0x41, 0x50, 0x50, 0x4c, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xf6, 0xd6, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d, 0x6c, 0x63, 0x6d, 0x73, + 0x5f, 0x07, 0x0d, 0x3e, 0x4d, 0x32, 0xf2, 0x6e, 0x5d, 0x77, 0x26, 0xcc, + 0x23, 0xb0, 0x6a, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, + 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x01, 0x20, 0x00, 0x00, 0x00, 0x42, + 0x63, 0x70, 0x72, 0x74, 0x00, 0x00, 0x01, 0x64, 0x00, 0x00, 0x01, 0x00, + 0x77, 0x74, 0x70, 0x74, 0x00, 0x00, 0x02, 0x64, 0x00, 0x00, 0x00, 0x14, + 0x63, 0x68, 0x61, 0x64, 0x00, 0x00, 0x02, 0x78, 0x00, 0x00, 0x00, 0x2c, + 0x72, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xa4, 0x00, 0x00, 0x00, 0x14, + 0x62, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xb8, 0x00, 0x00, 0x00, 0x14, + 0x67, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xcc, 0x00, 0x00, 0x00, 0x14, + 0x72, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x67, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x62, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x63, 0x68, 0x72, 0x6d, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x24, + 0x64, 0x6d, 0x6e, 0x64, 0x00, 0x00, 0x03, 0x24, 0x00, 0x00, 0x00, 0x28, + 0x64, 0x6d, 0x64, 0x64, 0x00, 0x00, 0x03, 0x4c, 0x00, 0x00, 0x00, 0x32, + 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0x26, + 0x00, 0x00, 0x00, 0x1c, 0x00, 0x52, 0x00, 0x47, 0x00, 0x42, 0x00, 0x5f, + 0x00, 0x44, 0x00, 0x36, 0x00, 0x35, 0x00, 0x5f, 0x00, 0x53, 0x00, 0x52, + 0x00, 0x47, 0x00, 0x5f, 0x00, 0x52, 0x00, 0x65, 0x00, 0x6c, 0x00, 0x5f, + 0x00, 0x37, 0x00, 0x30, 0x00, 0x39, 0x00, 0x00, 0x6d, 0x6c, 0x75, 0x63, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, + 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x00, 0x00, 0x1c, + 0x00, 0x43, 0x00, 0x6f, 0x00, 0x70, 0x00, 0x79, 0x00, 0x72, 0x00, 0x69, + 0x00, 0x67, 0x00, 0x68, 0x00, 0x74, 0x00, 0x20, 0x00, 0x32, 0x00, 0x30, + 0x00, 0x31, 0x00, 0x38, 0x00, 0x20, 0x00, 0x47, 0x00, 0x6f, 0x00, 0x6f, + 0x00, 0x67, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x20, 0x00, 0x4c, 0x00, 0x4c, + 0x00, 0x43, 0x00, 0x2c, 0x00, 0x20, 0x00, 0x43, 0x00, 0x43, 0x00, 0x2d, + 0x00, 0x42, 0x00, 0x59, 0x00, 0x2d, 0x00, 0x53, 0x00, 0x41, 0x00, 0x20, + 0x00, 0x33, 0x00, 0x2e, 0x00, 0x30, 0x00, 0x20, 0x00, 0x55, 0x00, 0x6e, + 0x00, 0x70, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x74, 0x00, 0x65, 0x00, 0x64, + 0x00, 0x20, 0x00, 0x6c, 0x00, 0x69, 0x00, 0x63, 0x00, 0x65, 0x00, 0x6e, + 0x00, 0x73, 0x00, 0x65, 0x00, 0x28, 0x00, 0x68, 0x00, 0x74, 0x00, 0x74, + 0x00, 0x70, 0x00, 0x73, 0x00, 0x3a, 0x00, 0x2f, 0x00, 0x2f, 0x00, 0x63, + 0x00, 0x72, 0x00, 0x65, 0x00, 0x61, 0x00, 0x74, 0x00, 0x69, 0x00, 0x76, + 0x00, 0x65, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x6d, 0x00, 0x6f, + 0x00, 0x6e, 0x00, 0x73, 0x00, 0x2e, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x67, + 0x00, 0x2f, 0x00, 0x6c, 0x00, 0x69, 0x00, 0x63, 0x00, 0x65, 0x00, 0x6e, + 0x00, 0x73, 0x00, 0x65, 0x00, 0x73, 0x00, 0x2f, 0x00, 0x62, 0x00, 0x79, + 0x00, 0x2d, 0x00, 0x73, 0x00, 0x61, 0x00, 0x2f, 0x00, 0x33, 0x00, 0x2e, + 0x00, 0x30, 0x00, 0x2f, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x67, 0x00, 0x61, + 0x00, 0x6c, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x64, 0x00, 0x65, 0x00, 0x29, + 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf6, 0xd6, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d, 0x73, 0x66, 0x33, 0x32, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x0c, 0x42, 0x00, 0x00, 0x05, 0xde, + 0xff, 0xff, 0xf3, 0x25, 0x00, 0x00, 0x07, 0x93, 0x00, 0x00, 0xfd, 0x90, + 0xff, 0xff, 0xfb, 0xa1, 0xff, 0xff, 0xfd, 0xa2, 0x00, 0x00, 0x03, 0xdc, + 0x00, 0x00, 0xc0, 0x6e, 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x6f, 0xa0, 0x00, 0x00, 0x38, 0xf5, 0x00, 0x00, 0x03, 0x90, + 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x24, 0x9f, + 0x00, 0x00, 0x0f, 0x84, 0x00, 0x00, 0xb6, 0xc4, 0x58, 0x59, 0x5a, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0x97, 0x00, 0x00, 0xb7, 0x87, + 0x00, 0x00, 0x18, 0xd9, 0x70, 0x61, 0x72, 0x61, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x38, 0xe4, 0x00, 0x00, 0xe8, 0xf0, + 0x00, 0x00, 0x17, 0x10, 0x00, 0x00, 0x38, 0xe4, 0x00, 0x00, 0x14, 0xbc, + 0x63, 0x68, 0x72, 0x6d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x00, 0x00, 0xa3, 0xd7, 0x00, 0x00, 0x54, 0x7c, 0x00, 0x00, 0x4c, 0xcd, + 0x00, 0x00, 0x99, 0x9a, 0x00, 0x00, 0x26, 0x67, 0x00, 0x00, 0x0f, 0x5c, + 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0x0c, + 0x00, 0x00, 0x00, 0x1c, 0x00, 0x47, 0x00, 0x6f, 0x00, 0x6f, 0x00, 0x67, + 0x00, 0x6c, 0x00, 0x65, 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x49, 0x00, 0x6d, + 0x00, 0x61, 0x00, 0x67, 0x00, 0x65, 0x00, 0x20, 0x00, 0x63, 0x00, 0x6f, + 0x00, 0x64, 0x00, 0x65, 0x00, 0x63, 0x00, 0x00, +}; + +} // namespace + +TEST(IccCodecTest, Icc) { + // Empty string cannot be tested, encoder checks against writing it. + TestProfile("a"); + TestProfile("ab"); + TestProfile("aaaa"); + + { + // Exactly the ICC header size + PaddedBytes profile(128); + for (size_t i = 0; i < 128; i++) { + profile[i] = 0; + } + TestProfile(profile); + } + + { + PaddedBytes profile; + profile.append(kTestProfile, kTestProfile + sizeof(kTestProfile)); + TestProfile(profile); + } + + // Test substrings of full profile + { + PaddedBytes profile; + for (size_t i = 0; i <= 256; i++) { + profile.push_back(kTestProfile[i]); + TestProfile(profile); + } + } +} + +// kTestProfile after encoding with the ICC codec +static const unsigned char kEncodedTestProfile[] = { + 0x1f, 0x8b, 0x1, 0x13, 0x10, 0x0, 0x0, 0x0, 0x20, 0x4c, 0xcc, 0x3, + 0xe7, 0xa0, 0xa5, 0xa2, 0x90, 0xa4, 0x27, 0xe8, 0x79, 0x1d, 0xe3, 0x26, + 0x57, 0x54, 0xef, 0x0, 0xe8, 0x97, 0x2, 0xce, 0xa1, 0xd7, 0x85, 0x16, + 0xb4, 0x29, 0x94, 0x58, 0xf2, 0x56, 0xc0, 0x76, 0xea, 0x23, 0xec, 0x7c, + 0x73, 0x51, 0x41, 0x40, 0x23, 0x21, 0x95, 0x4, 0x75, 0x12, 0xc9, 0xcc, + 0x16, 0xbd, 0xb6, 0x99, 0xad, 0xf8, 0x75, 0x35, 0xb6, 0x42, 0xae, 0xae, + 0xae, 0x86, 0x56, 0xf8, 0xcc, 0x16, 0x30, 0xb3, 0x45, 0xad, 0xd, 0x40, + 0xd6, 0xd1, 0xd6, 0x99, 0x40, 0xbe, 0xe2, 0xdc, 0x31, 0x7, 0xa6, 0xb9, + 0x27, 0x92, 0x38, 0x0, 0x3, 0x5e, 0x2c, 0xbe, 0xe6, 0xfb, 0x19, 0xbf, + 0xf3, 0x6d, 0xbc, 0x4d, 0x64, 0xe5, 0xba, 0x76, 0xde, 0x31, 0x65, 0x66, + 0x14, 0xa6, 0x3a, 0xc5, 0x8f, 0xb1, 0xb4, 0xba, 0x1f, 0xb1, 0xb8, 0xd4, + 0x75, 0xba, 0x18, 0x86, 0x95, 0x3c, 0x26, 0xf6, 0x25, 0x62, 0x53, 0xfd, + 0x9c, 0x94, 0x76, 0xf6, 0x95, 0x2c, 0xb1, 0xfd, 0xdc, 0xc0, 0xe4, 0x3f, + 0xb3, 0xff, 0x67, 0xde, 0xd5, 0x94, 0xcc, 0xb0, 0x83, 0x2f, 0x28, 0x93, + 0x92, 0x3, 0xa1, 0x41, 0x64, 0x60, 0x62, 0x70, 0x80, 0x87, 0xaf, 0xe7, + 0x60, 0x4a, 0x20, 0x23, 0xb3, 0x11, 0x7, 0x38, 0x38, 0xd4, 0xa, 0x66, + 0xb5, 0x93, 0x41, 0x90, 0x19, 0x17, 0x18, 0x60, 0xa5, 0xb, 0x7a, 0x24, + 0xaa, 0x20, 0x81, 0xac, 0xa9, 0xa1, 0x70, 0xa6, 0x12, 0x8a, 0x4a, 0xa3, + 0xa0, 0xf9, 0x9a, 0x97, 0xe7, 0xa8, 0xac, 0x8, 0xa8, 0xc4, 0x2a, 0x86, + 0xa7, 0x69, 0x1e, 0x67, 0xe6, 0xbe, 0xa4, 0xd3, 0xff, 0x91, 0x61, 0xf6, + 0x8a, 0xe6, 0xb5, 0xb3, 0x61, 0x9f, 0x19, 0x17, 0x98, 0x27, 0x6b, 0xe9, + 0x8, 0x98, 0xe1, 0x21, 0x4a, 0x9, 0xb5, 0xd7, 0xca, 0xfa, 0x94, 0xd0, + 0x69, 0x1a, 0xeb, 0x52, 0x1, 0x4e, 0xf5, 0xf6, 0xdf, 0x7f, 0xe7, 0x29, + 0x70, 0xee, 0x4, 0xda, 0x2f, 0xa4, 0xff, 0xfe, 0xbb, 0x6f, 0xa8, 0xff, + 0xfe, 0xdb, 0xaf, 0x8, 0xf6, 0x72, 0xa1, 0x40, 0x5d, 0xf0, 0x2d, 0x8, + 0x82, 0x5b, 0x87, 0xbd, 0x10, 0x8, 0xe9, 0x7, 0xee, 0x4b, 0x80, 0xda, + 0x4a, 0x4, 0xc5, 0x5e, 0xa0, 0xb7, 0x1e, 0x60, 0xb0, 0x59, 0x76, 0x60, + 0xb, 0x2e, 0x19, 0x8a, 0x2e, 0x1c, 0xe6, 0x6, 0x20, 0xb8, 0x64, 0x18, + 0x2a, 0xcf, 0x51, 0x94, 0xd4, 0xee, 0xc3, 0xfe, 0x39, 0x74, 0xd4, 0x2b, + 0x48, 0xc9, 0x83, 0x4c, 0x9b, 0xd0, 0x4c, 0x35, 0x10, 0xe3, 0x9, 0xf7, + 0x72, 0xf0, 0x7a, 0xe, 0xbf, 0x7d, 0x36, 0x2e, 0x19, 0x7e, 0x3f, 0xc, + 0xf7, 0x93, 0xe7, 0xf4, 0x1d, 0x32, 0xc6, 0xb0, 0x89, 0xad, 0xe0, 0x28, + 0xc1, 0xa7, 0x59, 0xe3, 0x0, +}; + +// Tests that the decoded kEncodedTestProfile matches kTestProfile. +TEST(IccCodecTest, EncodedIccProfile) { + jxl::BitReader reader(jxl::Span(kEncodedTestProfile, + sizeof(kEncodedTestProfile))); + jxl::PaddedBytes dec; + ASSERT_TRUE(ReadICC(&reader, &dec)); + ASSERT_TRUE(reader.Close()); + EXPECT_EQ(sizeof(kTestProfile), dec.size()); + if (sizeof(kTestProfile) == dec.size()) { + for (size_t i = 0; i < dec.size(); i++) { + EXPECT_EQ(kTestProfile[i], dec[i]); + if (kTestProfile[i] != dec[i]) break; // One output is enough + } + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/image.cc b/media/libjxl/src/lib/jxl/image.cc new file mode 100644 index 000000000..34b315d85 --- /dev/null +++ b/media/libjxl/src/lib/jxl/image.cc @@ -0,0 +1,247 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image.h" + +#include // swap + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/image.cc" +#include +#include + +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { + +namespace HWY_NAMESPACE { +size_t GetVectorSize() { return HWY_LANES(uint8_t); } +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { + +HWY_EXPORT(GetVectorSize); // Local function. + +size_t VectorSize() { + static size_t bytes = HWY_DYNAMIC_DISPATCH(GetVectorSize)(); + return bytes; +} + +// Returns distance [bytes] between the start of two consecutive rows, a +// multiple of vector/cache line size but NOT CacheAligned::kAlias - see below. +size_t BytesPerRow(const size_t xsize, const size_t sizeof_t) { + const size_t vec_size = VectorSize(); + size_t valid_bytes = xsize * sizeof_t; + + // Allow unaligned accesses starting at the last valid value - this may raise + // msan errors unless the user calls InitializePaddingForUnalignedAccesses. + // Skip for the scalar case because no extra lanes will be loaded. + if (vec_size != 0) { + valid_bytes += vec_size - sizeof_t; + } + + // Round up to vector and cache line size. + const size_t align = std::max(vec_size, CacheAligned::kAlignment); + size_t bytes_per_row = RoundUpTo(valid_bytes, align); + + // During the lengthy window before writes are committed to memory, CPUs + // guard against read after write hazards by checking the address, but + // only the lower 11 bits. We avoid a false dependency between writes to + // consecutive rows by ensuring their sizes are not multiples of 2 KiB. + // Avoid2K prevents the same problem for the planes of an Image3. + if (bytes_per_row % CacheAligned::kAlias == 0) { + bytes_per_row += align; + } + + JXL_ASSERT(bytes_per_row % align == 0); + return bytes_per_row; +} + +} // namespace + +PlaneBase::PlaneBase(const size_t xsize, const size_t ysize, + const size_t sizeof_t) + : xsize_(static_cast(xsize)), + ysize_(static_cast(ysize)), + orig_xsize_(static_cast(xsize)), + orig_ysize_(static_cast(ysize)) { + // (Can't profile CacheAligned itself because it is used by profiler.h) + PROFILER_FUNC; + + JXL_CHECK(xsize == xsize_); + JXL_CHECK(ysize == ysize_); + + JXL_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8); + + bytes_per_row_ = 0; + // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate + // if nonzero, because "zero" bytes still have padding/bookkeeping overhead. + if (xsize != 0 && ysize != 0) { + bytes_per_row_ = BytesPerRow(xsize, sizeof_t); + bytes_ = AllocateArray(bytes_per_row_ * ysize); + JXL_CHECK(bytes_.get()); + InitializePadding(sizeof_t, Padding::kRoundUp); + } +} + +void PlaneBase::InitializePadding(const size_t sizeof_t, Padding padding) { +#if defined(MEMORY_SANITIZER) || HWY_IDE + if (xsize_ == 0 || ysize_ == 0) return; + + const size_t vec_size = VectorSize(); + if (vec_size == 0) return; // Scalar mode: no padding needed + + const size_t valid_size = xsize_ * sizeof_t; + const size_t initialize_size = padding == Padding::kRoundUp + ? RoundUpTo(valid_size, vec_size) + : valid_size + vec_size - sizeof_t; + if (valid_size == initialize_size) return; + + for (size_t y = 0; y < ysize_; ++y) { + uint8_t* JXL_RESTRICT row = static_cast(VoidRow(y)); +#if defined(__clang__) && (__clang_major__ <= 6) + // There's a bug in msan in clang-6 when handling AVX2 operations. This + // workaround allows tests to pass on msan, although it is slower and + // prevents msan warnings from uninitialized images. + std::fill(row, msan::kSanitizerSentinelByte, initialize_size); +#else + memset(row + valid_size, msan::kSanitizerSentinelByte, + initialize_size - valid_size); +#endif // clang6 + } +#endif // MEMORY_SANITIZER +} + +void PlaneBase::Swap(PlaneBase& other) { + std::swap(xsize_, other.xsize_); + std::swap(ysize_, other.ysize_); + std::swap(orig_xsize_, other.orig_xsize_); + std::swap(orig_ysize_, other.orig_ysize_); + std::swap(bytes_per_row_, other.bytes_per_row_); + std::swap(bytes_, other.bytes_); +} + +Image3F PadImageMirror(const Image3F& in, const size_t xborder, + const size_t yborder) { + size_t xsize = in.xsize(); + size_t ysize = in.ysize(); + Image3F out(xsize + 2 * xborder, ysize + 2 * yborder); + if (xborder > xsize || yborder > ysize) { + for (size_t c = 0; c < 3; c++) { + for (int32_t y = 0; y < static_cast(out.ysize()); y++) { + float* row_out = out.PlaneRow(c, y); + const float* row_in = in.PlaneRow( + c, Mirror(y - static_cast(yborder), in.ysize())); + for (int32_t x = 0; x < static_cast(out.xsize()); x++) { + int32_t xin = Mirror(x - static_cast(xborder), in.xsize()); + row_out[x] = row_in[xin]; + } + } + } + return out; + } + CopyImageTo(in, Rect(xborder, yborder, xsize, ysize), &out); + for (size_t c = 0; c < 3; c++) { + // Horizontal pad. + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xborder; x++) { + out.PlaneRow(c, y + yborder)[x] = + in.ConstPlaneRow(c, y)[xborder - x - 1]; + out.PlaneRow(c, y + yborder)[x + xsize + xborder] = + in.ConstPlaneRow(c, y)[xsize - 1 - x]; + } + } + // Vertical pad. + for (size_t y = 0; y < yborder; y++) { + memcpy(out.PlaneRow(c, y), out.ConstPlaneRow(c, 2 * yborder - 1 - y), + out.xsize() * sizeof(float)); + memcpy(out.PlaneRow(c, y + ysize + yborder), + out.ConstPlaneRow(c, ysize + yborder - 1 - y), + out.xsize() * sizeof(float)); + } + } + return out; +} + +void PadImageToBlockMultipleInPlace(Image3F* JXL_RESTRICT in) { + PROFILER_FUNC; + const size_t xsize_orig = in->xsize(); + const size_t ysize_orig = in->ysize(); + const size_t xsize = RoundUpToBlockDim(xsize_orig); + const size_t ysize = RoundUpToBlockDim(ysize_orig); + // Expands image size to the originally-allocated size. + in->ShrinkTo(xsize, ysize); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize_orig; y++) { + float* JXL_RESTRICT row = in->PlaneRow(c, y); + for (size_t x = xsize_orig; x < xsize; x++) { + row[x] = row[xsize_orig - 1]; + } + } + const float* JXL_RESTRICT row_src = in->ConstPlaneRow(c, ysize_orig - 1); + for (size_t y = ysize_orig; y < ysize; y++) { + memcpy(in->PlaneRow(c, y), row_src, xsize * sizeof(float)); + } + } +} + +static void DownsampleImage(const ImageF& input, size_t factor, + ImageF* output) { + JXL_ASSERT(factor != 1); + output->ShrinkTo(DivCeil(input.xsize(), factor), + DivCeil(input.ysize(), factor)); + size_t in_stride = input.PixelsPerRow(); + for (size_t y = 0; y < output->ysize(); y++) { + float* row_out = output->Row(y); + const float* row_in = input.Row(factor * y); + for (size_t x = 0; x < output->xsize(); x++) { + size_t cnt = 0; + float sum = 0; + for (size_t iy = 0; iy < factor && iy + factor * y < input.ysize(); + iy++) { + for (size_t ix = 0; ix < factor && ix + factor * x < input.xsize(); + ix++) { + sum += row_in[iy * in_stride + x * factor + ix]; + cnt++; + } + } + row_out[x] = sum / cnt; + } + } +} + +void DownsampleImage(ImageF* image, size_t factor) { + // Allocate extra space to avoid a reallocation when padding. + ImageF downsampled(DivCeil(image->xsize(), factor) + kBlockDim, + DivCeil(image->ysize(), factor) + kBlockDim); + DownsampleImage(*image, factor, &downsampled); + *image = std::move(downsampled); +} + +void DownsampleImage(Image3F* opsin, size_t factor) { + JXL_ASSERT(factor != 1); + // Allocate extra space to avoid a reallocation when padding. + Image3F downsampled(DivCeil(opsin->xsize(), factor) + kBlockDim, + DivCeil(opsin->ysize(), factor) + kBlockDim); + downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, + downsampled.ysize() - kBlockDim); + for (size_t c = 0; c < 3; c++) { + DownsampleImage(opsin->Plane(c), factor, &downsampled.Plane(c)); + } + *opsin = std::move(downsampled); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/image.h b/media/libjxl/src/lib/jxl/image.h new file mode 100644 index 000000000..5fe2c558c --- /dev/null +++ b/media/libjxl/src/lib/jxl/image.h @@ -0,0 +1,494 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_H_ +#define LIB_JXL_IMAGE_H_ + +// SIMD/multicore-friendly planar image representation with row accessors. + +#include +#include +#include +#include + +#include +#include +#include // std::move + +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" + +namespace jxl { + +// Type-independent parts of Plane<> - reduces code duplication and facilitates +// moving member function implementations to cc file. +struct PlaneBase { + PlaneBase() + : xsize_(0), + ysize_(0), + orig_xsize_(0), + orig_ysize_(0), + bytes_per_row_(0), + bytes_(nullptr) {} + PlaneBase(size_t xsize, size_t ysize, size_t sizeof_t); + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo() instead. + PlaneBase(const PlaneBase& other) = delete; + PlaneBase& operator=(const PlaneBase& other) = delete; + + // Move constructor (required for returning Image from function) + PlaneBase(PlaneBase&& other) noexcept = default; + + // Move assignment (required for std::vector) + PlaneBase& operator=(PlaneBase&& other) noexcept = default; + + void Swap(PlaneBase& other); + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. May also be used to + // un-shrink the image. Caller is responsible for ensuring xsize/ysize are <= + // the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + JXL_CHECK(xsize <= orig_xsize_); + JXL_CHECK(ysize <= orig_ysize_); + xsize_ = static_cast(xsize); + ysize_ = static_cast(ysize); + // NOTE: we can't recompute bytes_per_row for more compact storage and + // better locality because that would invalidate the image contents. + } + + // How many pixels. + JXL_INLINE size_t xsize() const { return xsize_; } + JXL_INLINE size_t ysize() const { return ysize_; } + + // NOTE: do not use this for copying rows - the valid xsize may be much less. + JXL_INLINE size_t bytes_per_row() const { return bytes_per_row_; } + + // Raw access to byte contents, for interfacing with other libraries. + // Unsigned char instead of char to avoid surprises (sign extension). + JXL_INLINE uint8_t* bytes() { + void* p = bytes_.get(); + return static_cast(JXL_ASSUME_ALIGNED(p, 64)); + } + JXL_INLINE const uint8_t* bytes() const { + const void* p = bytes_.get(); + return static_cast(JXL_ASSUME_ALIGNED(p, 64)); + } + + protected: + // Returns pointer to the start of a row. + JXL_INLINE void* VoidRow(const size_t y) const { +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + if (y >= ysize_) { + JXL_ABORT("Row(%" PRIu64 ") in (%u x %u) image\n", (uint64_t)y, xsize_, + ysize_); + } +#endif + + void* row = bytes_.get() + y * bytes_per_row_; + return JXL_ASSUME_ALIGNED(row, 64); + } + + enum class Padding { + // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default. + kRoundUp, + // Allow LoadU(d, row + x) for x = xsize() - 1. This requires an extra + // vector to be initialized. If done by default, this would suppress + // legitimate msan warnings. We therefore require users to explicitly call + // InitializePadding before using unaligned loads (e.g. convolution). + kUnaligned + }; + + // Initializes the minimum bytes required to suppress msan warnings from + // legitimate (according to Padding mode) vector loads/stores on the right + // border, where some lanes are uninitialized and assumed to be unused. + void InitializePadding(size_t sizeof_t, Padding padding); + + // (Members are non-const to enable assignment during move-assignment.) + uint32_t xsize_; // In valid pixels, not including any padding. + uint32_t ysize_; + uint32_t orig_xsize_; + uint32_t orig_ysize_; + size_t bytes_per_row_; // Includes padding. + CacheAlignedUniquePtr bytes_; +}; + +// Single channel, aligned rows separated by padding. T must be POD. +// +// 'Single channel' (one 2D array per channel) simplifies vectorization +// (repeating the same operation on multiple adjacent components) without the +// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients +// can easily iterate over all components in a row and Image requires no +// knowledge of the pixel format beyond the component type "T". +// +// 'Aligned' means each row is aligned to the L1 cache line size. This prevents +// false sharing between two threads operating on adjacent rows. +// +// 'Padding' is still relevant because vectors could potentially be larger than +// a cache line. By rounding up row sizes to the vector size, we allow +// reading/writing ALIGNED vectors whose first lane is a valid sample. This +// avoids needing a separate loop to handle remaining unaligned lanes. +// +// This image layout could also be achieved with a vector and a row accessor +// function, but a class wrapper with support for "deleter" allows wrapping +// existing memory allocated by clients without copying the pixels. It also +// provides convenient accessors for xsize/ysize, which shortens function +// argument lists. Supports move-construction so it can be stored in containers. +template +class Plane : public PlaneBase { + public: + using T = ComponentType; + static constexpr size_t kNumPlanes = 1; + + Plane() = default; + Plane(const size_t xsize, const size_t ysize) + : PlaneBase(xsize, ysize, sizeof(T)) {} + + void InitializePaddingForUnalignedAccesses() { + InitializePadding(sizeof(T), Padding::kUnaligned); + } + + JXL_INLINE T* Row(const size_t y) { return static_cast(VoidRow(y)); } + + // Returns pointer to const (see above). + JXL_INLINE const T* Row(const size_t y) const { + return static_cast(VoidRow(y)); + } + + // Documents that the access is const. + JXL_INLINE const T* ConstRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must + // NOT be used to determine xsize. + JXL_INLINE intptr_t PixelsPerRow() const { + return static_cast(bytes_per_row_ / sizeof(T)); + } +}; + +using ImageSB = Plane; +using ImageB = Plane; +using ImageS = Plane; // signed integer or half-float +using ImageU = Plane; +using ImageI = Plane; +using ImageF = Plane; +using ImageD = Plane; + +// Also works for Image3 and mixed argument types. +template +bool SameSize(const Image1& image1, const Image2& image2) { + return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize(); +} + +template +class Image3; + +// Rectangular region in image(s). Factoring this out of Image instead of +// shifting the pointer by x0/y0 allows this to apply to multiple images with +// different resolutions (e.g. color transform and quantization field). +// Can compare using SameSize(rect1, rect2). +template +class RectT { + public: + // Most windows are xsize_max * ysize_max, except those on the borders where + // begin + size_max > end. + constexpr RectT(T xbegin, T ybegin, size_t xsize_max, size_t ysize_max, + T xend, T yend) + : x0_(xbegin), + y0_(ybegin), + xsize_(ClampedSize(xbegin, xsize_max, xend)), + ysize_(ClampedSize(ybegin, ysize_max, yend)) {} + + // Construct with origin and known size (typically from another Rect). + constexpr RectT(T xbegin, T ybegin, size_t xsize, size_t ysize) + : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {} + + // Construct a rect that covers a whole image/plane/ImageBundle etc. + template + explicit RectT(const ImageT& image) + : RectT(0, 0, image.xsize(), image.ysize()) {} + + RectT() : RectT(0, 0, 0, 0) {} + + RectT(const RectT&) = default; + RectT& operator=(const RectT&) = default; + + // Construct a subrect that resides in an image/plane/ImageBundle etc. + template + RectT Crop(const ImageT& image) const { + return Intersection(RectT(image)); + } + + // Construct a subrect that resides in the [0, ysize) x [0, xsize) region of + // the current rect. + RectT Crop(size_t area_xsize, size_t area_ysize) const { + return Intersection(RectT(0, 0, area_xsize, area_ysize)); + } + + // Returns a rect that only contains `num` lines with offset `y` from `y0()`. + RectT Lines(size_t y, size_t num) const { + JXL_DASSERT(y + num <= ysize_); + return RectT(x0_, y0_ + y, xsize_, num); + } + + RectT Line(size_t y) const { return Lines(y, 1); } + + JXL_MUST_USE_RESULT RectT Intersection(const RectT& other) const { + return RectT(std::max(x0_, other.x0_), std::max(y0_, other.y0_), xsize_, + ysize_, std::min(x1(), other.x1()), + std::min(y1(), other.y1())); + } + + JXL_MUST_USE_RESULT RectT Translate(int64_t x_offset, + int64_t y_offset) const { + return RectT(x0_ + x_offset, y0_ + y_offset, xsize_, ysize_); + } + + template + V* Row(Plane* image, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image->Row(y + y0_) + x0_; + } + + template + const V* Row(const Plane* image, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image->Row(y + y0_) + x0_; + } + + template + V* PlaneRow(Image3* image, const size_t c, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image->PlaneRow(c, y + y0_) + x0_; + } + + template + const V* ConstRow(const Plane& image, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image.ConstRow(y + y0_) + x0_; + } + + template + const V* ConstPlaneRow(const Image3& image, size_t c, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image.ConstPlaneRow(c, y + y0_) + x0_; + } + + bool IsInside(const RectT& other) const { + return x0_ >= other.x0() && x1() <= other.x1() && y0_ >= other.y0() && + y1() <= other.y1(); + } + + // Returns true if this Rect fully resides in the given image. ImageT could be + // Plane or Image3; however if ImageT is Rect, results are nonsensical. + template + bool IsInside(const ImageT& image) const { + return IsInside(RectT(image)); + } + + T x0() const { return x0_; } + T y0() const { return y0_; } + size_t xsize() const { return xsize_; } + size_t ysize() const { return ysize_; } + T x1() const { return x0_ + xsize_; } + T y1() const { return y0_ + ysize_; } + + RectT ShiftLeft(size_t shiftx, size_t shifty) const { + return RectT(x0_ * (1 << shiftx), y0_ * (1 << shifty), xsize_ << shiftx, + ysize_ << shifty); + } + RectT ShiftLeft(size_t shift) const { return ShiftLeft(shift, shift); } + + // Requires x0(), y0() to be multiples of 1< CeilShiftRight(size_t shiftx, size_t shifty) const { + JXL_ASSERT(x0_ % (1 << shiftx) == 0); + JXL_ASSERT(y0_ % (1 << shifty) == 0); + return RectT(x0_ / (1 << shiftx), y0_ / (1 << shifty), + DivCeil(xsize_, T{1} << shiftx), + DivCeil(ysize_, T{1} << shifty)); + } + RectT CeilShiftRight(std::pair shift) const { + return CeilShiftRight(shift.first, shift.second); + } + RectT CeilShiftRight(size_t shift) const { + return CeilShiftRight(shift, shift); + } + + template + RectT As() const { + return RectT(U(x0_), U(y0_), U(xsize_), U(ysize_)); + } + + private: + // Returns size_max, or whatever is left in [begin, end). + static constexpr size_t ClampedSize(T begin, size_t size_max, T end) { + return (static_cast(begin + size_max) <= end) + ? size_max + : (end > begin ? end - begin : 0); + } + + T x0_; + T y0_; + + size_t xsize_; + size_t ysize_; +}; + +template +std::string Description(RectT r) { + std::ostringstream os; + os << "[" << r.x0() << ".." << r.x1() << ")x" + << "[" << r.y0() << ".." << r.y1() << ")"; + return os.str(); +} + +using Rect = RectT; + +// Currently, we abuse Image to either refer to an image that owns its storage +// or one that doesn't. In similar vein, we abuse Image* function parameters to +// either mean "assign to me" or "fill the provided image with data". +// Hopefully, the "assign to me" meaning will go away and most images in the +// codebase will not be backed by own storage. When this happens we can redesign +// Image to be a non-storage-holding view class and introduce BackedImage in +// those places that actually need it. + +// NOTE: we can't use Image as a view because invariants are violated +// (alignment and the presence of padding before/after each "row"). + +// A bundle of 3 same-sized images. Typically constructed by moving from three +// rvalue references to Image. To overwrite an existing Image3 using +// single-channel producers, we also need access to Image*. Constructing +// temporary non-owning Image pointing to one plane of an existing Image3 risks +// dangling references, especially if the wrapper is moved. Therefore, we +// store an array of Image (which are compact enough that size is not a concern) +// and provide Plane+Row accessors. +template +class Image3 { + public: + using T = ComponentType; + using PlaneT = jxl::Plane; + static constexpr size_t kNumPlanes = 3; + + Image3() : planes_{PlaneT(), PlaneT(), PlaneT()} {} + + Image3(const size_t xsize, const size_t ysize) + : planes_{PlaneT(xsize, ysize), PlaneT(xsize, ysize), + PlaneT(xsize, ysize)} {} + + Image3(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + } + + Image3(PlaneT&& plane0, PlaneT&& plane1, PlaneT&& plane2) { + JXL_CHECK(SameSize(plane0, plane1)); + JXL_CHECK(SameSize(plane0, plane2)); + planes_[0] = std::move(plane0); + planes_[1] = std::move(plane1); + planes_[2] = std::move(plane2); + } + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo instead. + Image3(const Image3& other) = delete; + Image3& operator=(const Image3& other) = delete; + + Image3& operator=(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + return *this; + } + + // Returns row pointer; usage: PlaneRow(idx_plane, y)[x] = val. + JXL_INLINE T* PlaneRow(const size_t c, const size_t y) { + // Custom implementation instead of calling planes_[c].Row ensures only a + // single multiplication is needed for PlaneRow(0..2, y). + PlaneRowBoundsCheck(c, y); + const size_t row_offset = y * planes_[0].bytes_per_row(); + void* row = planes_[c].bytes() + row_offset; + return static_cast(JXL_ASSUME_ALIGNED(row, 64)); + } + + // Returns const row pointer; usage: val = PlaneRow(idx_plane, y)[x]. + JXL_INLINE const T* PlaneRow(const size_t c, const size_t y) const { + PlaneRowBoundsCheck(c, y); + const size_t row_offset = y * planes_[0].bytes_per_row(); + const void* row = planes_[c].bytes() + row_offset; + return static_cast(JXL_ASSUME_ALIGNED(row, 64)); + } + + // Returns const row pointer, even if called from a non-const Image3. + JXL_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const { + PlaneRowBoundsCheck(c, y); + return PlaneRow(c, y); + } + + JXL_INLINE const PlaneT& Plane(size_t idx) const { return planes_[idx]; } + + JXL_INLINE PlaneT& Plane(size_t idx) { return planes_[idx]; } + + void Swap(Image3& other) { + for (size_t c = 0; c < 3; ++c) { + other.planes_[c].Swap(planes_[c]); + } + } + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. May also be used to + // un-shrink the image. Caller is responsible for ensuring xsize/ysize are <= + // the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + for (PlaneT& plane : planes_) { + plane.ShrinkTo(xsize, ysize); + } + } + + // Sizes of all three images are guaranteed to be equal. + JXL_INLINE size_t xsize() const { return planes_[0].xsize(); } + JXL_INLINE size_t ysize() const { return planes_[0].ysize(); } + // Returns offset [bytes] from one row to the next row of the same plane. + // WARNING: this must NOT be used to determine xsize, nor for copying rows - + // the valid xsize may be much less. + JXL_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); } + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must NOT be used + // to determine xsize. + JXL_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); } + + private: + void PlaneRowBoundsCheck(const size_t c, const size_t y) const { +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + if (c >= kNumPlanes || y >= ysize()) { + JXL_ABORT("PlaneRow(%" PRIu64 ", %" PRIu64 ") in (%" PRIu64 " x %" PRIu64 + ") image\n", + static_cast(c), static_cast(y), + static_cast(xsize()), static_cast(ysize())); + } +#endif + } + + private: + PlaneT planes_[kNumPlanes]; +}; + +using Image3B = Image3; +using Image3S = Image3; +using Image3U = Image3; +using Image3I = Image3; +using Image3F = Image3; +using Image3D = Image3; + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_H_ diff --git a/media/libjxl/src/lib/jxl/image_bundle.cc b/media/libjxl/src/lib/jxl/image_bundle.cc new file mode 100644 index 000000000..dfbc02ddb --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_bundle.cc @@ -0,0 +1,128 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_bundle.h" + +#include +#include + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/luminance.h" + +namespace jxl { + +void ImageBundle::ShrinkTo(size_t xsize, size_t ysize) { + if (HasColor()) color_.ShrinkTo(xsize, ysize); + for (ImageF& ec : extra_channels_) { + ec.ShrinkTo(xsize, ysize); + } +} + +// Called by all other SetFrom*. +void ImageBundle::SetFromImage(Image3F&& color, + const ColorEncoding& c_current) { + JXL_CHECK(color.xsize() != 0 && color.ysize() != 0); + JXL_CHECK(metadata_->color_encoding.IsGray() == c_current.IsGray()); + color_ = std::move(color); + c_current_ = c_current; + VerifySizes(); +} + +void ImageBundle::VerifyMetadata() const { + JXL_CHECK(!c_current_.ICC().empty()); + JXL_CHECK(metadata_->color_encoding.IsGray() == IsGray()); + + if (metadata_->HasAlpha() && alpha().xsize() == 0) { + JXL_ABORT("MD alpha_bits %u IB alpha %" PRIuS " x %" PRIuS "\n", + metadata_->GetAlphaBits(), alpha().xsize(), alpha().ysize()); + } + const uint32_t alpha_bits = metadata_->GetAlphaBits(); + JXL_CHECK(alpha_bits <= 32); + + // metadata_->num_extra_channels may temporarily differ from + // extra_channels_.size(), e.g. after SetAlpha. They are synced by the next + // call to VisitFields. +} + +void ImageBundle::VerifySizes() const { + const size_t xs = xsize(); + const size_t ys = ysize(); + + if (HasExtraChannels()) { + JXL_CHECK(xs != 0 && ys != 0); + for (const ImageF& ec : extra_channels_) { + JXL_CHECK(ec.xsize() == xs); + JXL_CHECK(ec.ysize() == ys); + } + } +} + +size_t ImageBundle::DetectRealBitdepth() const { + return metadata_->bit_depth.bits_per_sample; + + // TODO(lode): let this function return lower bit depth if possible, e.g. + // return 8 bits in case the original image came from a 16-bit PNG that + // was in fact representable as 8-bit PNG. Ensure that the implementation + // returns 16 if e.g. two consecutive 16-bit values appeared in the original + // image (such as 32768 and 32769), take into account that e.g. the values + // 3-bit can represent is not a superset of the values 2-bit can represent, + // and there may be slight imprecisions in the floating point image. +} + +const ImageF& ImageBundle::black() const { + JXL_ASSERT(HasBlack()); + const size_t ec = metadata_->Find(ExtraChannel::kBlack) - + metadata_->extra_channel_info.data(); + JXL_ASSERT(ec < extra_channels_.size()); + return extra_channels_[ec]; +} +const ImageF& ImageBundle::alpha() const { + JXL_ASSERT(HasAlpha()); + const size_t ec = metadata_->Find(ExtraChannel::kAlpha) - + metadata_->extra_channel_info.data(); + JXL_ASSERT(ec < extra_channels_.size()); + return extra_channels_[ec]; +} +ImageF* ImageBundle::alpha() { + JXL_ASSERT(HasAlpha()); + const size_t ec = metadata_->Find(ExtraChannel::kAlpha) - + metadata_->extra_channel_info.data(); + JXL_ASSERT(ec < extra_channels_.size()); + return &extra_channels_[ec]; +} + +void ImageBundle::SetAlpha(ImageF&& alpha, bool alpha_is_premultiplied) { + const ExtraChannelInfo* eci = metadata_->Find(ExtraChannel::kAlpha); + // Must call SetAlphaBits first, otherwise we don't know which channel index + JXL_CHECK(eci != nullptr); + JXL_CHECK(alpha.xsize() != 0 && alpha.ysize() != 0); + JXL_CHECK(eci->alpha_associated == alpha_is_premultiplied); + if (extra_channels_.size() < metadata_->extra_channel_info.size()) { + // TODO(jon): get rid of this case + extra_channels_.insert( + extra_channels_.begin() + (eci - metadata_->extra_channel_info.data()), + std::move(alpha)); + } else { + extra_channels_[eci - metadata_->extra_channel_info.data()] = + std::move(alpha); + } + // num_extra_channels is automatically set in visitor + VerifySizes(); +} + +void ImageBundle::SetExtraChannels(std::vector&& extra_channels) { + for (const ImageF& plane : extra_channels) { + JXL_CHECK(plane.xsize() != 0 && plane.ysize() != 0); + } + extra_channels_ = std::move(extra_channels); + VerifySizes(); +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/image_bundle.h b/media/libjxl/src/lib/jxl/image_bundle.h new file mode 100644 index 000000000..d233abbbc --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_bundle.h @@ -0,0 +1,256 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_BUNDLE_H_ +#define LIB_JXL_IMAGE_BUNDLE_H_ + +// The main image or frame consists of a bundle of associated images. + +#include +#include + +#include + +#include "jxl/cms_interface.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +// A bundle of color/alpha/depth/plane images. +class ImageBundle { + public: + // Uninitialized state for use as output parameter. + ImageBundle() : metadata_(nullptr) {} + // Caller is responsible for setting metadata before calling Set*. + explicit ImageBundle(const ImageMetadata* metadata) : metadata_(metadata) {} + + // Move-only (allows storing in std::vector). + ImageBundle(ImageBundle&&) = default; + ImageBundle& operator=(ImageBundle&&) = default; + + ImageBundle Copy() const { + ImageBundle copy(metadata_); + copy.color_ = CopyImage(color_); + copy.c_current_ = c_current_; + copy.extra_channels_.reserve(extra_channels_.size()); + for (const ImageF& plane : extra_channels_) { + copy.extra_channels_.emplace_back(CopyImage(plane)); + } + + copy.jpeg_data = + jpeg_data ? make_unique(*jpeg_data) : nullptr; + copy.color_transform = color_transform; + copy.chroma_subsampling = chroma_subsampling; + + return copy; + } + + // -- SIZE + + size_t xsize() const { + if (IsJPEG()) return jpeg_data->width; + if (color_.xsize() != 0) return color_.xsize(); + return extra_channels_.empty() ? 0 : extra_channels_[0].xsize(); + } + size_t ysize() const { + if (IsJPEG()) return jpeg_data->height; + if (color_.ysize() != 0) return color_.ysize(); + return extra_channels_.empty() ? 0 : extra_channels_[0].ysize(); + } + void ShrinkTo(size_t xsize, size_t ysize); + + // sizes taking orientation into account + size_t oriented_xsize() const { + if (static_cast(metadata_->GetOrientation()) > 4) { + return ysize(); + } else { + return xsize(); + } + } + size_t oriented_ysize() const { + if (static_cast(metadata_->GetOrientation()) > 4) { + return xsize(); + } else { + return ysize(); + } + } + + // -- COLOR + + // Whether color() is valid/usable. Returns true in most cases. Even images + // with spot colors (one example of when !planes().empty()) typically have a + // part that can be converted to RGB. + bool HasColor() const { return color_.xsize() != 0; } + + // For resetting the size when switching from a reference to main frame. + void RemoveColor() { color_ = Image3F(); } + + // Do not use if !HasColor(). + const Image3F& color() const { + // If this fails, Set* was not called - perhaps because decoding failed? + JXL_DASSERT(HasColor()); + return color_; + } + + // Do not use if !HasColor(). + Image3F* color() { + JXL_DASSERT(HasColor()); + return &color_; + } + + // If c_current.IsGray(), all planes must be identical. NOTE: c_current is + // independent of metadata()->color_encoding, which is the original, whereas + // a decoder might return pixels in a different c_current. + // This only sets the color channels, you must also make extra channels + // match the amount that is in the metadata. + void SetFromImage(Image3F&& color, const ColorEncoding& c_current); + + // -- COLOR ENCODING + + const ColorEncoding& c_current() const { return c_current_; } + + // Returns whether the color image has identical planes. Once established by + // Set*, remains unchanged until a subsequent Set* or TransformTo. + bool IsGray() const { return c_current_.IsGray(); } + + bool IsSRGB() const { return c_current_.IsSRGB(); } + bool IsLinearSRGB() const { + return c_current_.white_point == WhitePoint::kD65 && + c_current_.primaries == Primaries::kSRGB && c_current_.tf.IsLinear(); + } + + // Set the c_current profile without doing any transformation, e.g. if the + // transformation was already applied. + void OverrideProfile(const ColorEncoding& new_c_current) { + c_current_ = new_c_current; + } + + // TODO(lode): TransformTo and CopyTo are implemented in enc_image_bundle.cc, + // move these functions out of this header file and class, to + // enc_image_bundle.h. + + // Transforms color to c_desired and sets c_current to c_desired. Alpha and + // metadata remains unchanged. + Status TransformTo(const ColorEncoding& c_desired, const JxlCmsInterface& cms, + ThreadPool* pool = nullptr); + // Copies this:rect, converts to c_desired, and allocates+fills out. + Status CopyTo(const Rect& rect, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, Image3F* out, + ThreadPool* pool = nullptr) const; + + // Detect 'real' bit depth, which can be lower than nominal bit depth + // (this is common in PNG), returns 'real' bit depth + size_t DetectRealBitdepth() const; + + // -- ALPHA + + void SetAlpha(ImageF&& alpha, bool alpha_is_premultiplied); + bool HasAlpha() const { + return metadata_->Find(ExtraChannel::kAlpha) != nullptr; + } + bool AlphaIsPremultiplied() const { + const ExtraChannelInfo* eci = metadata_->Find(ExtraChannel::kAlpha); + return (eci == nullptr) ? false : eci->alpha_associated; + } + const ImageF& alpha() const; + ImageF* alpha(); + + // -- EXTRA CHANNELS + bool HasBlack() const { + return metadata_->Find(ExtraChannel::kBlack) != nullptr; + } + const ImageF& black() const; + + // Extra channels of unknown interpretation (e.g. spot colors). + void SetExtraChannels(std::vector&& extra_channels); + void ClearExtraChannels() { extra_channels_.clear(); } + bool HasExtraChannels() const { return !extra_channels_.empty(); } + const std::vector& extra_channels() const { return extra_channels_; } + std::vector& extra_channels() { return extra_channels_; } + + const ImageMetadata* metadata() const { return metadata_; } + + void VerifyMetadata() const; + + void SetDecodedBytes(size_t decoded_bytes) { decoded_bytes_ = decoded_bytes; } + size_t decoded_bytes() const { return decoded_bytes_; } + + // -- JPEG transcoding: + + // Returns true if image does or will represent quantized DCT-8 coefficients, + // stored in 8x8 pixel regions. + bool IsJPEG() const { +#if JPEGXL_ENABLE_TRANSCODE_JPEG + return jpeg_data != nullptr; +#else // JPEGXL_ENABLE_TRANSCODE_JPEG + return false; +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + } + + std::unique_ptr jpeg_data; + // these fields are used to signal the input JPEG color space + // NOTE: JPEG doesn't actually provide a way to determine whether YCbCr was + // applied or not. + ColorTransform color_transform = ColorTransform::kNone; + YCbCrChromaSubsampling chroma_subsampling; + + FrameOrigin origin{0, 0}; + + // Animation-related information, corresponding to the timecode and duration + // fields of the jxl::AnimationFrame of the jxl::FrameHeader. + // TODO(lode): ImageBundle is used here to carry the information from + // jxl::FrameHeader, consider instead passing a jxl::FrameHeader directly to + // EncodeFrame or having a field of that type here. + uint32_t duration = 0; + uint32_t timecode = 0; + + // TODO(lode): these fields do not match the JXL frame header, it should be + // possible to specify up to 4 (3 if nonzero duration) slots to save this + // frame as reference (see save_as_reference). + bool use_for_next_frame = false; + bool blend = false; + BlendMode blendmode = BlendMode::kBlend; + + std::string name; + + private: + // Called after any Set* to ensure their sizes are compatible. + void VerifySizes() const; + + // Required for TransformTo so that an ImageBundle is self-sufficient. Always + // points to the same thing, but cannot be const-pointer because that prevents + // the compiler from generating a move ctor. + const ImageMetadata* metadata_; + + // Initialized by Set*: + Image3F color_; // If empty, planes_ is not; all planes equal if IsGray(). + ColorEncoding c_current_; // of color_ + + // Initialized by SetPlanes; size = ImageMetadata.num_extra_channels + std::vector extra_channels_; + + // How many bytes of the input were actually read. + size_t decoded_bytes_ = 0; +}; + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_BUNDLE_H_ diff --git a/media/libjxl/src/lib/jxl/image_bundle_test.cc b/media/libjxl/src/lib/jxl/image_bundle_test.cc new file mode 100644 index 000000000..6de2e49db --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_bundle_test.cc @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_bundle.h" + +#include "gtest/gtest.h" +#include "lib/jxl/aux_out.h" + +namespace jxl { +namespace { + +TEST(ImageBundleTest, ExtraChannelName) { + AuxOut aux_out; + BitWriter writer; + BitWriter::Allotment allotment(&writer, 99); + + ImageMetadata metadata; + ExtraChannelInfo eci; + eci.type = ExtraChannel::kBlack; + eci.name = "testK"; + metadata.extra_channel_info.push_back(std::move(eci)); + ASSERT_TRUE(WriteImageMetadata(metadata, &writer, /*layer=*/0, &aux_out)); + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, /*layer=*/0, &aux_out); + + BitReader reader(writer.GetSpan()); + ImageMetadata metadata_out; + ASSERT_TRUE(ReadImageMetadata(&reader, &metadata_out)); + EXPECT_TRUE(reader.Close()); + EXPECT_EQ("testK", metadata_out.Find(ExtraChannel::kBlack)->name); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/image_metadata.cc b/media/libjxl/src/lib/jxl/image_metadata.cc new file mode 100644 index 000000000..7a1ee1c6b --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_metadata.cc @@ -0,0 +1,477 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_metadata.h" + +#include +#include + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/fields.h" + +namespace jxl { +BitDepth::BitDepth() { Bundle::Init(this); } +Status BitDepth::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &floating_point_sample)); + // The same fields (bits_per_sample and exponent_bits_per_sample) are read + // in a different way depending on floating_point_sample's value. It's still + // default-initialized correctly so using visitor->Conditional is not + // required. + if (!floating_point_sample) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(8), Val(10), Val(12), BitsOffset(6, 1), 8, &bits_per_sample)); + exponent_bits_per_sample = 0; + } else { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(32), Val(16), Val(24), BitsOffset(6, 1), 32, &bits_per_sample)); + // The encoded value is exponent_bits_per_sample - 1, encoded in 3 bits + // so the value can be in range [1, 8]. + const uint32_t offset = 1; + exponent_bits_per_sample -= offset; + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bits(4, 8 - offset, &exponent_bits_per_sample)); + exponent_bits_per_sample += offset; + } + + // Error-checking for floating point ranges. + if (floating_point_sample) { + if (exponent_bits_per_sample < 2 || exponent_bits_per_sample > 8) { + return JXL_FAILURE("Invalid exponent_bits_per_sample: %u", + exponent_bits_per_sample); + } + int mantissa_bits = + static_cast(bits_per_sample) - exponent_bits_per_sample - 1; + if (mantissa_bits < 2 || mantissa_bits > 23) { + return JXL_FAILURE("Invalid bits_per_sample: %u", bits_per_sample); + } + } else { + if (bits_per_sample > 31) { + return JXL_FAILURE("Invalid bits_per_sample: %u", bits_per_sample); + } + } + return true; +} + +std::string BitDepth::DebugString() const { + std::ostringstream os; + os << (floating_point_sample ? "F" : "U"); + os << bits_per_sample; + if (floating_point_sample) os << "." << exponent_bits_per_sample; + return os.str(); +} + +CustomTransformData::CustomTransformData() { Bundle::Init(this); } +Status CustomTransformData::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + if (visitor->Conditional(nonserialized_xyb_encoded)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&opsin_inverse_matrix)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &custom_weights_mask)); + if (visitor->Conditional((custom_weights_mask & 0x1) != 0)) { + // 4 5x5 kernels, but all of them can be obtained by symmetry from one, + // which is symmetric along its main diagonal. The top-left kernel is + // defined by + // + // 0 1 2 3 4 + // 1 5 6 7 8 + // 2 6 9 10 11 + // 3 7 10 12 13 + // 4 8 11 13 14 + float constexpr kWeights2[15] = { + -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, + 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, 0.56661550f, + 0.03777607f, -0.01986694f, -0.03144731f, -0.01185068f, -0.00213539f}; + for (size_t i = 0; i < 15; i++) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kWeights2[i], &upsampling2_weights[i])); + } + } + if (visitor->Conditional((custom_weights_mask & 0x2) != 0)) { + // 16 5x5 kernels, but all of them can be obtained by symmetry from + // three, two of which are symmetric along their main diagonals. The top + // left 4 kernels are defined by + // + // 0 1 2 3 4 5 6 7 8 9 + // 1 10 11 12 13 14 15 16 17 18 + // 2 11 19 20 21 22 23 24 25 26 + // 3 12 20 27 28 29 30 31 32 33 + // 4 13 21 28 34 35 36 37 38 39 + // + // 5 14 22 29 35 40 41 42 43 44 + // 6 15 23 30 36 41 45 46 47 48 + // 7 16 24 31 37 42 46 49 50 51 + // 8 17 25 32 38 43 47 50 52 53 + // 9 18 26 33 39 44 48 51 53 54 + constexpr float kWeights4[55] = { + -0.02419067f, -0.03491987f, -0.03693351f, -0.03094285f, -0.00529785f, + -0.01663432f, -0.03556863f, -0.03888905f, -0.03516850f, -0.00989469f, + 0.23651958f, 0.33392945f, -0.01073543f, -0.01313181f, -0.03556694f, + 0.13048175f, 0.40103025f, 0.03951150f, -0.02077584f, 0.46914198f, + -0.00209270f, -0.01484589f, -0.04064806f, 0.18942530f, 0.56279892f, + 0.06674400f, -0.02335494f, -0.03551682f, -0.00754830f, -0.02267919f, + -0.02363578f, 0.00315804f, -0.03399098f, -0.01359519f, -0.00091653f, + -0.00335467f, -0.01163294f, -0.01610294f, -0.00974088f, -0.00191622f, + -0.01095446f, -0.03198464f, -0.04455121f, -0.02799790f, -0.00645912f, + 0.06390599f, 0.22963888f, 0.00630981f, -0.01897349f, 0.67537268f, + 0.08483369f, -0.02534994f, -0.02205197f, -0.01667999f, -0.00384443f}; + for (size_t i = 0; i < 55; i++) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kWeights4[i], &upsampling4_weights[i])); + } + } + if (visitor->Conditional((custom_weights_mask & 0x4) != 0)) { + // 64 5x5 kernels, all of them can be obtained by symmetry from + // 10, 4 of which are symmetric along their main diagonals. The top + // left 16 kernels are defined by + // 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 + // 1 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 21 22 23 24 25 26 + // 2 15 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 + // 3 16 28 39 3a 3b 3c 3d 3e 3f 40 41 42 43 44 45 46 47 48 49 + // 4 17 29 3a 4a 4b 4c 4d 4e 4f 50 51 52 53 54 55 56 57 58 59 + + // 5 18 2a 3b 4b 5a 5b 5c 5d 5e 5f 60 61 62 63 64 65 66 67 68 + // 6 19 2b 3c 4c 5b 69 6a 6b 6c 6d 6e 6f 70 71 72 73 74 75 76 + // 7 1a 2c 3d 4d 5c 6a 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 + // 8 1b 2d 3e 4e 5d 6b 78 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f + // 9 1c 2e 3f 4f 5e 6c 79 85 90 91 92 93 94 95 96 97 98 99 9a + + // a 1d 2f 40 50 5f 6d 7a 86 91 9b 9c 9d 9e 9f a0 a1 a2 a3 a4 + // b 1e 30 41 51 60 6e 7b 87 92 9c a5 a6 a7 a8 a9 aa ab ac ad + // c 1f 31 42 52 61 6f 7c 88 93 9d a6 ae af b0 b1 b2 b3 b4 b5 + // d 20 32 43 53 62 70 7d 89 94 9e a7 af b6 b7 b8 b9 ba bb bc + // e 21 33 44 54 63 71 7e 8a 95 9f a8 b0 b7 bd be bf c0 c1 c2 + + // f 22 34 45 55 64 72 7f 8b 96 a0 a9 b1 b8 be c3 c4 c5 c6 c7 + // 10 23 35 46 56 65 73 80 8c 97 a1 aa b2 b9 bf c4 c8 c9 ca cb + // 11 24 36 47 57 66 74 81 8d 98 a2 ab b3 ba c0 c5 c9 cc cd ce + // 12 25 37 48 58 67 75 82 8e 99 a3 ac b4 bb c1 c6 ca cd cf d0 + // 13 26 38 49 59 68 76 83 8f 9a a4 ad b5 bc c2 c7 cb ce d0 d1 + constexpr float kWeights8[210] = { + -0.02928613f, -0.03706353f, -0.03783812f, -0.03324558f, -0.00447632f, + -0.02519406f, -0.03752601f, -0.03901508f, -0.03663285f, -0.00646649f, + -0.02066407f, -0.03838633f, -0.04002101f, -0.03900035f, -0.00901973f, + -0.01626393f, -0.03954148f, -0.04046620f, -0.03979621f, -0.01224485f, + 0.29895328f, 0.35757708f, -0.02447552f, -0.01081748f, -0.04314594f, + 0.23903219f, 0.41119301f, -0.00573046f, -0.01450239f, -0.04246845f, + 0.17567618f, 0.45220643f, 0.02287757f, -0.01936783f, -0.03583255f, + 0.11572472f, 0.47416733f, 0.06284440f, -0.02685066f, 0.42720050f, + -0.02248939f, -0.01155273f, -0.04562755f, 0.28689496f, 0.49093869f, + -0.00007891f, -0.01545926f, -0.04562659f, 0.21238920f, 0.53980934f, + 0.03369474f, -0.02070211f, -0.03866988f, 0.14229550f, 0.56593398f, + 0.08045181f, -0.02888298f, -0.03680918f, -0.00542229f, -0.02920477f, + -0.02788574f, -0.02118180f, -0.03942402f, -0.00775547f, -0.02433614f, + -0.03193943f, -0.02030828f, -0.04044014f, -0.01074016f, -0.01930822f, + -0.03620399f, -0.01974125f, -0.03919545f, -0.01456093f, -0.00045072f, + -0.00360110f, -0.01020207f, -0.01231907f, -0.00638988f, -0.00071592f, + -0.00279122f, -0.00957115f, -0.01288327f, -0.00730937f, -0.00107783f, + -0.00210156f, -0.00890705f, -0.01317668f, -0.00813895f, -0.00153491f, + -0.02128481f, -0.04173044f, -0.04831487f, -0.03293190f, -0.00525260f, + -0.01720322f, -0.04052736f, -0.05045706f, -0.03607317f, -0.00738030f, + -0.01341764f, -0.03965629f, -0.05151616f, -0.03814886f, -0.01005819f, + 0.18968273f, 0.33063684f, -0.01300105f, -0.01372950f, -0.04017465f, + 0.13727832f, 0.36402234f, 0.01027890f, -0.01832107f, -0.03365072f, + 0.08734506f, 0.38194295f, 0.04338228f, -0.02525993f, 0.56408126f, + 0.00458352f, -0.01648227f, -0.04887868f, 0.24585519f, 0.62026135f, + 0.04314807f, -0.02213737f, -0.04158014f, 0.16637289f, 0.65027023f, + 0.09621636f, -0.03101388f, -0.04082742f, -0.00904519f, -0.02790922f, + -0.02117818f, 0.00798662f, -0.03995711f, -0.01243427f, -0.02231705f, + -0.02946266f, 0.00992055f, -0.03600283f, -0.01684920f, -0.00111684f, + -0.00411204f, -0.01297130f, -0.01723725f, -0.01022545f, -0.00165306f, + -0.00313110f, -0.01218016f, -0.01763266f, -0.01125620f, -0.00231663f, + -0.01374149f, -0.03797620f, -0.05142937f, -0.03117307f, -0.00581914f, + -0.01064003f, -0.03608089f, -0.05272168f, -0.03375670f, -0.00795586f, + 0.09628104f, 0.27129991f, -0.00353779f, -0.01734151f, -0.03153981f, + 0.05686230f, 0.28500998f, 0.02230594f, -0.02374955f, 0.68214326f, + 0.05018048f, -0.02320852f, -0.04383616f, 0.18459474f, 0.71517975f, + 0.10805613f, -0.03263677f, -0.03637639f, -0.01394373f, -0.02511203f, + -0.01728636f, 0.05407331f, -0.02867568f, -0.01893131f, -0.00240854f, + -0.00446511f, -0.01636187f, -0.02377053f, -0.01522848f, -0.00333334f, + -0.00819975f, -0.02964169f, -0.04499287f, -0.02745350f, -0.00612408f, + 0.02727416f, 0.19446600f, 0.00159832f, -0.02232473f, 0.74982506f, + 0.11452620f, -0.03348048f, -0.01605681f, -0.02070339f, -0.00458223f}; + for (size_t i = 0; i < 210; i++) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kWeights8[i], &upsampling8_weights[i])); + } + } + return true; +} + +ExtraChannelInfo::ExtraChannelInfo() { Bundle::Init(this); } +Status ExtraChannelInfo::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + // General + JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(ExtraChannel::kAlpha, &type)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&bit_depth)); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(3), Val(4), BitsOffset(3, 1), 0, &dim_shift)); + if ((1U << dim_shift) > 8) { + return JXL_FAILURE("dim_shift %u too large", dim_shift); + } + + JXL_QUIET_RETURN_IF_ERROR(VisitNameString(visitor, &name)); + + // Conditional + if (visitor->Conditional(type == ExtraChannel::kAlpha)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &alpha_associated)); + } + if (visitor->Conditional(type == ExtraChannel::kSpotColor)) { + for (float& c : spot_color) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0, &c)); + } + } + if (visitor->Conditional(type == ExtraChannel::kCFA)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(1), Bits(2), BitsOffset(4, 3), + BitsOffset(8, 19), 1, &cfa_channel)); + } + + if (type == ExtraChannel::kUnknown || + (int(ExtraChannel::kReserved0) <= int(type) && + int(type) <= int(ExtraChannel::kReserved7))) { + return JXL_FAILURE("Unknown extra channel (bits %u, shift %u, name '%s')\n", + bit_depth.bits_per_sample, dim_shift, name.c_str()); + } + return true; +} + +std::string ExtraChannelInfo::DebugString() const { + std::ostringstream os; + os << (type == ExtraChannel::kAlpha ? "Alpha" + : type == ExtraChannel::kDepth ? "Depth" + : type == ExtraChannel::kSpotColor ? "Spot" + : type == ExtraChannel::kSelectionMask ? "Mask" + : type == ExtraChannel::kBlack ? "Black" + : type == ExtraChannel::kCFA ? "CFA" + : type == ExtraChannel::kThermal ? "Thermal" + : "Unknown"); + if (type == ExtraChannel::kAlpha && alpha_associated) os << "(premul)"; + os << " " << bit_depth.DebugString(); + os << " shift: " << dim_shift; + return os.str(); +} + +ImageMetadata::ImageMetadata() { Bundle::Init(this); } +Status ImageMetadata::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + // Bundle::AllDefault does not allow usage when reading (it may abort the + // program when a codestream has invalid values), but when reading we + // overwrite the extra_fields value, so do not need to call AllDefault. + bool tone_mapping_default = + visitor->IsReading() ? false : Bundle::AllDefault(tone_mapping); + + bool extra_fields = (orientation != 1 || have_preview || have_animation || + have_intrinsic_size || !tone_mapping_default); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &extra_fields)); + if (visitor->Conditional(extra_fields)) { + orientation--; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &orientation)); + orientation++; + // (No need for bounds checking because we read exactly 3 bits) + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_intrinsic_size)); + if (visitor->Conditional(have_intrinsic_size)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&intrinsic_size)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_preview)); + if (visitor->Conditional(have_preview)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&preview_size)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_animation)); + if (visitor->Conditional(have_animation)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&animation)); + } + } else { + orientation = 1; // identity + have_intrinsic_size = false; + have_preview = false; + have_animation = false; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&bit_depth)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bool(true, &modular_16_bit_buffer_sufficient)); + + num_extra_channels = extra_channel_info.size(); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2), + BitsOffset(12, 1), 0, + &num_extra_channels)); + + if (visitor->Conditional(num_extra_channels != 0)) { + if (visitor->IsReading()) { + extra_channel_info.resize(num_extra_channels); + } + for (ExtraChannelInfo& eci : extra_channel_info) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&eci)); + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &xyb_encoded)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&color_encoding)); + if (visitor->Conditional(extra_fields)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&tone_mapping)); + } + + // Treat as if only the fields up to extra channels exist. + if (visitor->IsReading() && nonserialized_only_parse_basic_info) { + return true; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + // Extensions: in chronological order of being added to the format. + return visitor->EndExtensions(); +} + +OpsinInverseMatrix::OpsinInverseMatrix() { Bundle::Init(this); } +Status OpsinInverseMatrix::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + for (int i = 0; i < 9; ++i) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16( + DefaultInverseOpsinAbsorbanceMatrix()[i], &inverse_matrix[i])); + } + for (int i = 0; i < 3; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kNegOpsinAbsorbanceBiasRGB[i], &opsin_biases[i])); + } + for (int i = 0; i < 4; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kDefaultQuantBias[i], &quant_biases[i])); + } + return true; +} + +ToneMapping::ToneMapping() { Bundle::Init(this); } +Status ToneMapping::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kDefaultIntensityTarget, &intensity_target)); + if (intensity_target <= 0.f) { + return JXL_FAILURE("invalid intensity target"); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.0f, &min_nits)); + if (min_nits < 0.f || min_nits > intensity_target) { + return JXL_FAILURE("invalid min %f vs max %f", min_nits, intensity_target); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &relative_to_max_display)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.0f, &linear_below)); + if (linear_below < 0 || (relative_to_max_display && linear_below > 1.0f)) { + return JXL_FAILURE("invalid linear_below %f (%s)", linear_below, + relative_to_max_display ? "relative" : "absolute"); + } + + return true; +} + +Status ReadImageMetadata(BitReader* JXL_RESTRICT reader, + ImageMetadata* JXL_RESTRICT metadata) { + return Bundle::Read(reader, metadata); +} + +Status WriteImageMetadata(const ImageMetadata& metadata, + BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* aux_out) { + return Bundle::Write(metadata, writer, layer, aux_out); +} + +void ImageMetadata::SetAlphaBits(uint32_t bits, bool alpha_is_premultiplied) { + std::vector& eciv = extra_channel_info; + ExtraChannelInfo* alpha = Find(ExtraChannel::kAlpha); + if (bits == 0) { + if (alpha != nullptr) { + // Remove the alpha channel from the extra channel info. It's + // theoretically possible that there are multiple, remove all in that + // case. This ensure a next HasAlpha() will return false. + const auto is_alpha = [](const ExtraChannelInfo& eci) { + return eci.type == ExtraChannel::kAlpha; + }; + eciv.erase(std::remove_if(eciv.begin(), eciv.end(), is_alpha), + eciv.end()); + } + } else { + if (alpha == nullptr) { + ExtraChannelInfo info; + info.type = ExtraChannel::kAlpha; + info.bit_depth.bits_per_sample = bits; + info.dim_shift = 0; + info.alpha_associated = alpha_is_premultiplied; + // Prepend rather than append: in case there already are other extra + // channels, prefer alpha channel to be listed first. + eciv.insert(eciv.begin(), info); + } else { + // Ignores potential extra alpha channels, only sets to first one. + alpha->bit_depth.bits_per_sample = bits; + alpha->bit_depth.floating_point_sample = false; + alpha->bit_depth.exponent_bits_per_sample = 0; + alpha->alpha_associated = alpha_is_premultiplied; + } + } + num_extra_channels = extra_channel_info.size(); + if (bits > 12) modular_16_bit_buffer_sufficient = false; +} + +std::string ImageMetadata::DebugString() const { + std::ostringstream os; + os << bit_depth.DebugString(); + if (modular_16_bit_buffer_sufficient) { + os << " (modular 16)"; + } + os << (xyb_encoded ? " xyb encoded" : " orig profile"); + os << " " << Description(color_encoding); + if (num_extra_channels > 0) { + os << " extra channels:"; + for (size_t i = 0; i < num_extra_channels; ++i) { + os << " (" << extra_channel_info[i].DebugString() << ")"; + if (i + 1 < num_extra_channels) os << ","; + } + } + if (have_preview) { + os << " preview: " << preview_size.xsize() << "x" << preview_size.ysize(); + } + if (orientation != 1) { + os << " orientation: " << orientation; + } + return os.str(); +} + +std::string CodecMetadata::DebugString() const { + std::ostringstream os; + os << size.xsize() << "x" << size.ysize(); + os << " " << m.DebugString(); + return os.str(); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/image_metadata.h b/media/libjxl/src/lib/jxl/image_metadata.h new file mode 100644 index 000000000..9008e42f6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_metadata.h @@ -0,0 +1,423 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Main codestream header bundles, the metadata that applies to all frames. +// Enums must align with the C API definitions in codestream_header.h. + +#ifndef LIB_JXL_IMAGE_METADATA_H_ +#define LIB_JXL_IMAGE_METADATA_H_ + +#include +#include + +#include +#include + +#include "jxl/codestream_header.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { + +// EXIF orientation of the image. This field overrides any field present in +// actual EXIF metadata. The value tells which transformation the decoder must +// apply after decoding to display the image with the correct orientation. +enum class Orientation : uint32_t { + // Values 1..8 match the EXIF definitions. + kIdentity = JXL_ORIENT_IDENTITY, + kFlipHorizontal = JXL_ORIENT_FLIP_HORIZONTAL, + kRotate180 = JXL_ORIENT_ROTATE_180, + kFlipVertical = JXL_ORIENT_FLIP_VERTICAL, + kTranspose = JXL_ORIENT_TRANSPOSE, + kRotate90 = JXL_ORIENT_ROTATE_90_CW, + kAntiTranspose = JXL_ORIENT_ANTI_TRANSPOSE, + kRotate270 = JXL_ORIENT_ROTATE_90_CCW, +}; +// Don't need an EnumBits because Orientation is not read via Enum(). + +enum class ExtraChannel : uint32_t { + // First two enumerators (most common) are cheaper to encode + kAlpha = JXL_CHANNEL_ALPHA, + kDepth = JXL_CHANNEL_DEPTH, + + kSpotColor = JXL_CHANNEL_SPOT_COLOR, + kSelectionMask = JXL_CHANNEL_SELECTION_MASK, + kBlack = JXL_CHANNEL_BLACK, // for CMYK + kCFA = JXL_CHANNEL_CFA, // Bayer channel + kThermal = JXL_CHANNEL_THERMAL, + kReserved0 = JXL_CHANNEL_RESERVED0, + kReserved1 = JXL_CHANNEL_RESERVED1, + kReserved2 = JXL_CHANNEL_RESERVED2, + kReserved3 = JXL_CHANNEL_RESERVED3, + kReserved4 = JXL_CHANNEL_RESERVED4, + kReserved5 = JXL_CHANNEL_RESERVED5, + kReserved6 = JXL_CHANNEL_RESERVED6, + kReserved7 = JXL_CHANNEL_RESERVED7, + // disambiguated via name string, raise warning if unsupported + kUnknown = JXL_CHANNEL_UNKNOWN, + // like kUnknown but can silently be ignored + kOptional = JXL_CHANNEL_OPTIONAL +}; +static inline const char* EnumName(ExtraChannel /*unused*/) { + return "ExtraChannel"; +} +static inline constexpr uint64_t EnumBits(ExtraChannel /*unused*/) { + using EC = ExtraChannel; + return MakeBit(EC::kAlpha) | MakeBit(EC::kDepth) | MakeBit(EC::kSpotColor) | + MakeBit(EC::kSelectionMask) | MakeBit(EC::kBlack) | MakeBit(EC::kCFA) | + MakeBit(EC::kThermal) | MakeBit(EC::kUnknown) | MakeBit(EC::kOptional); +} + +// Used in ImageMetadata and ExtraChannelInfo. +struct BitDepth : public Fields { + BitDepth(); + JXL_FIELDS_NAME(BitDepth) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + std::string DebugString() const; + + // Whether the original (uncompressed) samples are floating point or + // unsigned integer. + bool floating_point_sample; + + // Bit depth of the original (uncompressed) image samples. Must be in the + // range [1, 32]. + uint32_t bits_per_sample; + + // Floating point exponent bits of the original (uncompressed) image samples, + // only used if floating_point_sample is true. + // If used, the samples are floating point with: + // - 1 sign bit + // - exponent_bits_per_sample exponent bits + // - (bits_per_sample - exponent_bits_per_sample - 1) mantissa bits + // If used, exponent_bits_per_sample must be in the range + // [2, 8] and amount of mantissa bits must be in the range [2, 23]. + // NOTE: exponent_bits_per_sample is 8 for single precision binary32 + // point, 5 for half precision binary16, 7 for fp24. + uint32_t exponent_bits_per_sample; +}; + +// Describes one extra channel. +struct ExtraChannelInfo : public Fields { + ExtraChannelInfo(); + JXL_FIELDS_NAME(ExtraChannelInfo) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + std::string DebugString() const; + + mutable bool all_default; + + ExtraChannel type; + BitDepth bit_depth; + uint32_t dim_shift; // downsampled by 2^dim_shift on each axis + + std::string name; // UTF-8 + + // Conditional: + bool alpha_associated; // i.e. premultiplied + float spot_color[4]; // spot color in linear RGBA + uint32_t cfa_channel; +}; + +struct OpsinInverseMatrix : public Fields { + OpsinInverseMatrix(); + JXL_FIELDS_NAME(OpsinInverseMatrix) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + mutable bool all_default; + + float inverse_matrix[9]; + float opsin_biases[3]; + float quant_biases[4]; +}; + +// Information useful for mapping HDR images to lower dynamic range displays. +struct ToneMapping : public Fields { + ToneMapping(); + JXL_FIELDS_NAME(ToneMapping) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + mutable bool all_default; + + // Upper bound on the intensity level present in the image. For unsigned + // integer pixel encodings, this is the brightness of the largest + // representable value. The image does not necessarily contain a pixel + // actually this bright. An encoder is allowed to set 255 for SDR images + // without computing a histogram. + float intensity_target; // [nits] + + // Lower bound on the intensity level present in the image. This may be + // loose, i.e. lower than the actual darkest pixel. When tone mapping, a + // decoder will map [min_nits, intensity_target] to the display range. + float min_nits; + + bool relative_to_max_display; // see below + // The tone mapping will leave unchanged (linear mapping) any pixels whose + // brightness is strictly below this. The interpretation depends on + // relative_to_max_display. If true, this is a ratio [0, 1] of the maximum + // display brightness [nits], otherwise an absolute brightness [nits]. + float linear_below; +}; + +// Contains weights to customize some trasnforms - in particular, XYB and +// upsampling. +struct CustomTransformData : public Fields { + CustomTransformData(); + JXL_FIELDS_NAME(CustomTransformData) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Must be set before calling VisitFields. Must equal xyb_encoded of + // ImageMetadata, should be set by ImageMetadata during VisitFields. + bool nonserialized_xyb_encoded = false; + + mutable bool all_default; + + OpsinInverseMatrix opsin_inverse_matrix; + + uint32_t custom_weights_mask; + float upsampling2_weights[15]; + float upsampling4_weights[55]; + float upsampling8_weights[210]; +}; + +// Properties of the original image bundle. This enables Encode(Decode()) to +// re-create an equivalent image without user input. +struct ImageMetadata : public Fields { + ImageMetadata(); + JXL_FIELDS_NAME(ImageMetadata) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Returns bit depth of the JPEG XL compressed alpha channel, or 0 if no alpha + // channel present. In the theoretical case that there are multiple alpha + // channels, returns the bit depht of the first. + uint32_t GetAlphaBits() const { + const ExtraChannelInfo* alpha = Find(ExtraChannel::kAlpha); + if (alpha == nullptr) return 0; + JXL_ASSERT(alpha->bit_depth.bits_per_sample != 0); + return alpha->bit_depth.bits_per_sample; + } + + // Sets bit depth of alpha channel, adding extra channel if needed, or + // removing all alpha channels if bits is 0. + // Assumes integer alpha channel and not designed to support multiple + // alpha channels (it's possible to use those features by manipulating + // extra_channel_info directly). + // + // Callers must insert the actual channel image at the same index before any + // further modifications to extra_channel_info. + void SetAlphaBits(uint32_t bits, bool alpha_is_premultiplied = false); + + bool HasAlpha() const { return GetAlphaBits() != 0; } + + // Sets the original bit depth fields to indicate unsigned integer of the + // given bit depth. + // TODO(lode): move function to BitDepth + void SetUintSamples(uint32_t bits) { + bit_depth.bits_per_sample = bits; + bit_depth.exponent_bits_per_sample = 0; + bit_depth.floating_point_sample = false; + // RCT / Squeeze may add one bit each, and this is about int16_t, + // so uint13 should still be OK but limiting it to 12 seems safer. + // TODO(jon): figure out a better way to set this header field. + // (in particular, if modular mode is not used it doesn't matter, + // and if transforms are restricted, up to 15-bit could be done) + if (bits > 12) modular_16_bit_buffer_sufficient = false; + } + // Sets the original bit depth fields to indicate single precision floating + // point. + // TODO(lode): move function to BitDepth + void SetFloat32Samples() { + bit_depth.bits_per_sample = 32; + bit_depth.exponent_bits_per_sample = 8; + bit_depth.floating_point_sample = true; + modular_16_bit_buffer_sufficient = false; + } + + void SetFloat16Samples() { + bit_depth.bits_per_sample = 16; + bit_depth.exponent_bits_per_sample = 5; + bit_depth.floating_point_sample = true; + modular_16_bit_buffer_sufficient = false; + } + + void SetIntensityTarget(float intensity_target) { + tone_mapping.intensity_target = intensity_target; + } + float IntensityTarget() const { + JXL_ASSERT(tone_mapping.intensity_target != 0); + return tone_mapping.intensity_target; + } + + // Returns first ExtraChannelInfo of the given type, or nullptr if none. + const ExtraChannelInfo* Find(ExtraChannel type) const { + for (const ExtraChannelInfo& eci : extra_channel_info) { + if (eci.type == type) return &eci; + } + return nullptr; + } + + // Returns first ExtraChannelInfo of the given type, or nullptr if none. + ExtraChannelInfo* Find(ExtraChannel type) { + for (ExtraChannelInfo& eci : extra_channel_info) { + if (eci.type == type) return &eci; + } + return nullptr; + } + + Orientation GetOrientation() const { + return static_cast(orientation); + } + + bool ExtraFieldsDefault() const; + + std::string DebugString() const; + + mutable bool all_default; + + BitDepth bit_depth; + bool modular_16_bit_buffer_sufficient; // otherwise 32 is. + + // Whether the colors values of the pixels of frames are encoded in the + // codestream using the absolute XYB color space, or the using values that + // follow the color space defined by the ColorEncoding or ICC profile. This + // determines when or whether a CMS (Color Management System) is needed to get + // the pixels in a desired color space. In one case, the pixels have one known + // color space and a CMS is needed to convert them to the original image's + // color space, in the other case the pixels have the color space of the + // original image and a CMS is required if a different display space, or a + // single known consistent color space for multiple decoded images, is + // desired. In all cases, the color space of all frames from a single image is + // the same, both VarDCT and modular frames. + // + // If true: then frames can be decoded to XYB (which can also be converted to + // linear and non-linear sRGB with the built in conversion without CMS). The + // attached ColorEncoding or ICC profile has no effect on the meaning of the + // pixel's color values, but instead indicates what the color profile of the + // original image was, and what color profile one should convert to when + // decoding to integers to prevent clipping and precision loss. To do that + // conversion requires a CMS. + // + // If false: then the color values of decoded frames are in the space defined + // by the attached ColorEncoding or ICC profile. To instead get the pixels in + // a chosen known color space, such as sRGB, requires a CMS, since the + // attached ColorEncoding or ICC profile could be any arbitrary color space. + // This mode is typically used for lossless images encoded as integers. + // Frames can also use YCbCr encoding, some frames may and some may not, but + // this is not a different color space but a certain encoding of the RGB + // values. + // + // Note: if !xyb_encoded, but the attached color profile indicates XYB (which + // can happen either if it's a ColorEncoding with color_space_ == + // ColorSpace::kXYB, or if it's an ICC Profile that has been crafted to + // represent XYB), then the frames still may not use ColorEncoding kXYB, they + // must still use kNone (or kYCbCr, which would mean applying the YCbCr + // transform to the 3-channel XYB data), since with !xyb_encoded, the 3 + // channels are stored as-is, no matter what meaning the color profile assigns + // to them. To use ColorEncoding::kXYB, xyb_encoded must be true. + // + // This value is defined in image metadata because this is the global + // codestream header. This value does not affect the image itself, so is not + // image metadata per se, it only affects the encoding, and what color space + // the decoder can receive the pixels in without needing a CMS. + bool xyb_encoded; + + ColorEncoding color_encoding; + + // These values are initialized to defaults such that the 'extra_fields' + // condition in VisitFields uses correctly initialized values. + uint32_t orientation = 1; + bool have_preview = false; + bool have_animation = false; + bool have_intrinsic_size = false; + + // If present, the stored image has the dimensions of the first SizeHeader, + // but decoders are advised to resample or display per `intrinsic_size`. + SizeHeader intrinsic_size; // only if have_intrinsic_size + + ToneMapping tone_mapping; + + // When reading: deserialized. When writing: automatically set from vector. + uint32_t num_extra_channels; + std::vector extra_channel_info; + + // Only present if m.have_preview. + PreviewHeader preview_size; + // Only present if m.have_animation. + AnimationHeader animation; + + uint64_t extensions; + + // Option to stop parsing after basic info, and treat as if the later + // fields do not participate. Use to parse only basic image information + // excluding the final larger or variable sized data. + bool nonserialized_only_parse_basic_info = false; +}; + +Status ReadImageMetadata(BitReader* JXL_RESTRICT reader, + ImageMetadata* JXL_RESTRICT metadata); + +Status WriteImageMetadata(const ImageMetadata& metadata, + BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* aux_out); + +// All metadata applicable to the entire codestream (dimensions, extra channels, +// ...) +struct CodecMetadata { + // TODO(lode): use the preview and animation fields too, in place of the + // nonserialized_ ones in ImageMetadata. + ImageMetadata m; + // The size of the codestream: this is the nominal size applicable to all + // frames, although some frames can have a different effective size through + // crop, dc_level or representing a the preview. + SizeHeader size; + // Often default. + CustomTransformData transform_data; + + size_t xsize() const { return size.xsize(); } + size_t ysize() const { return size.ysize(); } + size_t oriented_xsize(bool keep_orientation) const { + if (static_cast(m.GetOrientation()) > 4 && !keep_orientation) { + return ysize(); + } else { + return xsize(); + } + } + size_t oriented_preview_xsize(bool keep_orientation) const { + if (static_cast(m.GetOrientation()) > 4 && !keep_orientation) { + return m.preview_size.ysize(); + } else { + return m.preview_size.xsize(); + } + } + size_t oriented_ysize(bool keep_orientation) const { + if (static_cast(m.GetOrientation()) > 4 && !keep_orientation) { + return xsize(); + } else { + return ysize(); + } + } + size_t oriented_preview_ysize(bool keep_orientation) const { + if (static_cast(m.GetOrientation()) > 4 && !keep_orientation) { + return m.preview_size.xsize(); + } else { + return m.preview_size.ysize(); + } + } + + std::string DebugString() const; +}; + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_METADATA_H_ diff --git a/media/libjxl/src/lib/jxl/image_ops.h b/media/libjxl/src/lib/jxl/image_ops.h new file mode 100644 index 000000000..63fc08749 --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_ops.h @@ -0,0 +1,806 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_OPS_H_ +#define LIB_JXL_IMAGE_OPS_H_ + +// Operations on images. + +#include +#include +#include +#include + +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +template +void CopyImageTo(const Plane& from, Plane* JXL_RESTRICT to) { + PROFILER_ZONE("CopyImage1"); + JXL_ASSERT(SameSize(from, *to)); + if (from.ysize() == 0 || from.xsize() == 0) return; + for (size_t y = 0; y < from.ysize(); ++y) { + const T* JXL_RESTRICT row_from = from.ConstRow(y); + T* JXL_RESTRICT row_to = to->Row(y); + memcpy(row_to, row_from, from.xsize() * sizeof(T)); + } +} + +// DEPRECATED - prefer to preallocate result. +template +Plane CopyImage(const Plane& from) { + Plane to(from.xsize(), from.ysize()); + CopyImageTo(from, &to); + return to; +} + +// Copies `from:rect_from` to `to:rect_to`. +template +void CopyImageTo(const Rect& rect_from, const Plane& from, + const Rect& rect_to, Plane* JXL_RESTRICT to) { + PROFILER_ZONE("CopyImageR"); + JXL_DASSERT(SameSize(rect_from, rect_to)); + JXL_DASSERT(rect_from.IsInside(from)); + JXL_DASSERT(rect_to.IsInside(*to)); + if (rect_from.xsize() == 0) return; + for (size_t y = 0; y < rect_from.ysize(); ++y) { + const T* JXL_RESTRICT row_from = rect_from.ConstRow(from, y); + T* JXL_RESTRICT row_to = rect_to.Row(to, y); + memcpy(row_to, row_from, rect_from.xsize() * sizeof(T)); + } +} + +// DEPRECATED - Returns a copy of the "image" pixels that lie in "rect". +template +Plane CopyImage(const Rect& rect, const Plane& image) { + Plane copy(rect.xsize(), rect.ysize()); + CopyImageTo(rect, image, ©); + return copy; +} + +// Copies `from:rect_from` to `to:rect_to`. +template +void CopyImageTo(const Rect& rect_from, const Image3& from, + const Rect& rect_to, Image3* JXL_RESTRICT to) { + PROFILER_ZONE("CopyImageR"); + JXL_ASSERT(SameSize(rect_from, rect_to)); + for (size_t c = 0; c < 3; c++) { + CopyImageTo(rect_from, from.Plane(c), rect_to, &to->Plane(c)); + } +} + +template +void ConvertPlaneAndClamp(const Rect& rect_from, const Plane& from, + const Rect& rect_to, Plane* JXL_RESTRICT to) { + PROFILER_ZONE("ConvertPlane"); + JXL_ASSERT(SameSize(rect_from, rect_to)); + using M = decltype(T() + U()); + for (size_t y = 0; y < rect_to.ysize(); ++y) { + const T* JXL_RESTRICT row_from = rect_from.ConstRow(from, y); + U* JXL_RESTRICT row_to = rect_to.Row(to, y); + for (size_t x = 0; x < rect_to.xsize(); ++x) { + row_to[x] = + std::min(std::max(row_from[x], std::numeric_limits::min()), + std::numeric_limits::max()); + } + } +} + +// Copies `from` to `to`. +template +void CopyImageTo(const T& from, T* JXL_RESTRICT to) { + return CopyImageTo(Rect(from), from, Rect(*to), to); +} + +// Copies `from:rect_from` to `to`. +template +void CopyImageTo(const Rect& rect_from, const T& from, T* JXL_RESTRICT to) { + return CopyImageTo(rect_from, from, Rect(*to), to); +} + +// Copies `from` to `to:rect_to`. +template +void CopyImageTo(const T& from, const Rect& rect_to, T* JXL_RESTRICT to) { + return CopyImageTo(Rect(from), from, rect_to, to); +} + +// Copies `from:rect_from` to `to:rect_to`; also copies `padding` pixels of +// border around `from:rect_from`, in all directions, whenever they are inside +// the first image. +template +void CopyImageToWithPadding(const Rect& from_rect, const T& from, + size_t padding, const Rect& to_rect, T* to) { + size_t xextra0 = std::min(padding, from_rect.x0()); + size_t xextra1 = + std::min(padding, from.xsize() - from_rect.x0() - from_rect.xsize()); + size_t yextra0 = std::min(padding, from_rect.y0()); + size_t yextra1 = + std::min(padding, from.ysize() - from_rect.y0() - from_rect.ysize()); + JXL_DASSERT(to_rect.x0() >= xextra0); + JXL_DASSERT(to_rect.y0() >= yextra0); + + return CopyImageTo(Rect(from_rect.x0() - xextra0, from_rect.y0() - yextra0, + from_rect.xsize() + xextra0 + xextra1, + from_rect.ysize() + yextra0 + yextra1), + from, + Rect(to_rect.x0() - xextra0, to_rect.y0() - yextra0, + to_rect.xsize() + xextra0 + xextra1, + to_rect.ysize() + yextra0 + yextra1), + to); +} + +// DEPRECATED - prefer to preallocate result. +template +Image3 CopyImage(const Image3& from) { + Image3 copy(from.xsize(), from.ysize()); + CopyImageTo(from, ©); + return copy; +} + +// DEPRECATED - prefer to preallocate result. +template +Image3 CopyImage(const Rect& rect, const Image3& from) { + Image3 to(rect.xsize(), rect.ysize()); + CopyImageTo(rect, from.Plane(0), to.Plane(0)); + CopyImageTo(rect, from.Plane(1), to.Plane(1)); + CopyImageTo(rect, from.Plane(2), to.Plane(2)); + return to; +} + +// Sets "thickness" pixels on each border to "value". This is faster than +// initializing the entire image and overwriting valid/interior pixels. +template +void SetBorder(const size_t thickness, const T value, Image3* image) { + const size_t xsize = image->xsize(); + const size_t ysize = image->ysize(); + // Top: fill entire row + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < std::min(thickness, ysize); ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + std::fill(row, row + xsize, value); + } + + // Bottom: fill entire row + for (size_t y = ysize - thickness; y < ysize; ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + std::fill(row, row + xsize, value); + } + + // Left/right: fill the 'columns' on either side, but only if the image is + // big enough that they don't already belong to the top/bottom rows. + if (ysize >= 2 * thickness) { + for (size_t y = thickness; y < ysize - thickness; ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + std::fill(row, row + thickness, value); + std::fill(row + xsize - thickness, row + xsize, value); + } + } + } +} + +template +void Subtract(const ImageIn& image1, const ImageIn& image2, ImageOut* out) { + using T = typename ImageIn::T; + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] - row2[x]; + } + } +} + +// In-place. +template +void SubtractFrom(const Plane& what, Plane* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstRow(y); + Tout* JXL_RESTRICT row_to = to->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] -= row_what[x]; + } + } +} + +// In-place. +template +void AddTo(const Plane& what, Plane* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstRow(y); + Tout* JXL_RESTRICT row_to = to->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } +} + +template +void AddTo(Rect rectFrom, const Plane& what, Rect rectTo, + Plane* to) { + JXL_ASSERT(SameSize(rectFrom, rectTo)); + const size_t xsize = rectTo.xsize(); + const size_t ysize = rectTo.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = rectFrom.ConstRow(what, y); + Tout* JXL_RESTRICT row_to = rectTo.Row(to, y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } +} + +// Returns linear combination of two grayscale images. +template +Plane LinComb(const T lambda1, const Plane& image1, const T lambda2, + const Plane& image2) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + Plane out(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = lambda1 * row1[x] + lambda2 * row2[x]; + } + } + return out; +} + +// Returns a pixel-by-pixel multiplication of image by lambda. +template +Plane ScaleImage(const T lambda, const Plane& image) { + Plane out(image.xsize(), image.ysize()); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* const JXL_RESTRICT row = image.Row(y); + T* const JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < image.xsize(); ++x) { + row_out[x] = lambda * row[x]; + } + } + return out; +} + +// Multiplies image by lambda in-place +template +void ScaleImage(const T lambda, Plane* image) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = lambda * row[x]; + } + } +} + +template +Plane Product(const Plane& a, const Plane& b) { + Plane c(a.xsize(), a.ysize()); + for (size_t y = 0; y < a.ysize(); ++y) { + const T* const JXL_RESTRICT row_a = a.Row(y); + const T* const JXL_RESTRICT row_b = b.Row(y); + T* const JXL_RESTRICT row_c = c.Row(y); + for (size_t x = 0; x < a.xsize(); ++x) { + row_c[x] = row_a[x] * row_b[x]; + } + } + return c; +} + +float DotProduct(const ImageF& a, const ImageF& b); + +template +void FillImage(const T value, Plane* image) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = value; + } + } +} + +template +void ZeroFillImage(Plane* image) { + if (image->xsize() == 0) return; + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + memset(row, 0, image->xsize() * sizeof(T)); + } +} + +// Mirrors out of bounds coordinates and returns valid coordinates unchanged. +// We assume the radius (distance outside the image) is small compared to the +// image size, otherwise this might not terminate. +// The mirror is outside the last column (border pixel is also replicated). +static inline int64_t Mirror(int64_t x, const int64_t xsize) { + JXL_DASSERT(xsize != 0); + + // TODO(janwas): replace with branchless version + while (x < 0 || x >= xsize) { + if (x < 0) { + x = -x - 1; + } else { + x = 2 * xsize - 1 - x; + } + } + return x; +} + +// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size): + +// Mirrors (repeating the edge pixel once). Useful for convolutions. +struct WrapMirror { + JXL_INLINE int64_t operator()(const int64_t coord, const int64_t size) const { + return Mirror(coord, size); + } +}; + +// Returns the same coordinate: required for TFNode with Border(), or useful +// when we know "coord" is already valid (e.g. interior of an image). +struct WrapUnchanged { + JXL_INLINE int64_t operator()(const int64_t coord, int64_t /*size*/) const { + return coord; + } +}; + +// Similar to Wrap* but for row pointers (reduces Row() multiplications). + +class WrapRowMirror { + public: + template + WrapRowMirror(const ImageOrView& image, size_t ysize) + : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {} + + const float* operator()(const float* const JXL_RESTRICT row, + const int64_t stride) const { + if (row < first_row_) { + const int64_t num_before = first_row_ - row; + // Mirrored; one row before => row 0, two before = row 1, ... + return first_row_ + num_before - stride; + } + if (row > last_row_) { + const int64_t num_after = row - last_row_; + // Mirrored; one row after => last row, two after = last - 1, ... + return last_row_ - num_after + stride; + } + return row; + } + + private: + const float* const JXL_RESTRICT first_row_; + const float* const JXL_RESTRICT last_row_; +}; + +struct WrapRowUnchanged { + JXL_INLINE const float* operator()(const float* const JXL_RESTRICT row, + int64_t /*stride*/) const { + return row; + } +}; + +// Sets "thickness" pixels on each border to "value". This is faster than +// initializing the entire image and overwriting valid/interior pixels. +template +void SetBorder(const size_t thickness, const T value, Plane* image) { + const size_t xsize = image->xsize(); + const size_t ysize = image->ysize(); + // Top: fill entire row + for (size_t y = 0; y < std::min(thickness, ysize); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + std::fill(row, row + xsize, value); + } + + // Bottom: fill entire row + for (size_t y = ysize - thickness; y < ysize; ++y) { + T* const JXL_RESTRICT row = image->Row(y); + std::fill(row, row + xsize, value); + } + + // Left/right: fill the 'columns' on either side, but only if the image is + // big enough that they don't already belong to the top/bottom rows. + if (ysize >= 2 * thickness) { + for (size_t y = thickness; y < ysize - thickness; ++y) { + T* const JXL_RESTRICT row = image->Row(y); + std::fill(row, row + thickness, value); + std::fill(row + xsize - thickness, row + xsize, value); + } + } +} + +// Computes the minimum and maximum pixel value. +template +void ImageMinMax(const Plane& image, T* const JXL_RESTRICT min, + T* const JXL_RESTRICT max) { + *min = std::numeric_limits::max(); + *max = std::numeric_limits::lowest(); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* const JXL_RESTRICT row = image.Row(y); + for (size_t x = 0; x < image.xsize(); ++x) { + *min = std::min(*min, row[x]); + *max = std::max(*max, row[x]); + } + } +} + +// Copies pixels, scaling their value relative to the "from" min/max by +// "to_range". Example: U8 [0, 255] := [0.0, 1.0], to_range = 1.0 => +// outputs [0.0, 1.0]. +template +void ImageConvert(const Plane& from, const float to_range, + Plane* const JXL_RESTRICT to) { + JXL_ASSERT(SameSize(from, *to)); + FromType min_from, max_from; + ImageMinMax(from, &min_from, &max_from); + const float scale = to_range / (max_from - min_from); + for (size_t y = 0; y < from.ysize(); ++y) { + const FromType* const JXL_RESTRICT row_from = from.Row(y); + ToType* const JXL_RESTRICT row_to = to->Row(y); + for (size_t x = 0; x < from.xsize(); ++x) { + row_to[x] = static_cast((row_from[x] - min_from) * scale); + } + } +} + +template +Plane ConvertToFloat(const Plane& from) { + float factor = 1.0f / std::numeric_limits::max(); + if (std::is_same::value || std::is_same::value) { + factor = 1.0f; + } + Plane to(from.xsize(), from.ysize()); + for (size_t y = 0; y < from.ysize(); ++y) { + const From* const JXL_RESTRICT row_from = from.Row(y); + float* const JXL_RESTRICT row_to = to.Row(y); + for (size_t x = 0; x < from.xsize(); ++x) { + row_to[x] = row_from[x] * factor; + } + } + return to; +} + +template +Plane ImageFromPacked(const std::vector& packed, const size_t xsize, + const size_t ysize) { + Plane out(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + T* const JXL_RESTRICT row = out.Row(y); + const T* const JXL_RESTRICT packed_row = &packed[y * xsize]; + memcpy(row, packed_row, xsize * sizeof(T)); + } + return out; +} + +// Computes independent minimum and maximum values for each plane. +template +void Image3MinMax(const Image3& image, const Rect& rect, + std::array* out_min, std::array* out_max) { + for (size_t c = 0; c < 3; ++c) { + T min = std::numeric_limits::max(); + T max = std::numeric_limits::min(); + for (size_t y = 0; y < rect.ysize(); ++y) { + const T* JXL_RESTRICT row = rect.ConstPlaneRow(image, c, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + min = std::min(min, row[x]); + max = std::max(max, row[x]); + } + } + (*out_min)[c] = min; + (*out_max)[c] = max; + } +} + +// Computes independent minimum and maximum values for each plane. +template +void Image3MinMax(const Image3& image, std::array* out_min, + std::array* out_max) { + Image3MinMax(image, Rect(image), out_min, out_max); +} + +template +void Image3Max(const Image3& image, std::array* out_max) { + for (size_t c = 0; c < 3; ++c) { + T max = std::numeric_limits::min(); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* JXL_RESTRICT row = image.ConstPlaneRow(c, y); + for (size_t x = 0; x < image.xsize(); ++x) { + max = std::max(max, row[x]); + } + } + (*out_max)[c] = max; + } +} + +// Computes the sum of the pixels in `rect`. +template +T ImageSum(const Plane& image, const Rect& rect) { + T result = 0; + for (size_t y = 0; y < rect.ysize(); ++y) { + const T* JXL_RESTRICT row = rect.ConstRow(image, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + result += row[x]; + } + } + return result; +} + +template +T ImageSum(const Plane& image) { + return ImageSum(image, Rect(image)); +} + +template +std::array Image3Sum(const Image3& image, const Rect& rect) { + std::array out_sum = 0; + for (size_t c = 0; c < 3; ++c) { + (out_sum)[c] = ImageSum(image.Plane(c), rect); + } + return out_sum; +} + +template +std::array Image3Sum(const Image3& image) { + return Image3Sum(image, Rect(image)); +} + +template +std::vector PackedFromImage(const Plane& image, const Rect& rect) { + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + std::vector packed(xsize * ysize); + for (size_t y = 0; y < rect.ysize(); ++y) { + memcpy(&packed[y * xsize], rect.ConstRow(image, y), xsize * sizeof(T)); + } + return packed; +} + +template +std::vector PackedFromImage(const Plane& image) { + return PackedFromImage(image, Rect(image)); +} + +// Computes the median pixel value. +template +T ImageMedian(const Plane& image, const Rect& rect) { + std::vector pixels = PackedFromImage(image, rect); + return Median(&pixels); +} + +template +T ImageMedian(const Plane& image) { + return ImageMedian(image, Rect(image)); +} + +template +std::array Image3Median(const Image3& image, const Rect& rect) { + std::array out_median; + for (size_t c = 0; c < 3; ++c) { + (out_median)[c] = ImageMedian(image.Plane(c), rect); + } + return out_median; +} + +template +std::array Image3Median(const Image3& image) { + return Image3Median(image, Rect(image)); +} + +template +void Image3Convert(const Image3& from, const float to_range, + Image3* const JXL_RESTRICT to) { + JXL_ASSERT(SameSize(from, *to)); + std::array min_from, max_from; + Image3MinMax(from, &min_from, &max_from); + float scales[3]; + for (size_t c = 0; c < 3; ++c) { + scales[c] = to_range / (max_from[c] - min_from[c]); + } + float scale = std::min(scales[0], std::min(scales[1], scales[2])); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < from.ysize(); ++y) { + const FromType* JXL_RESTRICT row_from = from.ConstPlaneRow(c, y); + ToType* JXL_RESTRICT row_to = to->PlaneRow(c, y); + for (size_t x = 0; x < from.xsize(); ++x) { + const float to = (row_from[x] - min_from[c]) * scale; + row_to[x] = static_cast(to); + } + } + } +} + +template +Image3F ConvertToFloat(const Image3& from) { + return Image3F(ConvertToFloat(from.Plane(0)), ConvertToFloat(from.Plane(1)), + ConvertToFloat(from.Plane(2))); +} + +template +void Subtract(const Image3& image1, const Image3& image2, + Image3* out) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ysize; ++y) { + const Tin* const JXL_RESTRICT row1 = image1.ConstPlaneRow(c, y); + const Tin* const JXL_RESTRICT row2 = image2.ConstPlaneRow(c, y); + Tout* const JXL_RESTRICT row_out = out->PlaneRow(c, y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] - row2[x]; + } + } + } +} + +template +void SubtractFrom(const Image3& what, Image3* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstPlaneRow(c, y); + Tout* JXL_RESTRICT row_to = to->PlaneRow(c, y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] -= row_what[x]; + } + } + } +} + +template +void AddTo(const Image3& what, Image3* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstPlaneRow(c, y); + Tout* JXL_RESTRICT row_to = to->PlaneRow(c, y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } + } +} + +// Adds `what` of the size of `rect` to `to` in the position of `rect`. +template +void AddTo(const Rect& rect, const Image3& what, Image3* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + JXL_ASSERT(xsize == rect.xsize()); + JXL_ASSERT(ysize == rect.ysize()); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstPlaneRow(c, y); + Tout* JXL_RESTRICT row_to = rect.PlaneRow(to, c, y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } + } +} + +template +Image3 ScaleImage(const T lambda, const Image3& image) { + Image3 out(image.xsize(), image.ysize()); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image.ysize(); ++y) { + const T* JXL_RESTRICT row = image.ConstPlaneRow(c, y); + T* JXL_RESTRICT row_out = out.PlaneRow(c, y); + for (size_t x = 0; x < image.xsize(); ++x) { + row_out[x] = lambda * row[x]; + } + } + } + return out; +} + +// Multiplies image by lambda in-place +template +void ScaleImage(const T lambda, Image3* image) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->PlaneRow(c, y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = lambda * row[x]; + } + } + } +} + +// Initializes all planes to the same "value". +template +void FillImage(const T value, Image3* image) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = value; + } + } + } +} + +template +void FillPlane(const T value, Plane* image) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = value; + } + } +} + +template +void FillImage(const T value, Image3* image, Rect rect) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < rect.ysize(); ++y) { + T* JXL_RESTRICT row = rect.PlaneRow(image, c, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + row[x] = value; + } + } + } +} + +template +void FillPlane(const T value, Plane* image, Rect rect) { + for (size_t y = 0; y < rect.ysize(); ++y) { + T* JXL_RESTRICT row = rect.Row(image, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + row[x] = value; + } + } +} + +template +void ZeroFillImage(Image3* image) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + if (image->xsize() != 0) memset(row, 0, image->xsize() * sizeof(T)); + } + } +} + +template +void ZeroFillPlane(Plane* image, Rect rect) { + for (size_t y = 0; y < rect.ysize(); ++y) { + T* JXL_RESTRICT row = rect.Row(image, y); + memset(row, 0, rect.xsize() * sizeof(T)); + } +} + +// Pad an image with xborder columns on each vertical side and yboder rows +// above and below, mirroring the image. +Image3F PadImageMirror(const Image3F& in, size_t xborder, size_t yborder); + +// Same as above, but operates in-place. Assumes that the `in` image was +// allocated large enough. +void PadImageToBlockMultipleInPlace(Image3F* JXL_RESTRICT in); + +// Downsamples an image by a given factor. +void DownsampleImage(Image3F* opsin, size_t factor); +void DownsampleImage(ImageF* image, size_t factor); + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_OPS_H_ diff --git a/media/libjxl/src/lib/jxl/image_ops_test.cc b/media/libjxl/src/lib/jxl/image_ops_test.cc new file mode 100644 index 000000000..8937364e8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_ops_test.cc @@ -0,0 +1,164 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_ops.h" + +#include +#include +#include + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +template +void TestPacked(const size_t xsize, const size_t ysize) { + Plane image1(xsize, ysize); + RandomFillImage(&image1); + const std::vector& packed = PackedFromImage(image1); + const Plane& image2 = ImageFromPacked(packed, xsize, ysize); + EXPECT_TRUE(SamePixels(image1, image2)); +} + +TEST(ImageTest, TestPacked) { + TestPacked(1, 1); + TestPacked(7, 1); + TestPacked(1, 7); + + TestPacked(1, 1); + TestPacked(7, 1); + TestPacked(1, 7); + + TestPacked(1, 1); + TestPacked(7, 1); + TestPacked(1, 7); + + TestPacked(1, 1); + TestPacked(7, 1); + TestPacked(1, 7); +} + +// Ensure entire payload is readable/writable for various size/offset combos. +TEST(ImageTest, TestAllocator) { + Rng rng(0); + const size_t k32 = 32; + const size_t kAlign = CacheAligned::kAlignment; + for (size_t size : {k32 * 1, k32 * 2, k32 * 3, k32 * 4, k32 * 5, + CacheAligned::kAlias, 2 * CacheAligned::kAlias + 4}) { + for (size_t offset = 0; offset <= CacheAligned::kAlias; offset += kAlign) { + uint8_t* bytes = + static_cast(CacheAligned::Allocate(size, offset)); + JXL_CHECK(reinterpret_cast(bytes) % kAlign == 0); + // Ensure we can write/read the last byte. Use RNG to fool the compiler + // into thinking the write is necessary. + memset(bytes, 0, size); + bytes[size - 1] = 1; // greatest element + uint32_t pos = rng.UniformU(0, size - 1); // random but != greatest + JXL_CHECK(bytes[pos] < bytes[size - 1]); + + CacheAligned::Free(bytes); + } + } +} + +template +void TestFillImpl(Image3* img, const char* layout) { + FillImage(T(1), img); + for (size_t y = 0; y < img->ysize(); ++y) { + for (size_t c = 0; c < 3; ++c) { + T* JXL_RESTRICT row = img->PlaneRow(c, y); + for (size_t x = 0; x < img->xsize(); ++x) { + if (row[x] != T(1)) { + printf("Not 1 at c=%" PRIuS " %" PRIuS ", %" PRIuS " (%" PRIuS + " x %" PRIuS ") (%s)\n", + c, x, y, img->xsize(), img->ysize(), layout); + abort(); + } + row[x] = T(2); + } + } + } + + // Same for ZeroFillImage and swapped c/y loop ordering. + ZeroFillImage(img); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < img->ysize(); ++y) { + T* JXL_RESTRICT row = img->PlaneRow(c, y); + for (size_t x = 0; x < img->xsize(); ++x) { + if (row[x] != T(0)) { + printf("Not 0 at c=%" PRIuS " %" PRIuS ", %" PRIuS " (%" PRIuS + " x %" PRIuS ") (%s)\n", + c, x, y, img->xsize(), img->ysize(), layout); + abort(); + } + row[x] = T(3); + } + } + } +} + +template +void TestFillT() { + for (uint32_t xsize : {0, 1, 15, 16, 31, 32}) { + for (uint32_t ysize : {0, 1, 15, 16, 31, 32}) { + Image3 image(xsize, ysize); + TestFillImpl(&image, "size ctor"); + + Image3 planar(Plane(xsize, ysize), Plane(xsize, ysize), + Plane(xsize, ysize)); + TestFillImpl(&planar, "planar"); + } + } +} + +// Ensure y/c/x and c/y/x loops visit pixels no more than once. +TEST(ImageTest, TestFill) { + TestFillT(); + TestFillT(); + TestFillT(); + TestFillT(); +} + +TEST(ImageTest, CopyImageToWithPaddingTest) { + Plane src(100, 61); + for (size_t y = 0; y < src.ysize(); y++) { + for (size_t x = 0; x < src.xsize(); x++) { + src.Row(y)[x] = x * 1000 + y; + } + } + Rect src_rect(10, 20, 30, 40); + EXPECT_TRUE(src_rect.IsInside(src)); + + Plane dst(60, 50); + FillImage(0u, &dst); + Rect dst_rect(20, 5, 30, 40); + EXPECT_TRUE(dst_rect.IsInside(dst)); + + CopyImageToWithPadding(src_rect, src, /*padding=*/2, dst_rect, &dst); + + // ysize is + 3 instead of + 4 because we are at the y image boundary on the + // source image. + Rect padded_dst_rect(20 - 2, 5 - 2, 30 + 4, 40 + 3); + for (size_t y = 0; y < dst.ysize(); y++) { + for (size_t x = 0; x < dst.xsize(); x++) { + if (Rect(x, y, 1, 1).IsInside(padded_dst_rect)) { + EXPECT_EQ((x - dst_rect.x0() + src_rect.x0()) * 1000 + + (y - dst_rect.y0() + src_rect.y0()), + dst.Row(y)[x]); + } else { + EXPECT_EQ(0u, dst.Row(y)[x]); + } + } + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/image_test_utils.h b/media/libjxl/src/lib/jxl/image_test_utils.h new file mode 100644 index 000000000..4549c194b --- /dev/null +++ b/media/libjxl/src/lib/jxl/image_test_utils.h @@ -0,0 +1,268 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_TEST_UTILS_H_ +#define LIB_JXL_IMAGE_TEST_UTILS_H_ + +#include + +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +template +void VerifyEqual(const Plane& expected, const Plane& actual) { + JXL_CHECK(SameSize(expected, actual)); + for (size_t y = 0; y < expected.ysize(); ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + const T* const JXL_RESTRICT row_actual = actual.Row(y); + for (size_t x = 0; x < expected.xsize(); ++x) { + ASSERT_EQ(row_expected[x], row_actual[x]) << x << " " << y; + } + } +} + +template +void VerifyEqual(const Image3& expected, const Image3& actual) { + for (size_t c = 0; c < 3; ++c) { + VerifyEqual(expected.Plane(c), actual.Plane(c)); + } +} + +template +bool SamePixels(const Plane& image1, const Plane& image2, + const Rect rect) { + if (!rect.IsInside(image1) || !rect.IsInside(image2)) { + ADD_FAILURE() << "requested rectangle is not fully inside the image"; + return false; + } + size_t mismatches = 0; + for (size_t y = rect.y0(); y < rect.ysize(); ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + for (size_t x = rect.x0(); x < rect.xsize(); ++x) { + if (row1[x] != row2[x]) { + ADD_FAILURE() << "pixel mismatch" << x << ", " << y << ": " + << double(row1[x]) << " != " << double(row2[x]); + if (++mismatches > 4) { + return false; + } + } + } + } + return mismatches == 0; +} + +template +bool SamePixels(const Plane& image1, const Plane& image2) { + JXL_CHECK(SameSize(image1, image2)); + return SamePixels(image1, image2, Rect(image1)); +} + +template +bool SamePixels(const Image3& image1, const Image3& image2) { + JXL_CHECK(SameSize(image1, image2)); + for (size_t c = 0; c < 3; ++c) { + if (!SamePixels(image1.Plane(c), image2.Plane(c))) { + return false; + } + } + return true; +} + +// Use for floating-point images with fairly large numbers; tolerates small +// absolute errors and/or small relative errors. +template +void VerifyRelativeError(const Plane& expected, const Plane& actual, + const double threshold_l1, + const double threshold_relative, + const intptr_t border = 0, const size_t c = 0) { + JXL_CHECK(SameSize(expected, actual)); + const intptr_t xsize = expected.xsize(); + const intptr_t ysize = expected.ysize(); + + // Max over current scanline to give a better idea whether there are + // systematic errors or just one outlier. Invalid if negative. + double max_l1 = -1; + double max_relative = -1; + bool any_bad = false; + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + const T* const JXL_RESTRICT row_actual = actual.Row(y); + for (intptr_t x = border; x < xsize - border; ++x) { + const double l1 = std::abs(row_expected[x] - row_actual[x]); + + // Cannot compute relative, only check/update L1. + if (std::abs(row_expected[x]) < 1E-10) { + if (l1 > threshold_l1) { + any_bad = true; + max_l1 = std::max(max_l1, l1); + } + } else { + const double relative = l1 / std::abs(double(row_expected[x])); + if (l1 > threshold_l1 && relative > threshold_relative) { + // Fails both tolerances => will exit below, update max_*. + any_bad = true; + max_l1 = std::max(max_l1, l1); + max_relative = std::max(max_relative, relative); + } + } + } + } + if (any_bad) { + // Never had a valid relative value, don't print it. + if (max_relative < 0) { + fprintf(stderr, "c=%" PRIu64 ": max +/- %E exceeds +/- %.2E\n", + static_cast(c), max_l1, threshold_l1); + } else { + fprintf(stderr, + "c=%" PRIu64 ": max +/- %E, x %E exceeds +/- %.2E, x %.2E\n", + static_cast(c), max_l1, max_relative, threshold_l1, + threshold_relative); + } + // Dump the expected image and actual image if the region is small enough. + const intptr_t kMaxTestDumpSize = 16; + if (xsize <= kMaxTestDumpSize + 2 * border && + ysize <= kMaxTestDumpSize + 2 * border) { + fprintf(stderr, "Expected image:\n"); + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + for (intptr_t x = border; x < xsize - border; ++x) { + fprintf(stderr, "%10lf ", static_cast(row_expected[x])); + } + fprintf(stderr, "\n"); + } + + fprintf(stderr, "Actual image:\n"); + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + const T* const JXL_RESTRICT row_actual = actual.Row(y); + for (intptr_t x = border; x < xsize - border; ++x) { + const double l1 = std::abs(row_expected[x] - row_actual[x]); + + bool bad = l1 > threshold_l1; + if (row_expected[x] > 1E-10) { + const double relative = l1 / std::abs(double(row_expected[x])); + bad &= relative > threshold_relative; + } + if (bad) { + fprintf(stderr, "%10lf ", static_cast(row_actual[x])); + } else { + fprintf(stderr, "%10s ", "=="); + } + } + fprintf(stderr, "\n"); + } + } + + // Find first failing x for further debugging. + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + const T* const JXL_RESTRICT row_actual = actual.Row(y); + + for (intptr_t x = border; x < xsize - border; ++x) { + const double l1 = std::abs(row_expected[x] - row_actual[x]); + + bool bad = l1 > threshold_l1; + if (row_expected[x] > 1E-10) { + const double relative = l1 / std::abs(double(row_expected[x])); + bad &= relative > threshold_relative; + } + if (bad) { + FAIL() << x << ", " << y << " (" << expected.xsize() << " x " + << expected.ysize() << ") expected " + << static_cast(row_expected[x]) << " actual " + << static_cast(row_actual[x]); + } + } + } + return; // if any_bad, we should have exited. + } +} + +template +void VerifyRelativeError(const Image3& expected, const Image3& actual, + const float threshold_l1, + const float threshold_relative, + const intptr_t border = 0) { + for (size_t c = 0; c < 3; ++c) { + VerifyRelativeError(expected.Plane(c), actual.Plane(c), threshold_l1, + threshold_relative, border, c); + } +} + +template +void GenerateImage(Rng& rng, Plane* image, U begin, U end) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + if (std::is_same::value || std::is_same::value) { + row[x] = rng.UniformF(begin, end); + } else if (std::is_signed::value) { + row[x] = rng.UniformI(begin, end); + } else { + row[x] = rng.UniformU(begin, end); + } + } + } +} + +template +void RandomFillImage(Plane* image, const T begin, const T end, + const int seed = 129) { + Rng rng(seed); + GenerateImage(rng, image, begin, end); +} + +template +typename std::enable_if::value>::type RandomFillImage( + Plane* image) { + Rng rng(129); + GenerateImage(rng, image, int64_t(0), + int64_t(std::numeric_limits::max()) + 1); +} + +JXL_INLINE void RandomFillImage(Plane* image) { + Rng rng(129); + GenerateImage(rng, image, 0.0f, std::numeric_limits::max()); +} + +template +void GenerateImage(Rng& rng, Image3* image, U begin, U end) { + for (size_t c = 0; c < 3; ++c) { + GenerateImage(rng, &image->Plane(c), begin, end); + } +} + +template +typename std::enable_if::value>::type RandomFillImage( + Image3* image) { + Rng rng(129); + GenerateImage(rng, image, int64_t(0), + int64_t(std::numeric_limits::max()) + 1); +} + +JXL_INLINE void RandomFillImage(Image3F* image) { + Rng rng(129); + GenerateImage(rng, image, 0.0f, std::numeric_limits::max()); +} + +template +void RandomFillImage(Image3* image, const U begin, const U end, + const int seed = 129) { + Rng rng(seed); + GenerateImage(rng, image, begin, end); +} + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_TEST_UTILS_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data.cc b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data.cc new file mode 100644 index 000000000..db49a1c21 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data.cc @@ -0,0 +1,145 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/dec_jpeg_data.h" + +#include + +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { +namespace jpeg { +Status DecodeJPEGData(Span encoded, JPEGData* jpeg_data) { + Status ret = true; + const uint8_t* in = encoded.data(); + size_t available_in = encoded.size(); + { + BitReader br(encoded); + BitReaderScopedCloser br_closer(&br, &ret); + JXL_RETURN_IF_ERROR(Bundle::Read(&br, jpeg_data)); + JXL_RETURN_IF_ERROR(br.JumpToByteBoundary()); + in += br.TotalBitsConsumed() / 8; + available_in -= br.TotalBitsConsumed() / 8; + } + JXL_RETURN_IF_ERROR(ret); + + BrotliDecoderState* brotli_dec = + BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + + struct BrotliDecDeleter { + BrotliDecoderState* brotli_dec; + ~BrotliDecDeleter() { BrotliDecoderDestroyInstance(brotli_dec); } + } brotli_dec_deleter{brotli_dec}; + + BrotliDecoderResult result = + BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS; + + auto br_read = [&](std::vector& data) -> Status { + size_t available_out = data.size(); + uint8_t* out = data.data(); + while (available_out != 0) { + if (BrotliDecoderIsFinished(brotli_dec)) { + return JXL_FAILURE("Not enough decompressed output"); + } + uint8_t* next_out_before = out; + size_t avail_out_before = available_out; + msan::MemoryIsInitialized(in, available_in); + result = BrotliDecoderDecompressStream(brotli_dec, &available_in, &in, + &available_out, &out, nullptr); + if (result != + BrotliDecoderResult::BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT && + result != BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS) { + return JXL_FAILURE( + "Brotli decoding error: %s\n", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(brotli_dec))); + } + msan::UnpoisonMemory(next_out_before, avail_out_before - available_out); + } + return true; + }; + size_t num_icc = 0; + for (size_t i = 0; i < jpeg_data->app_data.size(); i++) { + auto& marker = jpeg_data->app_data[i]; + if (jpeg_data->app_marker_type[i] != AppMarkerType::kUnknown) { + // Set the size of the marker. + size_t size_minus_1 = marker.size() - 1; + marker[1] = size_minus_1 >> 8; + marker[2] = size_minus_1 & 0xFF; + if (jpeg_data->app_marker_type[i] == AppMarkerType::kICC) { + if (marker.size() < 17) { + return JXL_FAILURE("ICC markers must be at least 17 bytes"); + } + marker[0] = 0xE2; + memcpy(&marker[3], kIccProfileTag, sizeof kIccProfileTag); + marker[15] = ++num_icc; + } + } else { + JXL_RETURN_IF_ERROR(br_read(marker)); + if (marker[1] * 256u + marker[2] + 1u != marker.size()) { + return JXL_FAILURE("Incorrect marker size"); + } + } + } + for (size_t i = 0; i < jpeg_data->app_data.size(); i++) { + auto& marker = jpeg_data->app_data[i]; + if (jpeg_data->app_marker_type[i] == AppMarkerType::kICC) { + marker[16] = num_icc; + } + if (jpeg_data->app_marker_type[i] == AppMarkerType::kExif) { + marker[0] = 0xE1; + if (marker.size() < 3 + sizeof kExifTag) { + return JXL_FAILURE("Incorrect Exif marker size"); + } + memcpy(&marker[3], kExifTag, sizeof kExifTag); + } + if (jpeg_data->app_marker_type[i] == AppMarkerType::kXMP) { + marker[0] = 0xE1; + if (marker.size() < 3 + sizeof kXMPTag) { + return JXL_FAILURE("Incorrect XMP marker size"); + } + memcpy(&marker[3], kXMPTag, sizeof kXMPTag); + } + } + // TODO(eustas): actually inject ICC profile and check it fits perfectly. + for (size_t i = 0; i < jpeg_data->com_data.size(); i++) { + auto& marker = jpeg_data->com_data[i]; + JXL_RETURN_IF_ERROR(br_read(marker)); + if (marker[1] * 256u + marker[2] + 1u != marker.size()) { + return JXL_FAILURE("Incorrect marker size"); + } + } + for (size_t i = 0; i < jpeg_data->inter_marker_data.size(); i++) { + JXL_RETURN_IF_ERROR(br_read(jpeg_data->inter_marker_data[i])); + } + JXL_RETURN_IF_ERROR(br_read(jpeg_data->tail_data)); + + // Check if there is more decompressed output. + size_t available_out = 1; + uint64_t dummy; + uint8_t* next_out = reinterpret_cast(&dummy); + result = BrotliDecoderDecompressStream(brotli_dec, &available_in, &in, + &available_out, &next_out, nullptr); + if (available_out == 0 || + result == BrotliDecoderResult::BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + return JXL_FAILURE("Excess data in compressed stream"); + } + if (result == BrotliDecoderResult::BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT) { + return JXL_FAILURE("Incomplete brotli-stream"); + } + if (!BrotliDecoderIsFinished(brotli_dec) || + result != BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS) { + return JXL_FAILURE("Corrupted brotli-stream"); + } + if (available_in != 0) { + return JXL_FAILURE("Unused data after brotli stream"); + } + + return true; +} +} // namespace jpeg +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data.h b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data.h new file mode 100644 index 000000000..b9d50bf9f --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data.h @@ -0,0 +1,19 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_DATA_H_ +#define LIB_JXL_JPEG_DEC_JPEG_DATA_H_ + +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { +Status DecodeJPEGData(Span encoded, JPEGData* jpeg_data); +} +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_DATA_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data_writer.cc b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data_writer.cc new file mode 100644 index 000000000..5336e47fd --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data_writer.cc @@ -0,0 +1,995 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" + +#include +#include /* for memset, memcpy */ + +#include +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/common.h" +#include "lib/jxl/jpeg/dec_jpeg_serialization_state.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +namespace { + +enum struct SerializationStatus { + NEEDS_MORE_INPUT, + NEEDS_MORE_OUTPUT, + ERROR, + DONE +}; + +const int kJpegPrecision = 8; + +// JpegBitWriter: buffer size +const size_t kJpegBitWriterChunkSize = 16384; + +// DCTCodingState: maximum number of correction bits to buffer +const int kJPEGMaxCorrectionBits = 1u << 16; + +// Returns non-zero if and only if x has a zero byte, i.e. one of +// x & 0xff, x & 0xff00, ..., x & 0xff00000000000000 is zero. +static JXL_INLINE uint64_t HasZeroByte(uint64_t x) { + return (x - 0x0101010101010101ULL) & ~x & 0x8080808080808080ULL; +} + +void JpegBitWriterInit(JpegBitWriter* bw, + std::deque* output_queue) { + bw->output = output_queue; + bw->chunk = OutputChunk(kJpegBitWriterChunkSize); + bw->pos = 0; + bw->put_buffer = 0; + bw->put_bits = 64; + bw->healthy = true; + bw->data = bw->chunk.buffer->data(); +} + +static JXL_NOINLINE void SwapBuffer(JpegBitWriter* bw) { + bw->chunk.len = bw->pos; + bw->output->emplace_back(std::move(bw->chunk)); + bw->chunk = OutputChunk(kJpegBitWriterChunkSize); + bw->data = bw->chunk.buffer->data(); + bw->pos = 0; +} + +static JXL_INLINE void Reserve(JpegBitWriter* bw, size_t n_bytes) { + if (JXL_UNLIKELY((bw->pos + n_bytes) > kJpegBitWriterChunkSize)) { + SwapBuffer(bw); + } +} + +/** + * Writes the given byte to the output, writes an extra zero if byte is 0xFF. + * + * This method is "careless" - caller must make sure that there is enough + * space in the output buffer. Emits up to 2 bytes to buffer. + */ +static JXL_INLINE void EmitByte(JpegBitWriter* bw, int byte) { + bw->data[bw->pos++] = byte; + if (byte == 0xFF) bw->data[bw->pos++] = 0; +} + +static JXL_INLINE void DischargeBitBuffer(JpegBitWriter* bw) { + // At this point we are ready to emit the most significant 6 bytes of + // put_buffer_ to the output. + // The JPEG format requires that after every 0xff byte in the entropy + // coded section, there is a zero byte, therefore we first check if any of + // the 6 most significant bytes of put_buffer_ is 0xFF. + Reserve(bw, 12); + if (HasZeroByte(~bw->put_buffer | 0xFFFF)) { + // We have a 0xFF byte somewhere, examine each byte and append a zero + // byte if necessary. + EmitByte(bw, (bw->put_buffer >> 56) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 48) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 40) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 32) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 24) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 16) & 0xFF); + } else { + // We don't have any 0xFF bytes, output all 6 bytes without checking. + bw->data[bw->pos] = (bw->put_buffer >> 56) & 0xFF; + bw->data[bw->pos + 1] = (bw->put_buffer >> 48) & 0xFF; + bw->data[bw->pos + 2] = (bw->put_buffer >> 40) & 0xFF; + bw->data[bw->pos + 3] = (bw->put_buffer >> 32) & 0xFF; + bw->data[bw->pos + 4] = (bw->put_buffer >> 24) & 0xFF; + bw->data[bw->pos + 5] = (bw->put_buffer >> 16) & 0xFF; + bw->pos += 6; + } + bw->put_buffer <<= 48; + bw->put_bits += 48; +} + +static JXL_INLINE void WriteBits(JpegBitWriter* bw, int nbits, uint64_t bits) { + // This is an optimization; if everything goes well, + // then |nbits| is positive; if non-existing Huffman symbol is going to be + // encoded, its length should be zero; later encoder could check the + // "health" of JpegBitWriter. + if (nbits == 0) { + bw->healthy = false; + return; + } + bw->put_bits -= nbits; + bw->put_buffer |= (bits << bw->put_bits); + if (bw->put_bits <= 16) DischargeBitBuffer(bw); +} + +void EmitMarker(JpegBitWriter* bw, int marker) { + Reserve(bw, 2); + JXL_DASSERT(marker != 0xFF); + bw->data[bw->pos++] = 0xFF; + bw->data[bw->pos++] = marker; +} + +bool JumpToByteBoundary(JpegBitWriter* bw, const uint8_t** pad_bits, + const uint8_t* pad_bits_end) { + size_t n_bits = bw->put_bits & 7u; + uint8_t pad_pattern; + if (*pad_bits == nullptr) { + pad_pattern = (1u << n_bits) - 1; + } else { + pad_pattern = 0; + const uint8_t* src = *pad_bits; + // TODO(eustas): bitwise reading looks insanely ineffective... + while (n_bits--) { + pad_pattern <<= 1; + if (src >= pad_bits_end) return false; + // TODO(eustas): DCHECK *src == {0, 1} + pad_pattern |= !!*(src++); + } + *pad_bits = src; + } + + Reserve(bw, 16); + + while (bw->put_bits <= 56) { + int c = (bw->put_buffer >> 56) & 0xFF; + EmitByte(bw, c); + bw->put_buffer <<= 8; + bw->put_bits += 8; + } + if (bw->put_bits < 64) { + int pad_mask = 0xFFu >> (64 - bw->put_bits); + int c = ((bw->put_buffer >> 56) & ~pad_mask) | pad_pattern; + EmitByte(bw, c); + } + bw->put_buffer = 0; + bw->put_bits = 64; + + return true; +} + +void JpegBitWriterFinish(JpegBitWriter* bw) { + if (bw->pos == 0) return; + bw->chunk.len = bw->pos; + bw->output->emplace_back(std::move(bw->chunk)); + bw->chunk = OutputChunk(nullptr, 0); + bw->data = nullptr; + bw->pos = 0; +} + +void DCTCodingStateInit(DCTCodingState* s) { + s->eob_run_ = 0; + s->cur_ac_huff_ = nullptr; + s->refinement_bits_.clear(); + s->refinement_bits_.reserve(kJPEGMaxCorrectionBits); +} + +// Emit all buffered data to the bit stream using the given Huffman code and +// bit writer. +static JXL_INLINE void Flush(DCTCodingState* s, JpegBitWriter* bw) { + if (s->eob_run_ > 0) { + int nbits = FloorLog2Nonzero(s->eob_run_); + int symbol = nbits << 4u; + WriteBits(bw, s->cur_ac_huff_->depth[symbol], + s->cur_ac_huff_->code[symbol]); + if (nbits > 0) { + WriteBits(bw, nbits, s->eob_run_ & ((1 << nbits) - 1)); + } + s->eob_run_ = 0; + } + for (size_t i = 0; i < s->refinement_bits_.size(); ++i) { + WriteBits(bw, 1, s->refinement_bits_[i]); + } + s->refinement_bits_.clear(); +} + +// Buffer some more data at the end-of-band (the last non-zero or newly +// non-zero coefficient within the [Ss, Se] spectral band). +static JXL_INLINE void BufferEndOfBand(DCTCodingState* s, + const HuffmanCodeTable* ac_huff, + const std::vector* new_bits, + JpegBitWriter* bw) { + if (s->eob_run_ == 0) { + s->cur_ac_huff_ = ac_huff; + } + ++s->eob_run_; + if (new_bits) { + s->refinement_bits_.insert(s->refinement_bits_.end(), new_bits->begin(), + new_bits->end()); + } + if (s->eob_run_ == 0x7FFF || + s->refinement_bits_.size() > kJPEGMaxCorrectionBits - kDCTBlockSize + 1) { + Flush(s, bw); + } +} + +bool BuildHuffmanCodeTable(const JPEGHuffmanCode& huff, + HuffmanCodeTable* table) { + int huff_code[kJpegHuffmanAlphabetSize]; + // +1 for a sentinel element. + uint32_t huff_size[kJpegHuffmanAlphabetSize + 1]; + int p = 0; + for (size_t l = 1; l <= kJpegHuffmanMaxBitLength; ++l) { + int i = huff.counts[l]; + if (p + i > kJpegHuffmanAlphabetSize + 1) { + return false; + } + while (i--) huff_size[p++] = l; + } + + if (p == 0) { + return true; + } + + // Reuse sentinel element. + int last_p = p - 1; + huff_size[last_p] = 0; + + int code = 0; + uint32_t si = huff_size[0]; + p = 0; + while (huff_size[p]) { + while ((huff_size[p]) == si) { + huff_code[p++] = code; + code++; + } + code <<= 1; + si++; + } + for (p = 0; p < last_p; p++) { + int i = huff.values[p]; + table->depth[i] = huff_size[p]; + table->code[i] = huff_code[p]; + } + return true; +} + +bool EncodeSOI(SerializationState* state) { + state->output_queue.push_back(OutputChunk({0xFF, 0xD8})); + return true; +} + +bool EncodeEOI(const JPEGData& jpg, SerializationState* state) { + state->output_queue.push_back(OutputChunk({0xFF, 0xD9})); + state->output_queue.emplace_back(jpg.tail_data); + return true; +} + +bool EncodeSOF(const JPEGData& jpg, uint8_t marker, SerializationState* state) { + if (marker <= 0xC2) state->is_progressive = (marker == 0xC2); + + const size_t n_comps = jpg.components.size(); + const size_t marker_len = 8 + 3 * n_comps; + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = marker; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + data[pos++] = kJpegPrecision; + data[pos++] = jpg.height >> 8u; + data[pos++] = jpg.height & 0xFFu; + data[pos++] = jpg.width >> 8u; + data[pos++] = jpg.width & 0xFFu; + data[pos++] = n_comps; + for (size_t i = 0; i < n_comps; ++i) { + data[pos++] = jpg.components[i].id; + data[pos++] = ((jpg.components[i].h_samp_factor << 4u) | + (jpg.components[i].v_samp_factor)); + const size_t quant_idx = jpg.components[i].quant_idx; + if (quant_idx >= jpg.quant.size()) return false; + data[pos++] = jpg.quant[quant_idx].index; + } + return true; +} + +bool EncodeSOS(const JPEGData& jpg, const JPEGScanInfo& scan_info, + SerializationState* state) { + const size_t n_scans = scan_info.num_components; + const size_t marker_len = 6 + 2 * n_scans; + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = 0xDA; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + data[pos++] = n_scans; + for (size_t i = 0; i < n_scans; ++i) { + const JPEGComponentScanInfo& si = scan_info.components[i]; + if (si.comp_idx >= jpg.components.size()) return false; + data[pos++] = jpg.components[si.comp_idx].id; + data[pos++] = (si.dc_tbl_idx << 4u) + si.ac_tbl_idx; + } + data[pos++] = scan_info.Ss; + data[pos++] = scan_info.Se; + data[pos++] = ((scan_info.Ah << 4u) | (scan_info.Al)); + return true; +} + +bool EncodeDHT(const JPEGData& jpg, SerializationState* state) { + const std::vector& huffman_code = jpg.huffman_code; + + size_t marker_len = 2; + for (size_t i = state->dht_index; i < huffman_code.size(); ++i) { + const JPEGHuffmanCode& huff = huffman_code[i]; + marker_len += kJpegHuffmanMaxBitLength; + for (size_t j = 0; j < huff.counts.size(); ++j) { + marker_len += huff.counts[j]; + } + if (huff.is_last) break; + } + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = 0xC4; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + while (true) { + const size_t huffman_code_index = state->dht_index++; + if (huffman_code_index >= huffman_code.size()) { + return false; + } + const JPEGHuffmanCode& huff = huffman_code[huffman_code_index]; + size_t index = huff.slot_id; + HuffmanCodeTable* huff_table; + if (index & 0x10) { + index -= 0x10; + huff_table = &state->ac_huff_table[index]; + } else { + huff_table = &state->dc_huff_table[index]; + } + // TODO(eustas): cache + // TODO(eustas): set up non-existing symbols + if (!BuildHuffmanCodeTable(huff, huff_table)) { + return false; + } + size_t total_count = 0; + size_t max_length = 0; + for (size_t i = 0; i < huff.counts.size(); ++i) { + if (huff.counts[i] != 0) { + max_length = i; + } + total_count += huff.counts[i]; + } + --total_count; + data[pos++] = huff.slot_id; + for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) { + data[pos++] = (i == max_length ? huff.counts[i] - 1 : huff.counts[i]); + } + for (size_t i = 0; i < total_count; ++i) { + data[pos++] = huff.values[i]; + } + if (huff.is_last) break; + } + return true; +} + +bool EncodeDQT(const JPEGData& jpg, SerializationState* state) { + int marker_len = 2; + for (size_t i = state->dqt_index; i < jpg.quant.size(); ++i) { + const JPEGQuantTable& table = jpg.quant[i]; + marker_len += 1 + (table.precision ? 2 : 1) * kDCTBlockSize; + if (table.is_last) break; + } + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = 0xDB; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + while (true) { + const size_t idx = state->dqt_index++; + if (idx >= jpg.quant.size()) { + return false; // corrupt input + } + const JPEGQuantTable& table = jpg.quant[idx]; + data[pos++] = (table.precision << 4u) + table.index; + for (size_t i = 0; i < kDCTBlockSize; ++i) { + int val_idx = kJPEGNaturalOrder[i]; + int val = table.values[val_idx]; + if (table.precision) { + data[pos++] = val >> 8u; + } + data[pos++] = val & 0xFFu; + } + if (table.is_last) break; + } + return true; +} + +bool EncodeDRI(const JPEGData& jpg, SerializationState* state) { + state->seen_dri_marker = true; + OutputChunk dri_marker = {0xFF, + 0xDD, + 0, + 4, + static_cast(jpg.restart_interval >> 8), + static_cast(jpg.restart_interval & 0xFF)}; + state->output_queue.push_back(std::move(dri_marker)); + return true; +} + +bool EncodeRestart(uint8_t marker, SerializationState* state) { + state->output_queue.push_back(OutputChunk({0xFF, marker})); + return true; +} + +bool EncodeAPP(const JPEGData& jpg, uint8_t marker, SerializationState* state) { + // TODO(eustas): check that marker corresponds to payload? + (void)marker; + + size_t app_index = state->app_index++; + if (app_index >= jpg.app_data.size()) return false; + state->output_queue.push_back(OutputChunk({0xFF})); + state->output_queue.emplace_back(jpg.app_data[app_index]); + return true; +} + +bool EncodeCOM(const JPEGData& jpg, SerializationState* state) { + size_t com_index = state->com_index++; + if (com_index >= jpg.com_data.size()) return false; + state->output_queue.push_back(OutputChunk({0xFF})); + state->output_queue.emplace_back(jpg.com_data[com_index]); + return true; +} + +bool EncodeInterMarkerData(const JPEGData& jpg, SerializationState* state) { + size_t index = state->data_index++; + if (index >= jpg.inter_marker_data.size()) return false; + state->output_queue.emplace_back(jpg.inter_marker_data[index]); + return true; +} + +bool EncodeDCTBlockSequential(const coeff_t* coeffs, + const HuffmanCodeTable& dc_huff, + const HuffmanCodeTable& ac_huff, + int num_zero_runs, coeff_t* last_dc_coeff, + JpegBitWriter* bw) { + coeff_t temp2; + coeff_t temp; + temp2 = coeffs[0]; + temp = temp2 - *last_dc_coeff; + *last_dc_coeff = temp2; + temp2 = temp; + if (temp < 0) { + temp = -temp; + if (temp < 0) return false; + temp2--; + } + int dc_nbits = (temp == 0) ? 0 : (FloorLog2Nonzero(temp) + 1); + WriteBits(bw, dc_huff.depth[dc_nbits], dc_huff.code[dc_nbits]); + if (dc_nbits >= 12) return false; + if (dc_nbits > 0) { + WriteBits(bw, dc_nbits, temp2 & ((1u << dc_nbits) - 1)); + } + int r = 0; + for (int k = 1; k < 64; ++k) { + if ((temp = coeffs[kJPEGNaturalOrder[k]]) == 0) { + r++; + continue; + } + if (temp < 0) { + temp = -temp; + if (temp < 0) return false; + temp2 = ~temp; + } else { + temp2 = temp; + } + while (r > 15) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + int ac_nbits = FloorLog2Nonzero(temp) + 1; + if (ac_nbits >= 16) return false; + int symbol = (r << 4u) + ac_nbits; + WriteBits(bw, ac_huff.depth[symbol], ac_huff.code[symbol]); + WriteBits(bw, ac_nbits, temp2 & ((1 << ac_nbits) - 1)); + r = 0; + } + for (int i = 0; i < num_zero_runs; ++i) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + if (r > 0) { + WriteBits(bw, ac_huff.depth[0], ac_huff.code[0]); + } + return true; +} + +bool EncodeDCTBlockProgressive(const coeff_t* coeffs, + const HuffmanCodeTable& dc_huff, + const HuffmanCodeTable& ac_huff, int Ss, int Se, + int Al, int num_zero_runs, + DCTCodingState* coding_state, + coeff_t* last_dc_coeff, JpegBitWriter* bw) { + bool eob_run_allowed = Ss > 0; + coeff_t temp2; + coeff_t temp; + if (Ss == 0) { + temp2 = coeffs[0] >> Al; + temp = temp2 - *last_dc_coeff; + *last_dc_coeff = temp2; + temp2 = temp; + if (temp < 0) { + temp = -temp; + if (temp < 0) return false; + temp2--; + } + int nbits = (temp == 0) ? 0 : (FloorLog2Nonzero(temp) + 1); + WriteBits(bw, dc_huff.depth[nbits], dc_huff.code[nbits]); + if (nbits > 0) { + WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1)); + } + ++Ss; + } + if (Ss > Se) { + return true; + } + int r = 0; + for (int k = Ss; k <= Se; ++k) { + if ((temp = coeffs[kJPEGNaturalOrder[k]]) == 0) { + r++; + continue; + } + if (temp < 0) { + temp = -temp; + if (temp < 0) return false; + temp >>= Al; + temp2 = ~temp; + } else { + temp >>= Al; + temp2 = temp; + } + if (temp == 0) { + r++; + continue; + } + Flush(coding_state, bw); + while (r > 15) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + int nbits = FloorLog2Nonzero(temp) + 1; + int symbol = (r << 4u) + nbits; + WriteBits(bw, ac_huff.depth[symbol], ac_huff.code[symbol]); + WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1)); + r = 0; + } + if (num_zero_runs > 0) { + Flush(coding_state, bw); + for (int i = 0; i < num_zero_runs; ++i) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + } + if (r > 0) { + BufferEndOfBand(coding_state, &ac_huff, nullptr, bw); + if (!eob_run_allowed) { + Flush(coding_state, bw); + } + } + return true; +} + +bool EncodeRefinementBits(const coeff_t* coeffs, + const HuffmanCodeTable& ac_huff, int Ss, int Se, + int Al, DCTCodingState* coding_state, + JpegBitWriter* bw) { + bool eob_run_allowed = Ss > 0; + if (Ss == 0) { + // Emit next bit of DC component. + WriteBits(bw, 1, (coeffs[0] >> Al) & 1); + ++Ss; + } + if (Ss > Se) { + return true; + } + int abs_values[kDCTBlockSize]; + int eob = 0; + for (int k = Ss; k <= Se; k++) { + const coeff_t abs_val = std::abs(coeffs[kJPEGNaturalOrder[k]]); + abs_values[k] = abs_val >> Al; + if (abs_values[k] == 1) { + eob = k; + } + } + int r = 0; + std::vector refinement_bits; + refinement_bits.reserve(kDCTBlockSize); + for (int k = Ss; k <= Se; k++) { + if (abs_values[k] == 0) { + r++; + continue; + } + while (r > 15 && k <= eob) { + Flush(coding_state, bw); + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + for (int bit : refinement_bits) { + WriteBits(bw, 1, bit); + } + refinement_bits.clear(); + } + if (abs_values[k] > 1) { + refinement_bits.push_back(abs_values[k] & 1u); + continue; + } + Flush(coding_state, bw); + int symbol = (r << 4u) + 1; + int new_non_zero_bit = (coeffs[kJPEGNaturalOrder[k]] < 0) ? 0 : 1; + WriteBits(bw, ac_huff.depth[symbol], ac_huff.code[symbol]); + WriteBits(bw, 1, new_non_zero_bit); + for (int bit : refinement_bits) { + WriteBits(bw, 1, bit); + } + refinement_bits.clear(); + r = 0; + } + if (r > 0 || !refinement_bits.empty()) { + BufferEndOfBand(coding_state, &ac_huff, &refinement_bits, bw); + if (!eob_run_allowed) { + Flush(coding_state, bw); + } + } + return true; +} + +template +SerializationStatus JXL_NOINLINE DoEncodeScan(const JPEGData& jpg, + SerializationState* state) { + const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index]; + EncodeScanState& ss = state->scan_state; + + const int restart_interval = + state->seen_dri_marker ? jpg.restart_interval : 0; + + const auto get_next_extra_zero_run_index = [&ss, &scan_info]() -> int { + if (ss.extra_zero_runs_pos < scan_info.extra_zero_runs.size()) { + return scan_info.extra_zero_runs[ss.extra_zero_runs_pos].block_idx; + } else { + return -1; + } + }; + + const auto get_next_reset_point = [&ss, &scan_info]() -> int { + if (ss.next_reset_point_pos < scan_info.reset_points.size()) { + return scan_info.reset_points[ss.next_reset_point_pos++]; + } else { + return -1; + } + }; + + if (ss.stage == EncodeScanState::HEAD) { + if (!EncodeSOS(jpg, scan_info, state)) return SerializationStatus::ERROR; + JpegBitWriterInit(&ss.bw, &state->output_queue); + DCTCodingStateInit(&ss.coding_state); + ss.restarts_to_go = restart_interval; + ss.next_restart_marker = 0; + ss.block_scan_index = 0; + ss.extra_zero_runs_pos = 0; + ss.next_extra_zero_run_index = get_next_extra_zero_run_index(); + ss.next_reset_point_pos = 0; + ss.next_reset_point = get_next_reset_point(); + ss.mcu_y = 0; + memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff)); + ss.stage = EncodeScanState::BODY; + } + JpegBitWriter* bw = &ss.bw; + DCTCodingState* coding_state = &ss.coding_state; + + JXL_DASSERT(ss.stage == EncodeScanState::BODY); + + // "Non-interleaved" means color data comes in separate scans, in other words + // each scan can contain only one color component. + const bool is_interleaved = (scan_info.num_components > 1); + int MCUs_per_row = 0; + int MCU_rows = 0; + jpg.CalculateMcuSize(scan_info, &MCUs_per_row, &MCU_rows); + const bool is_progressive = state->is_progressive; + const int Al = is_progressive ? scan_info.Al : 0; + const int Ss = is_progressive ? scan_info.Ss : 0; + const int Se = is_progressive ? scan_info.Se : 63; + + // DC-only is defined by [0..0] spectral range. + const bool want_ac = ((Ss != 0) || (Se != 0)); + // TODO: support streaming decoding again. + const bool complete_ac = true; + const bool has_ac = true; + if (want_ac && !has_ac) return SerializationStatus::NEEDS_MORE_INPUT; + + // |has_ac| implies |complete_dc| but not vice versa; for the sake of + // simplicity we pretend they are equal, because they are separated by just a + // few bytes of input. + const bool complete_dc = has_ac; + const bool complete = want_ac ? complete_ac : complete_dc; + // When "incomplete" |ac_dc| tracks information about current ("incomplete") + // band parsing progress. + + // FIXME: Is this always complete? + // const int last_mcu_y = + // complete ? MCU_rows : parsing_state.internal->ac_dc.next_mcu_y * + // v_group; + (void)complete; + const int last_mcu_y = complete ? MCU_rows : 0; + + for (; ss.mcu_y < last_mcu_y; ++ss.mcu_y) { + for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) { + // Possibly emit a restart marker. + if (restart_interval > 0 && ss.restarts_to_go == 0) { + Flush(coding_state, bw); + if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) { + return SerializationStatus::ERROR; + } + EmitMarker(bw, 0xD0 + ss.next_restart_marker); + ss.next_restart_marker += 1; + ss.next_restart_marker &= 0x7; + ss.restarts_to_go = restart_interval; + memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff)); + } + // Encode one MCU + for (size_t i = 0; i < scan_info.num_components; ++i) { + const JPEGComponentScanInfo& si = scan_info.components[i]; + const JPEGComponent& c = jpg.components[si.comp_idx]; + const HuffmanCodeTable& dc_huff = state->dc_huff_table[si.dc_tbl_idx]; + const HuffmanCodeTable& ac_huff = state->ac_huff_table[si.ac_tbl_idx]; + int n_blocks_y = is_interleaved ? c.v_samp_factor : 1; + int n_blocks_x = is_interleaved ? c.h_samp_factor : 1; + for (int iy = 0; iy < n_blocks_y; ++iy) { + for (int ix = 0; ix < n_blocks_x; ++ix) { + int block_y = ss.mcu_y * n_blocks_y + iy; + int block_x = mcu_x * n_blocks_x + ix; + int block_idx = block_y * c.width_in_blocks + block_x; + if (ss.block_scan_index == ss.next_reset_point) { + Flush(coding_state, bw); + ss.next_reset_point = get_next_reset_point(); + } + int num_zero_runs = 0; + if (ss.block_scan_index == ss.next_extra_zero_run_index) { + num_zero_runs = scan_info.extra_zero_runs[ss.extra_zero_runs_pos] + .num_extra_zero_runs; + ++ss.extra_zero_runs_pos; + ss.next_extra_zero_run_index = get_next_extra_zero_run_index(); + } + const coeff_t* coeffs = &c.coeffs[block_idx << 6]; + bool ok; + if (kMode == 0) { + ok = EncodeDCTBlockSequential(coeffs, dc_huff, ac_huff, + num_zero_runs, + ss.last_dc_coeff + si.comp_idx, bw); + } else if (kMode == 1) { + ok = EncodeDCTBlockProgressive( + coeffs, dc_huff, ac_huff, Ss, Se, Al, num_zero_runs, + coding_state, ss.last_dc_coeff + si.comp_idx, bw); + } else { + ok = EncodeRefinementBits(coeffs, ac_huff, Ss, Se, Al, + coding_state, bw); + } + if (!ok) return SerializationStatus::ERROR; + ++ss.block_scan_index; + } + } + } + --ss.restarts_to_go; + } + } + if (ss.mcu_y < MCU_rows) { + if (!bw->healthy) return SerializationStatus::ERROR; + return SerializationStatus::NEEDS_MORE_INPUT; + } + Flush(coding_state, bw); + if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) { + return SerializationStatus::ERROR; + } + JpegBitWriterFinish(bw); + ss.stage = EncodeScanState::HEAD; + state->scan_index++; + if (!bw->healthy) return SerializationStatus::ERROR; + + return SerializationStatus::DONE; +} + +static SerializationStatus JXL_INLINE EncodeScan(const JPEGData& jpg, + SerializationState* state) { + const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index]; + const bool is_progressive = state->is_progressive; + const int Al = is_progressive ? scan_info.Al : 0; + const int Ah = is_progressive ? scan_info.Ah : 0; + const int Ss = is_progressive ? scan_info.Ss : 0; + const int Se = is_progressive ? scan_info.Se : 63; + const bool need_sequential = + !is_progressive || (Ah == 0 && Al == 0 && Ss == 0 && Se == 63); + if (need_sequential) { + return DoEncodeScan<0>(jpg, state); + } else if (Ah == 0) { + return DoEncodeScan<1>(jpg, state); + } else { + return DoEncodeScan<2>(jpg, state); + } +} + +SerializationStatus SerializeSection(uint8_t marker, SerializationState* state, + const JPEGData& jpg) { + const auto to_status = [](bool result) { + return result ? SerializationStatus::DONE : SerializationStatus::ERROR; + }; + // TODO(eustas): add and use marker enum + switch (marker) { + case 0xC0: + case 0xC1: + case 0xC2: + case 0xC9: + case 0xCA: + return to_status(EncodeSOF(jpg, marker, state)); + + case 0xC4: + return to_status(EncodeDHT(jpg, state)); + + case 0xD0: + case 0xD1: + case 0xD2: + case 0xD3: + case 0xD4: + case 0xD5: + case 0xD6: + case 0xD7: + return to_status(EncodeRestart(marker, state)); + + case 0xD9: + return to_status(EncodeEOI(jpg, state)); + + case 0xDA: + return EncodeScan(jpg, state); + + case 0xDB: + return to_status(EncodeDQT(jpg, state)); + + case 0xDD: + return to_status(EncodeDRI(jpg, state)); + + case 0xE0: + case 0xE1: + case 0xE2: + case 0xE3: + case 0xE4: + case 0xE5: + case 0xE6: + case 0xE7: + case 0xE8: + case 0xE9: + case 0xEA: + case 0xEB: + case 0xEC: + case 0xED: + case 0xEE: + case 0xEF: + return to_status(EncodeAPP(jpg, marker, state)); + + case 0xFE: + return to_status(EncodeCOM(jpg, state)); + + case 0xFF: + return to_status(EncodeInterMarkerData(jpg, state)); + + default: + return SerializationStatus::ERROR; + } +} + +} // namespace + +// TODO(veluca): add streaming support again. +Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out) { + SerializationState ss; + + size_t written = 0; + const auto maybe_push_output = [&]() -> Status { + if (ss.stage != SerializationState::ERROR) { + while (!ss.output_queue.empty()) { + auto& chunk = ss.output_queue.front(); + size_t num_written = out(chunk.next, chunk.len); + if (num_written == 0 && chunk.len > 0) { + return StatusMessage(Status(StatusCode::kNotEnoughBytes), + "Failed to write output"); + } + chunk.len -= num_written; + written += num_written; + if (chunk.len == 0) { + ss.output_queue.pop_front(); + } + } + } + return true; + }; + + while (true) { + switch (ss.stage) { + case SerializationState::INIT: { + // Valid Brunsli requires, at least, 0xD9 marker. + // This might happen on corrupted stream, or on unconditioned JPEGData. + // TODO(eustas): check D9 in the only one and is the last one. + if (jpg.marker_order.empty()) { + ss.stage = SerializationState::ERROR; + break; + } + + ss.dc_huff_table.resize(kMaxHuffmanTables); + ss.ac_huff_table.resize(kMaxHuffmanTables); + if (jpg.has_zero_padding_bit) { + ss.pad_bits = jpg.padding_bits.data(); + ss.pad_bits_end = ss.pad_bits + jpg.padding_bits.size(); + } + + EncodeSOI(&ss); + JXL_QUIET_RETURN_IF_ERROR(maybe_push_output()); + ss.stage = SerializationState::SERIALIZE_SECTION; + break; + } + + case SerializationState::SERIALIZE_SECTION: { + if (ss.section_index >= jpg.marker_order.size()) { + ss.stage = SerializationState::DONE; + break; + } + uint8_t marker = jpg.marker_order[ss.section_index]; + SerializationStatus status = SerializeSection(marker, &ss, jpg); + if (status == SerializationStatus::ERROR) { + JXL_WARNING("Failed to encode marker 0x%.2x", marker); + ss.stage = SerializationState::ERROR; + break; + } + JXL_QUIET_RETURN_IF_ERROR(maybe_push_output()); + if (status == SerializationStatus::NEEDS_MORE_INPUT) { + return JXL_FAILURE("Incomplete serialization data"); + } else if (status != SerializationStatus::DONE) { + JXL_DASSERT(false); + ss.stage = SerializationState::ERROR; + break; + } + ++ss.section_index; + break; + } + + case SerializationState::DONE: + JXL_ASSERT(ss.output_queue.empty()); + return true; + + case SerializationState::ERROR: + return JXL_FAILURE("JPEG serialization error"); + } + } +} + +Status EncodeImageJPGCoefficients(const CodecInOut* io, PaddedBytes* bytes) { + auto write = [&bytes](const uint8_t* buf, size_t len) { + bytes->append(buf, buf + len); + return len; + }; + return WriteJpeg(*io->Main().jpeg_data, write); +} + +} // namespace jpeg +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data_writer.h b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data_writer.h new file mode 100644 index 000000000..f272ae7df --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_data_writer.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Functions for writing a JPEGData object into a jpeg byte stream. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_ +#define LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_ + +#include +#include + +#include + +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +// Function type used to write len bytes into buf. Returns the number of bytes +// written. +using JPEGOutput = std::function; + +Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out); + +// Reconstructs the JPEG from the coefficients and metadata in CodecInOut. +Status EncodeImageJPGCoefficients(const CodecInOut* io, PaddedBytes* bytes); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_output_chunk.h b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_output_chunk.h new file mode 100644 index 000000000..e003c0495 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_output_chunk.h @@ -0,0 +1,72 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_ +#define LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_ + +#include +#include + +#include +#include +#include + +namespace jxl { +namespace jpeg { + +/** + * A chunk of output data. + * + * Data producer creates OutputChunks and adds them to the end output queue. + * Once control flow leaves the producer code, it is considered that chunk of + * data is final and can not be changed; to underline this fact |next| is a + * const-pointer. + * + * Data consumer removes OutputChunks from the beginning of the output queue. + * It is possible to consume OutputChunks partially, by updating |next| and + * |len|. + * + * There are 2 types of output chunks: + * - owning: actual data is stored in |buffer| field; producer fills data after + * the instance it created; it is legal to reduce |len| to show that not all + * the capacity of |buffer| is used + * - non-owning: represents the data stored (owned) somewhere else + */ +struct OutputChunk { + // Non-owning + template + explicit OutputChunk(Bytes& bytes) : len(bytes.size()) { + // Deal both with const qualifier and data type. + const void* src = bytes.data(); + next = reinterpret_cast(src); + } + + // Non-owning + OutputChunk(const uint8_t* data, size_t size) : next(data), len(size) {} + + // Owning + explicit OutputChunk(size_t size = 0) { + buffer.reset(new std::vector(size)); + next = buffer->data(); + len = size; + } + + // Owning + OutputChunk(std::initializer_list bytes) { + buffer.reset(new std::vector(bytes)); + next = buffer->data(); + len = bytes.size(); + } + + const uint8_t* next; + size_t len; + // TODO(veluca): consider removing the unique_ptr. + std::unique_ptr> buffer; +}; + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_serialization_state.h b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_serialization_state.h new file mode 100644 index 000000000..a25c335b5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/dec_jpeg_serialization_state.h @@ -0,0 +1,95 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_ +#define LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_ + +#include +#include + +#include "lib/jxl/jpeg/dec_jpeg_output_chunk.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +struct HuffmanCodeTable { + int depth[256]; + int code[256]; +}; + +// Handles the packing of bits into output bytes. +struct JpegBitWriter { + bool healthy; + std::deque* output; + OutputChunk chunk; + uint8_t* data; + size_t pos; + uint64_t put_buffer; + int put_bits; +}; + +// Holds data that is buffered between 8x8 blocks in progressive mode. +struct DCTCodingState { + // The run length of end-of-band symbols in a progressive scan. + int eob_run_; + // The huffman table to be used when flushing the state. + const HuffmanCodeTable* cur_ac_huff_; + // The sequence of currently buffered refinement bits for a successive + // approximation scan (one where Ah > 0). + std::vector refinement_bits_; +}; + +struct EncodeScanState { + enum Stage { HEAD, BODY }; + + Stage stage = HEAD; + + int mcu_y; + JpegBitWriter bw; + coeff_t last_dc_coeff[kMaxComponents] = {0}; + int restarts_to_go; + int next_restart_marker; + int block_scan_index; + DCTCodingState coding_state; + size_t extra_zero_runs_pos; + int next_extra_zero_run_index; + size_t next_reset_point_pos; + int next_reset_point; +}; + +struct SerializationState { + enum Stage { + INIT, + SERIALIZE_SECTION, + DONE, + ERROR, + }; + + Stage stage = INIT; + + std::deque output_queue; + + size_t section_index = 0; + int dht_index = 0; + int dqt_index = 0; + int app_index = 0; + int com_index = 0; + int data_index = 0; + int scan_index = 0; + std::vector dc_huff_table; + std::vector ac_huff_table; + const uint8_t* pad_bits = nullptr; + const uint8_t* pad_bits_end = nullptr; + bool seen_dri_marker = false; + bool is_progressive = false; + + EncodeScanState scan_state; +}; + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data.cc b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data.cc new file mode 100644 index 000000000..0f625d8e4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data.cc @@ -0,0 +1,381 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/enc_jpeg_data.h" + +#include +#include + +#include "lib/jxl/jpeg/enc_jpeg_data_reader.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { +namespace jpeg { + +namespace { + +constexpr int BITS_IN_JSAMPLE = 8; +using ByteSpan = Span; + +// TODO(eustas): move to jpeg_data, to use from codec_jpg as well. +// See if there is a canonically chunked ICC profile and mark corresponding +// app-tags with AppMarkerType::kICC. +Status DetectIccProfile(JPEGData& jpeg_data) { + JXL_DASSERT(jpeg_data.app_data.size() == jpeg_data.app_marker_type.size()); + size_t num_icc = 0; + size_t num_icc_jpeg = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + const auto& app = jpeg_data.app_data[i]; + size_t pos = 0; + if (app[pos++] != 0xE2) continue; + // At least APPn + size; otherwise it should be intermarker-data. + JXL_DASSERT(app.size() >= 3); + size_t tag_length = (app[pos] << 8) + app[pos + 1]; + pos += 2; + JXL_DASSERT(app.size() == tag_length + 1); + // Empty payload is 2 bytes for tag length itself + signature + if (tag_length < 2 + sizeof kIccProfileTag) continue; + + if (memcmp(&app[pos], kIccProfileTag, sizeof kIccProfileTag) != 0) continue; + pos += sizeof kIccProfileTag; + uint8_t chunk_id = app[pos++]; + uint8_t num_chunks = app[pos++]; + if (chunk_id != num_icc + 1) continue; + if (num_icc_jpeg == 0) num_icc_jpeg = num_chunks; + if (num_icc_jpeg != num_chunks) continue; + num_icc++; + jpeg_data.app_marker_type[i] = AppMarkerType::kICC; + } + if (num_icc != num_icc_jpeg) { + return JXL_FAILURE("Invalid ICC chunks"); + } + return true; +} + +bool GetMarkerPayload(const uint8_t* data, size_t size, ByteSpan* payload) { + if (size < 3) { + return false; + } + size_t hi = data[1]; + size_t lo = data[2]; + size_t internal_size = (hi << 8u) | lo; + // Second byte of marker is not counted towards size. + if (internal_size != size - 1) { + return false; + } + // cut second marker byte and "length" from payload. + *payload = ByteSpan(data, size); + payload->remove_prefix(3); + return true; +} + +Status DetectBlobs(jpeg::JPEGData& jpeg_data) { + JXL_DASSERT(jpeg_data.app_data.size() == jpeg_data.app_marker_type.size()); + bool have_exif = false, have_xmp = false; + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + auto& marker = jpeg_data.app_data[i]; + if (marker.empty() || marker[0] != kApp1) { + continue; + } + ByteSpan payload; + if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) { + // Something is wrong with this marker; does not care. + continue; + } + if (!have_exif && payload.size() >= sizeof kExifTag && + !memcmp(payload.data(), kExifTag, sizeof kExifTag)) { + jpeg_data.app_marker_type[i] = AppMarkerType::kExif; + have_exif = true; + } + if (!have_xmp && payload.size() >= sizeof kXMPTag && + !memcmp(payload.data(), kXMPTag, sizeof kXMPTag)) { + jpeg_data.app_marker_type[i] = AppMarkerType::kXMP; + have_xmp = true; + } + } + return true; +} + +Status ParseChunkedMarker(const jpeg::JPEGData& src, uint8_t marker_type, + const ByteSpan& tag, PaddedBytes* output, + bool allow_permutations = false) { + output->clear(); + + std::vector chunks; + std::vector presence; + size_t expected_number_of_parts = 0; + bool is_first_chunk = true; + size_t ordinal = 0; + for (const auto& marker : src.app_data) { + if (marker.empty() || marker[0] != marker_type) { + continue; + } + ByteSpan payload; + if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) { + // Something is wrong with this marker; does not care. + continue; + } + if ((payload.size() < tag.size()) || + memcmp(payload.data(), tag.data(), tag.size()) != 0) { + continue; + } + payload.remove_prefix(tag.size()); + if (payload.size() < 2) { + return JXL_FAILURE("Chunk is too small."); + } + uint8_t index = payload[0]; + uint8_t total = payload[1]; + ordinal++; + if (!allow_permutations) { + if (index != ordinal) return JXL_FAILURE("Invalid chunk order."); + } + + payload.remove_prefix(2); + + JXL_RETURN_IF_ERROR(total != 0); + if (is_first_chunk) { + is_first_chunk = false; + expected_number_of_parts = total; + // 1-based indices; 0-th element is added for convenience. + chunks.resize(total + 1); + presence.resize(total + 1); + } else { + JXL_RETURN_IF_ERROR(expected_number_of_parts == total); + } + + if (index == 0 || index > total) { + return JXL_FAILURE("Invalid chunk index."); + } + + if (presence[index]) { + return JXL_FAILURE("Duplicate chunk."); + } + presence[index] = true; + chunks[index] = payload; + } + + for (size_t i = 0; i < expected_number_of_parts; ++i) { + // 0-th element is not used. + size_t index = i + 1; + if (!presence[index]) { + return JXL_FAILURE("Missing chunk."); + } + output->append(chunks[index]); + } + + return true; +} + +Status SetBlobsFromJpegData(const jpeg::JPEGData& jpeg_data, Blobs* blobs) { + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + auto& marker = jpeg_data.app_data[i]; + if (marker.empty() || marker[0] != kApp1) { + continue; + } + ByteSpan payload; + if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) { + // Something is wrong with this marker; does not care. + continue; + } + if (payload.size() >= sizeof kExifTag && + !memcmp(payload.data(), kExifTag, sizeof kExifTag)) { + if (blobs->exif.empty()) { + blobs->exif.resize(payload.size() - sizeof kExifTag); + memcpy(blobs->exif.data(), payload.data() + sizeof kExifTag, + payload.size() - sizeof kExifTag); + } else { + JXL_WARNING( + "ReJPEG: multiple Exif blobs, storing only first one in the JPEG " + "XL container\n"); + } + } + if (payload.size() >= sizeof kXMPTag && + !memcmp(payload.data(), kXMPTag, sizeof kXMPTag)) { + if (blobs->xmp.empty()) { + blobs->xmp.resize(payload.size() - sizeof kXMPTag); + memcpy(blobs->xmp.data(), payload.data() + sizeof kXMPTag, + payload.size() - sizeof kXMPTag); + } else { + JXL_WARNING( + "ReJPEG: multiple XMP blobs, storing only first one in the JPEG " + "XL container\n"); + } + } + } + return true; +} + +static inline bool IsJPG(const Span bytes) { + return bytes.size() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xD8; +} + +} // namespace + +Status SetColorEncodingFromJpegData(const jpeg::JPEGData& jpg, + ColorEncoding* color_encoding) { + PaddedBytes icc_profile; + if (!ParseChunkedMarker(jpg, kApp2, ByteSpan(kIccProfileTag), &icc_profile)) { + JXL_WARNING("ReJPEG: corrupted ICC profile\n"); + icc_profile.clear(); + } + + if (icc_profile.empty()) { + bool is_gray = (jpg.components.size() == 1); + *color_encoding = ColorEncoding::SRGB(is_gray); + return true; + } + + return color_encoding->SetICC(std::move(icc_profile)); +} + +Status EncodeJPEGData(JPEGData& jpeg_data, PaddedBytes* bytes, + const CompressParams& cparams) { + jpeg_data.app_marker_type.resize(jpeg_data.app_data.size(), + AppMarkerType::kUnknown); + JXL_RETURN_IF_ERROR(DetectIccProfile(jpeg_data)); + JXL_RETURN_IF_ERROR(DetectBlobs(jpeg_data)); + BitWriter writer; + JXL_RETURN_IF_ERROR(Bundle::Write(jpeg_data, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + *bytes = std::move(writer).TakeBytes(); + BrotliEncoderState* brotli_enc = + BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); + int effort = cparams.brotli_effort; + if (effort < 0) effort = 11 - static_cast(cparams.speed_tier); + BrotliEncoderSetParameter(brotli_enc, BROTLI_PARAM_QUALITY, effort); + size_t total_data = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + if (jpeg_data.app_marker_type[i] != AppMarkerType::kUnknown) { + continue; + } + total_data += jpeg_data.app_data[i].size(); + } + for (size_t i = 0; i < jpeg_data.com_data.size(); i++) { + total_data += jpeg_data.com_data[i].size(); + } + for (size_t i = 0; i < jpeg_data.inter_marker_data.size(); i++) { + total_data += jpeg_data.inter_marker_data[i].size(); + } + total_data += jpeg_data.tail_data.size(); + size_t initial_size = bytes->size(); + size_t brotli_capacity = BrotliEncoderMaxCompressedSize(total_data); + BrotliEncoderSetParameter(brotli_enc, BROTLI_PARAM_SIZE_HINT, total_data); + bytes->resize(bytes->size() + brotli_capacity); + size_t enc_size = 0; + auto br_append = [&](const std::vector& data, bool last) { + size_t available_in = data.size(); + const uint8_t* in = data.data(); + uint8_t* out = &(*bytes)[initial_size + enc_size]; + do { + uint8_t* out_before = out; + msan::MemoryIsInitialized(in, available_in); + JXL_CHECK(BrotliEncoderCompressStream( + brotli_enc, last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS, + &available_in, &in, &brotli_capacity, &out, &enc_size)); + msan::UnpoisonMemory(out_before, out - out_before); + } while (BrotliEncoderHasMoreOutput(brotli_enc) || available_in > 0); + }; + + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + if (jpeg_data.app_marker_type[i] != AppMarkerType::kUnknown) { + continue; + } + br_append(jpeg_data.app_data[i], /*last=*/false); + } + for (size_t i = 0; i < jpeg_data.com_data.size(); i++) { + br_append(jpeg_data.com_data[i], /*last=*/false); + } + for (size_t i = 0; i < jpeg_data.inter_marker_data.size(); i++) { + br_append(jpeg_data.inter_marker_data[i], /*last=*/false); + } + br_append(jpeg_data.tail_data, /*last=*/true); + BrotliEncoderDestroyInstance(brotli_enc); + bytes->resize(initial_size + enc_size); + return true; +} + +Status DecodeImageJPG(const Span bytes, CodecInOut* io) { + if (!IsJPG(bytes)) return false; + io->frames.clear(); + io->frames.reserve(1); + io->frames.emplace_back(&io->metadata.m); + io->Main().jpeg_data = make_unique(); + jpeg::JPEGData* jpeg_data = io->Main().jpeg_data.get(); + if (!jpeg::ReadJpeg(bytes.data(), bytes.size(), jpeg::JpegReadMode::kReadAll, + jpeg_data)) { + return JXL_FAILURE("Error reading JPEG"); + } + JXL_RETURN_IF_ERROR( + SetColorEncodingFromJpegData(*jpeg_data, &io->metadata.m.color_encoding)); + JXL_RETURN_IF_ERROR(SetBlobsFromJpegData(*jpeg_data, &io->blobs)); + size_t nbcomp = jpeg_data->components.size(); + if (nbcomp != 1 && nbcomp != 3) { + return JXL_FAILURE("Cannot recompress JPEGs with neither 1 nor 3 channels"); + } + YCbCrChromaSubsampling cs; + if (nbcomp == 3) { + uint8_t hsample[3], vsample[3]; + for (size_t i = 0; i < nbcomp; i++) { + hsample[i] = jpeg_data->components[i].h_samp_factor; + vsample[i] = jpeg_data->components[i].v_samp_factor; + } + JXL_RETURN_IF_ERROR(cs.Set(hsample, vsample)); + } else if (nbcomp == 1) { + uint8_t hsample[3], vsample[3]; + for (size_t i = 0; i < 3; i++) { + hsample[i] = jpeg_data->components[0].h_samp_factor; + vsample[i] = jpeg_data->components[0].v_samp_factor; + } + JXL_RETURN_IF_ERROR(cs.Set(hsample, vsample)); + } + bool is_rgb = false; + { + const auto& markers = jpeg_data->marker_order; + // If there is a JFIF marker, this is YCbCr. Otherwise... + if (std::find(markers.begin(), markers.end(), 0xE0) == markers.end()) { + // Try to find an 'Adobe' marker. + size_t app_markers = 0; + size_t i = 0; + for (; i < markers.size(); i++) { + // This is an APP marker. + if ((markers[i] & 0xF0) == 0xE0) { + JXL_CHECK(app_markers < jpeg_data->app_data.size()); + // APP14 marker + if (markers[i] == 0xEE) { + const auto& data = jpeg_data->app_data[app_markers]; + if (data.size() == 15 && data[3] == 'A' && data[4] == 'd' && + data[5] == 'o' && data[6] == 'b' && data[7] == 'e') { + // 'Adobe' marker. + is_rgb = data[14] == 0; + break; + } + } + app_markers++; + } + } + + if (i == markers.size()) { + // No 'Adobe' marker, guess from component IDs. + is_rgb = nbcomp == 3 && jpeg_data->components[0].id == 'R' && + jpeg_data->components[1].id == 'G' && + jpeg_data->components[2].id == 'B'; + } + } + } + + io->Main().chroma_subsampling = cs; + io->Main().color_transform = + (!is_rgb || nbcomp == 1) ? ColorTransform::kYCbCr : ColorTransform::kNone; + + io->metadata.m.SetIntensityTarget(kDefaultIntensityTarget); + io->metadata.m.SetUintSamples(BITS_IN_JSAMPLE); + io->SetFromImage(Image3F(jpeg_data->width, jpeg_data->height), + io->metadata.m.color_encoding); + SetIntensityTarget(io); + return true; +} + +} // namespace jpeg +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data.h b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data.h new file mode 100644 index 000000000..806128c46 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_ENC_JPEG_DATA_H_ +#define LIB_JXL_JPEG_ENC_JPEG_DATA_H_ + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { +Status EncodeJPEGData(JPEGData& jpeg_data, PaddedBytes* bytes, + const CompressParams& cparams); + +Status SetColorEncodingFromJpegData(const jpeg::JPEGData& jpg, + ColorEncoding* color_encoding); + +/** + * Decodes bytes containing JPEG codestream into a CodecInOut as coefficients + * only, for lossless JPEG transcoding. + */ +Status DecodeImageJPG(Span bytes, CodecInOut* io); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_ENC_JPEG_DATA_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data_reader.cc b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data_reader.cc new file mode 100644 index 000000000..4a6c1de47 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data_reader.cc @@ -0,0 +1,1147 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/enc_jpeg_data_reader.h" + +#include +#include + +#include +#include +#include + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/jpeg/enc_jpeg_huffman_decode.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +// By default only print debug messages when JXL_DEBUG_ON_ERROR is enabled. +#ifndef JXL_DEBUG_JPEG_DATA_READER +#define JXL_DEBUG_JPEG_DATA_READER JXL_DEBUG_ON_ERROR +#endif // JXL_DEBUG_JPEG_DATA_READER + +#define JXL_JPEG_DEBUG(format, ...) \ + JXL_DEBUG(JXL_DEBUG_JPEG_DATA_READER, format, ##__VA_ARGS__) + +namespace jxl { +namespace jpeg { + +namespace { +static const int kBrunsliMaxSampling = 15; +static const size_t kBrunsliMaxNumBlocks = 1ull << 24; + +// Macros for commonly used error conditions. + +#define JXL_JPEG_VERIFY_LEN(n) \ + if (*pos + (n) > len) { \ + JXL_JPEG_DEBUG("Unexpected end of input: pos=%" PRIuS \ + " need=%d len=%" PRIuS, \ + *pos, static_cast(n), len); \ + jpg->error = JPEGReadError::UNEXPECTED_EOF; \ + return false; \ + } + +#define JXL_JPEG_VERIFY_INPUT(var, low, high, code) \ + if ((var) < (low) || (var) > (high)) { \ + JXL_JPEG_DEBUG("Invalid " #var ": %d", static_cast(var)); \ + jpg->error = JPEGReadError::INVALID_##code; \ + return false; \ + } + +#define JXL_JPEG_VERIFY_MARKER_END() \ + if (start_pos + marker_len != *pos) { \ + JXL_JPEG_DEBUG("Invalid marker length: declared=%" PRIuS \ + " actual=%" PRIuS, \ + marker_len, (*pos - start_pos)); \ + jpg->error = JPEGReadError::WRONG_MARKER_SIZE; \ + return false; \ + } + +#define JXL_JPEG_EXPECT_MARKER() \ + if (pos + 2 > len || data[pos] != 0xff) { \ + JXL_JPEG_DEBUG("Marker byte (0xff) expected, found: 0x%.2x pos=%" PRIuS \ + " len=%" PRIuS, \ + (pos < len ? data[pos] : 0), pos, len); \ + jpg->error = JPEGReadError::MARKER_BYTE_NOT_FOUND; \ + return false; \ + } + +inline int ReadUint8(const uint8_t* data, size_t* pos) { + return data[(*pos)++]; +} + +inline int ReadUint16(const uint8_t* data, size_t* pos) { + int v = (data[*pos] << 8) + data[*pos + 1]; + *pos += 2; + return v; +} + +// Reads the Start of Frame (SOF) marker segment and fills in *jpg with the +// parsed data. +bool ProcessSOF(const uint8_t* data, const size_t len, JpegReadMode mode, + size_t* pos, JPEGData* jpg) { + if (jpg->width != 0) { + JXL_JPEG_DEBUG("Duplicate SOF marker."); + jpg->error = JPEGReadError::DUPLICATE_SOF; + return false; + } + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(8); + size_t marker_len = ReadUint16(data, pos); + int precision = ReadUint8(data, pos); + int height = ReadUint16(data, pos); + int width = ReadUint16(data, pos); + int num_components = ReadUint8(data, pos); + // 'jbrd' is hardcoded for 8bits: + JXL_JPEG_VERIFY_INPUT(precision, 8, 8, PRECISION); + JXL_JPEG_VERIFY_INPUT(height, 1, kMaxDimPixels, HEIGHT); + JXL_JPEG_VERIFY_INPUT(width, 1, kMaxDimPixels, WIDTH); + JXL_JPEG_VERIFY_INPUT(num_components, 1, kMaxComponents, NUMCOMP); + JXL_JPEG_VERIFY_LEN(3 * num_components); + jpg->height = height; + jpg->width = width; + jpg->components.resize(num_components); + + // Read sampling factors and quant table index for each component. + std::vector ids_seen(256, false); + int max_h_samp_factor = 1; + int max_v_samp_factor = 1; + for (size_t i = 0; i < jpg->components.size(); ++i) { + const int id = ReadUint8(data, pos); + if (ids_seen[id]) { // (cf. section B.2.2, syntax of Ci) + JXL_JPEG_DEBUG("Duplicate ID %d in SOF.", id); + jpg->error = JPEGReadError::DUPLICATE_COMPONENT_ID; + return false; + } + ids_seen[id] = true; + jpg->components[i].id = id; + int factor = ReadUint8(data, pos); + int h_samp_factor = factor >> 4; + int v_samp_factor = factor & 0xf; + JXL_JPEG_VERIFY_INPUT(h_samp_factor, 1, kBrunsliMaxSampling, SAMP_FACTOR); + JXL_JPEG_VERIFY_INPUT(v_samp_factor, 1, kBrunsliMaxSampling, SAMP_FACTOR); + jpg->components[i].h_samp_factor = h_samp_factor; + jpg->components[i].v_samp_factor = v_samp_factor; + jpg->components[i].quant_idx = ReadUint8(data, pos); + max_h_samp_factor = std::max(max_h_samp_factor, h_samp_factor); + max_v_samp_factor = std::max(max_v_samp_factor, v_samp_factor); + } + + // We have checked above that none of the sampling factors are 0, so the max + // sampling factors can not be 0. + int MCU_rows = DivCeil(jpg->height, max_v_samp_factor * 8); + int MCU_cols = DivCeil(jpg->width, max_h_samp_factor * 8); + // Compute the block dimensions for each component. + for (size_t i = 0; i < jpg->components.size(); ++i) { + JPEGComponent* c = &jpg->components[i]; + if (max_h_samp_factor % c->h_samp_factor != 0 || + max_v_samp_factor % c->v_samp_factor != 0) { + JXL_JPEG_DEBUG("Non-integral subsampling ratios."); + jpg->error = JPEGReadError::INVALID_SAMPLING_FACTORS; + return false; + } + c->width_in_blocks = MCU_cols * c->h_samp_factor; + c->height_in_blocks = MCU_rows * c->v_samp_factor; + const uint64_t num_blocks = + static_cast(c->width_in_blocks) * c->height_in_blocks; + if (num_blocks > kBrunsliMaxNumBlocks) { + JXL_JPEG_DEBUG("Image too large."); + jpg->error = JPEGReadError::IMAGE_TOO_LARGE; + return false; + } + if (mode == JpegReadMode::kReadAll) { + c->coeffs.resize(num_blocks * kDCTBlockSize); + } + } + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the Start of Scan (SOS) marker segment and fills in *scan_info with the +// parsed data. +bool ProcessSOS(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(3); + size_t marker_len = ReadUint16(data, pos); + size_t comps_in_scan = ReadUint8(data, pos); + JXL_JPEG_VERIFY_INPUT(comps_in_scan, 1, jpg->components.size(), + COMPS_IN_SCAN); + + JPEGScanInfo scan_info; + scan_info.num_components = comps_in_scan; + JXL_JPEG_VERIFY_LEN(2 * comps_in_scan); + std::vector ids_seen(256, false); + for (size_t i = 0; i < comps_in_scan; ++i) { + uint32_t id = ReadUint8(data, pos); + if (ids_seen[id]) { // (cf. section B.2.3, regarding CSj) + JXL_JPEG_DEBUG("Duplicate ID %d in SOS.", id); + jpg->error = JPEGReadError::DUPLICATE_COMPONENT_ID; + return false; + } + ids_seen[id] = true; + bool found_index = false; + for (size_t j = 0; j < jpg->components.size(); ++j) { + if (jpg->components[j].id == id) { + scan_info.components[i].comp_idx = j; + found_index = true; + } + } + if (!found_index) { + JXL_JPEG_DEBUG("SOS marker: Could not find component with id %d", id); + jpg->error = JPEGReadError::COMPONENT_NOT_FOUND; + return false; + } + int c = ReadUint8(data, pos); + int dc_tbl_idx = c >> 4; + int ac_tbl_idx = c & 0xf; + JXL_JPEG_VERIFY_INPUT(dc_tbl_idx, 0, 3, HUFFMAN_INDEX); + JXL_JPEG_VERIFY_INPUT(ac_tbl_idx, 0, 3, HUFFMAN_INDEX); + scan_info.components[i].dc_tbl_idx = dc_tbl_idx; + scan_info.components[i].ac_tbl_idx = ac_tbl_idx; + } + JXL_JPEG_VERIFY_LEN(3); + scan_info.Ss = ReadUint8(data, pos); + scan_info.Se = ReadUint8(data, pos); + JXL_JPEG_VERIFY_INPUT(static_cast(scan_info.Ss), 0, 63, START_OF_SCAN); + JXL_JPEG_VERIFY_INPUT(scan_info.Se, scan_info.Ss, 63, END_OF_SCAN); + int c = ReadUint8(data, pos); + scan_info.Ah = c >> 4; + scan_info.Al = c & 0xf; + if (scan_info.Ah != 0 && scan_info.Al != scan_info.Ah - 1) { + // section G.1.1.1.2 : Successive approximation control only improves + // by one bit at a time. But it's not always respected, so we just issue + // a warning. + JXL_WARNING("Invalid progressive parameters: Al=%d Ah=%d", scan_info.Al, + scan_info.Ah); + } + // Check that all the Huffman tables needed for this scan are defined. + for (size_t i = 0; i < comps_in_scan; ++i) { + bool found_dc_table = false; + bool found_ac_table = false; + for (size_t j = 0; j < jpg->huffman_code.size(); ++j) { + uint32_t slot_id = jpg->huffman_code[j].slot_id; + if (slot_id == scan_info.components[i].dc_tbl_idx) { + found_dc_table = true; + } else if (slot_id == scan_info.components[i].ac_tbl_idx + 16) { + found_ac_table = true; + } + } + if (scan_info.Ss == 0 && !found_dc_table) { + JXL_JPEG_DEBUG( + "SOS marker: Could not find DC Huffman table with index %d", + scan_info.components[i].dc_tbl_idx); + jpg->error = JPEGReadError::HUFFMAN_TABLE_NOT_FOUND; + return false; + } + if (scan_info.Se > 0 && !found_ac_table) { + JXL_JPEG_DEBUG( + "SOS marker: Could not find AC Huffman table with index %d", + scan_info.components[i].ac_tbl_idx); + jpg->error = JPEGReadError::HUFFMAN_TABLE_NOT_FOUND; + return false; + } + } + jpg->scan_info.push_back(scan_info); + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the Define Huffman Table (DHT) marker segment and fills in *jpg with +// the parsed data. Builds the Huffman decoding table in either dc_huff_lut or +// ac_huff_lut, depending on the type and solt_id of Huffman code being read. +bool ProcessDHT(const uint8_t* data, const size_t len, JpegReadMode mode, + std::vector* dc_huff_lut, + std::vector* ac_huff_lut, size_t* pos, + JPEGData* jpg) { + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + if (marker_len == 2) { + JXL_JPEG_DEBUG("DHT marker: no Huffman table found"); + jpg->error = JPEGReadError::EMPTY_DHT; + return false; + } + while (*pos < start_pos + marker_len) { + JXL_JPEG_VERIFY_LEN(1 + kJpegHuffmanMaxBitLength); + JPEGHuffmanCode huff; + huff.slot_id = ReadUint8(data, pos); + int huffman_index = huff.slot_id; + int is_ac_table = (huff.slot_id & 0x10) != 0; + HuffmanTableEntry* huff_lut; + if (is_ac_table) { + huffman_index -= 0x10; + JXL_JPEG_VERIFY_INPUT(huffman_index, 0, 3, HUFFMAN_INDEX); + huff_lut = &(*ac_huff_lut)[huffman_index * kJpegHuffmanLutSize]; + } else { + JXL_JPEG_VERIFY_INPUT(huffman_index, 0, 3, HUFFMAN_INDEX); + huff_lut = &(*dc_huff_lut)[huffman_index * kJpegHuffmanLutSize]; + } + huff.counts[0] = 0; + int total_count = 0; + int space = 1 << kJpegHuffmanMaxBitLength; + int max_depth = 1; + for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) { + int count = ReadUint8(data, pos); + if (count != 0) { + max_depth = i; + } + huff.counts[i] = count; + total_count += count; + space -= count * (1 << (kJpegHuffmanMaxBitLength - i)); + } + if (is_ac_table) { + JXL_JPEG_VERIFY_INPUT(total_count, 0, kJpegHuffmanAlphabetSize, + HUFFMAN_CODE); + } else { + JXL_JPEG_VERIFY_INPUT(total_count, 0, kJpegDCAlphabetSize, HUFFMAN_CODE); + } + JXL_JPEG_VERIFY_LEN(total_count); + std::vector values_seen(256, false); + for (int i = 0; i < total_count; ++i) { + int value = ReadUint8(data, pos); + if (!is_ac_table) { + JXL_JPEG_VERIFY_INPUT(value, 0, kJpegDCAlphabetSize - 1, HUFFMAN_CODE); + } + if (values_seen[value]) { + JXL_JPEG_DEBUG("Duplicate Huffman code value %d", value); + jpg->error = JPEGReadError::INVALID_HUFFMAN_CODE; + return false; + } + values_seen[value] = true; + huff.values[i] = value; + } + // Add an invalid symbol that will have the all 1 code. + ++huff.counts[max_depth]; + huff.values[total_count] = kJpegHuffmanAlphabetSize; + space -= (1 << (kJpegHuffmanMaxBitLength - max_depth)); + if (space < 0) { + JXL_JPEG_DEBUG("Invalid Huffman code lengths."); + jpg->error = JPEGReadError::INVALID_HUFFMAN_CODE; + return false; + } else if (space > 0 && huff_lut[0].value != 0xffff) { + // Re-initialize the values to an invalid symbol so that we can recognize + // it when reading the bit stream using a Huffman code with space > 0. + for (int i = 0; i < kJpegHuffmanLutSize; ++i) { + huff_lut[i].bits = 0; + huff_lut[i].value = 0xffff; + } + } + huff.is_last = (*pos == start_pos + marker_len); + if (mode == JpegReadMode::kReadAll) { + BuildJpegHuffmanTable(&huff.counts[0], &huff.values[0], huff_lut); + } + jpg->huffman_code.push_back(huff); + } + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the Define Quantization Table (DQT) marker segment and fills in *jpg +// with the parsed data. +bool ProcessDQT(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + if (marker_len == 2) { + JXL_JPEG_DEBUG("DQT marker: no quantization table found"); + jpg->error = JPEGReadError::EMPTY_DQT; + return false; + } + while (*pos < start_pos + marker_len && jpg->quant.size() < kMaxQuantTables) { + JXL_JPEG_VERIFY_LEN(1); + int quant_table_index = ReadUint8(data, pos); + int quant_table_precision = quant_table_index >> 4; + JXL_JPEG_VERIFY_INPUT(quant_table_precision, 0, 1, QUANT_TBL_PRECISION); + quant_table_index &= 0xf; + JXL_JPEG_VERIFY_INPUT(quant_table_index, 0, 3, QUANT_TBL_INDEX); + JXL_JPEG_VERIFY_LEN((quant_table_precision + 1) * kDCTBlockSize); + JPEGQuantTable table; + table.index = quant_table_index; + table.precision = quant_table_precision; + for (size_t i = 0; i < kDCTBlockSize; ++i) { + int quant_val = + quant_table_precision ? ReadUint16(data, pos) : ReadUint8(data, pos); + JXL_JPEG_VERIFY_INPUT(quant_val, 1, 65535, QUANT_VAL); + table.values[kJPEGNaturalOrder[i]] = quant_val; + } + table.is_last = (*pos == start_pos + marker_len); + jpg->quant.push_back(table); + } + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the DRI marker and saves the restart interval into *jpg. +bool ProcessDRI(const uint8_t* data, const size_t len, size_t* pos, + bool* found_dri, JPEGData* jpg) { + if (*found_dri) { + JXL_JPEG_DEBUG("Duplicate DRI marker."); + jpg->error = JPEGReadError::DUPLICATE_DRI; + return false; + } + *found_dri = true; + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(4); + size_t marker_len = ReadUint16(data, pos); + int restart_interval = ReadUint16(data, pos); + jpg->restart_interval = restart_interval; + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Saves the APP marker segment as a string to *jpg. +bool ProcessAPP(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + JXL_JPEG_VERIFY_INPUT(marker_len, 2, 65535, MARKER_LEN); + JXL_JPEG_VERIFY_LEN(marker_len - 2); + JXL_DASSERT(*pos >= 3); + // Save the marker type together with the app data. + const uint8_t* app_str_start = data + *pos - 3; + std::vector app_str(app_str_start, app_str_start + marker_len + 1); + *pos += marker_len - 2; + jpg->app_data.push_back(app_str); + return true; +} + +// Saves the COM marker segment as a string to *jpg. +bool ProcessCOM(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + JXL_JPEG_VERIFY_INPUT(marker_len, 2, 65535, MARKER_LEN); + JXL_JPEG_VERIFY_LEN(marker_len - 2); + const uint8_t* com_str_start = data + *pos - 3; + std::vector com_str(com_str_start, com_str_start + marker_len + 1); + *pos += marker_len - 2; + jpg->com_data.push_back(com_str); + return true; +} + +// Helper structure to read bits from the entropy coded data segment. +struct BitReaderState { + BitReaderState(const uint8_t* data, const size_t len, size_t pos) + : data_(data), len_(len) { + Reset(pos); + } + + void Reset(size_t pos) { + pos_ = pos; + val_ = 0; + bits_left_ = 0; + next_marker_pos_ = len_ - 2; + FillBitWindow(); + } + + // Returns the next byte and skips the 0xff/0x00 escape sequences. + uint8_t GetNextByte() { + if (pos_ >= next_marker_pos_) { + ++pos_; + return 0; + } + uint8_t c = data_[pos_++]; + if (c == 0xff) { + uint8_t escape = data_[pos_]; + if (escape == 0) { + ++pos_; + } else { + // 0xff was followed by a non-zero byte, which means that we found the + // start of the next marker segment. + next_marker_pos_ = pos_ - 1; + } + } + return c; + } + + void FillBitWindow() { + if (bits_left_ <= 16) { + while (bits_left_ <= 56) { + val_ <<= 8; + val_ |= (uint64_t)GetNextByte(); + bits_left_ += 8; + } + } + } + + int ReadBits(int nbits) { + FillBitWindow(); + uint64_t val = (val_ >> (bits_left_ - nbits)) & ((1ULL << nbits) - 1); + bits_left_ -= nbits; + return val; + } + + // Sets *pos to the next stream position where parsing should continue. + // Enqueue the padding bits seen (0 or 1). + // Returns false if there is inconsistent or invalid padding or the stream + // ended too early. + bool FinishStream(JPEGData* jpg, size_t* pos) { + int npadbits = bits_left_ & 7; + if (npadbits > 0) { + uint64_t padmask = (1ULL << npadbits) - 1; + uint64_t padbits = (val_ >> (bits_left_ - npadbits)) & padmask; + if (padbits != padmask) { + jpg->has_zero_padding_bit = true; + } + for (int i = npadbits - 1; i >= 0; --i) { + jpg->padding_bits.push_back((padbits >> i) & 1); + } + } + // Give back some bytes that we did not use. + int unused_bytes_left = bits_left_ >> 3; + while (unused_bytes_left-- > 0) { + --pos_; + // If we give back a 0 byte, we need to check if it was a 0xff/0x00 escape + // sequence, and if yes, we need to give back one more byte. + if (pos_ < next_marker_pos_ && data_[pos_] == 0 && + data_[pos_ - 1] == 0xff) { + --pos_; + } + } + if (pos_ > next_marker_pos_) { + // Data ran out before the scan was complete. + JXL_JPEG_DEBUG("Unexpected end of scan."); + return false; + } + *pos = pos_; + return true; + } + + const uint8_t* data_; + const size_t len_; + size_t pos_; + uint64_t val_; + int bits_left_; + size_t next_marker_pos_; +}; + +// Returns the next Huffman-coded symbol. +int ReadSymbol(const HuffmanTableEntry* table, BitReaderState* br) { + int nbits; + br->FillBitWindow(); + int val = (br->val_ >> (br->bits_left_ - 8)) & 0xff; + table += val; + nbits = table->bits - 8; + if (nbits > 0) { + br->bits_left_ -= 8; + table += table->value; + val = (br->val_ >> (br->bits_left_ - nbits)) & ((1 << nbits) - 1); + table += val; + } + br->bits_left_ -= table->bits; + return table->value; +} + +/** + * Returns the DC diff or AC value for extra bits value x and prefix code s. + * + * CCITT Rec. T.81 (1992 E) + * Table F.1 – Difference magnitude categories for DC coding + * SSSS | DIFF values + * ------+-------------------------- + * 0 | 0 + * 1 | –1, 1 + * 2 | –3, –2, 2, 3 + * 3 | –7..–4, 4..7 + * ......|.......................... + * 11 | –2047..–1024, 1024..2047 + * + * CCITT Rec. T.81 (1992 E) + * Table F.2 – Categories assigned to coefficient values + * [ Same as Table F.1, but does not include SSSS equal to 0 and 11] + * + * + * CCITT Rec. T.81 (1992 E) + * F.1.2.1.1 Structure of DC code table + * For each category,... additional bits... appended... to uniquely identify + * which difference... occurred... When DIFF is positive... SSSS... bits of DIFF + * are appended. When DIFF is negative... SSSS... bits of (DIFF – 1) are + * appended... Most significant bit... is 0 for negative differences and 1 for + * positive differences. + * + * In other words the upper half of extra bits range represents DIFF as is. + * The lower half represents the negative DIFFs with an offset. + */ +int HuffExtend(int x, int s) { + JXL_DASSERT(s >= 1); + int half = 1 << (s - 1); + if (x >= half) { + JXL_DASSERT(x < (1 << s)); + return x; + } else { + return x - (1 << s) + 1; + } +} + +// Decodes one 8x8 block of DCT coefficients from the bit stream. +bool DecodeDCTBlock(const HuffmanTableEntry* dc_huff, + const HuffmanTableEntry* ac_huff, int Ss, int Se, int Al, + int* eobrun, bool* reset_state, int* num_zero_runs, + BitReaderState* br, JPEGData* jpg, coeff_t* last_dc_coeff, + coeff_t* coeffs) { + // Nowadays multiplication is even faster than variable shift. + int Am = 1 << Al; + bool eobrun_allowed = Ss > 0; + if (Ss == 0) { + int s = ReadSymbol(dc_huff, br); + if (s >= kJpegDCAlphabetSize) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for DC coefficient.", s); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + int diff = 0; + if (s > 0) { + int bits = br->ReadBits(s); + diff = HuffExtend(bits, s); + } + int coeff = diff + *last_dc_coeff; + const int dc_coeff = coeff * Am; + coeffs[0] = dc_coeff; + // TODO(eustas): is there a more elegant / explicit way to check this? + if (dc_coeff != coeffs[0]) { + JXL_JPEG_DEBUG("Invalid DC coefficient %d", dc_coeff); + jpg->error = JPEGReadError::NON_REPRESENTABLE_DC_COEFF; + return false; + } + *last_dc_coeff = coeff; + ++Ss; + } + if (Ss > Se) { + return true; + } + if (*eobrun > 0) { + --(*eobrun); + return true; + } + *num_zero_runs = 0; + for (int k = Ss; k <= Se; k++) { + int sr = ReadSymbol(ac_huff, br); + if (sr >= kJpegHuffmanAlphabetSize) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for AC coefficient %d", sr, k); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + int r = sr >> 4; + int s = sr & 15; + if (s > 0) { + k += r; + if (k > Se) { + JXL_JPEG_DEBUG("Out-of-band coefficient %d band was %d-%d", k, Ss, Se); + jpg->error = JPEGReadError::OUT_OF_BAND_COEFF; + return false; + } + if (s + Al >= kJpegDCAlphabetSize) { + JXL_JPEG_DEBUG( + "Out of range AC coefficient value: s = %d Al = %d k = %d", s, Al, + k); + jpg->error = JPEGReadError::NON_REPRESENTABLE_AC_COEFF; + return false; + } + int bits = br->ReadBits(s); + int coeff = HuffExtend(bits, s); + coeffs[kJPEGNaturalOrder[k]] = coeff * Am; + *num_zero_runs = 0; + } else if (r == 15) { + k += 15; + ++(*num_zero_runs); + } else { + if (eobrun_allowed && k == Ss && *eobrun == 0) { + // We have two end-of-block runs right after each other, so we signal + // the jpeg encoder to force a state reset at this point. + *reset_state = true; + } + *eobrun = 1 << r; + if (r > 0) { + if (!eobrun_allowed) { + JXL_JPEG_DEBUG("End-of-block run crossing DC coeff."); + jpg->error = JPEGReadError::EOB_RUN_TOO_LONG; + return false; + } + *eobrun += br->ReadBits(r); + } + break; + } + } + --(*eobrun); + return true; +} + +bool RefineDCTBlock(const HuffmanTableEntry* ac_huff, int Ss, int Se, int Al, + int* eobrun, bool* reset_state, BitReaderState* br, + JPEGData* jpg, coeff_t* coeffs) { + // Nowadays multiplication is even faster than variable shift. + int Am = 1 << Al; + bool eobrun_allowed = Ss > 0; + if (Ss == 0) { + int s = br->ReadBits(1); + coeff_t dc_coeff = coeffs[0]; + dc_coeff |= s * Am; + coeffs[0] = dc_coeff; + ++Ss; + } + if (Ss > Se) { + return true; + } + int p1 = Am; + int m1 = -Am; + int k = Ss; + int r; + int s; + bool in_zero_run = false; + if (*eobrun <= 0) { + for (; k <= Se; k++) { + s = ReadSymbol(ac_huff, br); + if (s >= kJpegHuffmanAlphabetSize) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for AC coefficient %d", s, k); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + r = s >> 4; + s &= 15; + if (s) { + if (s != 1) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for AC coefficient %d", s, + k); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + s = br->ReadBits(1) ? p1 : m1; + in_zero_run = false; + } else { + if (r != 15) { + if (eobrun_allowed && k == Ss && *eobrun == 0) { + // We have two end-of-block runs right after each other, so we + // signal the jpeg encoder to force a state reset at this point. + *reset_state = true; + } + *eobrun = 1 << r; + if (r > 0) { + if (!eobrun_allowed) { + JXL_JPEG_DEBUG("End-of-block run crossing DC coeff."); + jpg->error = JPEGReadError::EOB_RUN_TOO_LONG; + return false; + } + *eobrun += br->ReadBits(r); + } + break; + } + in_zero_run = true; + } + do { + coeff_t thiscoef = coeffs[kJPEGNaturalOrder[k]]; + if (thiscoef != 0) { + if (br->ReadBits(1)) { + if ((thiscoef & p1) == 0) { + if (thiscoef >= 0) { + thiscoef += p1; + } else { + thiscoef += m1; + } + } + } + coeffs[kJPEGNaturalOrder[k]] = thiscoef; + } else { + if (--r < 0) { + break; + } + } + k++; + } while (k <= Se); + if (s) { + if (k > Se) { + JXL_JPEG_DEBUG("Out-of-band coefficient %d band was %d-%d", k, Ss, + Se); + jpg->error = JPEGReadError::OUT_OF_BAND_COEFF; + return false; + } + coeffs[kJPEGNaturalOrder[k]] = s; + } + } + } + if (in_zero_run) { + JXL_JPEG_DEBUG("Extra zero run before end-of-block."); + jpg->error = JPEGReadError::EXTRA_ZERO_RUN; + return false; + } + if (*eobrun > 0) { + for (; k <= Se; k++) { + coeff_t thiscoef = coeffs[kJPEGNaturalOrder[k]]; + if (thiscoef != 0) { + if (br->ReadBits(1)) { + if ((thiscoef & p1) == 0) { + if (thiscoef >= 0) { + thiscoef += p1; + } else { + thiscoef += m1; + } + } + } + coeffs[kJPEGNaturalOrder[k]] = thiscoef; + } + } + } + --(*eobrun); + return true; +} + +bool ProcessRestart(const uint8_t* data, const size_t len, + int* next_restart_marker, BitReaderState* br, + JPEGData* jpg) { + size_t pos = 0; + if (!br->FinishStream(jpg, &pos)) { + jpg->error = JPEGReadError::INVALID_SCAN; + return false; + } + int expected_marker = 0xd0 + *next_restart_marker; + JXL_JPEG_EXPECT_MARKER(); + int marker = data[pos + 1]; + if (marker != expected_marker) { + JXL_JPEG_DEBUG("Did not find expected restart marker %d actual %d", + expected_marker, marker); + jpg->error = JPEGReadError::WRONG_RESTART_MARKER; + return false; + } + br->Reset(pos + 2); + *next_restart_marker += 1; + *next_restart_marker &= 0x7; + return true; +} + +bool ProcessScan(const uint8_t* data, const size_t len, + const std::vector& dc_huff_lut, + const std::vector& ac_huff_lut, + uint16_t scan_progression[kMaxComponents][kDCTBlockSize], + bool is_progressive, size_t* pos, JPEGData* jpg) { + if (!ProcessSOS(data, len, pos, jpg)) { + return false; + } + JPEGScanInfo* scan_info = &jpg->scan_info.back(); + bool is_interleaved = (scan_info->num_components > 1); + int max_h_samp_factor = 1; + int max_v_samp_factor = 1; + for (size_t i = 0; i < jpg->components.size(); ++i) { + max_h_samp_factor = + std::max(max_h_samp_factor, jpg->components[i].h_samp_factor); + max_v_samp_factor = + std::max(max_v_samp_factor, jpg->components[i].v_samp_factor); + } + + int MCU_rows = DivCeil(jpg->height, max_v_samp_factor * 8); + int MCUs_per_row = DivCeil(jpg->width, max_h_samp_factor * 8); + if (!is_interleaved) { + const JPEGComponent& c = jpg->components[scan_info->components[0].comp_idx]; + MCUs_per_row = DivCeil(jpg->width * c.h_samp_factor, 8 * max_h_samp_factor); + MCU_rows = DivCeil(jpg->height * c.v_samp_factor, 8 * max_v_samp_factor); + } + coeff_t last_dc_coeff[kMaxComponents] = {0}; + BitReaderState br(data, len, *pos); + int restarts_to_go = jpg->restart_interval; + int next_restart_marker = 0; + int eobrun = -1; + int block_scan_index = 0; + const int Al = is_progressive ? scan_info->Al : 0; + const int Ah = is_progressive ? scan_info->Ah : 0; + const int Ss = is_progressive ? scan_info->Ss : 0; + const int Se = is_progressive ? scan_info->Se : 63; + const uint16_t scan_bitmask = Ah == 0 ? (0xffff << Al) : (1u << Al); + const uint16_t refinement_bitmask = (1 << Al) - 1; + for (size_t i = 0; i < scan_info->num_components; ++i) { + int comp_idx = scan_info->components[i].comp_idx; + for (int k = Ss; k <= Se; ++k) { + if (scan_progression[comp_idx][k] & scan_bitmask) { + JXL_JPEG_DEBUG( + "Overlapping scans: component=%d k=%d prev_mask: %u cur_mask %u", + comp_idx, k, scan_progression[i][k], scan_bitmask); + jpg->error = JPEGReadError::OVERLAPPING_SCANS; + return false; + } + if (scan_progression[comp_idx][k] & refinement_bitmask) { + JXL_JPEG_DEBUG( + "Invalid scan order, a more refined scan was already done: " + "component=%d k=%d prev_mask=%u cur_mask=%u", + comp_idx, k, scan_progression[i][k], scan_bitmask); + jpg->error = JPEGReadError::INVALID_SCAN_ORDER; + return false; + } + scan_progression[comp_idx][k] |= scan_bitmask; + } + } + if (Al > 10) { + JXL_JPEG_DEBUG("Scan parameter Al=%d is not supported.", Al); + jpg->error = JPEGReadError::NON_REPRESENTABLE_AC_COEFF; + return false; + } + for (int mcu_y = 0; mcu_y < MCU_rows; ++mcu_y) { + for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) { + // Handle the restart intervals. + if (jpg->restart_interval > 0) { + if (restarts_to_go == 0) { + if (ProcessRestart(data, len, &next_restart_marker, &br, jpg)) { + restarts_to_go = jpg->restart_interval; + memset(static_cast(last_dc_coeff), 0, sizeof(last_dc_coeff)); + if (eobrun > 0) { + JXL_JPEG_DEBUG("End-of-block run too long."); + jpg->error = JPEGReadError::EOB_RUN_TOO_LONG; + return false; + } + eobrun = -1; // fresh start + } else { + return false; + } + } + --restarts_to_go; + } + // Decode one MCU. + for (size_t i = 0; i < scan_info->num_components; ++i) { + JPEGComponentScanInfo* si = &scan_info->components[i]; + JPEGComponent* c = &jpg->components[si->comp_idx]; + const HuffmanTableEntry* dc_lut = + &dc_huff_lut[si->dc_tbl_idx * kJpegHuffmanLutSize]; + const HuffmanTableEntry* ac_lut = + &ac_huff_lut[si->ac_tbl_idx * kJpegHuffmanLutSize]; + int nblocks_y = is_interleaved ? c->v_samp_factor : 1; + int nblocks_x = is_interleaved ? c->h_samp_factor : 1; + for (int iy = 0; iy < nblocks_y; ++iy) { + for (int ix = 0; ix < nblocks_x; ++ix) { + int block_y = mcu_y * nblocks_y + iy; + int block_x = mcu_x * nblocks_x + ix; + int block_idx = block_y * c->width_in_blocks + block_x; + bool reset_state = false; + int num_zero_runs = 0; + coeff_t* coeffs = &c->coeffs[block_idx * kDCTBlockSize]; + if (Ah == 0) { + if (!DecodeDCTBlock(dc_lut, ac_lut, Ss, Se, Al, &eobrun, + &reset_state, &num_zero_runs, &br, jpg, + &last_dc_coeff[si->comp_idx], coeffs)) { + return false; + } + } else { + if (!RefineDCTBlock(ac_lut, Ss, Se, Al, &eobrun, &reset_state, + &br, jpg, coeffs)) { + return false; + } + } + if (reset_state) { + scan_info->reset_points.emplace_back(block_scan_index); + } + if (num_zero_runs > 0) { + JPEGScanInfo::ExtraZeroRunInfo info; + info.block_idx = block_scan_index; + info.num_extra_zero_runs = num_zero_runs; + scan_info->extra_zero_runs.push_back(info); + } + ++block_scan_index; + } + } + } + } + } + if (eobrun > 0) { + JXL_JPEG_DEBUG("End-of-block run too long."); + jpg->error = JPEGReadError::EOB_RUN_TOO_LONG; + return false; + } + if (!br.FinishStream(jpg, pos)) { + jpg->error = JPEGReadError::INVALID_SCAN; + return false; + } + if (*pos > len) { + JXL_JPEG_DEBUG("Unexpected end of file during scan. pos=%" PRIuS + " len=%" PRIuS, + *pos, len); + jpg->error = JPEGReadError::UNEXPECTED_EOF; + return false; + } + return true; +} + +// Changes the quant_idx field of the components to refer to the index of the +// quant table in the jpg->quant array. +bool FixupIndexes(JPEGData* jpg) { + for (size_t i = 0; i < jpg->components.size(); ++i) { + JPEGComponent* c = &jpg->components[i]; + bool found_index = false; + for (size_t j = 0; j < jpg->quant.size(); ++j) { + if (jpg->quant[j].index == c->quant_idx) { + c->quant_idx = j; + found_index = true; + break; + } + } + if (!found_index) { + JXL_JPEG_DEBUG("Quantization table with index %u not found", + c->quant_idx); + jpg->error = JPEGReadError::QUANT_TABLE_NOT_FOUND; + return false; + } + } + return true; +} + +size_t FindNextMarker(const uint8_t* data, const size_t len, size_t pos) { + // kIsValidMarker[i] == 1 means (0xc0 + i) is a valid marker. + static const uint8_t kIsValidMarker[] = { + 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + }; + size_t num_skipped = 0; + while (pos + 1 < len && (data[pos] != 0xff || data[pos + 1] < 0xc0 || + !kIsValidMarker[data[pos + 1] - 0xc0])) { + ++pos; + ++num_skipped; + } + return num_skipped; +} + +} // namespace + +bool ReadJpeg(const uint8_t* data, const size_t len, JpegReadMode mode, + JPEGData* jpg) { + size_t pos = 0; + // Check SOI marker. + JXL_JPEG_EXPECT_MARKER(); + int marker = data[pos + 1]; + pos += 2; + if (marker != 0xd8) { + JXL_JPEG_DEBUG("Did not find expected SOI marker, actual=%d", marker); + jpg->error = JPEGReadError::SOI_NOT_FOUND; + return false; + } + int lut_size = kMaxHuffmanTables * kJpegHuffmanLutSize; + std::vector dc_huff_lut(lut_size); + std::vector ac_huff_lut(lut_size); + bool found_sof = false; + bool found_dri = false; + uint16_t scan_progression[kMaxComponents][kDCTBlockSize] = {{0}}; + + jpg->padding_bits.resize(0); + bool is_progressive = false; // default + do { + // Read next marker. + size_t num_skipped = FindNextMarker(data, len, pos); + if (num_skipped > 0) { + // Add a fake marker to indicate arbitrary in-between-markers data. + jpg->marker_order.push_back(0xff); + jpg->inter_marker_data.emplace_back(data + pos, data + pos + num_skipped); + pos += num_skipped; + } + JXL_JPEG_EXPECT_MARKER(); + marker = data[pos + 1]; + pos += 2; + bool ok = true; + switch (marker) { + case 0xc0: + case 0xc1: + case 0xc2: + is_progressive = (marker == 0xc2); + ok = ProcessSOF(data, len, mode, &pos, jpg); + found_sof = true; + break; + case 0xc4: + ok = ProcessDHT(data, len, mode, &dc_huff_lut, &ac_huff_lut, &pos, jpg); + break; + case 0xd0: + case 0xd1: + case 0xd2: + case 0xd3: + case 0xd4: + case 0xd5: + case 0xd6: + case 0xd7: + // RST markers do not have any data. + break; + case 0xd9: + // Found end marker. + break; + case 0xda: + if (mode == JpegReadMode::kReadAll) { + ok = ProcessScan(data, len, dc_huff_lut, ac_huff_lut, + scan_progression, is_progressive, &pos, jpg); + } + break; + case 0xdb: + ok = ProcessDQT(data, len, &pos, jpg); + break; + case 0xdd: + ok = ProcessDRI(data, len, &pos, &found_dri, jpg); + break; + case 0xe0: + case 0xe1: + case 0xe2: + case 0xe3: + case 0xe4: + case 0xe5: + case 0xe6: + case 0xe7: + case 0xe8: + case 0xe9: + case 0xea: + case 0xeb: + case 0xec: + case 0xed: + case 0xee: + case 0xef: + if (mode != JpegReadMode::kReadTables) { + ok = ProcessAPP(data, len, &pos, jpg); + } + break; + case 0xfe: + if (mode != JpegReadMode::kReadTables) { + ok = ProcessCOM(data, len, &pos, jpg); + } + break; + default: + JXL_JPEG_DEBUG("Unsupported marker: %d pos=%" PRIuS " len=%" PRIuS, + marker, pos, len); + jpg->error = JPEGReadError::UNSUPPORTED_MARKER; + ok = false; + break; + } + if (!ok) { + return false; + } + jpg->marker_order.push_back(marker); + if (mode == JpegReadMode::kReadHeader && found_sof) { + break; + } + } while (marker != 0xd9); + + if (!found_sof) { + JXL_JPEG_DEBUG("Missing SOF marker."); + jpg->error = JPEGReadError::SOF_NOT_FOUND; + return false; + } + + // Supplemental checks. + if (mode == JpegReadMode::kReadAll) { + if (pos < len) { + jpg->tail_data = std::vector(data + pos, data + len); + } + if (!FixupIndexes(jpg)) { + return false; + } + if (jpg->huffman_code.empty()) { + // Section B.2.4.2: "If a table has never been defined for a particular + // destination, then when this destination is specified in a scan header, + // the results are unpredictable." + JXL_JPEG_DEBUG("Need at least one Huffman code table."); + jpg->error = JPEGReadError::HUFFMAN_TABLE_ERROR; + return false; + } + if (jpg->huffman_code.size() >= kMaxDHTMarkers) { + JXL_JPEG_DEBUG("Too many Huffman tables."); + jpg->error = JPEGReadError::HUFFMAN_TABLE_ERROR; + return false; + } + } + return true; +} + +} // namespace jpeg +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data_reader.h b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data_reader.h new file mode 100644 index 000000000..3fad820e9 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_data_reader.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Functions for reading a jpeg byte stream into a JPEGData object. + +#ifndef LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ +#define LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ + +#include +#include + +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +enum class JpegReadMode { + kReadHeader, // only basic headers + kReadTables, // headers and tables (quant, Huffman, ...) + kReadAll, // everything +}; + +// Parses the JPEG stream contained in data[*pos ... len) and fills in *jpg with +// the parsed information. +// If mode is kReadHeader, it fills in only the image dimensions in *jpg. +// Returns false if the data is not valid JPEG, or if it contains an unsupported +// JPEG feature. +bool ReadJpeg(const uint8_t* data, const size_t len, JpegReadMode mode, + JPEGData* jpg); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc new file mode 100644 index 000000000..38282e640 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc @@ -0,0 +1,103 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/enc_jpeg_huffman_decode.h" + +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +// Returns the table width of the next 2nd level table, count is the histogram +// of bit lengths for the remaining symbols, len is the code length of the next +// processed symbol. +static inline int NextTableBitSize(const int* count, int len) { + int left = 1 << (len - kJpegHuffmanRootTableBits); + while (len < static_cast(kJpegHuffmanMaxBitLength)) { + left -= count[len]; + if (left <= 0) break; + ++len; + left <<= 1; + } + return len - kJpegHuffmanRootTableBits; +} + +void BuildJpegHuffmanTable(const uint32_t* count, const uint32_t* symbols, + HuffmanTableEntry* lut) { + HuffmanTableEntry code; // current table entry + HuffmanTableEntry* table; // next available space in table + int len; // current code length + int idx; // symbol index + int key; // prefix code + int reps; // number of replicate key values in current table + int low; // low bits for current root entry + int table_bits; // key length of current table + int table_size; // size of current table + + // Make a local copy of the input bit length histogram. + int tmp_count[kJpegHuffmanMaxBitLength + 1] = {0}; + int total_count = 0; + for (len = 1; len <= static_cast(kJpegHuffmanMaxBitLength); ++len) { + tmp_count[len] = count[len]; + total_count += tmp_count[len]; + } + + table = lut; + table_bits = kJpegHuffmanRootTableBits; + table_size = 1 << table_bits; + + // Special case code with only one value. + if (total_count == 1) { + code.bits = 0; + code.value = symbols[0]; + for (key = 0; key < table_size; ++key) { + table[key] = code; + } + return; + } + + // Fill in root table. + key = 0; + idx = 0; + for (len = 1; len <= kJpegHuffmanRootTableBits; ++len) { + for (; tmp_count[len] > 0; --tmp_count[len]) { + code.bits = len; + code.value = symbols[idx++]; + reps = 1 << (kJpegHuffmanRootTableBits - len); + while (reps--) { + table[key++] = code; + } + } + } + + // Fill in 2nd level tables and add pointers to root table. + table += table_size; + table_size = 0; + low = 0; + for (len = kJpegHuffmanRootTableBits + 1; + len <= static_cast(kJpegHuffmanMaxBitLength); ++len) { + for (; tmp_count[len] > 0; --tmp_count[len]) { + // Start a new sub-table if the previous one is full. + if (low >= table_size) { + table += table_size; + table_bits = NextTableBitSize(tmp_count, len); + table_size = 1 << table_bits; + low = 0; + lut[key].bits = table_bits + kJpegHuffmanRootTableBits; + lut[key].value = (table - lut) - key; + ++key; + } + code.bits = len - kJpegHuffmanRootTableBits; + code.value = symbols[idx++]; + reps = 1 << (table_bits - code.bits); + while (reps--) { + table[low++] = code; + } + } + } +} + +} // namespace jpeg +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_huffman_decode.h b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_huffman_decode.h new file mode 100644 index 000000000..b8a60e410 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/enc_jpeg_huffman_decode.h @@ -0,0 +1,41 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Utility function for building a Huffman lookup table for the jpeg decoder. + +#ifndef LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_ +#define LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_ + +#include + +namespace jxl { +namespace jpeg { + +constexpr int kJpegHuffmanRootTableBits = 8; +// Maximum huffman lookup table size. +// According to zlib/examples/enough.c, 758 entries are always enough for +// an alphabet of 257 symbols (256 + 1 special symbol for the all 1s code) and +// max bit length 16 if the root table has 8 bits. +constexpr int kJpegHuffmanLutSize = 758; + +struct HuffmanTableEntry { + // Initialize the value to an invalid symbol so that we can recognize it + // when reading the bit stream using a Huffman code with space > 0. + HuffmanTableEntry() : bits(0), value(0xffff) {} + + uint8_t bits; // number of bits used for this symbol + uint16_t value; // symbol value or table offset +}; + +// Builds jpeg-style Huffman lookup table from the given symbols. +// The symbols are in order of increasing bit lengths. The number of symbols +// with bit length n is given in counts[n] for each n >= 1. +void BuildJpegHuffmanTable(const uint32_t* counts, const uint32_t* symbols, + HuffmanTableEntry* lut); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_ diff --git a/media/libjxl/src/lib/jxl/jpeg/jpeg_data.cc b/media/libjxl/src/lib/jxl/jpeg/jpeg_data.cc new file mode 100644 index 000000000..a78d77c96 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/jpeg_data.cc @@ -0,0 +1,460 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/jpeg_data.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace jpeg { + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +namespace { +enum JPEGComponentType : uint32_t { + kGray = 0, + kYCbCr = 1, + kRGB = 2, + kCustom = 3, +}; + +struct JPEGInfo { + size_t num_app_markers = 0; + size_t num_com_markers = 0; + size_t num_scans = 0; + size_t num_intermarker = 0; + bool has_dri = false; +}; + +Status VisitMarker(uint8_t* marker, Visitor* visitor, JPEGInfo* info) { + uint32_t marker32 = *marker - 0xc0; + JXL_RETURN_IF_ERROR(visitor->Bits(6, 0x00, &marker32)); + *marker = marker32 + 0xc0; + if ((*marker & 0xf0) == 0xe0) { + info->num_app_markers++; + } + if (*marker == 0xfe) { + info->num_com_markers++; + } + if (*marker == 0xda) { + info->num_scans++; + } + // We use a fake 0xff marker to signal intermarker data. + if (*marker == 0xff) { + info->num_intermarker++; + } + if (*marker == 0xdd) { + info->has_dri = true; + } + return true; +} + +} // namespace + +Status JPEGData::VisitFields(Visitor* visitor) { + bool is_gray = components.size() == 1; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &is_gray)); + if (visitor->IsReading()) { + components.resize(is_gray ? 1 : 3); + } + JPEGInfo info; + if (visitor->IsReading()) { + uint8_t marker = 0xc0; + do { + JXL_RETURN_IF_ERROR(VisitMarker(&marker, visitor, &info)); + marker_order.push_back(marker); + if (marker_order.size() > 16384) { + return JXL_FAILURE("Too many markers: %" PRIuS "\n", + marker_order.size()); + } + } while (marker != 0xd9); + } else { + if (marker_order.size() > 16384) { + return JXL_FAILURE("Too many markers: %" PRIuS "\n", marker_order.size()); + } + for (size_t i = 0; i < marker_order.size(); i++) { + JXL_RETURN_IF_ERROR(VisitMarker(&marker_order[i], visitor, &info)); + } + if (!marker_order.empty()) { + // Last marker should always be EOI marker. + JXL_CHECK(marker_order.back() == 0xd9); + } + } + + // Size of the APP and COM markers. + if (visitor->IsReading()) { + app_data.resize(info.num_app_markers); + app_marker_type.resize(info.num_app_markers); + com_data.resize(info.num_com_markers); + scan_info.resize(info.num_scans); + } + JXL_ASSERT(app_data.size() == info.num_app_markers); + JXL_ASSERT(app_marker_type.size() == info.num_app_markers); + JXL_ASSERT(com_data.size() == info.num_com_markers); + JXL_ASSERT(scan_info.size() == info.num_scans); + for (size_t i = 0; i < app_data.size(); i++) { + auto& app = app_data[i]; + // Encodes up to 8 different values. + JXL_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), BitsOffset(1, 2), BitsOffset(2, 4), 0, + reinterpret_cast(&app_marker_type[i]))); + if (app_marker_type[i] != AppMarkerType::kUnknown && + app_marker_type[i] != AppMarkerType::kICC && + app_marker_type[i] != AppMarkerType::kExif && + app_marker_type[i] != AppMarkerType::kXMP) { + return JXL_FAILURE("Unknown app marker type %u", + static_cast(app_marker_type[i])); + } + uint32_t len = app.size() - 1; + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) app.resize(len + 1); + if (app.size() < 3) { + return JXL_FAILURE("Invalid marker size: %" PRIuS "\n", app.size()); + } + } + for (auto& com : com_data) { + uint32_t len = com.size() - 1; + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) com.resize(len + 1); + if (com.size() < 3) { + return JXL_FAILURE("Invalid marker size: %" PRIuS "\n", com.size()); + } + } + + uint32_t num_quant_tables = quant.size(); + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 2, &num_quant_tables)); + if (num_quant_tables == 4) { + return JXL_FAILURE("Invalid number of quant tables"); + } + if (visitor->IsReading()) { + quant.resize(num_quant_tables); + } + for (size_t i = 0; i < num_quant_tables; i++) { + if (quant[i].precision > 1) { + return JXL_FAILURE( + "Quant tables with more than 16 bits are not supported"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(1, 0, &quant[i].precision)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, i, &quant[i].index)); + JXL_RETURN_IF_ERROR(visitor->Bool(true, &quant[i].is_last)); + } + + JPEGComponentType component_type = + components.size() == 1 && components[0].id == 1 ? JPEGComponentType::kGray + : components.size() == 3 && components[0].id == 1 && + components[1].id == 2 && components[2].id == 3 + ? JPEGComponentType::kYCbCr + : components.size() == 3 && components[0].id == 'R' && + components[1].id == 'G' && components[2].id == 'B' + ? JPEGComponentType::kRGB + : JPEGComponentType::kCustom; + JXL_RETURN_IF_ERROR( + visitor->Bits(2, JPEGComponentType::kYCbCr, + reinterpret_cast(&component_type))); + uint32_t num_components; + if (component_type == JPEGComponentType::kGray) { + num_components = 1; + } else if (component_type != JPEGComponentType::kCustom) { + num_components = 3; + } else { + num_components = components.size(); + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 3, &num_components)); + if (num_components != 1 && num_components != 3) { + return JXL_FAILURE("Invalid number of components: %u", num_components); + } + } + if (visitor->IsReading()) { + components.resize(num_components); + } + if (component_type == JPEGComponentType::kCustom) { + for (size_t i = 0; i < components.size(); i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(8, 0, &components[i].id)); + } + } else if (component_type == JPEGComponentType::kGray) { + components[0].id = 1; + } else if (component_type == JPEGComponentType::kRGB) { + components[0].id = 'R'; + components[1].id = 'G'; + components[2].id = 'B'; + } else { + components[0].id = 1; + components[1].id = 2; + components[2].id = 3; + } + size_t used_tables = 0; + for (size_t i = 0; i < components.size(); i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &components[i].quant_idx)); + if (components[i].quant_idx >= quant.size()) { + return JXL_FAILURE("Invalid quant table for component %" PRIuS ": %u\n", + i, components[i].quant_idx); + } + used_tables |= 1U << components[i].quant_idx; + } + for (size_t i = 0; i < quant.size(); i++) { + if (used_tables & (1 << i)) continue; + if (i == 0) return JXL_FAILURE("First quant table unused."); + // Unused quant table has to be set to copy of previous quant table + for (size_t j = 0; j < 64; j++) { + if (quant[i].values[j] != quant[i - 1].values[j]) { + return JXL_FAILURE("Non-trivial unused quant table"); + } + } + } + + uint32_t num_huff = huffman_code.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(4), BitsOffset(3, 2), BitsOffset(4, 10), + BitsOffset(6, 26), 4, &num_huff)); + if (visitor->IsReading()) { + huffman_code.resize(num_huff); + } + for (JPEGHuffmanCode& hc : huffman_code) { + bool is_ac = hc.slot_id >> 4; + uint32_t id = hc.slot_id & 0xF; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &is_ac)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &id)); + hc.slot_id = (static_cast(is_ac) << 4) | id; + JXL_RETURN_IF_ERROR(visitor->Bool(true, &hc.is_last)); + size_t num_symbols = 0; + for (size_t i = 0; i <= 16; i++) { + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(3, 2), + Bits(8), 0, &hc.counts[i])); + num_symbols += hc.counts[i]; + } + if (num_symbols < 1) { + // Actually, at least 2 symbols are required, since one of them is EOI. + return JXL_FAILURE("Empty Huffman table"); + } + if (num_symbols > hc.values.size()) { + return JXL_FAILURE("Huffman code too large (%" PRIuS ")", num_symbols); + } + // Presence flags for 4 * 64 + 1 values. + uint64_t value_slots[5] = {}; + for (size_t i = 0; i < num_symbols; i++) { + // Goes up to 256, included. Might have the same symbol appear twice... + JXL_RETURN_IF_ERROR(visitor->U32(Bits(2), BitsOffset(2, 4), + BitsOffset(4, 8), BitsOffset(8, 1), 0, + &hc.values[i])); + value_slots[hc.values[i] >> 6] |= (uint64_t)1 << (hc.values[i] & 0x3F); + } + if (hc.values[num_symbols - 1] != kJpegHuffmanAlphabetSize) { + return JXL_FAILURE("Missing EOI symbol"); + } + // Last element, denoting EOI, have to be 1 after the loop. + JXL_ASSERT(value_slots[4] == 1); + size_t num_values = 1; + for (size_t i = 0; i < 4; ++i) num_values += hwy::PopCount(value_slots[i]); + if (num_values != num_symbols) { + return JXL_FAILURE("Duplicate Huffman symbols"); + } + if (!is_ac) { + bool only_dc = ((value_slots[0] >> kJpegDCAlphabetSize) | value_slots[1] | + value_slots[2] | value_slots[3]) == 0; + if (!only_dc) return JXL_FAILURE("Huffman symbols out of DC range"); + } + } + + for (auto& scan : scan_info) { + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 1, &scan.num_components)); + if (scan.num_components >= 4) { + return JXL_FAILURE("Invalid number of components in SOS marker"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(6, 0, &scan.Ss)); + JXL_RETURN_IF_ERROR(visitor->Bits(6, 63, &scan.Se)); + JXL_RETURN_IF_ERROR(visitor->Bits(4, 0, &scan.Al)); + JXL_RETURN_IF_ERROR(visitor->Bits(4, 0, &scan.Ah)); + for (size_t i = 0; i < scan.num_components; i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].comp_idx)); + if (scan.components[i].comp_idx >= components.size()) { + return JXL_FAILURE("Invalid component idx in SOS marker"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].ac_tbl_idx)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].dc_tbl_idx)); + } + // TODO(veluca): actually set and use this value. + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), Val(2), BitsOffset(3, 3), + kMaxNumPasses - 1, + &scan.last_needed_pass)); + } + + // From here on, this is data that is not strictly necessary to get a valid + // JPEG, but necessary for bit-exact JPEG reconstruction. + if (info.has_dri) { + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &restart_interval)); + } + + uint64_t padding_spot_limit = scan_info.size(); + + for (auto& scan : scan_info) { + uint32_t num_reset_points = scan.reset_points.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(2, 1), BitsOffset(4, 4), + BitsOffset(16, 20), 0, &num_reset_points)); + if (visitor->IsReading()) { + scan.reset_points.resize(num_reset_points); + } + int last_block_idx = -1; + for (auto& block_idx : scan.reset_points) { + block_idx -= last_block_idx + 1; + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(3, 1), + BitsOffset(5, 9), BitsOffset(28, 41), 0, + &block_idx)); + block_idx += last_block_idx + 1; + if (static_cast(block_idx) < last_block_idx + 1) { + return JXL_FAILURE("Invalid block ID: %u, last block was %d", block_idx, + last_block_idx); + } + // TODO(eustas): better upper boundary could be given at this point; also + // it could be applied during reset_points reading. + if (block_idx > (1u << 30)) { + // At most 8K x 8K x num_channels blocks are expected. That is, + // typically, 1.5 * 2^27. 2^30 should be sufficient for any sane + // image. + return JXL_FAILURE("Invalid block ID: %u", block_idx); + } + last_block_idx = block_idx; + } + + uint32_t num_extra_zero_runs = scan.extra_zero_runs.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(2, 1), BitsOffset(4, 4), + BitsOffset(16, 20), 0, + &num_extra_zero_runs)); + if (visitor->IsReading()) { + scan.extra_zero_runs.resize(num_extra_zero_runs); + } + last_block_idx = -1; + for (size_t i = 0; i < scan.extra_zero_runs.size(); ++i) { + uint32_t& block_idx = scan.extra_zero_runs[i].block_idx; + JXL_RETURN_IF_ERROR(visitor->U32( + Val(1), BitsOffset(2, 2), BitsOffset(4, 5), BitsOffset(8, 20), 1, + &scan.extra_zero_runs[i].num_extra_zero_runs)); + block_idx -= last_block_idx + 1; + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(3, 1), + BitsOffset(5, 9), BitsOffset(28, 41), 0, + &block_idx)); + block_idx += last_block_idx + 1; + if (static_cast(block_idx) < last_block_idx + 1) { + return JXL_FAILURE("Invalid block ID: %u, last block was %d", block_idx, + last_block_idx); + } + if (block_idx > (1u << 30)) { + // At most 8K x 8K x num_channels blocks are expected. That is, + // typically, 1.5 * 2^27. 2^30 should be sufficient for any sane + // image. + return JXL_FAILURE("Invalid block ID: %u", block_idx); + } + last_block_idx = block_idx; + } + + if (restart_interval > 0) { + int MCUs_per_row = 0; + int MCU_rows = 0; + CalculateMcuSize(scan, &MCUs_per_row, &MCU_rows); + padding_spot_limit += DivCeil(MCU_rows * MCUs_per_row, restart_interval); + } + } + std::vector inter_marker_data_sizes; + inter_marker_data_sizes.reserve(info.num_intermarker); + for (size_t i = 0; i < info.num_intermarker; ++i) { + uint32_t len = visitor->IsReading() ? 0 : inter_marker_data[i].size(); + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) inter_marker_data_sizes.emplace_back(len); + } + uint32_t tail_data_len = tail_data.size(); + if (!visitor->IsReading() && tail_data_len > 4260096) { + error = JPEGReadError::TAIL_DATA_TOO_LARGE; + return JXL_FAILURE("Tail data too large (max size = 4260096, size = %u).", + tail_data_len); + } + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(8, 1), + BitsOffset(16, 257), BitsOffset(22, 65793), + 0, &tail_data_len)); + + JXL_RETURN_IF_ERROR(visitor->Bool(false, &has_zero_padding_bit)); + if (has_zero_padding_bit) { + uint32_t nbit = padding_bits.size(); + JXL_RETURN_IF_ERROR(visitor->Bits(24, 0, &nbit)); + if (nbit > 7 * padding_spot_limit) { + return JXL_FAILURE("Number of padding bits does not correspond to image"); + } + // TODO(eustas): check that that much bits of input are available. + if (visitor->IsReading()) { + padding_bits.resize(nbit); + } + // TODO(eustas): read in (8-64?) bit groups to reduce overhead. + for (uint8_t& bit : padding_bits) { + bool bbit = bit; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &bbit)); + bit = bbit; + } + } + + // Apply postponed actions. + if (visitor->IsReading()) { + tail_data.resize(tail_data_len); + JXL_ASSERT(inter_marker_data_sizes.size() == info.num_intermarker); + inter_marker_data.reserve(info.num_intermarker); + for (size_t i = 0; i < info.num_intermarker; ++i) { + inter_marker_data.emplace_back(inter_marker_data_sizes[i]); + } + } + + return true; +} + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +void JPEGData::CalculateMcuSize(const JPEGScanInfo& scan, int* MCUs_per_row, + int* MCU_rows) const { + const bool is_interleaved = (scan.num_components > 1); + const JPEGComponent& base_component = components[scan.components[0].comp_idx]; + // h_group / v_group act as numerators for converting number of blocks to + // number of MCU. In interleaved mode it is 1, so MCU is represented with + // max_*_samp_factor blocks. In non-interleaved mode we choose numerator to + // be the samping factor, consequently MCU is always represented with single + // block. + const int h_group = is_interleaved ? 1 : base_component.h_samp_factor; + const int v_group = is_interleaved ? 1 : base_component.v_samp_factor; + int max_h_samp_factor = 1; + int max_v_samp_factor = 1; + for (const auto& c : components) { + max_h_samp_factor = std::max(c.h_samp_factor, max_h_samp_factor); + max_v_samp_factor = std::max(c.v_samp_factor, max_v_samp_factor); + } + *MCUs_per_row = DivCeil(width * h_group, 8 * max_h_samp_factor); + *MCU_rows = DivCeil(height * v_group, 8 * max_v_samp_factor); +} + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +Status SetJPEGDataFromICC(const PaddedBytes& icc, jpeg::JPEGData* jpeg_data) { + size_t icc_pos = 0; + for (size_t i = 0; i < jpeg_data->app_data.size(); i++) { + if (jpeg_data->app_marker_type[i] != jpeg::AppMarkerType::kICC) { + continue; + } + size_t len = jpeg_data->app_data[i].size() - 17; + if (icc_pos + len > icc.size()) { + return JXL_FAILURE( + "ICC length is less than APP markers: requested %" PRIuS + " more bytes, " + "%" PRIuS " available", + len, icc.size() - icc_pos); + } + memcpy(&jpeg_data->app_data[i][17], icc.data() + icc_pos, len); + icc_pos += len; + } + if (icc_pos != icc.size() && icc_pos != 0) { + return JXL_FAILURE("ICC length is more than APP markers"); + } + return true; +} + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jpeg +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/jpeg/jpeg_data.h b/media/libjxl/src/lib/jxl/jpeg/jpeg_data.h new file mode 100644 index 000000000..8fbc8696e --- /dev/null +++ b/media/libjxl/src/lib/jxl/jpeg/jpeg_data.h @@ -0,0 +1,268 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Data structures that represent the non-pixel contents of a jpeg file. + +#ifndef LIB_JXL_JPEG_JPEG_DATA_H_ +#define LIB_JXL_JPEG_JPEG_DATA_H_ + +#include +#include + +#include +#include + +#include "lib/jxl/common.h" // JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/fields.h" + +namespace jxl { +namespace jpeg { + +constexpr int kMaxComponents = 4; +constexpr int kMaxQuantTables = 4; +constexpr int kMaxHuffmanTables = 4; +constexpr size_t kJpegHuffmanMaxBitLength = 16; +constexpr int kJpegHuffmanAlphabetSize = 256; +constexpr int kJpegDCAlphabetSize = 12; +constexpr int kMaxDHTMarkers = 512; +constexpr int kMaxDimPixels = 65535; +constexpr uint8_t kApp1 = 0xE1; +constexpr uint8_t kApp2 = 0xE2; +const uint8_t kIccProfileTag[12] = "ICC_PROFILE"; +const uint8_t kExifTag[6] = "Exif\0"; +const uint8_t kXMPTag[29] = "http://ns.adobe.com/xap/1.0/"; + +/* clang-format off */ +constexpr uint32_t kJPEGNaturalOrder[80] = { + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63, + // extra entries for safety in decoder + 63, 63, 63, 63, 63, 63, 63, 63, + 63, 63, 63, 63, 63, 63, 63, 63 +}; + +constexpr uint32_t kJPEGZigZagOrder[64] = { + 0, 1, 5, 6, 14, 15, 27, 28, + 2, 4, 7, 13, 16, 26, 29, 42, + 3, 8, 12, 17, 25, 30, 41, 43, + 9, 11, 18, 24, 31, 40, 44, 53, + 10, 19, 23, 32, 39, 45, 52, 54, + 20, 22, 33, 38, 46, 51, 55, 60, + 21, 34, 37, 47, 50, 56, 59, 61, + 35, 36, 48, 49, 57, 58, 62, 63 +}; +/* clang-format on */ + +enum struct JPEGReadError { + OK = 0, + SOI_NOT_FOUND, + SOF_NOT_FOUND, + UNEXPECTED_EOF, + MARKER_BYTE_NOT_FOUND, + UNSUPPORTED_MARKER, + WRONG_MARKER_SIZE, + INVALID_PRECISION, + INVALID_WIDTH, + INVALID_HEIGHT, + INVALID_NUMCOMP, + INVALID_SAMP_FACTOR, + INVALID_START_OF_SCAN, + INVALID_END_OF_SCAN, + INVALID_SCAN_BIT_POSITION, + INVALID_COMPS_IN_SCAN, + INVALID_HUFFMAN_INDEX, + INVALID_QUANT_TBL_INDEX, + INVALID_QUANT_VAL, + INVALID_MARKER_LEN, + INVALID_SAMPLING_FACTORS, + INVALID_HUFFMAN_CODE, + INVALID_SYMBOL, + NON_REPRESENTABLE_DC_COEFF, + NON_REPRESENTABLE_AC_COEFF, + INVALID_SCAN, + OVERLAPPING_SCANS, + INVALID_SCAN_ORDER, + EXTRA_ZERO_RUN, + DUPLICATE_DRI, + DUPLICATE_SOF, + WRONG_RESTART_MARKER, + DUPLICATE_COMPONENT_ID, + COMPONENT_NOT_FOUND, + HUFFMAN_TABLE_NOT_FOUND, + HUFFMAN_TABLE_ERROR, + QUANT_TABLE_NOT_FOUND, + EMPTY_DHT, + EMPTY_DQT, + OUT_OF_BAND_COEFF, + EOB_RUN_TOO_LONG, + IMAGE_TOO_LARGE, + INVALID_QUANT_TBL_PRECISION, + TAIL_DATA_TOO_LARGE +}; + +// Quantization values for an 8x8 pixel block. +struct JPEGQuantTable { + std::array values; + uint32_t precision = 0; + // The index of this quantization table as it was parsed from the input JPEG. + // Each DQT marker segment contains an 'index' field, and we save this index + // here. Valid values are 0 to 3. + uint32_t index = 0; + // Set to true if this table is the last one within its marker segment. + bool is_last = true; +}; + +// Huffman code and decoding lookup table used for DC and AC coefficients. +struct JPEGHuffmanCode { + // Bit length histogram. + std::array counts = {}; + // Symbol values sorted by increasing bit lengths. + std::array values = {}; + // The index of the Huffman code in the current set of Huffman codes. For AC + // component Huffman codes, 0x10 is added to the index. + int slot_id = 0; + // Set to true if this Huffman code is the last one within its marker segment. + bool is_last = true; +}; + +// Huffman table indexes used for one component of one scan. +struct JPEGComponentScanInfo { + uint32_t comp_idx; + uint32_t dc_tbl_idx; + uint32_t ac_tbl_idx; +}; + +// Contains information that is used in one scan. +struct JPEGScanInfo { + // Parameters used for progressive scans (named the same way as in the spec): + // Ss : Start of spectral band in zig-zag sequence. + // Se : End of spectral band in zig-zag sequence. + // Ah : Successive approximation bit position, high. + // Al : Successive approximation bit position, low. + uint32_t Ss; + uint32_t Se; + uint32_t Ah; + uint32_t Al; + uint32_t num_components = 0; + std::array components; + // Last codestream pass that is needed to write this scan. + uint32_t last_needed_pass = 0; + + // Extra information required for bit-precise JPEG file reconstruction. + + // Set of block indexes where the JPEG encoder has to flush the end-of-block + // runs and refinement bits. + std::vector reset_points; + // The number of extra zero runs (Huffman symbol 0xf0) before the end of + // block (if nonzero), indexed by block index. + // All of these symbols can be omitted without changing the pixel values, but + // some jpeg encoders put these at the end of blocks. + typedef struct { + uint32_t block_idx; + uint32_t num_extra_zero_runs; + } ExtraZeroRunInfo; + std::vector extra_zero_runs; +}; + +typedef int16_t coeff_t; + +// Represents one component of a jpeg file. +struct JPEGComponent { + JPEGComponent() + : id(0), + h_samp_factor(1), + v_samp_factor(1), + quant_idx(0), + width_in_blocks(0), + height_in_blocks(0) {} + + // One-byte id of the component. + uint32_t id; + // Horizontal and vertical sampling factors. + // In interleaved mode, each minimal coded unit (MCU) has + // h_samp_factor x v_samp_factor DCT blocks from this component. + int h_samp_factor; + int v_samp_factor; + // The index of the quantization table used for this component. + uint32_t quant_idx; + // The dimensions of the component measured in 8x8 blocks. + uint32_t width_in_blocks; + uint32_t height_in_blocks; + // The DCT coefficients of this component, laid out block-by-block, divided + // through the quantization matrix values. + std::vector coeffs; +}; + +enum class AppMarkerType : uint32_t { + kUnknown = 0, + kICC = 1, + kExif = 2, + kXMP = 3, +}; + +// Represents a parsed jpeg file. +struct JPEGData : public Fields { + JPEGData() + : width(0), + height(0), + restart_interval(0), + error(JPEGReadError::OK), + has_zero_padding_bit(false) {} + + JXL_FIELDS_NAME(JPEGData) +#if JPEGXL_ENABLE_TRANSCODE_JPEG + // Doesn't serialize everything - skips brotli-encoded data and what is + // already encoded in the codestream. + Status VisitFields(Visitor* visitor) override; +#else + Status VisitFields(Visitor* /* visitor */) override { + JXL_ABORT("JPEG transcoding support not enabled"); + } +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + + void CalculateMcuSize(const JPEGScanInfo& scan, int* MCUs_per_row, + int* MCU_rows) const; + + int width; + int height; + uint32_t restart_interval; + std::vector> app_data; + std::vector app_marker_type; + std::vector> com_data; + std::vector quant; + std::vector huffman_code; + std::vector components; + std::vector scan_info; + std::vector marker_order; + std::vector> inter_marker_data; + std::vector tail_data; + JPEGReadError error; + + // Extra information required for bit-precise JPEG file reconstruction. + + bool has_zero_padding_bit; + std::vector padding_bits; +}; + +#if JPEGXL_ENABLE_TRANSCODE_JPEG +// Set ICC profile in jpeg_data. +Status SetJPEGDataFromICC(const PaddedBytes& icc, jpeg::JPEGData* jpeg_data); +#else +static JXL_INLINE Status SetJPEGDataFromICC(const PaddedBytes& /* icc */, + jpeg::JPEGData* /* jpeg_data */) { + JXL_ABORT("JPEG transcoding support not enabled"); +} +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_JPEG_DATA_H_ diff --git a/media/libjxl/src/lib/jxl/jxl.syms b/media/libjxl/src/lib/jxl/jxl.syms new file mode 100644 index 000000000..0f398d715 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jxl.syms @@ -0,0 +1,5 @@ +{ + extern "C" { + jpegxl_*; + }; +}; diff --git a/media/libjxl/src/lib/jxl/jxl.version b/media/libjxl/src/lib/jxl/jxl.version new file mode 100644 index 000000000..26b0e9e54 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jxl.version @@ -0,0 +1,17 @@ +JXL_0 { + global: + Jxl*; + + local: + # Hide all the std namespace symbols. std namespace is explicitly marked + # as visibility(default) and header-only functions or methods (such as those + # from templates) should be exposed in shared libraries as weak symbols but + # this is only needed when we expose those types in the shared library API + # in any way. We don't use C++ std types in the API and we also don't + # support exceptions in the library. + # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion + # about this. + extern "C++" { + *std::*; + }; +}; diff --git a/media/libjxl/src/lib/jxl/jxl_inspection.h b/media/libjxl/src/lib/jxl/jxl_inspection.h new file mode 100644 index 000000000..0b70a5852 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jxl_inspection.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JXL_INSPECTION_H_ +#define LIB_JXL_JXL_INSPECTION_H_ + +#include + +#include "lib/jxl/image.h" + +namespace jxl { +// Type of the inspection-callback which, if enabled, will be called on various +// intermediate data during image processing, allowing inspection access. +// +// Returns false if processing can be stopped at that point, true otherwise. +// This is only advisory - it is always OK to just continue processing. +using InspectorImage3F = std::function; +} // namespace jxl + +#endif // LIB_JXL_JXL_INSPECTION_H_ diff --git a/media/libjxl/src/lib/jxl/jxl_osx.syms b/media/libjxl/src/lib/jxl/jxl_osx.syms new file mode 100644 index 000000000..96bc56802 --- /dev/null +++ b/media/libjxl/src/lib/jxl/jxl_osx.syms @@ -0,0 +1 @@ +_Jxl* diff --git a/media/libjxl/src/lib/jxl/jxl_test.cc b/media/libjxl/src/lib/jxl/jxl_test.cc new file mode 100644 index 000000000..63ce6125f --- /dev/null +++ b/media/libjxl/src/lib/jxl/jxl_test.cc @@ -0,0 +1,1744 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/jxl.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/codec_y4m_testonly.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_butteraugli_pnorm.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/fake_parallel_runner_testonly.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/jpeg/dec_jpeg_data.h" +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" +#include "tools/box/box.h" + +namespace jxl { +namespace { +using test::Roundtrip; + +#define JXL_TEST_NL 0 // Disabled in code + +void CreateImage1x1(CodecInOut* io) { + Image3F image(1, 1); + ZeroFillImage(&image); + io->metadata.m.SetUintSamples(8); + io->metadata.m.color_encoding = ColorEncoding::SRGB(); + io->SetFromImage(std::move(image), io->metadata.m.color_encoding); +} + +TEST(JxlTest, HeaderSize) { + CodecInOut io; + CreateImage1x1(&io); + + CompressParams cparams; + cparams.butteraugli_distance = 1.5; + ThreadPool* pool = nullptr; + + { + CodecInOut io2; + AuxOut aux_out; + Roundtrip(&io, cparams, {}, pool, &io2, &aux_out); + EXPECT_LE(aux_out.layers[kLayerHeader].total_bits, 41u); + } + + { + CodecInOut io2; + io.metadata.m.SetAlphaBits(8); + ImageF alpha(1, 1); + alpha.Row(0)[0] = 1; + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + AuxOut aux_out; + Roundtrip(&io, cparams, {}, pool, &io2, &aux_out); + EXPECT_LE(aux_out.layers[kLayerHeader].total_bits, 49u); + } +} + +TEST(JxlTest, RoundtripSinglePixel) { + CodecInOut io; + CreateImage1x1(&io); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + ThreadPool* pool = nullptr; + CodecInOut io2; + Roundtrip(&io, cparams, {}, pool, &io2); +} + +// Changing serialized signature causes Decode to fail. +#ifndef JXL_CRASH_ON_ERROR +TEST(JxlTest, RoundtripMarker) { + CodecInOut io; + CreateImage1x1(&io); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + AuxOut* aux_out = nullptr; + ThreadPool* pool = nullptr; + + PassesEncoderState enc_state; + for (size_t i = 0; i < 2; ++i) { + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + compressed[i] ^= 0xFF; + CodecInOut io2; + EXPECT_FALSE(test::DecodeFile({}, compressed, &io2, pool)); + } +} +#endif + +TEST(JxlTest, RoundtripTinyFast) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(32, 32); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 4.0f; + + CodecInOut io2; + const size_t enc_bytes = Roundtrip(&io, cparams, {}, pool, &io2); + printf("32x32 image size %" PRIuS " bytes\n", enc_bytes); +} + +TEST(JxlTest, RoundtripSmallD1) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + CodecInOut io_out; + size_t compressed_size; + + { + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + compressed_size = Roundtrip(&io, cparams, {}, pool, &io_out); + EXPECT_LE(compressed_size, 1000u); + EXPECT_THAT(ButteraugliDistance(io, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.0)); + } + + { + // And then, with a lower intensity target than the default, the bitrate + // should be smaller. + CodecInOut io_dim; + ASSERT_TRUE(SetFromBytes(Span(orig), &io_dim, pool)); + io_dim.metadata.m.SetIntensityTarget(100); + io_dim.ShrinkTo(io_dim.xsize() / 8, io_dim.ysize() / 8); + EXPECT_LT(Roundtrip(&io_dim, cparams, {}, pool, &io_out), compressed_size); + EXPECT_THAT( + ButteraugliDistance(io_dim, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.1)); + EXPECT_EQ(io_dim.metadata.m.IntensityTarget(), + io_out.metadata.m.IntensityTarget()); + } +} + +TEST(JxlTest, RoundtripOtherTransforms) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/64px/a2d1un_nkitzmiller_srgb8.png"); + std::unique_ptr io = jxl::make_unique(); + ASSERT_TRUE(SetFromBytes(Span(orig), io.get(), pool)); + + CompressParams cparams; + // Slow modes access linear image for adaptive quant search + cparams.speed_tier = SpeedTier::kKitten; + cparams.color_transform = ColorTransform::kNone; + cparams.butteraugli_distance = 5.0f; + + std::unique_ptr io2 = jxl::make_unique(); + const size_t compressed_size = + Roundtrip(io.get(), cparams, {}, pool, io2.get()); + EXPECT_LE(compressed_size, 23000u); + EXPECT_THAT(ButteraugliDistance(*io, *io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(3.0)); + + // Check the consistency when performing another roundtrip. + std::unique_ptr io3 = jxl::make_unique(); + const size_t compressed_size2 = + Roundtrip(io.get(), cparams, {}, pool, io3.get()); + EXPECT_LE(compressed_size2, 23000u); + EXPECT_THAT(ButteraugliDistance(*io, *io3, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(3.0)); +} + +TEST(JxlTest, RoundtripResample2) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize(), io.ysize()); + CompressParams cparams; + cparams.resampling = 2; + cparams.speed_tier = SpeedTier::kFalcon; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 17000u); + EXPECT_THAT(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), + IsSlightlyBelow(90)); +} + +TEST(JxlTest, RoundtripResample2Slow) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize(), io.ysize()); + CompressParams cparams; + cparams.resampling = 2; + cparams.butteraugli_distance = 10; + cparams.speed_tier = SpeedTier::kTortoise; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 5000u); + EXPECT_THAT(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), + IsSlightlyBelow(250)); +} + +TEST(JxlTest, RoundtripResample2MT) { + ThreadPoolInternal pool(4); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + // image has to be large enough to have multiple groups after downsampling + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + CompressParams cparams; + cparams.resampling = 2; + cparams.speed_tier = SpeedTier::kFalcon; + CodecInOut io2; + // TODO(veluca): Figure out why msan and release produce different + // file size. + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 200000u); + EXPECT_THAT(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), + IsSlightlyBelow(340)); +} + +// Roundtrip the image using a parallel runner that executes single-threaded but +// in random order. +TEST(JxlTest, RoundtripOutOfOrderProcessing) { + FakeParallelRunner fake_pool(/*order_seed=*/123, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + // Image size is selected so that the block border needed is larger than the + // amount of pixels available on the next block. + io.ShrinkTo(513, 515); + + CompressParams cparams; + // Force epf so we end up needing a lot of border. + cparams.epf = 3; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, &pool, &io2); + + EXPECT_GE(1.5, ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, RoundtripOutOfOrderProcessingBorder) { + FakeParallelRunner fake_pool(/*order_seed=*/47, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + // Image size is selected so that the block border needed is larger than the + // amount of pixels available on the next block. + io.ShrinkTo(513, 515); + + CompressParams cparams; + // Force epf so we end up needing a lot of border. + cparams.epf = 3; + cparams.resampling = 2; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, &pool, &io2); + + EXPECT_GE(2.8, ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, RoundtripResample4) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize(), io.ysize()); + CompressParams cparams; + cparams.resampling = 4; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 6000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(22)); +} + +TEST(JxlTest, RoundtripResample8) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize(), io.ysize()); + CompressParams cparams; + cparams.resampling = 8; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 2100u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(50)); +} + +TEST(JxlTest, RoundtripUnalignedD2) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 12, io.ysize() / 7); + + CompressParams cparams; + cparams.butteraugli_distance = 2.0; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 700u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.7)); +} + +#if JXL_TEST_NL + +TEST(JxlTest, RoundtripMultiGroupNL) { + ThreadPoolInternal pool(4); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + io.ShrinkTo(600, 1024); // partial X, full Y group + + CompressParams cparams; + + cparams.fast_mode = true; + cparams.butteraugli_distance = 1.0f; + CodecInOut io2; + Roundtrip(&io, cparams, {}, &pool, &io2); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(0.9f)); + + cparams.butteraugli_distance = 2.0f; + CodecInOut io3; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io3), 80000u); + EXPECT_THAT(ButteraugliDistance(io, io3, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(1.5f)); +} + +#endif + +TEST(JxlTest, RoundtripMultiGroup) { + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + { + ThreadPoolInternal pool(4); + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + } + io.ShrinkTo(600, 1024); + + auto test = [&](jxl::SpeedTier speed_tier, float target_distance, + size_t expected_size, float expected_distance) { + ThreadPoolInternal pool(4); + CompressParams cparams; + cparams.butteraugli_distance = target_distance; + cparams.speed_tier = speed_tier; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), expected_size); + EXPECT_THAT(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), + IsSlightlyBelow(expected_distance)); + }; + + auto run_kitten = std::async(std::launch::async, test, SpeedTier::kKitten, + 1.0f, 55000u, 11); + auto run_wombat = std::async(std::launch::async, test, SpeedTier::kWombat, + 2.0f, 34000u, 18); +} + +TEST(JxlTest, RoundtripRGBToGrayscale) { + ThreadPoolInternal pool(4); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + io.ShrinkTo(600, 1024); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0f; + cparams.speed_tier = SpeedTier::kFalcon; + + extras::JXLDecompressParams dparams; + dparams.color_space = "Gra_D65_Rel_SRG"; + + CodecInOut io2; + EXPECT_FALSE(io.Main().IsGray()); + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 55000u); + EXPECT_TRUE(io2.Main().IsGray()); + + // Convert original to grayscale here, because TransformTo refuses to + // convert between grayscale and RGB. + ColorEncoding srgb_lin = ColorEncoding::LinearSRGB(/*is_gray=*/false); + ASSERT_TRUE(io.TransformTo(srgb_lin, GetJxlCms(), &pool)); + Image3F* color = io.Main().color(); + for (size_t y = 0; y < color->ysize(); ++y) { + float* row_r = color->PlaneRow(0, y); + float* row_g = color->PlaneRow(1, y); + float* row_b = color->PlaneRow(2, y); + for (size_t x = 0; x < color->xsize(); ++x) { + float luma = 0.2126 * row_r[x] + 0.7152 * row_g[x] + 0.0722 * row_b[x]; + row_r[x] = row_g[x] = row_b[x] = luma; + } + } + ColorEncoding srgb_gamma = ColorEncoding::SRGB(/*is_gray=*/false); + ASSERT_TRUE(io.TransformTo(srgb_gamma, GetJxlCms(), &pool)); + io.metadata.m.color_encoding = io2.Main().c_current(); + io.Main().OverrideProfile(io2.Main().c_current()); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(1.7)); +} + +TEST(JxlTest, RoundtripLargeFast) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 450800u); +} + +TEST(JxlTest, RoundtripDotsForceEpf) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.epf = 2; + cparams.dots = Override::kOn; + cparams.speed_tier = SpeedTier::kSquirrel; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 450000u); +} + +// Checks for differing size/distance in two consecutive runs of distance 2, +// which involves additional processing including adaptive reconstruction. +// Failing this may be a sign of race conditions or invalid memory accesses. +TEST(JxlTest, RoundtripD2Consistent) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 2.0; + + // Try each xsize mod kBlockDim to verify right border handling. + for (size_t xsize = 48; xsize > 40; --xsize) { + io.ShrinkTo(xsize, 15); + + CodecInOut io2; + const size_t size2 = Roundtrip(&io, cparams, {}, &pool, &io2); + + CodecInOut io3; + const size_t size3 = Roundtrip(&io, cparams, {}, &pool, &io3); + + // Exact same compressed size. + EXPECT_EQ(size2, size3); + + // Exact same distance. + const float dist2 = ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()); + const float dist3 = ComputeDistance2(io.Main(), io3.Main(), GetJxlCms()); + EXPECT_EQ(dist2, dist3); + } +} + +// Same as above, but for full image, testing multiple groups. +TEST(JxlTest, RoundtripLargeConsistent) { + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + { + ThreadPoolInternal pool(8); + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + } + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 2.0; + + auto roundtrip_and_compare = [&]() { + ThreadPoolInternal pool(8); + CodecInOut io2; + size_t size = Roundtrip(&io, cparams, {}, &pool, &io2); + double dist = ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()); + return std::tuple(size, dist); + }; + + // Try each xsize mod kBlockDim to verify right border handling. + auto future2 = std::async(std::launch::async, roundtrip_and_compare); + auto future3 = std::async(std::launch::async, roundtrip_and_compare); + + const auto result2 = future2.get(); + const auto result3 = future3.get(); + + // Exact same compressed size. + EXPECT_EQ(std::get<0>(result2), std::get<0>(result3)); + + // Exact same distance. + EXPECT_EQ(std::get<1>(result2), std::get<1>(result3)); +} + +#if JXL_TEST_NL + +TEST(JxlTest, RoundtripSmallNL) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 1500u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.7)); +} + +#endif + +TEST(JxlTest, RoundtripNoGaborishNoAR) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + CompressParams cparams; + cparams.gaborish = Override::kOff; + cparams.epf = 0; + cparams.butteraugli_distance = 1.0; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 40000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(2.0)); +} + +TEST(JxlTest, RoundtripSmallNoGaborish) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.gaborish = Override::kOff; + cparams.butteraugli_distance = 1.0; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 900u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.2)); +} + +TEST(JxlTest, RoundtripSmallPatchesAlpha) { + ThreadPool* pool = nullptr; + CodecInOut io; + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + Image3F black_with_small_lines(256, 256); + ImageF alpha(black_with_small_lines.xsize(), black_with_small_lines.ysize()); + ZeroFillImage(&black_with_small_lines); + // This pattern should be picked up by the patch detection heuristics. + for (size_t y = 0; y < black_with_small_lines.ysize(); y++) { + float* JXL_RESTRICT row = black_with_small_lines.PlaneRow(1, y); + for (size_t x = 0; x < black_with_small_lines.xsize(); x++) { + if (x % 4 == 0 && (y / 32) % 4 == 0) row[x] = 127.0f; + } + } + io.metadata.m.SetAlphaBits(8); + io.SetFromImage(std::move(black_with_small_lines), + ColorEncoding::LinearSRGB()); + FillImage(1.0f, &alpha); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 0.1f; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 2000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(0.04f)); +} + +TEST(JxlTest, RoundtripSmallPatches) { + ThreadPool* pool = nullptr; + CodecInOut io; + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + Image3F black_with_small_lines(256, 256); + ZeroFillImage(&black_with_small_lines); + // This pattern should be picked up by the patch detection heuristics. + for (size_t y = 0; y < black_with_small_lines.ysize(); y++) { + float* JXL_RESTRICT row = black_with_small_lines.PlaneRow(1, y); + for (size_t x = 0; x < black_with_small_lines.xsize(); x++) { + if (x % 4 == 0 && (y / 32) % 4 == 0) row[x] = 127.0f; + } + } + io.SetFromImage(std::move(black_with_small_lines), + ColorEncoding::LinearSRGB()); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 0.1f; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 2000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(0.04f)); +} + +// Test header encoding of original bits per sample +TEST(JxlTest, RoundtripImageBundleOriginalBits) { + ThreadPool* pool = nullptr; + + // Image does not matter, only io.metadata.m and io2.metadata.m are tested. + Image3F image(1, 1); + ZeroFillImage(&image); + CodecInOut io; + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + io.SetFromImage(std::move(image), ColorEncoding::LinearSRGB()); + + CompressParams cparams; + + // Test unsigned integers from 1 to 32 bits + for (uint32_t bit_depth = 1; bit_depth <= 32; bit_depth++) { + if (bit_depth == 32) { + // TODO(lode): allow testing 32, however the code below ends up in + // enc_modular which does not support 32. We only want to test the header + // encoding though, so try without modular. + break; + } + + io.metadata.m.SetUintSamples(bit_depth); + CodecInOut io2; + Roundtrip(&io, cparams, {}, pool, &io2); + + EXPECT_EQ(bit_depth, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io2.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_EQ(0u, io2.metadata.m.GetAlphaBits()); + } + + // Test various existing and non-existing floating point formats + for (uint32_t bit_depth = 8; bit_depth <= 32; bit_depth++) { + if (bit_depth != 32) { + // TODO: test other float types once they work + break; + } + + uint32_t exponent_bit_depth; + if (bit_depth < 10) { + exponent_bit_depth = 2; + } else if (bit_depth < 12) { + exponent_bit_depth = 3; + } else if (bit_depth < 16) { + exponent_bit_depth = 4; + } else if (bit_depth < 20) { + exponent_bit_depth = 5; + } else if (bit_depth < 24) { + exponent_bit_depth = 6; + } else if (bit_depth < 28) { + exponent_bit_depth = 7; + } else { + exponent_bit_depth = 8; + } + + io.metadata.m.bit_depth.bits_per_sample = bit_depth; + io.metadata.m.bit_depth.floating_point_sample = true; + io.metadata.m.bit_depth.exponent_bits_per_sample = exponent_bit_depth; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, pool, &io2); + + EXPECT_EQ(bit_depth, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_TRUE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(exponent_bit_depth, + io2.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_EQ(0u, io2.metadata.m.GetAlphaBits()); + } +} + +TEST(JxlTest, RoundtripGrayscale) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData( + "external/wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_NE(io.xsize(), 0u); + io.ShrinkTo(128, 128); + EXPECT_TRUE(io.Main().IsGray()); + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + + { + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + EXPECT_TRUE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 7000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.6)); + } + + // Test with larger butteraugli distance and other settings enabled so + // different jxl codepaths trigger. + { + CompressParams cparams; + cparams.butteraugli_distance = 8.0; + + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + EXPECT_TRUE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 1300u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(6.0)); + } + + { + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + + CodecInOut io2; + extras::JXLDecompressParams dparams; + dparams.color_space = "RGB_D65_SRG_Rel_SRG"; + EXPECT_TRUE(test::DecodeFile(dparams, compressed, &io2, pool)); + EXPECT_FALSE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 7000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.6)); + } +} + +TEST(JxlTest, RoundtripAlpha) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + io.ShrinkTo(300, 300); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + + for (bool use_image_callback : {false, true}) { + for (bool unpremul_alpha : {false, true}) { + CodecInOut io2; + extras::JXLDecompressParams dparams; + dparams.use_image_callback = use_image_callback; + dparams.unpremultiply_alpha = unpremul_alpha; + EXPECT_TRUE(test::DecodeFile(dparams, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 10077u); + + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.2)); + } + } +} + +TEST(JxlTest, RoundtripAlphaPremultiplied) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io, io_nopremul; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_TRUE(SetFromBytes(Span(orig), &io_nopremul, pool)); + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + io.ShrinkTo(300, 300); + io_nopremul.ShrinkTo(300, 300); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + EXPECT_FALSE(io.Main().AlphaIsPremultiplied()); + EXPECT_TRUE(io.PremultiplyAlpha()); + EXPECT_TRUE(io.Main().AlphaIsPremultiplied()); + + EXPECT_FALSE(io_nopremul.Main().AlphaIsPremultiplied()); + + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + + for (bool use_image_callback : {false, true}) { + for (bool unpremul_alpha : {false, true}) { + for (bool use_uint8 : {false, true}) { + printf( + "Testing premultiplied alpha using %s %s requesting " + "%spremultiplied output.\n", + use_uint8 ? "uint8" : "float", + use_image_callback ? "image callback" : "image_buffer", + unpremul_alpha ? "un" : ""); + CodecInOut io2; + extras::JXLDecompressParams dparams; + dparams.use_image_callback = use_image_callback; + dparams.unpremultiply_alpha = unpremul_alpha; + if (use_uint8) { + dparams.accepted_formats = { + {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}}; + } + EXPECT_TRUE(test::DecodeFile(dparams, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 10000u); + EXPECT_EQ(unpremul_alpha, !io2.Main().AlphaIsPremultiplied()); + if (!unpremul_alpha) { + EXPECT_THAT( + ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.25)); + EXPECT_TRUE(io2.UnpremultiplyAlpha()); + EXPECT_FALSE(io2.Main().AlphaIsPremultiplied()); + } + EXPECT_THAT(ButteraugliDistance(io_nopremul, io2, cparams.ba_params, + GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.35)); + } + } + } +} + +TEST(JxlTest, RoundtripAlphaResampling) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + + CompressParams cparams; + cparams.resampling = 2; + cparams.ec_resampling = 2; + cparams.butteraugli_distance = 1.0; + cparams.speed_tier = SpeedTier::kHare; + + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 15000u); + + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(4.7)); +} + +TEST(JxlTest, RoundtripAlphaResamplingOnlyAlpha) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + + CompressParams cparams; + cparams.ec_resampling = 2; + cparams.butteraugli_distance = 1.0; + cparams.speed_tier = SpeedTier::kFalcon; + + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 34200u); + + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.85)); +} + +TEST(JxlTest, RoundtripAlphaNonMultipleOf8) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + io.ShrinkTo(12, 12); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 180u); + + // TODO(robryk): Fix the following line in presence of different alpha_bits in + // the two contexts. + // EXPECT_TRUE(SamePixels(io.Main().alpha(), io2.Main().alpha())); + // TODO(robryk): Fix the distance estimate used in the encoder. + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(0.9)); +} + +TEST(JxlTest, RoundtripAlpha16) { + ThreadPoolInternal pool(4); + + size_t xsize = 1200, ysize = 160; + Image3F color(xsize, ysize); + ImageF alpha(xsize, ysize); + // Generate 16-bit pattern that uses various colors and alpha values. + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + color.PlaneRow(0, y)[x] = (y * 65535 / ysize) * (1.0f / 65535); + color.PlaneRow(1, y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + color.PlaneRow(2, y)[x] = + ((y + x) * 65535 / (xsize + ysize)) * (1.0f / 65535); + alpha.Row(y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + } + } + const bool is_gray = false; + CodecInOut io; + io.metadata.m.SetUintSamples(16); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + io.SetFromImage(std::move(color), io.metadata.m.color_encoding); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + // The image is wider than 512 pixels to ensure multiple groups are tested. + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + + CompressParams cparams; + cparams.butteraugli_distance = 0.5; + cparams.speed_tier = SpeedTier::kWombat; + + io.metadata.m.SetUintSamples(16); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, &pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, &pool)); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(0.8)); +} + +namespace { +CompressParams CParamsForLossless() { + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + cparams.options.predictor = {Predictor::Weighted}; + return cparams; +} +} // namespace + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 3500000u); + // If this test fails with a very close to 0.0 but not exactly 0.0 butteraugli + // distance, then there is likely a floating point issue, that could be + // happening either in io or io2. The values of io are generated by + // external_image.cc, and those in io2 by the jxl decoder. If they use + // slightly different floating point operations (say, one casts int to float + // while other divides the int through 255.0f and later multiplies it by + // 255 again) they will get slightly different values. To fix, ensure both + // sides do the following formula for converting integer range 0-255 to + // floating point range 0.0f-255.0f: static_cast(i) + // without any further intermediate operations. + // Note that this precision issue is not a problem in practice if the values + // are equal when rounded to 8-bit int, but currently full exact precision is + // tested. + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0.0); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLosslessNoEncoderFastPathWP)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + cparams.speed_tier = SpeedTier::kFalcon; + cparams.options.skip_encoder_fast_path = true; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 3500000u); + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0.0); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLosslessNoEncoderFastPathGradient)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + cparams.speed_tier = SpeedTier::kThunder; + cparams.options.skip_encoder_fast_path = true; + cparams.options.predictor = {Predictor::Gradient}; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 3500000u); + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0.0); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLosslessNoEncoderVeryFastPathGradient)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + cparams.speed_tier = SpeedTier::kLightning; + cparams.options.skip_encoder_fast_path = true; + cparams.options.predictor = {Predictor::Gradient}; + + CodecInOut io2, io3; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 3500000u); + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0.0); + cparams.options.skip_encoder_fast_path = false; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io3), 3500000u); + EXPECT_EQ(ComputeDistance2(io.Main(), io3.Main(), GetJxlCms()), 0.0); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8Falcon)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + cparams.speed_tier = SpeedTier::kFalcon; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 3500000u); + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, RoundtripLossless8Alpha) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + EXPECT_EQ(8u, io.metadata.m.GetAlphaBits()); + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + + CompressParams cparams = CParamsForLossless(); + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 350000u); + // If fails, see note about floating point in RoundtripLossless8. + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0.0); + EXPECT_TRUE(SamePixels(*io.Main().alpha(), *io2.Main().alpha())); + EXPECT_EQ(8u, io2.metadata.m.GetAlphaBits()); + EXPECT_EQ(8u, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io2.metadata.m.bit_depth.exponent_bits_per_sample); +} + +TEST(JxlTest, RoundtripLossless16Alpha) { + ThreadPool* pool = nullptr; + + size_t xsize = 1200, ysize = 160; + Image3F color(xsize, ysize); + ImageF alpha(xsize, ysize); + // Generate 16-bit pattern that uses various colors and alpha values. + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + color.PlaneRow(0, y)[x] = (y * 65535 / ysize) * (1.0f / 65535); + color.PlaneRow(1, y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + color.PlaneRow(2, y)[x] = + ((y + x) * 65535 / (xsize + ysize)) * (1.0f / 65535); + alpha.Row(y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + } + } + const bool is_gray = false; + CodecInOut io; + io.metadata.m.SetUintSamples(16); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + io.SetFromImage(std::move(color), io.metadata.m.color_encoding); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + EXPECT_EQ(16u, io.metadata.m.GetAlphaBits()); + EXPECT_EQ(16u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + + CompressParams cparams = CParamsForLossless(); + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 7100u); + // If this test fails with a very close to 0.0 but not exactly 0.0 butteraugli + // distance, then there is likely a floating point issue, that could be + // happening either in io or io2. The values of io are generated by + // external_image.cc, and those in io2 by the jxl decoder. If they use + // slightly different floating point operations (say, one does "i / 257.0f" + // while the other does "i * (1.0f / 257)" they will get slightly different + // values. To fix, ensure both sides do the following formula for converting + // integer range 0-65535 to Image3F floating point range 0.0f-255.0f: + // "i * (1.0f / 257)". + // Note that this precision issue is not a problem in practice if the values + // are equal when rounded to 16-bit int, but currently full exact precision is + // tested. + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0.0); + EXPECT_TRUE(SamePixels(*io.Main().alpha(), *io2.Main().alpha())); + EXPECT_EQ(16u, io2.metadata.m.GetAlphaBits()); + EXPECT_EQ(16u, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io2.metadata.m.bit_depth.exponent_bits_per_sample); +} + +TEST(JxlTest, RoundtripLossless16AlphaNotMisdetectedAs8Bit) { + ThreadPool* pool = nullptr; + + size_t xsize = 128, ysize = 128; + Image3F color(xsize, ysize); + ImageF alpha(xsize, ysize); + // All 16-bit values, both color and alpha, of this image are below 64. + // This allows testing if a code path wrongly concludes it's an 8-bit instead + // of 16-bit image (or even 6-bit). + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + color.PlaneRow(0, y)[x] = (y * 64 / ysize) * (1.0f / 65535); + color.PlaneRow(1, y)[x] = (x * 64 / xsize) * (1.0f / 65535); + color.PlaneRow(2, y)[x] = + ((y + x) * 64 / (xsize + ysize)) * (1.0f / 65535); + alpha.Row(y)[x] = (64 * x / xsize) * (1.0f / 65535); + } + } + const bool is_gray = false; + CodecInOut io; + io.metadata.m.SetUintSamples(16); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + io.SetFromImage(std::move(color), io.metadata.m.color_encoding); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + EXPECT_EQ(16u, io.metadata.m.GetAlphaBits()); + EXPECT_EQ(16u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + + CompressParams cparams = CParamsForLossless(); + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 3100u); + EXPECT_EQ(16u, io2.metadata.m.GetAlphaBits()); + EXPECT_EQ(16u, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io2.metadata.m.bit_depth.exponent_bits_per_sample); + // If fails, see note about floating point in RoundtripLossless8. + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0.0); + EXPECT_TRUE(SamePixels(*io.Main().alpha(), *io2.Main().alpha())); +} + +TEST(JxlTest, RoundtripYCbCr420) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + const PaddedBytes yuv420 = ReadTestData("jxl/flower/flower.png.ffmpeg.y4m"); + CodecInOut io2; + ASSERT_TRUE(test::DecodeImageY4M(Span(yuv420), &io2)); + + CompressParams cparams = CParamsForLossless(); + cparams.speed_tier = SpeedTier::kThunder; + + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io2, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io3; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io3, pool)); + + EXPECT_LE(compressed.size(), 2000000u); + + // we're comparing an original PNG with a YCbCr 4:2:0 version + EXPECT_THAT(ComputeDistance2(io.Main(), io3.Main(), GetJxlCms()), + IsSlightlyBelow(4.3)); +} + +TEST(JxlTest, RoundtripDots) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0u); + + CompressParams cparams; + cparams.dots = Override::kOn; + cparams.butteraugli_distance = 0.04; + cparams.speed_tier = SpeedTier::kSquirrel; + + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 400000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(0.3)); +} + +TEST(JxlTest, RoundtripNoise) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0u); + + CompressParams cparams; + cparams.noise = Override::kOn; + cparams.speed_tier = SpeedTier::kSquirrel; + + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 40000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.6)); +} + +TEST(JxlTest, RoundtripLossless8Gray) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData( + "external/wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + CompressParams cparams = CParamsForLossless(); + + EXPECT_TRUE(io.Main().IsGray()); + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 130000u); + // If fails, see note about floating point in RoundtripLossless8. + EXPECT_EQ(ComputeDistance2(io.Main(), io2.Main(), GetJxlCms()), 0); + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool)); + EXPECT_TRUE(io2.Main().IsGray()); + EXPECT_EQ(8u, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io2.metadata.m.bit_depth.exponent_bits_per_sample); +} + +#if JPEGXL_ENABLE_GIF + +TEST(JxlTest, RoundtripAnimation) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/traffic_light.gif"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_EQ(4u, io.frames.size()); + + CompressParams cparams; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 3000u); + + EXPECT_EQ(io2.frames.size(), io.frames.size()); + test::CoalesceGIFAnimationWithAlpha(&io); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), +#if JXL_HIGH_PRECISION + 1.55); +#else + 1.75); +#endif +} + +TEST(JxlTest, RoundtripLosslessAnimation) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/traffic_light.gif"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_EQ(4u, io.frames.size()); + + CompressParams cparams = CParamsForLossless(); + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 1200u); + + EXPECT_EQ(io2.frames.size(), io.frames.size()); + test::CoalesceGIFAnimationWithAlpha(&io); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + 5e-4); +} + +TEST(JxlTest, RoundtripAnimationPatches) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/animation_patches.gif"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_EQ(2u, io.frames.size()); + + CompressParams cparams; + cparams.patches = Override::kOn; + CodecInOut io2; + // 40k with no patches, 27k with patch frames encoded multiple times. + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 24000u); + + EXPECT_EQ(io2.frames.size(), io.frames.size()); + // >10 with broken patches + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.2)); +} + +#endif // JPEGXL_ENABLE_GIF + +size_t RoundtripJpeg(const PaddedBytes& jpeg_in, ThreadPool* pool) { + CodecInOut io; + EXPECT_TRUE(jpeg::DecodeImageJPG(Span(jpeg_in), &io)); + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + PassesEncoderState passes_enc_state; + PaddedBytes compressed, codestream; + + EXPECT_TRUE(EncodeFile(cparams, &io, &passes_enc_state, &codestream, + GetJxlCms(), + /*aux_out=*/nullptr, pool)); + jpegxl::tools::JpegXlContainer enc_container; + enc_container.codestream = std::move(codestream); + jpeg::JPEGData data_in = *io.Main().jpeg_data; + jxl::PaddedBytes jpeg_data; + EXPECT_TRUE(EncodeJPEGData(data_in, &jpeg_data, cparams)); + enc_container.jpeg_reconstruction = jpeg_data.data(); + enc_container.jpeg_reconstruction_size = jpeg_data.size(); + EXPECT_TRUE(EncodeJpegXlContainerOneShot(enc_container, &compressed)); + + jxl::extras::JXLDecompressParams dparams; + dparams.runner = pool->runner(); + dparams.runner_opaque = pool->runner_opaque(); + std::vector out; + jxl::extras::PackedPixelFile ppf; + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + nullptr, &ppf, &out)); + EXPECT_EQ(out.size(), jpeg_in.size()); + size_t failures = 0; + for (size_t i = 0; i < std::min(out.size(), jpeg_in.size()); i++) { + if (out[i] != jpeg_in[i]) { + EXPECT_EQ(out[i], jpeg_in[i]) + << "byte mismatch " << i << " " << out[i] << " != " << jpeg_in[i]; + if (++failures > 4) { + return compressed.size(); + } + } + } + return compressed.size(); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression444)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png.im_q85_444.jpg"); + // JPEG size is 696,659 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 570000u); +} + +#if JPEGXL_ENABLE_JPEG + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png.im_q85_444.jpg"); + CodecInOut io; + ASSERT_TRUE(jpeg::DecodeImageJPG(Span(orig), &io)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + CodecInOut io3; + Roundtrip(&io, cparams, {}, &pool, &io3); + + // TODO(eustas): investigate, why SJPEG and JpegRecompression pixels are + // different. + EXPECT_THAT(ComputeDistance2(io2.Main(), io3.Main(), GetJxlCms()), + IsSlightlyBelow(12)); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels420)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png.im_q85_420.jpg"); + CodecInOut io; + ASSERT_TRUE(jpeg::DecodeImageJPG(Span(orig), &io)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + CodecInOut io3; + Roundtrip(&io, cparams, {}, &pool, &io3); + + EXPECT_THAT(ComputeDistance2(io2.Main(), io3.Main(), GetJxlCms()), + IsSlightlyBelow(11)); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels420EarlyFlush)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png.im_q85_420.jpg"); + CodecInOut io; + ASSERT_TRUE(jpeg::DecodeImageJPG(Span(orig), &io)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + extras::JXLDecompressParams dparams; + dparams.max_downsampling = 8; + + CodecInOut io3; + Roundtrip(&io, cparams, dparams, &pool, &io3); + + EXPECT_THAT(ComputeDistance2(io2.Main(), io3.Main(), GetJxlCms()), + IsSlightlyBelow(4410)); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels420Mul16)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower_cropped.jpg"); + CodecInOut io; + ASSERT_TRUE(jpeg::DecodeImageJPG(Span(orig), &io)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + CodecInOut io3; + Roundtrip(&io, cparams, {}, &pool, &io3); + + EXPECT_THAT(ComputeDistance2(io2.Main(), io3.Main(), GetJxlCms()), + IsSlightlyBelow(4)); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels_asymmetric)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("jxl/flower/flower.png.im_q85_asymmetric.jpg"); + CodecInOut io; + ASSERT_TRUE(jpeg::DecodeImageJPG(Span(orig), &io)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + CodecInOut io3; + Roundtrip(&io, cparams, {}, &pool, &io3); + + EXPECT_THAT(ComputeDistance2(io2.Main(), io3.Main(), GetJxlCms()), + IsSlightlyBelow(10)); +} + +#endif + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionGray)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("jxl/flower/flower.png.im_q85_gray.jpg"); + // JPEG size is 456,528 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 390000u); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression420)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png.im_q85_420.jpg"); + // JPEG size is 546,797 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 460000u); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression_luma_subsample)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("jxl/flower/flower.png.im_q85_luma_subsample.jpg"); + // JPEG size is 400,724 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 330000u); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression444_12)) { + // 444 JPEG that has an interesting sampling-factor (1x2, 1x2, 1x2). + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("jxl/flower/flower.png.im_q85_444_1x2.jpg"); + // JPEG size is 703,874 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 570000u); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression422)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png.im_q85_422.jpg"); + // JPEG size is 522,057 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 500000u); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression440)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png.im_q85_440.jpg"); + // JPEG size is 603,623 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 510000u); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression_asymmetric)) { + // 2x vertical downsample of one chroma channel, 2x horizontal downsample of + // the other. + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("jxl/flower/flower.png.im_q85_asymmetric.jpg"); + // JPEG size is 604,601 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 510000u); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression420Progr)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("jxl/flower/flower.png.im_q85_420_progr.jpg"); + // JPEG size is 522,057 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 460000u); +} + +TEST(JxlTest, RoundtripProgressive) { + ThreadPoolInternal pool(4); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + io.ShrinkTo(600, 1024); + + CompressParams cparams; + + cparams.butteraugli_distance = 1.0f; + cparams.progressive_dc = 1; + cparams.responsive = true; + cparams.progressive_mode = true; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 61700u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(1.17f)); +} + +TEST(JxlTest, RoundtripProgressiveLevel2Slow) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + io.ShrinkTo(600, 1024); + + CompressParams cparams; + + cparams.butteraugli_distance = 1.0f; + cparams.progressive_dc = 2; + cparams.speed_tier = SpeedTier::kTortoise; + cparams.responsive = true; + cparams.progressive_mode = true; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, &pool, &io2), 71000u); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(1.2f)); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/lehmer_code.h b/media/libjxl/src/lib/jxl/lehmer_code.h new file mode 100644 index 000000000..dd1d21c6f --- /dev/null +++ b/media/libjxl/src/lib/jxl/lehmer_code.h @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LEHMER_CODE_H_ +#define LIB_JXL_LEHMER_CODE_H_ + +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Permutation <=> factorial base representation (Lehmer code). + +using LehmerT = uint32_t; + +template +constexpr T ValueOfLowest1Bit(T t) { + return t & -t; +} + +// Computes the Lehmer (factorial basis) code of permutation, an array of n +// unique indices in [0..n), and stores it in code[0..len). N*logN time. +// temp must have n + 1 elements but need not be initialized. +template +void ComputeLehmerCode(const PermutationT* JXL_RESTRICT permutation, + uint32_t* JXL_RESTRICT temp, const size_t n, + LehmerT* JXL_RESTRICT code) { + for (size_t idx = 0; idx < n + 1; ++idx) temp[idx] = 0; + + for (size_t idx = 0; idx < n; ++idx) { + const PermutationT s = permutation[idx]; + + // Compute sum in Fenwick tree + uint32_t penalty = 0; + uint32_t i = s + 1; + while (i != 0) { + penalty += temp[i]; + i &= i - 1; // clear lowest bit + } + JXL_DASSERT(s >= penalty); + code[idx] = s - penalty; + i = s + 1; + // Add operation in Fenwick tree + while (i < n + 1) { + temp[i] += 1; + i += ValueOfLowest1Bit(i); + } + } +} + +// Decodes the Lehmer code in code[0..n) into permutation[0..n). +// temp must have 1 << CeilLog2(n) elements but need not be initialized. +template +void DecodeLehmerCode(const LehmerT* JXL_RESTRICT code, + uint32_t* JXL_RESTRICT temp, size_t n, + PermutationT* JXL_RESTRICT permutation) { + JXL_DASSERT(n != 0); + const size_t log2n = CeilLog2Nonzero(n); + const size_t padded_n = 1ull << log2n; + + for (size_t i = 0; i < padded_n; i++) { + const int32_t i1 = static_cast(i + 1); + temp[i] = static_cast(ValueOfLowest1Bit(i1)); + } + + for (size_t i = 0; i < n; i++) { + JXL_DASSERT(code[i] + i < n); + uint32_t rank = code[i] + 1; + + // Extract i-th unused element via implicit order-statistics tree. + size_t bit = padded_n; + size_t next = 0; + for (size_t i = 0; i <= log2n; i++) { + const size_t cand = next + bit; + JXL_DASSERT(cand >= 1); + bit >>= 1; + if (temp[cand - 1] < rank) { + next = cand; + rank -= temp[cand - 1]; + } + } + + permutation[i] = next; + + // Mark as used + next += 1; + while (next <= padded_n) { + temp[next - 1] -= 1; + next += ValueOfLowest1Bit(next); + } + } +} + +} // namespace jxl + +#endif // LIB_JXL_LEHMER_CODE_H_ diff --git a/media/libjxl/src/lib/jxl/lehmer_code_test.cc b/media/libjxl/src/lib/jxl/lehmer_code_test.cc new file mode 100644 index 000000000..74109c85b --- /dev/null +++ b/media/libjxl/src/lib/jxl/lehmer_code_test.cc @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/lehmer_code.h" + +#include +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/thread_pool_internal.h" + +namespace jxl { +namespace { + +template +struct WorkingSet { + explicit WorkingSet(size_t max_n) + : padded_n(1ull << CeilLog2Nonzero(max_n + 1)), + permutation(max_n), + temp(padded_n), + lehmer(max_n), + decoded(max_n) {} + + size_t padded_n; + std::vector permutation; + std::vector temp; + std::vector lehmer; + std::vector decoded; +}; + +template +void Roundtrip(size_t n, WorkingSet* ws) { + JXL_ASSERT(n != 0); + const size_t padded_n = 1ull << CeilLog2Nonzero(n); + + Rng rng(n * 65537 + 13); + + // Ensure indices fit into PermutationT + EXPECT_LE(n, 1ULL << (sizeof(PermutationT) * 8)); + + std::iota(ws->permutation.begin(), ws->permutation.begin() + n, 0); + + // For various random permutations: + for (size_t rep = 0; rep < 3; ++rep) { + rng.Shuffle(ws->permutation.data(), n); + + // Must decode to the same permutation + ComputeLehmerCode(ws->permutation.data(), ws->temp.data(), n, + ws->lehmer.data()); + memset(ws->temp.data(), 0, padded_n * 4); + DecodeLehmerCode(ws->lehmer.data(), ws->temp.data(), n, ws->decoded.data()); + + for (size_t i = 0; i < n; ++i) { + EXPECT_EQ(ws->permutation[i], ws->decoded[i]); + } + } +} + +// Preallocates arrays and tests n = [begin, end). +template +void RoundtripSizeRange(ThreadPool* pool, uint32_t begin, uint32_t end) { + ASSERT_NE(0u, begin); // n = 0 not allowed. + std::vector> working_sets; + + JXL_CHECK(RunOnPool( + pool, begin, end, + [&working_sets, end](const size_t num_threads) { + for (size_t i = 0; i < num_threads; i++) { + working_sets.emplace_back(end - 1); + } + return true; + }, + [&working_sets](const uint32_t n, const size_t thread) { + Roundtrip(n, &working_sets[thread]); + }, + "lehmer test")); +} + +TEST(LehmerCodeTest, TestRoundtrips) { + ThreadPoolInternal pool(8); + + RoundtripSizeRange(&pool, 1, 1026); + + // Ensures PermutationT can fit > 16 bit values. + RoundtripSizeRange(&pool, 65536, 65540); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/libjxl.pc.in b/media/libjxl/src/lib/jxl/libjxl.pc.in new file mode 100644 index 000000000..4a7af65b7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/libjxl.pc.in @@ -0,0 +1,13 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=@PKGCONFIG_TARGET_LIBS@ +includedir=@PKGCONFIG_TARGET_INCLUDES@ + +Name: libjxl +Description: Loads and saves JPEG XL files +Version: @JPEGXL_LIBRARY_VERSION@ +Requires.private: @JPEGXL_LIBRARY_REQUIRES@ +Libs: -L${libdir} -ljxl +Libs.private: -lm +Cflags: -I${includedir} +Cflags.private: -DJXL_STATIC_DEFINE diff --git a/media/libjxl/src/lib/jxl/linalg.cc b/media/libjxl/src/lib/jxl/linalg.cc new file mode 100644 index 000000000..61d66dd8d --- /dev/null +++ b/media/libjxl/src/lib/jxl/linalg.cc @@ -0,0 +1,235 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/linalg.h" + +#include + +#include +#include +#include +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +void AssertSymmetric(const ImageD& A) { +#if JXL_ENABLE_ASSERT + JXL_ASSERT(A.xsize() == A.ysize()); + for (size_t i = 0; i < A.xsize(); ++i) { + for (size_t j = i + 1; j < A.xsize(); ++j) { + JXL_ASSERT(std::abs(A.Row(i)[j] - A.Row(j)[i]) < 1e-15); + } + } +#endif +} + +void Diagonalize2x2(const double a0, const double a1, const double b, double* c, + double* s) { + if (std::abs(b) < 1e-15) { + *c = 1.0; + *s = 0.0; + return; + } + double phi = std::atan2(2 * b, a1 - a0); + double theta = b > 0.0 ? 0.5 * phi : 0.5 * phi + Pi(1.0); + *c = std::cos(theta); + *s = std::sin(theta); +} + +void GivensRotation(const double x, const double y, double* c, double* s) { + if (y == 0.0) { + *c = x < 0.0 ? -1.0 : 1.0; + *s = 0.0; + } else { + const double h = hypot(x, y); + const double d = 1.0 / h; + *c = x * d; + *s = -y * d; + } +} + +void RotateMatrixCols(ImageD* const JXL_RESTRICT U, int i, int j, double c, + double s) { + JXL_ASSERT(U->xsize() == U->ysize()); + const size_t N = U->xsize(); + double* const JXL_RESTRICT u_i = U->Row(i); + double* const JXL_RESTRICT u_j = U->Row(j); + std::vector rot_i, rot_j; + rot_i.reserve(N); + rot_j.reserve(N); + for (size_t k = 0; k < N; ++k) { + rot_i.push_back(u_i[k] * c - u_j[k] * s); + rot_j.push_back(u_i[k] * s + u_j[k] * c); + } + for (size_t k = 0; k < N; ++k) { + u_i[k] = rot_i[k]; + u_j[k] = rot_j[k]; + } +} +void HouseholderReflector(const size_t N, const double* x, double* u) { + const double sigma = x[0] <= 0.0 ? 1.0 : -1.0; + u[0] = x[0] - sigma * std::sqrt(DotProduct(N, x, x)); + for (size_t k = 1; k < N; ++k) { + u[k] = x[k]; + } + double u_norm = 1.0 / std::sqrt(DotProduct(N, u, u)); + for (size_t k = 0; k < N; ++k) { + u[k] *= u_norm; + } +} + +void ConvertToTridiagonal(const ImageD& A, ImageD* const JXL_RESTRICT T, + ImageD* const JXL_RESTRICT U) { + AssertSymmetric(A); + const size_t N = A.xsize(); + *U = Identity(A.xsize()); + *T = CopyImage(A); + std::vector u_stack; + for (size_t k = 0; k + 2 < N; ++k) { + if (DotProduct(N - k - 2, &T->Row(k)[k + 2], &T->Row(k)[k + 2]) > 1e-15) { + ImageD u(N, 1); + ZeroFillImage(&u); + HouseholderReflector(N - k - 1, &T->Row(k)[k + 1], &u.Row(0)[k + 1]); + ImageD v = MatMul(*T, u); + double scale = DotProduct(u, v); + v = LinComb(2.0, v, -2.0 * scale, u); + SubtractFrom(MatMul(u, Transpose(v)), T); + SubtractFrom(MatMul(v, Transpose(u)), T); + u_stack.emplace_back(std::move(u)); + } + } + while (!u_stack.empty()) { + const ImageD& u = u_stack.back(); + ImageD v = MatMul(Transpose(*U), u); + SubtractFrom(ScaleImage(2.0, MatMul(u, Transpose(v))), U); + u_stack.pop_back(); + } +} + +double WilkinsonShift(const double a0, const double a1, const double b) { + const double d = 0.5 * (a0 - a1); + if (d == 0.0) { + return a1 - std::abs(b); + } + const double sign_d = d > 0.0 ? 1.0 : -1.0; + return a1 - b * b / (d + sign_d * hypotf(d, b)); +} + +void ImplicitQRStep(ImageD* const JXL_RESTRICT U, double* const JXL_RESTRICT a, + double* const JXL_RESTRICT b, int m0, int m1) { + JXL_ASSERT(m1 - m0 > 2); + double x = a[m0] - WilkinsonShift(a[m1 - 2], a[m1 - 1], b[m1 - 1]); + double y = b[m0 + 1]; + for (int k = m0; k < m1 - 1; ++k) { + double c, s; + GivensRotation(x, y, &c, &s); + const double w = c * x - s * y; + const double d = a[k] - a[k + 1]; + const double z = (2 * c * b[k + 1] + d * s) * s; + a[k] -= z; + a[k + 1] += z; + b[k + 1] = d * c * s + (c * c - s * s) * b[k + 1]; + x = b[k + 1]; + if (k > m0) { + b[k] = w; + } + if (k < m1 - 2) { + y = -s * b[k + 2]; + b[k + 2] *= c; + } + RotateMatrixCols(U, k, k + 1, c, s); + } +} + +void ScanInterval(const double* const JXL_RESTRICT a, + const double* const JXL_RESTRICT b, int istart, + const int iend, const double eps, + std::deque >* intervals) { + for (int k = istart; k < iend; ++k) { + if ((k + 1 == iend) || + std::abs(b[k + 1]) < eps * (std::abs(a[k]) + std::abs(a[k + 1]))) { + if (k > istart) { + intervals->push_back(std::make_pair(istart, k + 1)); + } + istart = k + 1; + } + } +} + +void ConvertToDiagonal(const ImageD& A, ImageD* const JXL_RESTRICT diag, + ImageD* const JXL_RESTRICT U) { + AssertSymmetric(A); + const size_t N = A.xsize(); + ImageD T; + ConvertToTridiagonal(A, &T, U); + // From now on, the algorithm keeps the transformed matrix tri-diagonal, + // so we only need to keep track of the diagonal and the off-diagonal entries. + std::vector a(N); + std::vector b(N); + for (size_t k = 0; k < N; ++k) { + a[k] = T.Row(k)[k]; + if (k > 0) b[k] = T.Row(k)[k - 1]; + } + // Run the symmetric tri-diagonal QR algorithm with implicit Wilkinson shift. + const double kEpsilon = 1e-14; + std::deque > intervals; + ScanInterval(&a[0], &b[0], 0, N, kEpsilon, &intervals); + while (!intervals.empty()) { + const int istart = intervals[0].first; + const int iend = intervals[0].second; + intervals.pop_front(); + if (iend == istart + 2) { + double& a0 = a[istart]; + double& a1 = a[istart + 1]; + double& b1 = b[istart + 1]; + double c, s; + Diagonalize2x2(a0, a1, b1, &c, &s); + const double d = a0 - a1; + const double z = (2 * c * b1 + d * s) * s; + a0 -= z; + a1 += z; + b1 = 0.0; + RotateMatrixCols(U, istart, istart + 1, c, s); + } else { + ImplicitQRStep(U, &a[0], &b[0], istart, iend); + ScanInterval(&a[0], &b[0], istart, iend, kEpsilon, &intervals); + } + } + *diag = ImageD(N, 1); + double* const JXL_RESTRICT diag_row = diag->Row(0); + for (size_t k = 0; k < N; ++k) { + diag_row[k] = a[k]; + } +} + +void ComputeQRFactorization(const ImageD& A, ImageD* const JXL_RESTRICT Q, + ImageD* const JXL_RESTRICT R) { + JXL_ASSERT(A.xsize() == A.ysize()); + const size_t N = A.xsize(); + *Q = Identity(N); + *R = CopyImage(A); + std::vector u_stack; + for (size_t k = 0; k + 1 < N; ++k) { + if (DotProduct(N - k - 1, &R->Row(k)[k + 1], &R->Row(k)[k + 1]) > 1e-15) { + ImageD u(N, 1); + FillImage(0.0, &u); + HouseholderReflector(N - k, &R->Row(k)[k], &u.Row(0)[k]); + ImageD v = MatMul(Transpose(u), *R); + SubtractFrom(ScaleImage(2.0, MatMul(u, v)), R); + u_stack.emplace_back(std::move(u)); + } + } + while (!u_stack.empty()) { + const ImageD& u = u_stack.back(); + ImageD v = MatMul(Transpose(u), *Q); + SubtractFrom(ScaleImage(2.0, MatMul(u, v)), Q); + u_stack.pop_back(); + } +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/linalg.h b/media/libjxl/src/lib/jxl/linalg.h new file mode 100644 index 000000000..e44dd8535 --- /dev/null +++ b/media/libjxl/src/lib/jxl/linalg.h @@ -0,0 +1,295 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LINALG_H_ +#define LIB_JXL_LINALG_H_ + +// Linear algebra. + +#include + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +using ImageD = Plane; + +template +inline T DotProduct(const size_t N, const T* const JXL_RESTRICT a, + const T* const JXL_RESTRICT b) { + T sum = 0.0; + for (size_t k = 0; k < N; ++k) { + sum += a[k] * b[k]; + } + return sum; +} + +template +inline T L2NormSquared(const size_t N, const T* const JXL_RESTRICT a) { + return DotProduct(N, a, a); +} + +template +inline T L1Norm(const size_t N, const T* const JXL_RESTRICT a) { + T sum = 0; + for (size_t k = 0; k < N; ++k) { + sum += a[k] >= 0 ? a[k] : -a[k]; + } + return sum; +} + +inline double DotProduct(const ImageD& a, const ImageD& b) { + JXL_ASSERT(a.ysize() == 1); + JXL_ASSERT(b.ysize() == 1); + JXL_ASSERT(a.xsize() == b.xsize()); + const double* const JXL_RESTRICT row_a = a.Row(0); + const double* const JXL_RESTRICT row_b = b.Row(0); + return DotProduct(a.xsize(), row_a, row_b); +} + +inline ImageD Transpose(const ImageD& A) { + ImageD out(A.ysize(), A.xsize()); + for (size_t x = 0; x < A.xsize(); ++x) { + double* const JXL_RESTRICT row_out = out.Row(x); + for (size_t y = 0; y < A.ysize(); ++y) { + row_out[y] = A.Row(y)[x]; + } + } + return out; +} + +template +Plane MatMul(const Plane& A, const Plane& B) { + JXL_ASSERT(A.ysize() == B.xsize()); + Plane out(A.xsize(), B.ysize()); + for (size_t y = 0; y < B.ysize(); ++y) { + const Tin2* const JXL_RESTRICT row_b = B.Row(y); + Tout* const JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < A.xsize(); ++x) { + row_out[x] = 0.0; + for (size_t k = 0; k < B.xsize(); ++k) { + row_out[x] += A.Row(k)[x] * row_b[k]; + } + } + } + return out; +} + +template +ImageD MatMul(const Plane& A, const Plane& B) { + return MatMul(A, B); +} + +template +ImageI MatMulI(const Plane& A, const Plane& B) { + return MatMul(A, B); +} + +// Computes A = B * C, with sizes rows*cols: A=ha*wa, B=wa*wb, C=ha*wb +template +void MatMul(const T* a, const T* b, int ha, int wa, int wb, T* c) { + std::vector temp(wa); // Make better use of cache lines + for (int x = 0; x < wb; x++) { + for (int z = 0; z < wa; z++) { + temp[z] = b[z * wb + x]; + } + for (int y = 0; y < ha; y++) { + double e = 0; + for (int z = 0; z < wa; z++) { + e += a[y * wa + z] * temp[z]; + } + c[y * wb + x] = e; + } + } +} + +// Computes C = A + factor * B +template +void MatAdd(const T* a, const T* b, F factor, int h, int w, T* c) { + for (int i = 0; i < w * h; i++) { + c[i] = a[i] + b[i] * factor; + } +} + +template +inline Plane Identity(const size_t N) { + Plane out(N, N); + for (size_t i = 0; i < N; ++i) { + T* JXL_RESTRICT row = out.Row(i); + std::fill(row, row + N, 0); + row[i] = static_cast(1.0); + } + return out; +} + +inline ImageD Diagonal(const ImageD& d) { + JXL_ASSERT(d.ysize() == 1); + ImageD out(d.xsize(), d.xsize()); + const double* JXL_RESTRICT row_diag = d.Row(0); + for (size_t k = 0; k < d.xsize(); ++k) { + double* JXL_RESTRICT row_out = out.Row(k); + std::fill(row_out, row_out + d.xsize(), 0.0); + row_out[k] = row_diag[k]; + } + return out; +} + +// Computes c, s such that c^2 + s^2 = 1 and +// [c -s] [x] = [ * ] +// [s c] [y] [ 0 ] +void GivensRotation(double x, double y, double* c, double* s); + +// U = U * Givens(i, j, c, s) +void RotateMatrixCols(ImageD* JXL_RESTRICT U, int i, int j, double c, double s); + +// A is symmetric, U is orthogonal, T is tri-diagonal and +// A = U * T * Transpose(U). +void ConvertToTridiagonal(const ImageD& A, ImageD* JXL_RESTRICT T, + ImageD* JXL_RESTRICT U); + +// A is symmetric, U is orthogonal, and A = U * Diagonal(diag) * Transpose(U). +void ConvertToDiagonal(const ImageD& A, ImageD* JXL_RESTRICT diag, + ImageD* JXL_RESTRICT U); + +// A is square matrix, Q is orthogonal, R is upper triangular and A = Q * R; +void ComputeQRFactorization(const ImageD& A, ImageD* JXL_RESTRICT Q, + ImageD* JXL_RESTRICT R); + +// Inverts a 3x3 matrix in place +template +Status Inv3x3Matrix(T* matrix) { + // Intermediate computation is done in double precision. + double temp[9]; + temp[0] = static_cast(matrix[4]) * matrix[8] - + static_cast(matrix[5]) * matrix[7]; + temp[1] = static_cast(matrix[2]) * matrix[7] - + static_cast(matrix[1]) * matrix[8]; + temp[2] = static_cast(matrix[1]) * matrix[5] - + static_cast(matrix[2]) * matrix[4]; + temp[3] = static_cast(matrix[5]) * matrix[6] - + static_cast(matrix[3]) * matrix[8]; + temp[4] = static_cast(matrix[0]) * matrix[8] - + static_cast(matrix[2]) * matrix[6]; + temp[5] = static_cast(matrix[2]) * matrix[3] - + static_cast(matrix[0]) * matrix[5]; + temp[6] = static_cast(matrix[3]) * matrix[7] - + static_cast(matrix[4]) * matrix[6]; + temp[7] = static_cast(matrix[1]) * matrix[6] - + static_cast(matrix[0]) * matrix[7]; + temp[8] = static_cast(matrix[0]) * matrix[4] - + static_cast(matrix[1]) * matrix[3]; + double det = matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6]; + if (std::abs(det) < 1e-10) { + return JXL_FAILURE("Matrix determinant is too close to 0"); + } + double idet = 1.0 / det; + for (int i = 0; i < 9; i++) { + matrix[i] = temp[i] * idet; + } + return true; +} + +// Solves system of linear equations A * X = B using the conjugate gradient +// method. Matrix a must be a n*n, symmetric and positive definite. +// Vectors b and x must have n elements +template +void ConjugateGradient(const T* a, int n, const T* b, T* x) { + std::vector r(n); + MatMul(a, x, n, n, 1, r.data()); + MatAdd(b, r.data(), -1, n, 1, r.data()); + std::vector p = r; + T rr; + MatMul(r.data(), r.data(), 1, n, 1, &rr); // inner product + + if (rr == 0) return; // The initial values were already optimal + + for (int i = 0; i < n; i++) { + std::vector ap(n); + MatMul(a, p.data(), n, n, 1, ap.data()); + T alpha; + MatMul(r.data(), ap.data(), 1, n, 1, &alpha); + // Normally alpha couldn't be zero here but if numerical issues caused it, + // return assuming the solution is close. + if (alpha == 0) return; + alpha = rr / alpha; + MatAdd(x, p.data(), alpha, n, 1, x); + MatAdd(r.data(), ap.data(), -alpha, n, 1, r.data()); + + T rr2; + MatMul(r.data(), r.data(), 1, n, 1, &rr2); // inner product + if (rr2 < 1e-20) break; + + T beta = rr2 / rr; + MatAdd(r.data(), p.data(), beta, 1, n, p.data()); + rr = rr2; + } +} + +// Computes optimal coefficients r to approximate points p with linear +// combination of functions f. The matrix f has h rows and w columns, r has h +// values, p has w values. h is the amount of functions, w the amount of points. +// Uses the finite element method and minimizes mean square error. +template +void FEM(const T* f, int h, int w, const T* p, T* r) { + // Compute "Gramian" matrix G = F * F^T + // Speed up multiplication by using non-zero intervals in sparse F. + std::vector start(h); + std::vector end(h); + for (int y = 0; y < h; y++) { + start[y] = end[y] = 0; + for (int x = 0; x < w; x++) { + if (f[y * w + x] != 0) { + start[y] = x; + break; + } + } + for (int x = w - 1; x >= 0; x--) { + if (f[y * w + x] != 0) { + end[y] = x + 1; + break; + } + } + } + + std::vector g(h * h); + for (int y = 0; y < h; y++) { + for (int x = 0; x <= y; x++) { + T v = 0; + // Intersection of the two sparse intervals. + int s = std::max(start[x], start[y]); + int e = std::min(end[x], end[y]); + for (int z = s; z < e; z++) { + v += f[x * w + z] * f[y * w + z]; + } + // Symmetric, so two values output at once + g[y * h + x] = v; + g[x * h + y] = v; + } + } + + // B vector: sum of each column of F multiplied by corresponding p + std::vector b(h, 0); + for (int y = 0; y < h; y++) { + T v = 0; + for (int x = 0; x < w; x++) { + v += f[y * w + x] * p[x]; + } + b[y] = v; + } + + ConjugateGradient(g.data(), h, b.data(), r); +} + +} // namespace jxl + +#endif // LIB_JXL_LINALG_H_ diff --git a/media/libjxl/src/lib/jxl/linalg_test.cc b/media/libjxl/src/lib/jxl/linalg_test.cc new file mode 100644 index 000000000..292b984b5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/linalg_test.cc @@ -0,0 +1,146 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/linalg.h" + +#include "gtest/gtest.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +template +Plane RandomMatrix(const size_t xsize, const size_t ysize, Rng& rng, + const T vmin, const T vmax) { + Plane A(xsize, ysize); + GenerateImage(rng, &A, vmin, vmax); + return A; +} + +template +Plane RandomSymmetricMatrix(const size_t N, Rng& rng, const T vmin, + const T vmax) { + Plane A = RandomMatrix(N, N, rng, vmin, vmax); + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < i; ++j) { + A.Row(j)[i] = A.Row(i)[j]; + } + } + return A; +} +void VerifyMatrixEqual(const ImageD& A, const ImageD& B, const double eps) { + ASSERT_EQ(A.xsize(), B.xsize()); + ASSERT_EQ(A.ysize(), B.ysize()); + for (size_t y = 0; y < A.ysize(); ++y) { + for (size_t x = 0; x < A.xsize(); ++x) { + ASSERT_NEAR(A.Row(y)[x], B.Row(y)[x], eps); + } + } +} + +void VerifyOrthogonal(const ImageD& A, const double eps) { + VerifyMatrixEqual(Identity(A.xsize()), MatMul(Transpose(A), A), eps); +} + +void VerifyTridiagonal(const ImageD& T, const double eps) { + ASSERT_EQ(T.xsize(), T.ysize()); + for (size_t i = 0; i < T.xsize(); ++i) { + for (size_t j = i + 2; j < T.xsize(); ++j) { + ASSERT_NEAR(T.Row(i)[j], 0.0, eps); + ASSERT_NEAR(T.Row(j)[i], 0.0, eps); + } + } +} + +void VerifyUpperTriangular(const ImageD& R, const double eps) { + ASSERT_EQ(R.xsize(), R.ysize()); + for (size_t i = 0; i < R.xsize(); ++i) { + for (size_t j = i + 1; j < R.xsize(); ++j) { + ASSERT_NEAR(R.Row(i)[j], 0.0, eps); + } + } +} + +TEST(LinAlgTest, ConvertToTridiagonal) { + { + ImageD I = Identity(5); + ImageD T, U; + ConvertToTridiagonal(I, &T, &U); + VerifyMatrixEqual(I, T, 1e-15); + VerifyMatrixEqual(I, U, 1e-15); + } + { + ImageD A = Identity(5); + A.Row(0)[1] = A.Row(1)[0] = 2.0; + A.Row(0)[4] = A.Row(4)[0] = 3.0; + A.Row(2)[3] = A.Row(3)[2] = 2.0; + A.Row(3)[4] = A.Row(4)[3] = 2.0; + ImageD U, d; + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } + Rng rng(0); + for (int N = 2; N < 100; ++N) { + ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0); + ImageD T, U; + ConvertToTridiagonal(A, &T, &U); + VerifyOrthogonal(U, 1e-12); + VerifyTridiagonal(T, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(T, Transpose(U))), 1e-12); + } +} + +TEST(LinAlgTest, ConvertToDiagonal) { + { + ImageD I = Identity(5); + ImageD U, d; + ConvertToDiagonal(I, &d, &U); + VerifyMatrixEqual(I, U, 1e-15); + for (int k = 0; k < 5; ++k) { + ASSERT_NEAR(d.Row(0)[k], 1.0, 1e-15); + } + } + { + ImageD A = Identity(5); + A.Row(0)[1] = A.Row(1)[0] = 2.0; + A.Row(2)[3] = A.Row(3)[2] = 2.0; + A.Row(3)[4] = A.Row(4)[3] = 2.0; + ImageD U, d; + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } + Rng rng(0); + for (int N = 2; N < 100; ++N) { + ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0); + ImageD U, d; + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } +} + +TEST(LinAlgTest, ComputeQRFactorization) { + { + ImageD I = Identity(5); + ImageD Q, R; + ComputeQRFactorization(I, &Q, &R); + VerifyMatrixEqual(I, Q, 1e-15); + VerifyMatrixEqual(I, R, 1e-15); + } + Rng rng(0); + for (int N = 2; N < 100; ++N) { + ImageD A = RandomMatrix(N, N, rng, -1.0, 1.0); + ImageD Q, R; + ComputeQRFactorization(A, &Q, &R); + VerifyOrthogonal(Q, 1e-12); + VerifyUpperTriangular(R, 1e-12); + VerifyMatrixEqual(A, MatMul(Q, R), 1e-12); + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/loop_filter.cc b/media/libjxl/src/lib/jxl/loop_filter.cc new file mode 100644 index 000000000..1aec0f75d --- /dev/null +++ b/media/libjxl/src/lib/jxl/loop_filter.cc @@ -0,0 +1,99 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/loop_filter.h" + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +LoopFilter::LoopFilter() { Bundle::Init(this); } +Status LoopFilter::VisitFields(Visitor* JXL_RESTRICT visitor) { + // Must come before AllDefault. + + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &gab)); + if (visitor->Conditional(gab)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &gab_custom)); + if (visitor->Conditional(gab_custom)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_x_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_x_weight2)); + if (std::abs(1.0f + (gab_x_weight1 + gab_x_weight2) * 4) < 1e-8) { + return JXL_FAILURE( + "Gaborish x weights lead to near 0 unnormalized kernel"); + } + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_y_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_y_weight2)); + if (std::abs(1.0f + (gab_y_weight1 + gab_y_weight2) * 4) < 1e-8) { + return JXL_FAILURE( + "Gaborish y weights lead to near 0 unnormalized kernel"); + } + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_b_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_b_weight2)); + if (std::abs(1.0f + (gab_b_weight1 + gab_b_weight2) * 4) < 1e-8) { + return JXL_FAILURE( + "Gaborish b weights lead to near 0 unnormalized kernel"); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 2, &epf_iters)); + if (visitor->Conditional(epf_iters > 0)) { + if (visitor->Conditional(!nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_sharp_custom)); + if (visitor->Conditional(epf_sharp_custom)) { + for (size_t i = 0; i < kEpfSharpEntries; ++i) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16( + float(i) / float(kEpfSharpEntries - 1), &epf_sharp_lut[i])); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_weight_custom)); + if (visitor->Conditional(epf_weight_custom)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(40.0f, &epf_channel_scale[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(5.0f, &epf_channel_scale[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(3.5f, &epf_channel_scale[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.45f, &epf_pass1_zeroflush)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.6f, &epf_pass2_zeroflush)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_sigma_custom)); + if (visitor->Conditional(epf_sigma_custom)) { + if (visitor->Conditional(!nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.46f, &epf_quant_mul)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.9f, &epf_pass0_sigma_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(6.5f, &epf_pass2_sigma_scale)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(0.6666666666666666f, &epf_border_sad_mul)); + } + if (visitor->Conditional(nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.0f, &epf_sigma_for_modular)); + if (epf_sigma_for_modular < 1e-8) { + return JXL_FAILURE("EPF: sigma for modular is too small"); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + // Extensions: in chronological order of being added to the format. + return visitor->EndExtensions(); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/loop_filter.h b/media/libjxl/src/lib/jxl/loop_filter.h new file mode 100644 index 000000000..62501670f --- /dev/null +++ b/media/libjxl/src/lib/jxl/loop_filter.h @@ -0,0 +1,78 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LOOP_FILTER_H_ +#define LIB_JXL_LOOP_FILTER_H_ + +// Parameters for loop filter(s), stored in each frame. + +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +struct LoopFilter : public Fields { + LoopFilter(); + JXL_FIELDS_NAME(LoopFilter) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + size_t Padding() const { + static const size_t padding_per_epf_iter[4] = {0, 2, 3, 6}; + return padding_per_epf_iter[epf_iters] + (gab ? 1 : 0); + } + + mutable bool all_default; + + // --- Gaborish convolution + bool gab; + + bool gab_custom; + float gab_x_weight1; + float gab_x_weight2; + float gab_y_weight1; + float gab_y_weight2; + float gab_b_weight1; + float gab_b_weight2; + + // --- Edge-preserving filter + + // Number of EPF stages to apply. 0 means EPF disabled. 1 applies only the + // first stage, 2 applies both stages and 3 applies the first stage twice and + // the second stage once. + uint32_t epf_iters; + + bool epf_sharp_custom; + enum { kEpfSharpEntries = 8 }; + float epf_sharp_lut[kEpfSharpEntries]; + + bool epf_weight_custom; // Custom weight params + float epf_channel_scale[3]; // Relative weight of each channel + float epf_pass1_zeroflush; // Minimum weight for first pass + float epf_pass2_zeroflush; // Minimum weight for second pass + + bool epf_sigma_custom; // Custom sigma parameters + float epf_quant_mul; // Sigma is ~ this * quant + float epf_pass0_sigma_scale; // Multiplier for sigma in pass 0 + float epf_pass2_sigma_scale; // Multiplier for sigma in the second pass + float epf_border_sad_mul; // (inverse) multiplier for sigma on borders + + float epf_sigma_for_modular; + + uint64_t extensions; + + bool nonserialized_is_modular = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_LOOP_FILTER_H_ diff --git a/media/libjxl/src/lib/jxl/luminance.cc b/media/libjxl/src/lib/jxl/luminance.cc new file mode 100644 index 000000000..d5ce75a1b --- /dev/null +++ b/media/libjxl/src/lib/jxl/luminance.cc @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/luminance.h" + +#include "lib/jxl/codec_in_out.h" + +namespace jxl { + +void SetIntensityTarget(CodecInOut* io) { SetIntensityTarget(&io->metadata.m); } + +void SetIntensityTarget(ImageMetadata* m) { + if (m->color_encoding.tf.IsPQ()) { + // Peak luminance of PQ as defined by SMPTE ST 2084:2014. + m->SetIntensityTarget(10000); + } else if (m->color_encoding.tf.IsHLG()) { + // Nominal display peak luminance used as a reference by + // Rec. ITU-R BT.2100-2. + m->SetIntensityTarget(1000); + } else { + // SDR + m->SetIntensityTarget(kDefaultIntensityTarget); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/luminance.h b/media/libjxl/src/lib/jxl/luminance.h new file mode 100644 index 000000000..92f889a92 --- /dev/null +++ b/media/libjxl/src/lib/jxl/luminance.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LUMINANCE_H_ +#define LIB_JXL_LUMINANCE_H_ + +namespace jxl { + +// Chooses a default intensity target based on the transfer function of the +// image, if known. For SDR images or images not known to be HDR, returns +// kDefaultIntensityTarget, for images known to have PQ or HLG transfer function +// returns a higher value. +class CodecInOut; +void SetIntensityTarget(CodecInOut* io); + +struct ImageMetadata; +void SetIntensityTarget(ImageMetadata* m); + +} // namespace jxl + +#endif // LIB_JXL_LUMINANCE_H_ diff --git a/media/libjxl/src/lib/jxl/memory_manager_internal.cc b/media/libjxl/src/lib/jxl/memory_manager_internal.cc new file mode 100644 index 000000000..87727e75c --- /dev/null +++ b/media/libjxl/src/lib/jxl/memory_manager_internal.cc @@ -0,0 +1,18 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/memory_manager_internal.h" + +#include + +namespace jxl { + +void* MemoryManagerDefaultAlloc(void* opaque, size_t size) { + return malloc(size); +} + +void MemoryManagerDefaultFree(void* opaque, void* address) { free(address); } + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/memory_manager_internal.h b/media/libjxl/src/lib/jxl/memory_manager_internal.h new file mode 100644 index 000000000..b4a78903f --- /dev/null +++ b/media/libjxl/src/lib/jxl/memory_manager_internal.h @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ +#define LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ + +// Memory allocator with support for alignment + misalignment. + +#include +#include +#include +#include // memcpy + +#include +#include + +#include "jxl/memory_manager.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Default alloc and free functions. +void* MemoryManagerDefaultAlloc(void* opaque, size_t size); +void MemoryManagerDefaultFree(void* opaque, void* address); + +// Initializes the memory manager instance with the passed one. The +// MemoryManager passed in |memory_manager| may be NULL or contain NULL +// functions which will be initialized with the default ones. If either alloc +// or free are NULL, then both must be NULL, otherwise this function returns an +// error. +static JXL_INLINE Status MemoryManagerInit( + JxlMemoryManager* self, const JxlMemoryManager* memory_manager) { + if (memory_manager) { + *self = *memory_manager; + } else { + memset(self, 0, sizeof(*self)); + } + if (!self->alloc != !self->free) { + return false; + } + if (!self->alloc) self->alloc = jxl::MemoryManagerDefaultAlloc; + if (!self->free) self->free = jxl::MemoryManagerDefaultFree; + + return true; +} + +static JXL_INLINE void* MemoryManagerAlloc( + const JxlMemoryManager* memory_manager, size_t size) { + return memory_manager->alloc(memory_manager->opaque, size); +} + +static JXL_INLINE void MemoryManagerFree(const JxlMemoryManager* memory_manager, + void* address) { + return memory_manager->free(memory_manager->opaque, address); +} + +// Helper class to be used as a deleter in a unique_ptr call. +class MemoryManagerDeleteHelper { + public: + explicit MemoryManagerDeleteHelper(const JxlMemoryManager* memory_manager) + : memory_manager_(memory_manager) {} + + // Delete and free the passed pointer using the memory_manager. + template + void operator()(T* address) const { + if (!address) { + return; + } + address->~T(); + return memory_manager_->free(memory_manager_->opaque, address); + } + + private: + const JxlMemoryManager* memory_manager_; +}; + +template +using MemoryManagerUniquePtr = std::unique_ptr; + +// Creates a new object T allocating it with the memory allocator into a +// unique_ptr. +template +JXL_INLINE MemoryManagerUniquePtr MemoryManagerMakeUnique( + const JxlMemoryManager* memory_manager, Args&&... args) { + T* mem = + static_cast(memory_manager->alloc(memory_manager->opaque, sizeof(T))); + if (!mem) { + // Allocation error case. + return MemoryManagerUniquePtr(nullptr, + MemoryManagerDeleteHelper(memory_manager)); + } + return MemoryManagerUniquePtr(new (mem) T(std::forward(args)...), + MemoryManagerDeleteHelper(memory_manager)); +} + +} // namespace jxl + +#endif // LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ diff --git a/media/libjxl/src/lib/jxl/modular/encoding/context_predict.h b/media/libjxl/src/lib/jxl/modular/encoding/context_predict.h new file mode 100644 index 000000000..914cd6a4e --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/context_predict.h @@ -0,0 +1,626 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ +#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ + +#include +#include + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +namespace weighted { +constexpr static size_t kNumPredictors = 4; +constexpr static int64_t kPredExtraBits = 3; +constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1; +constexpr static size_t kNumProperties = 1; + +struct Header : public Fields { + JXL_FIELDS_NAME(WeightedPredictorHeader) + // TODO(janwas): move to cc file, avoid including fields.h. + Header() { Bundle::Init(this); } + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + auto visit_p = [visitor](pixel_type val, pixel_type *p) { + uint32_t up = *p; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up)); + *p = up; + return Status(true); + }; + JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3])); + return true; + } + + bool all_default; + pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0; + uint32_t w[kNumPredictors] = {}; +}; + +struct State { + pixel_type_w prediction[kNumPredictors] = {}; + pixel_type_w pred = 0; // *before* removing the added bits. + std::vector pred_errors[kNumPredictors]; + std::vector error; + const Header header; + + // Allows to approximate division by a number from 1 to 64. + uint32_t divlookup[64]; + + constexpr static pixel_type_w AddBits(pixel_type_w x) { + return uint64_t(x) << kPredExtraBits; + } + + State(Header header, size_t xsize, size_t ysize) : header(header) { + // Extra margin to avoid out-of-bounds writes. + // All have space for two rows of data. + for (size_t i = 0; i < 4; i++) { + pred_errors[i].resize((xsize + 2) * 2); + } + error.resize((xsize + 2) * 2); + // Initialize division lookup table. + for (int i = 0; i < 64; i++) { + divlookup[i] = (1 << 24) / (i + 1); + } + } + + // Approximates 4+(maxweight<<24)/(x+1), avoiding division + JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const { + int shift = static_cast(FloorLog2Nonzero(x + 1)) - 5; + if (shift < 0) shift = 0; + return 4 + ((maxweight * divlookup[x >> shift]) >> shift); + } + + // Approximates the weighted average of the input values with the given + // weights, avoiding division. Weights must sum to at least 16. + JXL_INLINE pixel_type_w + WeightedAverage(const pixel_type_w *JXL_RESTRICT p, + std::array w) const { + uint32_t weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + weight_sum += w[i]; + } + JXL_DASSERT(weight_sum > 15); + uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4. + weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + w[i] >>= log_weight - 4; + weight_sum += w[i]; + } + // for rounding. + pixel_type_w sum = (weight_sum >> 1) - 1; + for (size_t i = 0; i < kNumPredictors; i++) { + sum += p[i] * w[i]; + } + return (sum * divlookup[weight_sum - 1]) >> 24; + } + + template + JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize, + pixel_type_w N, pixel_type_w W, + pixel_type_w NE, pixel_type_w NW, + pixel_type_w NN, Properties *properties, + size_t offset) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + size_t pos_N = prev_row + x; + size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N; + size_t pos_NW = x > 0 ? pos_N - 1 : pos_N; + std::array weights; + for (size_t i = 0; i < kNumPredictors; i++) { + // pred_errors[pos_N] also contains the error of pixel W. + // pred_errors[pos_NW] also contains the error of pixel WW. + weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] + + pred_errors[i][pos_NW]; + weights[i] = ErrorWeight(weights[i], header.w[i]); + } + + N = AddBits(N); + W = AddBits(W); + NE = AddBits(NE); + NW = AddBits(NW); + NN = AddBits(NN); + + pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1]; + pixel_type_w teN = error[pos_N]; + pixel_type_w teNW = error[pos_NW]; + pixel_type_w sumWN = teN + teW; + pixel_type_w teNE = error[pos_NE]; + + if (compute_properties) { + pixel_type_w p = teW; + if (std::abs(teN) > std::abs(p)) p = teN; + if (std::abs(teNW) > std::abs(p)) p = teNW; + if (std::abs(teNE) > std::abs(p)) p = teNE; + (*properties)[offset++] = p; + } + + prediction[0] = W + NE - N; + prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5); + prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5); + prediction[3] = + N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc + + (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >> + 5); + + pred = WeightedAverage(prediction, weights); + + // If all three have the same sign, skip clamping. + if (((teN ^ teW) | (teN ^ teNW)) > 0) { + return (pred + kPredictionRound) >> kPredExtraBits; + } + + // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N). + pixel_type_w mx = std::max(W, std::max(NE, N)); + pixel_type_w mn = std::min(W, std::min(NE, N)); + pred = std::max(mn, std::min(mx, pred)); + return (pred + kPredictionRound) >> kPredExtraBits; + } + + JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y, + size_t xsize) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + val = AddBits(val); + error[cur_row + x] = pred - val; + for (size_t i = 0; i < kNumPredictors; i++) { + pixel_type_w err = + (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits; + // For predicting in the next row. + pred_errors[i][cur_row + x] = err; + // Add the error on this pixel to the error on the NE pixel. This has the + // effect of adding the error on this pixel to the E and EE pixels. + pred_errors[i][prev_row + x + 1] += err; + } + } +}; + +// Encoder helper function to set the parameters to some presets. +inline void PredictorMode(int i, Header *header) { + switch (i) { + case 0: + // ~ lossless16 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 10; + header->p3Ca = 7; + header->p3Cb = 7; + header->p3Cc = 7; + header->p3Cd = 0; + header->p3Ce = 0; + break; + case 1: + // ~ default lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xb; + header->p1C = 8; + header->p2C = 8; + header->p3Ca = 4; + header->p3Cb = 0; + header->p3Cc = 3; + header->p3Cd = 23; + header->p3Ce = 2; + break; + case 2: + // ~ west lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xd; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 9; + header->p3Ca = 7; + header->p3Cb = 0; + header->p3Cc = 0; + header->p3Cd = 16; + header->p3Ce = 9; + break; + case 3: + // ~ north lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xd; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 8; + header->p3Ca = 0; + header->p3Cb = 16; + header->p3Cc = 0; + header->p3Cd = 23; + header->p3Ce = 0; + break; + case 4: + default: + // something else, because why not + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 10; + header->p3Ca = 5; + header->p3Cb = 5; + header->p3Cc = 5; + header->p3Cd = 12; + header->p3Ce = 4; + break; + } +} +} // namespace weighted + +// Stores a node and its two children at the same time. This significantly +// reduces the number of branches needed during decoding. +struct FlatDecisionNode { + // Property + splitval of the top node. + int32_t property0; // -1 if leaf. + union { + PropertyVal splitval0; + Predictor predictor; + }; + uint32_t childID; // childID is ctx id if leaf. + // Property+splitval of the two child nodes. + union { + PropertyVal splitvals[2]; + int32_t multiplier; + }; + union { + int32_t properties[2]; + int64_t predictor_offset; + }; +}; +using FlatTree = std::vector; + +class MATreeLookup { + public: + explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} + struct LookupResult { + uint32_t context; + Predictor predictor; + int64_t offset; + int32_t multiplier; + }; + JXL_INLINE LookupResult Lookup(const Properties &properties) const { + uint32_t pos = 0; + while (true) { + const FlatDecisionNode &node = nodes_[pos]; + if (node.property0 < 0) { + return {node.childID, node.predictor, node.predictor_offset, + node.multiplier}; + } + bool p0 = properties[node.property0] <= node.splitval0; + uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; + uint32_t off1 = + 2 | (properties[node.properties[1]] <= node.splitvals[1] ? 1 : 0); + pos = node.childID + (p0 ? off1 : off0); + } + } + + private: + const FlatTree &nodes_; +}; + +static constexpr size_t kExtraPropsPerChannel = 4; +static constexpr size_t kNumNonrefProperties = + kNumStaticProperties + 13 + weighted::kNumProperties; + +constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; +constexpr size_t kGradientProp = 9; + +// Clamps gradient to the min/max of n, w (and l, implicitly). +static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w, + const int32_t l) { + const int32_t m = std::min(n, w); + const int32_t M = std::max(n, w); + // The end result of this operation doesn't overflow or underflow if the + // result is between m and M, but the intermediate value may overflow, so we + // do the intermediate operations in uint32_t and check later if we had an + // overflow or underflow condition comparing m, M and l directly. + // grad = M + m - l = n + w - l + const int32_t grad = + static_cast(static_cast(n) + static_cast(w) - + static_cast(l)); + // We use two sets of ternary operators to force the evaluation of them in + // any case, allowing the compiler to avoid branches and use cmovl/cmovg in + // x86. + const int32_t grad_clamp_M = (l < m) ? M : grad; + return (l > M) ? m : grad_clamp_M; +} + +inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) { + pixel_type_w p = a + b - c; + pixel_type_w pa = std::abs(p - a); + pixel_type_w pb = std::abs(p - b); + return pa < pb ? a : b; +} + +inline void PrecomputeReferences(const Channel &ch, size_t y, + const Image &image, uint32_t i, + Channel *references) { + ZeroFillImage(&references->plane); + uint32_t offset = 0; + size_t num_extra_props = references->w; + intptr_t onerow = references->plane.PixelsPerRow(); + for (int32_t j = static_cast(i) - 1; + j >= 0 && offset < num_extra_props; j--) { + if (image.channel[j].w != image.channel[i].w || + image.channel[j].h != image.channel[i].h) { + continue; + } + if (image.channel[j].hshift != image.channel[i].hshift) continue; + if (image.channel[j].vshift != image.channel[i].vshift) continue; + pixel_type *JXL_RESTRICT rp = references->Row(0) + offset; + const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y); + const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0); + for (size_t x = 0; x < ch.w; x++, rp += onerow) { + pixel_type_w v = rpp[x]; + rp[0] = std::abs(v); + rp[1] = v; + pixel_type_w vleft = (x ? rpp[x - 1] : 0); + pixel_type_w vtop = (y ? rpprev[x] : vleft); + pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft); + pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft); + rp[2] = std::abs(v - vpredicted); + rp[3] = v - vpredicted; + } + + offset += kExtraPropsPerChannel; + } +} + +struct PredictionResult { + int context = 0; + pixel_type_w guess = 0; + Predictor predictor; + int32_t multiplier; +}; + +inline void InitPropsRow( + Properties *p, + const std::array &static_props, + const int y) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + (*p)[i] = static_props[i]; + } + (*p)[2] = y; + (*p)[9] = 0; // local gradient. +} + +namespace detail { +enum PredictorMode { + kUseTree = 1, + kUseWP = 2, + kForceComputeProperties = 4, + kAllPredictions = 8, + kNoEdgeCases = 16 +}; + +JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left, + pixel_type_w top, pixel_type_w toptop, + pixel_type_w topleft, pixel_type_w topright, + pixel_type_w leftleft, + pixel_type_w toprightright, + pixel_type_w wp_pred) { + switch (p) { + case Predictor::Zero: + return pixel_type_w{0}; + case Predictor::Left: + return left; + case Predictor::Top: + return top; + case Predictor::Select: + return Select(left, top, topleft); + case Predictor::Weighted: + return wp_pred; + case Predictor::Gradient: + return pixel_type_w{ClampedGradient(left, top, topleft)}; + case Predictor::TopLeft: + return topleft; + case Predictor::TopRight: + return topright; + case Predictor::LeftLeft: + return leftleft; + case Predictor::Average0: + return (left + top) / 2; + case Predictor::Average1: + return (left + topleft) / 2; + case Predictor::Average2: + return (topleft + top) / 2; + case Predictor::Average3: + return (top + topright) / 2; + case Predictor::Average4: + return (6 * top - 2 * toptop + 7 * left + 1 * leftleft + + 1 * toprightright + 3 * topright + 8) / + 16; + default: + return pixel_type_w{0}; + } +} + +template +JXL_INLINE PredictionResult Predict( + Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const size_t x, const size_t y, Predictor predictor, + const MATreeLookup *lookup, const Channel *references, + weighted::State *wp_state, pixel_type_w *predictions) { + // We start in position 3 because of 2 static properties + y. + size_t offset = 3; + constexpr bool compute_properties = + mode & kUseTree || mode & kForceComputeProperties; + constexpr bool nec = mode & kNoEdgeCases; + pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0)); + pixel_type_w top = (nec || y ? pp[-onerow] : left); + pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left); + pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top); + pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left); + pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top); + pixel_type_w toprightright = + (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright); + + if (compute_properties) { + // location + (*p)[offset++] = x; + // neighbors + (*p)[offset++] = std::abs(top); + (*p)[offset++] = std::abs(left); + (*p)[offset++] = top; + (*p)[offset++] = left; + + // local gradient + (*p)[offset] = left - (*p)[offset + 1]; + offset++; + // local gradient + (*p)[offset++] = left + top - topleft; + + // FFV1 context properties + (*p)[offset++] = left - topleft; + (*p)[offset++] = topleft - top; + (*p)[offset++] = top - topright; + (*p)[offset++] = top - toptop; + (*p)[offset++] = left - leftleft; + } + + pixel_type_w wp_pred = 0; + if (mode & kUseWP) { + wp_pred = wp_state->Predict( + x, y, w, top, left, topright, topleft, toptop, p, offset); + } + if (!nec && compute_properties) { + offset += weighted::kNumProperties; + // Extra properties. + const pixel_type *JXL_RESTRICT rp = references->Row(x); + for (size_t i = 0; i < references->w; i++) { + (*p)[offset++] = rp[i]; + } + } + PredictionResult result; + if (mode & kUseTree) { + MATreeLookup::LookupResult lr = lookup->Lookup(*p); + result.context = lr.context; + result.guess = lr.offset; + result.multiplier = lr.multiplier; + predictor = lr.predictor; + } + if (mode & kAllPredictions) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft, + topright, leftleft, toprightright, wp_pred); + } + } + result.guess += PredictOne(predictor, left, top, toptop, topleft, topright, + leftleft, toprightright, wp_pred); + result.predictor = predictor; + + return result; +} +} // namespace detail + +inline PredictionResult PredictNoTreeNoWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor) { + return detail::Predict( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictNoTreeWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + weighted::State *wp_state) { + return detail::Predict( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeNoWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references) { + return detail::Predict( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} +// Only use for y > 1, x > 1, x < w-2, and empty references +JXL_INLINE PredictionResult +PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const MATreeLookup &tree_lookup, const Channel &references) { + return detail::Predict( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictLearn(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict( + p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, + wp_state, /*predictions=*/nullptr); +} + +inline void PredictLearnAll(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const Channel &references, + weighted::State *wp_state, + pixel_type_w *predictions) { + detail::Predict( + p, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, &references, wp_state, predictions); +} + +inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + pixel_type_w *predictions) { + detail::Predict( + /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, predictions); +} +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ diff --git a/media/libjxl/src/lib/jxl/modular/encoding/dec_ma.cc b/media/libjxl/src/lib/jxl/modular/encoding/dec_ma.cc new file mode 100644 index 000000000..66562f7df --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/dec_ma.cc @@ -0,0 +1,107 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/dec_ma.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +namespace { + +Status ValidateTree( + const Tree &tree, + const std::vector> &prop_bounds, + size_t root) { + if (tree[root].property == -1) return true; + size_t p = tree[root].property; + int val = tree[root].splitval; + if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree"); + // Splitting at max value makes no sense: left range will be exactly same + // as parent, right range will be invalid (min > max). + if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree"); + auto new_bounds = prop_bounds; + new_bounds[p].first = val + 1; + JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild)); + new_bounds[p] = prop_bounds[p]; + new_bounds[p].second = val; + return ValidateTree(tree, new_bounds, tree[root].rchild); +} + +Status DecodeTree(BitReader *br, ANSSymbolReader *reader, + const std::vector &context_map, Tree *tree, + size_t tree_size_limit) { + size_t leaf_id = 0; + size_t to_decode = 1; + tree->clear(); + while (to_decode > 0) { + JXL_RETURN_IF_ERROR(br->AllReadsWithinBounds()); + if (tree->size() > tree_size_limit) { + return JXL_FAILURE("Tree is too large: %" PRIuS " nodes vs %" PRIuS + " max nodes", + tree->size(), tree_size_limit); + } + to_decode--; + uint32_t prop1 = reader->ReadHybridUint(kPropertyContext, br, context_map); + if (prop1 > 256) return JXL_FAILURE("Invalid tree property value"); + int property = prop1 - 1; + if (property == -1) { + size_t predictor = + reader->ReadHybridUint(kPredictorContext, br, context_map); + if (predictor >= kNumModularPredictors) { + return JXL_FAILURE("Invalid predictor"); + } + int64_t predictor_offset = + UnpackSigned(reader->ReadHybridUint(kOffsetContext, br, context_map)); + uint32_t mul_log = + reader->ReadHybridUint(kMultiplierLogContext, br, context_map); + if (mul_log >= 31) { + return JXL_FAILURE("Invalid multiplier logarithm"); + } + uint32_t mul_bits = + reader->ReadHybridUint(kMultiplierBitsContext, br, context_map); + if (mul_bits + 1 >= 1u << (31u - mul_log)) { + return JXL_FAILURE("Invalid multiplier"); + } + uint32_t multiplier = (mul_bits + 1U) << mul_log; + tree->emplace_back(-1, 0, leaf_id++, 0, static_cast(predictor), + predictor_offset, multiplier); + continue; + } + int splitval = + UnpackSigned(reader->ReadHybridUint(kSplitValContext, br, context_map)); + tree->emplace_back(property, splitval, tree->size() + to_decode + 1, + tree->size() + to_decode + 2, Predictor::Zero, 0, 1); + to_decode += 2; + } + std::vector> prop_bounds; + prop_bounds.resize(256, {std::numeric_limits::min(), + std::numeric_limits::max()}); + return ValidateTree(*tree, prop_bounds, 0); +} +} // namespace + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit) { + std::vector tree_context_map; + ANSCode tree_code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumTreeContexts, &tree_code, &tree_context_map)); + // TODO(eustas): investigate more infinite tree cases. + if (tree_code.degenerate_symbols[tree_context_map[kPropertyContext]] > 0) { + return JXL_FAILURE("Infinite tree"); + } + ANSSymbolReader reader(&tree_code, br); + JXL_RETURN_IF_ERROR(DecodeTree(br, &reader, tree_context_map, tree, + std::min(tree_size_limit, kMaxTreeSize))); + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/encoding/dec_ma.h b/media/libjxl/src/lib/jxl/modular/encoding/dec_ma.h new file mode 100644 index 000000000..a910c4deb --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/dec_ma.h @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ + +#include +#include + +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// inner nodes +struct PropertyDecisionNode { + PropertyVal splitval; + int16_t property; // -1: leaf node, lchild points to leaf node + uint32_t lchild; + uint32_t rchild; + Predictor predictor; + int64_t predictor_offset; + uint32_t multiplier; + + PropertyDecisionNode(int p, int split_val, int lchild, int rchild, + Predictor predictor, int64_t predictor_offset, + uint32_t multiplier) + : splitval(split_val), + property(p), + lchild(lchild), + rchild(rchild), + predictor(predictor), + predictor_offset(predictor_offset), + multiplier(multiplier) {} + PropertyDecisionNode() + : splitval(0), + property(-1), + lchild(0), + rchild(0), + predictor(Predictor::Zero), + predictor_offset(0), + multiplier(1) {} + static PropertyDecisionNode Leaf(Predictor predictor, int64_t offset = 0, + uint32_t multiplier = 1) { + return PropertyDecisionNode(-1, 0, 0, 0, predictor, offset, multiplier); + } + static PropertyDecisionNode Split(int p, int split_val, int lchild, + int rchild = -1) { + if (rchild == -1) rchild = lchild + 1; + return PropertyDecisionNode(p, split_val, lchild, rchild, Predictor::Zero, + 0, 1); + } +}; + +using Tree = std::vector; + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ diff --git a/media/libjxl/src/lib/jxl/modular/encoding/enc_debug_tree.cc b/media/libjxl/src/lib/jxl/modular/encoding/enc_debug_tree.cc new file mode 100644 index 000000000..f2a1705e4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/enc_debug_tree.cc @@ -0,0 +1,124 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_debug_tree.h" + +#include +#include + +#include "lib/jxl/base/os_macros.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +#if JXL_OS_IOS +#define JXL_ENABLE_DOT 0 +#else +#define JXL_ENABLE_DOT 1 // iOS lacks C89 system() +#endif + +namespace jxl { + +const char *PredictorName(Predictor p) { + switch (p) { + case Predictor::Zero: + return "Zero"; + case Predictor::Left: + return "Left"; + case Predictor::Top: + return "Top"; + case Predictor::Average0: + return "Avg0"; + case Predictor::Average1: + return "Avg1"; + case Predictor::Average2: + return "Avg2"; + case Predictor::Average3: + return "Avg3"; + case Predictor::Average4: + return "Avg4"; + case Predictor::Select: + return "Sel"; + case Predictor::Gradient: + return "Grd"; + case Predictor::Weighted: + return "Wgh"; + case Predictor::TopLeft: + return "TopL"; + case Predictor::TopRight: + return "TopR"; + case Predictor::LeftLeft: + return "LL"; + default: + return "INVALID"; + }; +} + +std::string PropertyName(size_t i) { + static_assert(kNumNonrefProperties == 16, "Update this function"); + switch (i) { + case 0: + return "c"; + case 1: + return "g"; + case 2: + return "y"; + case 3: + return "x"; + case 4: + return "|N|"; + case 5: + return "|W|"; + case 6: + return "N"; + case 7: + return "W"; + case 8: + return "W-WW-NW+NWW"; + case 9: + return "W+N-NW"; + case 10: + return "W-NW"; + case 11: + return "NW-N"; + case 12: + return "N-NE"; + case 13: + return "N-NN"; + case 14: + return "W-WW"; + case 15: + return "WGH"; + default: + return "ch[" + ToString(15 - (int)i) + "]"; + } +} + +void PrintTree(const Tree &tree, const std::string &path) { + FILE *f = fopen((path + ".dot").c_str(), "w"); + fprintf(f, "graph{\n"); + for (size_t cur = 0; cur < tree.size(); cur++) { + if (tree[cur].property < 0) { + fprintf(f, "n%05" PRIuS " [label=\"%s%+" PRId64 " (x%u)\"];\n", cur, + PredictorName(tree[cur].predictor), tree[cur].predictor_offset, + tree[cur].multiplier); + } else { + fprintf(f, "n%05" PRIuS " [label=\"%s>%d\"];\n", cur, + PropertyName(tree[cur].property).c_str(), tree[cur].splitval); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].lchild); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].rchild); + } + } + fprintf(f, "}\n"); + fclose(f); +#if JXL_ENABLE_DOT + JXL_ASSERT( + system(("dot " + path + ".dot -T svg -o " + path + ".svg").c_str()) == 0); +#endif +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/encoding/enc_debug_tree.h b/media/libjxl/src/lib/jxl/modular/encoding/enc_debug_tree.h new file mode 100644 index 000000000..78deaab1b --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/enc_debug_tree.h @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ + +#include +#include + +#include +#include + +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +const char *PredictorName(Predictor p); +std::string PropertyName(size_t i); + +void PrintTree(const Tree &tree, const std::string &path); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ diff --git a/media/libjxl/src/lib/jxl/modular/encoding/enc_encoding.cc b/media/libjxl/src/lib/jxl/modular/encoding/enc_encoding.cc new file mode 100644 index 000000000..eeed2aee5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/enc_encoding.cc @@ -0,0 +1,560 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_debug_tree.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// Plot tree (if enabled) and predictor usage map. +constexpr bool kWantDebug = false; +constexpr bool kPrintTree = false; + +inline std::array PredictorColor(Predictor p) { + switch (p) { + case Predictor::Zero: + return {{0, 0, 0}}; + case Predictor::Left: + return {{255, 0, 0}}; + case Predictor::Top: + return {{0, 255, 0}}; + case Predictor::Average0: + return {{0, 0, 255}}; + case Predictor::Average4: + return {{192, 128, 128}}; + case Predictor::Select: + return {{255, 255, 0}}; + case Predictor::Gradient: + return {{255, 0, 255}}; + case Predictor::Weighted: + return {{0, 255, 255}}; + // TODO + default: + return {{255, 255, 255}}; + }; +} + +} // namespace + +void GatherTreeData(const Image &image, pixel_type chan, size_t group_id, + const weighted::Header &wp_header, + const ModularOptions &options, TreeSamples &tree_samples, + size_t *total_pixels) { + const Channel &channel = image.channel[chan]; + + JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w, + channel.h, chan); + + std::array static_props = { + {chan, (int)group_id}}; + Properties properties(kNumNonrefProperties + + kExtraPropsPerChannel * options.max_properties); + double pixel_fraction = std::min(1.0f, options.nb_repeats); + // a fraction of 0 is used to disable learning entirely. + if (pixel_fraction > 0) { + pixel_fraction = std::max(pixel_fraction, + std::min(1.0, 1024.0 / (channel.w * channel.h))); + } + uint64_t threshold = + (std::numeric_limits::max() >> 32) * pixel_fraction; + uint64_t s[2] = {static_cast(0x94D049BB133111EBull), + static_cast(0xBF58476D1CE4E5B9ull)}; + // Xorshift128+ adapted from xorshift128+-inl.h + auto use_sample = [&]() { + auto s1 = s[0]; + const auto s0 = s[1]; + const auto bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return (bits >> 32) <= threshold; + }; + + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + InitPropsRow(&properties, static_props, y); + // TODO(veluca): avoid computing WP if we don't use its property or + // predictions. + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w pred[kNumModularPredictors]; + if (tree_samples.NumPredictors() != 1) { + PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references, + &wp_state, pred); + } else { + pred[static_cast(tree_samples.PredictorFromIndex(0))] = + PredictLearn(&properties, channel.w, p + x, onerow, x, y, + tree_samples.PredictorFromIndex(0), references, + &wp_state) + .guess; + } + (*total_pixels)++; + if (use_sample()) { + tree_samples.AddSample(p[x], properties, pred); + } + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } +} + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector &multiplier_info = {}, + StaticPropRange static_prop_range = {}) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (static_prop_range[i][1] == 0) { + static_prop_range[i][1] = std::numeric_limits::max(); + } + } + if (!tree_samples.HasSamples()) { + Tree tree; + tree.emplace_back(); + tree.back().predictor = tree_samples.PredictorFromIndex(0); + tree.back().property = -1; + tree.back().predictor_offset = 0; + tree.back().multiplier = 1; + return tree; + } + float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels; + float required_cost = pixel_fraction * 0.9 + 0.1; + tree_samples.AllSamplesDone(); + Tree tree; + ComputeBestTree(tree_samples, + options.splitting_heuristics_node_threshold * required_cost, + multiplier_info, static_prop_range, + options.fast_decode_multiplier, &tree); + return tree; +} + +Status EncodeModularChannelMAANS(const Image &image, pixel_type chan, + const weighted::Header &wp_header, + const Tree &global_tree, Token **tokenpp, + AuxOut *aux_out, size_t group_id, + bool skip_encoder_fast_path) { + const Channel &channel = image.channel[chan]; + Token *tokenp = *tokenpp; + JXL_ASSERT(channel.w != 0 && channel.h != 0); + + Image3F predictor_img; + if (kWantDebug) predictor_img = Image3F(channel.w, channel.h); + + JXL_DEBUG_V(6, + "Encoding %" PRIuS "x%" PRIuS + " channel %d, " + "(shift=%i,%i)", + channel.w, channel.h, chan, channel.hshift, channel.vshift); + + std::array static_props = { + {chan, (int)group_id}}; + bool use_wp, is_wp_only; + bool is_gradient_only; + size_t num_props; + FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp, + &is_wp_only, &is_gradient_only); + Properties properties(num_props); + MATreeLookup tree_lookup(tree); + JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size()); + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. + uint16_t context_lookup[2 * kPropRangeFast] = {}; + int8_t offsets[2 * kPropRangeFast] = {}; + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, context_lookup, offsets); + } + if (is_gradient_only) { + is_gradient_only = TreeToLookupTable(tree, context_lookup, offsets); + } + + if (is_wp_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(Predictor::Weighted)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + int32_t residual = r[x] - guess - offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + int32_t residual = r[x] - guess; + *tokenp++ = Token(tree[0].childID, PackSigned(residual)); + } + } + } else if (is_gradient_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min( + std::max(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + int32_t residual = r[x] - guess - offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(Predictor::Zero)[c]), + &predictor_img.Plane(c)); + } + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + *tokenp++ = Token(tree[0].childID, PackSigned(p[x])); + } + } + } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted && + (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 && + tree[0].predictor_offset == 0 && !skip_encoder_fast_path) { + // multiplier is a power of 2. + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(tree[0].predictor)[c]), + &predictor_img.Plane(c)); + } + uint32_t mul_shift = FloorLog2Nonzero((uint32_t)tree[0].multiplier); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x, + y, tree[0].predictor); + pixel_type_w residual = r[x] - pred.guess; + JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual); + *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift)); + } + } + + } else if (!use_wp && !skip_encoder_fast_path) { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + } + } + } else { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + if (kWantDebug && WantDebugOutput(aux_out)) { + aux_out->DumpImage( + ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(), + predictor_img); + } + *tokenpp = tokenp; + return true; +} + +Status ModularEncode(const Image &image, const ModularOptions &options, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector *tokens, + size_t *width) { + if (image.error) return JXL_FAILURE("Invalid image"); + size_t nb_channels = image.channel.size(); + JXL_DEBUG_V( + 2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.", + nb_channels, image.bitdepth, image.w, image.h); + + if (nb_channels < 1) { + return true; // is there any use for a zero-channel image? + } + + // encode transforms + GroupHeader header_storage; + if (header == nullptr) header = &header_storage; + Bundle::Init(header); + if (options.predictor == Predictor::Weighted) { + weighted::PredictorMode(options.wp_mode, &header->wp_header); + } + header->transforms = image.transform; + // This doesn't actually work + if (tree != nullptr) { + header->use_global_tree = true; + } + if (tree_samples == nullptr && tree == nullptr) { + JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out)); + } + + TreeSamples tree_samples_storage; + size_t total_pixels_storage = 0; + if (!total_pixels) total_pixels = &total_pixels_storage; + // If there's no tree, compute one (or gather data to). + if (tree == nullptr) { + bool gather_data = tree_samples != nullptr; + if (tree_samples == nullptr) { + JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor( + options.predictor, options.wp_tree_mode)); + JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties( + options.splitting_heuristics_properties, options.wp_tree_mode)); + std::vector pixel_samples; + std::vector diff_samples; + std::vector group_pixel_count; + std::vector channel_pixel_count; + CollectPixelSamples(image, options, 0, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples); + std::vector dummy_multiplier_info; + StaticPropRange range; + tree_samples_storage.PreQuantizeProperties( + range, dummy_multiplier_info, group_pixel_count, channel_pixel_count, + pixel_samples, diff_samples, options.max_property_values); + } + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + GatherTreeData(image, i, group_id, header->wp_header, options, + gather_data ? *tree_samples : tree_samples_storage, + total_pixels); + } + if (gather_data) return true; + } + + JXL_ASSERT((tree == nullptr) == (tokens == nullptr)); + + Tree tree_storage; + std::vector> tokens_storage(1); + // Compute tree. + if (tree == nullptr) { + EntropyEncodingData code; + std::vector context_map; + + std::vector> tree_tokens(1); + tree_storage = + LearnTree(std::move(tree_samples_storage), *total_pixels, options); + tree = &tree_storage; + tokens = &tokens_storage[0]; + + Tree decoded_tree; + TokenizeTree(*tree, &tree_tokens[0], &decoded_tree); + JXL_ASSERT(tree->size() == decoded_tree.size()); + tree_storage = std::move(decoded_tree); + + if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) { + PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id)); + } + // Write tree + BuildAndEncodeHistograms(HistogramParams(), kNumTreeContexts, tree_tokens, + &code, &context_map, writer, kLayerModularTree, + aux_out); + WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree, + aux_out); + } + + size_t image_width = 0; + size_t total_tokens = 0; + for (size_t i = 0; i < nb_channels; i++) { + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + if (image.channel[i].w > image_width) image_width = image.channel[i].w; + total_tokens += image.channel[i].w * image.channel[i].h; + } + if (options.zero_tokens) { + tokens->resize(tokens->size() + total_tokens, {0, 0}); + } else { + // Do one big allocation for all the tokens we'll need, + // to avoid reallocs that might require copying. + size_t pos = tokens->size(); + tokens->resize(pos + total_tokens); + Token *tokenp = tokens->data() + pos; + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS( + image, i, header->wp_header, *tree, &tokenp, aux_out, group_id, + options.skip_encoder_fast_path)); + } + // Make sure we actually wrote all tokens + JXL_CHECK(tokenp == tokens->data() + tokens->size()); + } + + // Write data if not using a global tree/ANS stream. + if (!header->use_global_tree) { + EntropyEncodingData code; + std::vector context_map; + HistogramParams histo_params; + histo_params.image_widths.push_back(image_width); + BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2, + tokens_storage, &code, &context_map, writer, layer, + aux_out); + WriteTokens(tokens_storage[0], code, context_map, writer, layer, aux_out); + } else { + *width = image_width; + } + return true; +} + +Status ModularGenericCompress(Image &image, const ModularOptions &opts, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector *tokens, + size_t *width) { + if (image.w == 0 || image.h == 0) return true; + ModularOptions options = opts; // Make a copy to modify it. + + if (options.predictor == static_cast(-1)) { + options.predictor = Predictor::Gradient; + } + + size_t bits = writer ? writer->BitsWritten() : 0; + JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer, + group_id, tree_samples, total_pixels, tree, + header, tokens, width)); + bits = writer ? writer->BitsWritten() - bits : 0; + if (writer) { + JXL_DEBUG_V(4, + "Modular-encoded a %" PRIuS "x%" PRIuS + " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes", + image.w, image.h, image.bitdepth, image.channel.size(), + bits / 8); + } + (void)bits; + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/encoding/enc_encoding.h b/media/libjxl/src/lib/jxl/modular/encoding/enc_encoding.h new file mode 100644 index 000000000..8491c932c --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/enc_encoding.h @@ -0,0 +1,48 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ + +#include +#include + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector &multiplier_info = {}, + StaticPropRange static_prop_range = {}); + +// TODO(veluca): make cleaner interfaces. + +Status ModularGenericCompress( + Image &image, const ModularOptions &opts, BitWriter *writer, + AuxOut *aux_out = nullptr, size_t layer = 0, size_t group_id = 0, + // For gathering data for producing a global tree. + TreeSamples *tree_samples = nullptr, size_t *total_pixels = nullptr, + // For encoding with global tree. + const Tree *tree = nullptr, GroupHeader *header = nullptr, + std::vector *tokens = nullptr, size_t *widths = nullptr); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ diff --git a/media/libjxl/src/lib/jxl/modular/encoding/enc_ma.cc b/media/libjxl/src/lib/jxl/modular/encoding/enc_ma.cc new file mode 100644 index 000000000..90b11baa0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/enc_ma.cc @@ -0,0 +1,1023 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_ma.h" + +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/modular/encoding/ma_common.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" +#include +#include + +#include "lib/jxl/base/random.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; + +const HWY_FULL(float) df; +const HWY_FULL(int32_t) di; +size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } + +float EstimateBits(const int32_t *counts, int32_t *rounded_counts, + size_t num_symbols) { + // Try to approximate the effect of rounding up nonzero probabilities. + int32_t total = std::accumulate(counts, counts + num_symbols, 0); + const auto min = Set(di, (total + ANS_TAB_SIZE - 1) >> ANS_LOG_TAB_SIZE); + const auto zero_i = Zero(di); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + auto counts_v = LoadU(di, &counts[i]); + counts_v = IfThenElse(Eq(counts_v, zero_i), zero_i, + IfThenElse(Lt(counts_v, min), min, counts_v)); + StoreU(counts_v, di, &rounded_counts[i]); + } + // Compute entropy of the "rounded" probabilities. + const auto zero = Zero(df); + const size_t total_scalar = + std::accumulate(rounded_counts, rounded_counts + num_symbols, 0); + const auto inv_total = Set(df, 1.0f / total_scalar); + auto bits_lanes = Zero(df); + auto total_v = Set(di, total_scalar); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + const auto counts_v = ConvertTo(df, LoadU(di, &counts[i])); + const auto round_counts_v = LoadU(di, &rounded_counts[i]); + const auto probs = Mul(ConvertTo(df, round_counts_v), inv_total); + const auto nbps = IfThenElse(Eq(round_counts_v, total_v), BitCast(di, zero), + BitCast(di, FastLog2f(df, probs))); + bits_lanes = Sub(bits_lanes, IfThenElse(Eq(counts_v, zero), zero, + Mul(counts_v, BitCast(df, nbps)))); + } + return GetLane(SumOfLanes(df, bits_lanes)); +} + +void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, + int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { + // Note that the tree splits on *strictly greater*. + (*tree)[pos].lchild = tree->size(); + (*tree)[pos].rchild = tree->size() + 1; + (*tree)[pos].splitval = splitval; + (*tree)[pos].property = property; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = rpred; + tree->back().predictor_offset = roff; + tree->back().multiplier = 1; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = lpred; + tree->back().predictor_offset = loff; + tree->back().multiplier = 1; +} + +enum class IntersectionType { kNone, kPartial, kInside }; +IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, + uint32_t &partial_axis, uint32_t &partial_val) { + bool partial = false; + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (haystack[i][0] >= needle[i][1]) { + return IntersectionType::kNone; + } + if (haystack[i][1] <= needle[i][0]) { + return IntersectionType::kNone; + } + if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { + continue; + } + partial = true; + partial_axis = i; + if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { + partial_val = haystack[i][0] - 1; + } else { + JXL_DASSERT(haystack[i][1] > needle[i][0] && + haystack[i][1] < needle[i][1]); + partial_val = haystack[i][1] - 1; + } + } + return partial ? IntersectionType::kPartial : IntersectionType::kInside; +} + +void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos, + size_t end, size_t prop) { + auto cmp = [&](size_t a, size_t b) { + return int32_t(tree_samples.Property(prop, a)) - + int32_t(tree_samples.Property(prop, b)); + }; + Rng rng(0); + while (end > begin + 1) { + { + size_t pivot = rng.UniformU(begin, end); + tree_samples.Swap(begin, pivot); + } + size_t pivot_begin = begin; + size_t pivot_end = pivot_begin + 1; + for (size_t i = begin + 1; i < end; i++) { + JXL_DASSERT(i >= pivot_end); + JXL_DASSERT(pivot_end > pivot_begin); + int32_t cmp_result = cmp(i, pivot_begin); + if (cmp_result < 0) { // i < pivot, move pivot forward and put i before + // the pivot. + tree_samples.ThreeShuffle(pivot_begin, pivot_end, i); + pivot_begin++; + pivot_end++; + } else if (cmp_result == 0) { + tree_samples.Swap(pivot_end, i); + pivot_end++; + } + } + JXL_DASSERT(pivot_begin >= begin); + JXL_DASSERT(pivot_end > pivot_begin); + JXL_DASSERT(pivot_end <= end); + for (size_t i = begin; i < pivot_begin; i++) { + JXL_DASSERT(cmp(i, pivot_begin) < 0); + } + for (size_t i = pivot_end; i < end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) > 0); + } + for (size_t i = pivot_begin; i < pivot_end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) == 0); + } + // We now have that [begin, pivot_begin) is < pivot, [pivot_begin, + // pivot_end) is = pivot, and [pivot_end, end) is > pivot. + // If pos falls in the first or the last interval, we continue in that + // interval; otherwise, we are done. + if (pivot_begin > pos) { + end = pivot_begin; + } else if (pivot_end < pos) { + begin = pivot_end; + } else { + break; + } + } +} + +void FindBestSplit(TreeSamples &tree_samples, float threshold, + const std::vector &mul_info, + StaticPropRange initial_static_prop_range, + float fast_decode_multiplier, Tree *tree) { + struct NodeInfo { + size_t pos; + size_t begin; + size_t end; + uint64_t used_properties; + StaticPropRange static_prop_range; + }; + std::vector nodes; + nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, + initial_static_prop_range}); + + size_t num_predictors = tree_samples.NumPredictors(); + size_t num_properties = tree_samples.NumProperties(); + + // TODO(veluca): consider parallelizing the search (processing multiple nodes + // at a time). + while (!nodes.empty()) { + size_t pos = nodes.back().pos; + size_t begin = nodes.back().begin; + size_t end = nodes.back().end; + uint64_t used_properties = nodes.back().used_properties; + StaticPropRange static_prop_range = nodes.back().static_prop_range; + nodes.pop_back(); + if (begin == end) continue; + + struct SplitInfo { + size_t prop = 0; + uint32_t val = 0; + size_t pos = 0; + float lcost = std::numeric_limits::max(); + float rcost = std::numeric_limits::max(); + Predictor lpred = Predictor::Zero; + Predictor rpred = Predictor::Zero; + float Cost() { return lcost + rcost; } + }; + + SplitInfo best_split_static_constant; + SplitInfo best_split_static; + SplitInfo best_split_nonstatic; + SplitInfo best_split_nowp; + + JXL_DASSERT(begin <= end); + JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); + + // Compute the maximum token in the range. + size_t max_symbols = 0; + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + uint32_t tok = tree_samples.Token(pred, i); + max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; + } + } + max_symbols = Padded(max_symbols); + std::vector rounded_counts(max_symbols); + std::vector counts(max_symbols * num_predictors); + std::vector tot_extra_bits(num_predictors); + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + counts[pred * max_symbols + tree_samples.Token(pred, i)] += + tree_samples.Count(i); + tot_extra_bits[pred] += + tree_samples.NBits(pred, i) * tree_samples.Count(i); + } + } + + float base_bits; + { + size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); + base_bits = EstimateBits(counts.data() + pred * max_symbols, + rounded_counts.data(), max_symbols) + + tot_extra_bits[pred]; + } + + SplitInfo *best = &best_split_nonstatic; + + SplitInfo forced_split; + // The multiplier ranges cut halfway through the current ranges of static + // properties. We do this even if the current node is not a leaf, to + // minimize the number of nodes in the resulting tree. + for (size_t i = 0; i < mul_info.size(); i++) { + uint32_t axis, val; + IntersectionType t = + BoxIntersects(static_prop_range, mul_info[i].range, axis, val); + if (t == IntersectionType::kNone) continue; + if (t == IntersectionType::kInside) { + (*tree)[pos].multiplier = mul_info[i].multiplier; + break; + } + if (t == IntersectionType::kPartial) { + forced_split.val = tree_samples.QuantizeProperty(axis, val); + forced_split.prop = axis; + forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; + forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; + best = &forced_split; + best->pos = begin; + JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); + for (size_t x = begin; x < end; x++) { + if (tree_samples.Property(best->prop, x) <= best->val) { + best->pos++; + } + } + break; + } + } + + if (best != &forced_split) { + std::vector prop_value_used_count; + std::vector count_increase; + std::vector extra_bits_increase; + // For each property, compute which of its values are used, and what + // tokens correspond to those usages. Then, iterate through the values, + // and compute the entropy of each side of the split (of the form `prop > + // threshold`). Finally, find the split that minimizes the cost. + struct CostInfo { + float cost = std::numeric_limits::max(); + float extra_cost = 0; + float Cost() const { return cost + extra_cost; } + Predictor pred; // will be uninitialized in some cases, but never used. + }; + std::vector costs_l; + std::vector costs_r; + + std::vector counts_above(max_symbols); + std::vector counts_below(max_symbols); + + // The lower the threshold, the higher the expected noisiness of the + // estimate. Thus, discourage changing predictors. + float change_pred_penalty = 800.0f / (100.0f + threshold); + for (size_t prop = 0; prop < num_properties && base_bits > threshold; + prop++) { + costs_l.clear(); + costs_r.clear(); + size_t prop_size = tree_samples.NumPropertyValues(prop); + if (extra_bits_increase.size() < prop_size) { + count_increase.resize(prop_size * max_symbols); + extra_bits_increase.resize(prop_size); + } + // Clear prop_value_used_count (which cannot be cleared "on the go") + prop_value_used_count.clear(); + prop_value_used_count.resize(prop_size); + + size_t first_used = prop_size; + size_t last_used = 0; + + // TODO(veluca): consider finding multiple splits along a single + // property at the same time, possibly with a bottom-up approach. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + prop_value_used_count[p]++; + last_used = std::max(last_used, p); + first_used = std::min(first_used, p); + } + costs_l.resize(last_used - first_used); + costs_r.resize(last_used - first_used); + // For all predictors, compute the right and left costs of each split. + for (size_t pred = 0; pred < num_predictors; pred++) { + // Compute cost and histogram increments for each property value. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + size_t cnt = tree_samples.Count(i); + size_t sym = tree_samples.Token(pred, i); + count_increase[p * max_symbols + sym] += cnt; + extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt; + } + memcpy(counts_above.data(), counts.data() + pred * max_symbols, + max_symbols * sizeof counts_above[0]); + memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); + size_t extra_bits_below = 0; + // Exclude last used: this ensures neither counts_above nor + // counts_below is empty. + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + extra_bits_below += extra_bits_increase[i]; + // The increase for this property value has been used, and will not + // be used again: clear it. Also below. + extra_bits_increase[i] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + counts_above[sym] -= count_increase[i * max_symbols + sym]; + counts_below[sym] += count_increase[i * max_symbols + sym]; + count_increase[i * max_symbols + sym] = 0; + } + float rcost = EstimateBits(counts_above.data(), + rounded_counts.data(), max_symbols) + + tot_extra_bits[pred] - extra_bits_below; + float lcost = EstimateBits(counts_below.data(), + rounded_counts.data(), max_symbols) + + extra_bits_below; + JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); + float penalty = 0; + // Never discourage moving away from the Weighted predictor. + if (tree_samples.PredictorFromIndex(pred) != + (*tree)[pos].predictor && + (*tree)[pos].predictor != Predictor::Weighted) { + penalty = change_pred_penalty; + } + // If everything else is equal, disfavour Weighted (slower) and + // favour Zero (faster if it's the only predictor used in a + // group+channel combination) + if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { + penalty += 1e-8; + } + if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { + penalty -= 1e-8; + } + if (rcost + penalty < costs_r[i - first_used].Cost()) { + costs_r[i - first_used].cost = rcost; + costs_r[i - first_used].extra_cost = penalty; + costs_r[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + if (lcost + penalty < costs_l[i - first_used].Cost()) { + costs_l[i - first_used].cost = lcost; + costs_l[i - first_used].extra_cost = penalty; + costs_l[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + } + } + // Iterate through the possible splits and find the one with minimum sum + // of costs of the two sides. + size_t split = begin; + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + split += prop_value_used_count[i]; + float rcost = costs_r[i - first_used].cost; + float lcost = costs_l[i - first_used].cost; + // WP was not used + we would use the WP property or predictor + bool adds_wp = + (tree_samples.PropertyFromIndex(prop) == kWPProp && + (used_properties & (1LU << prop)) == 0) || + ((costs_l[i - first_used].pred == Predictor::Weighted || + costs_r[i - first_used].pred == Predictor::Weighted) && + (*tree)[pos].predictor != Predictor::Weighted); + bool zero_entropy_side = rcost == 0 || lcost == 0; + + SplitInfo &best = + prop < kNumStaticProperties + ? (zero_entropy_side ? best_split_static_constant + : best_split_static) + : (adds_wp ? best_split_nonstatic : best_split_nowp); + if (lcost + rcost < best.Cost()) { + best.prop = prop; + best.val = i; + best.pos = split; + best.lcost = lcost; + best.lpred = costs_l[i - first_used].pred; + best.rcost = rcost; + best.rpred = costs_r[i - first_used].pred; + } + } + // Clear extra_bits_increase and cost_increase for last_used. + extra_bits_increase[last_used] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + count_increase[last_used * max_symbols + sym] = 0; + } + } + + // Try to avoid introducing WP. + if (best_split_nowp.Cost() + threshold < base_bits && + best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_nowp; + } + // Split along static props if possible and not significantly more + // expensive. + if (best_split_static.Cost() + threshold < base_bits && + best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_static; + } + // Split along static props to create constant nodes if possible. + if (best_split_static_constant.Cost() + threshold < base_bits) { + best = &best_split_static_constant; + } + } + + if (best->Cost() + threshold < base_bits) { + uint32_t p = tree_samples.PropertyFromIndex(best->prop); + pixel_type dequant = + tree_samples.UnquantizeProperty(best->prop, best->val); + // Split node and try to split children. + MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); + // "Sort" according to winning property + SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop); + if (p >= kNumStaticProperties) { + used_properties |= 1 << best->prop; + } + auto new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(static_cast(dequant + 1) <= new_sp_range[p][1]); + new_sp_range[p][1] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, + used_properties, new_sp_range}); + new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(new_sp_range[p][0] <= static_cast(dequant + 1)); + new_sp_range[p][0] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, + used_properties, new_sp_range}); + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(FindBestSplit); // Local function. + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree) { + // TODO(veluca): take into account that different contexts can have different + // uint configs. + // + // Initialize tree. + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = tree_samples.PredictorFromIndex(0); + tree->back().predictor_offset = 0; + tree->back().multiplier = 1; + JXL_ASSERT(tree_samples.NumProperties() < 64); + + JXL_ASSERT(tree_samples.NumDistinctSamples() <= + std::numeric_limits::max()); + HWY_DYNAMIC_DISPATCH(FindBestSplit) + (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, + tree); +} + +constexpr int32_t TreeSamples::kPropertyRange; +constexpr uint32_t TreeSamples::kDedupEntryUnused; + +Status TreeSamples::SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode) { + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + predictors = {Predictor::Weighted}; + residuals.resize(1); + return true; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && + predictor == Predictor::Weighted) { + return JXL_FAILURE("Invalid predictor settings"); + } + if (predictor == Predictor::Variable) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictors.push_back(static_cast(i)); + } + std::swap(predictors[0], predictors[static_cast(Predictor::Weighted)]); + std::swap(predictors[1], predictors[static_cast(Predictor::Gradient)]); + } else if (predictor == Predictor::Best) { + predictors = {Predictor::Weighted, Predictor::Gradient}; + } else { + predictors = {predictor}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto wp_it = + std::find(predictors.begin(), predictors.end(), Predictor::Weighted); + if (wp_it != predictors.end()) { + predictors.erase(wp_it); + } + } + residuals.resize(predictors.size()); + return true; +} + +Status TreeSamples::SetProperties(const std::vector &properties, + ModularOptions::TreeMode wp_tree_mode) { + props_to_use = properties; + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + props_to_use = {static_cast(kWPProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { + props_to_use = {static_cast(kGradientProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp); + if (it != props_to_use.end()) { + props_to_use.erase(it); + } + } + if (props_to_use.empty()) { + return JXL_FAILURE("Invalid property set configuration"); + } + props.resize(props_to_use.size()); + return true; +} + +void TreeSamples::InitTable(size_t size) { + JXL_DASSERT((size & (size - 1)) == 0); + if (dedup_table_.size() == size) return; + dedup_table_.resize(size, kDedupEntryUnused); + for (size_t i = 0; i < NumDistinctSamples(); i++) { + if (sample_counts[i] != std::numeric_limits::max()) { + AddToTable(i); + } + } +} + +bool TreeSamples::AddToTableAndMerge(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos1])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos1]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos1]] == + std::numeric_limits::max()) { + dedup_table_[pos1] = kDedupEntryUnused; + } + return true; + } + if (dedup_table_[pos2] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos2])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos2]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos2]] == + std::numeric_limits::max()) { + dedup_table_[pos2] = kDedupEntryUnused; + } + return true; + } + AddToTable(a); + return false; +} + +void TreeSamples::AddToTable(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] == kDedupEntryUnused) { + dedup_table_[pos1] = a; + } else if (dedup_table_[pos2] == kDedupEntryUnused) { + dedup_table_[pos2] = a; + } +} + +void TreeSamples::PrepareForSamples(size_t num_samples) { + for (auto &res : residuals) { + res.reserve(res.size() + num_samples); + } + for (auto &p : props) { + p.reserve(p.size() + num_samples); + } + size_t total_num_samples = num_samples + sample_counts.size(); + size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2); + InitTable(next_pow2); +} + +size_t TreeSamples::Hash1(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd; + uint64_t h = constant; + for (const auto &r : residuals) { + h = h * constant + r[a].tok; + h = h * constant + r[a].nbits; + } + for (const auto &p : props) { + h = h * constant + p[a]; + } + return (h >> 16) & (dedup_table_.size() - 1); +} +size_t TreeSamples::Hash2(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; + uint64_t h = constant; + for (const auto &p : props) { + h = h * constant ^ p[a]; + } + for (const auto &r : residuals) { + h = h * constant ^ r[a].tok; + h = h * constant ^ r[a].nbits; + } + return (h >> 16) & (dedup_table_.size() - 1); +} + +bool TreeSamples::IsSameSample(size_t a, size_t b) const { + bool ret = true; + for (const auto &r : residuals) { + if (r[a].tok != r[b].tok) { + ret = false; + } + if (r[a].nbits != r[b].nbits) { + ret = false; + } + } + for (const auto &p : props) { + if (p[a] != p[b]) { + ret = false; + } + } + return ret; +} + +void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions) { + for (size_t i = 0; i < predictors.size(); i++) { + pixel_type v = pixel - predictions[static_cast(predictors[i])]; + uint32_t tok, nbits, bits; + HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits); + JXL_DASSERT(tok < 256); + JXL_DASSERT(nbits < 256); + residuals[i].emplace_back( + ResidualToken{static_cast(tok), static_cast(nbits)}); + } + for (size_t i = 0; i < props_to_use.size(); i++) { + props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]])); + } + sample_counts.push_back(1); + num_samples++; + if (AddToTableAndMerge(sample_counts.size() - 1)) { + for (auto &r : residuals) r.pop_back(); + for (auto &p : props) p.pop_back(); + sample_counts.pop_back(); + } +} + +void TreeSamples::Swap(size_t a, size_t b) { + if (a == b) return; + for (auto &r : residuals) { + std::swap(r[a], r[b]); + } + for (auto &p : props) { + std::swap(p[a], p[b]); + } + std::swap(sample_counts[a], sample_counts[b]); +} + +void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) { + if (b == c) return Swap(a, b); + for (auto &r : residuals) { + auto tmp = r[a]; + r[a] = r[c]; + r[c] = r[b]; + r[b] = tmp; + } + for (auto &p : props) { + auto tmp = p[a]; + p[a] = p[c]; + p[c] = p[b]; + p[b] = tmp; + } + auto tmp = sample_counts[a]; + sample_counts[a] = sample_counts[c]; + sample_counts[c] = sample_counts[b]; + sample_counts[b] = tmp; +} + +namespace { +std::vector QuantizeHistogram(const std::vector &histogram, + size_t num_chunks) { + if (histogram.empty()) return {}; + // TODO(veluca): selecting distinct quantiles is likely not the best + // way to go about this. + std::vector thresholds; + size_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU); + size_t cumsum = 0; + size_t threshold = 0; + for (size_t i = 0; i + 1 < histogram.size(); i++) { + cumsum += histogram[i]; + if (cumsum > (threshold + 1) * sum / num_chunks) { + thresholds.push_back(i); + while (cumsum >= (threshold + 1) * sum / num_chunks) threshold++; + } + } + return thresholds; +} + +std::vector QuantizeSamples(const std::vector &samples, + size_t num_chunks) { + if (samples.empty()) return {}; + int min = *std::min_element(samples.begin(), samples.end()); + constexpr int kRange = 512; + min = std::min(std::max(min, -kRange), kRange); + std::vector counts(2 * kRange + 1); + for (int s : samples) { + uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min; + counts[sample_offset]++; + } + std::vector thresholds = QuantizeHistogram(counts, num_chunks); + for (auto &v : thresholds) v += min; + return thresholds; +} +} // namespace + +void TreeSamples::PreQuantizeProperties( + const StaticPropRange &range, + const std::vector &multiplier_info, + const std::vector &group_pixel_count, + const std::vector &channel_pixel_count, + std::vector &pixel_samples, + std::vector &diff_samples, size_t max_property_values) { + // If we have forced splits because of multipliers, choose channel and group + // thresholds accordingly. + std::vector group_multiplier_thresholds; + std::vector channel_multiplier_thresholds; + for (const auto &v : multiplier_info) { + if (v.range[0][0] != range[0][0]) { + channel_multiplier_thresholds.push_back(v.range[0][0] - 1); + } + if (v.range[0][1] != range[0][1]) { + channel_multiplier_thresholds.push_back(v.range[0][1] - 1); + } + if (v.range[1][0] != range[1][0]) { + group_multiplier_thresholds.push_back(v.range[1][0] - 1); + } + if (v.range[1][1] != range[1][1]) { + group_multiplier_thresholds.push_back(v.range[1][1] - 1); + } + } + std::sort(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()); + channel_multiplier_thresholds.resize( + std::unique(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()) - + channel_multiplier_thresholds.begin()); + std::sort(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()); + group_multiplier_thresholds.resize( + std::unique(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()) - + group_multiplier_thresholds.begin()); + + compact_properties.resize(props_to_use.size()); + auto quantize_channel = [&]() { + if (!channel_multiplier_thresholds.empty()) { + return channel_multiplier_thresholds; + } + return QuantizeHistogram(channel_pixel_count, max_property_values); + }; + auto quantize_group_id = [&]() { + if (!group_multiplier_thresholds.empty()) { + return group_multiplier_thresholds; + } + return QuantizeHistogram(group_pixel_count, max_property_values); + }; + auto quantize_coordinate = [&]() { + std::vector quantized; + quantized.reserve(max_property_values - 1); + for (size_t i = 0; i + 1 < max_property_values; i++) { + quantized.push_back((i + 1) * 256 / max_property_values - 1); + } + return quantized; + }; + std::vector abs_pixel_thr; + std::vector pixel_thr; + auto quantize_pixel_property = [&]() { + if (pixel_thr.empty()) { + pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return pixel_thr; + }; + auto quantize_abs_pixel_property = [&]() { + if (abs_pixel_thr.empty()) { + quantize_pixel_property(); // Compute the non-abs thresholds. + for (auto &v : pixel_samples) v = std::abs(v); + abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return abs_pixel_thr; + }; + std::vector abs_diff_thr; + std::vector diff_thr; + auto quantize_diff_property = [&]() { + if (diff_thr.empty()) { + diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return diff_thr; + }; + auto quantize_abs_diff_property = [&]() { + if (abs_diff_thr.empty()) { + quantize_diff_property(); // Compute the non-abs thresholds. + for (auto &v : diff_samples) v = std::abs(v); + abs_diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return abs_diff_thr; + }; + auto quantize_wp = [&]() { + if (max_property_values < 32) { + return std::vector{-127, -63, -31, -15, -7, -3, -1, 0, + 1, 3, 7, 15, 31, 63, 127}; + } + if (max_property_values < 64) { + return std::vector{-255, -191, -127, -95, -63, -47, -31, -23, + -15, -11, -7, -5, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, + 63, 95, 127, 191, 255}; + } + return std::vector{ + -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47, + -39, -31, -27, -23, -19, -15, -13, -11, -9, -7, -6, + -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, + 6, 7, 9, 11, 13, 15, 19, 23, 27, 31, 39, + 47, 55, 63, 79, 95, 111, 127, 159, 191, 223, 255}; + }; + + property_mapping.resize(props_to_use.size()); + for (size_t i = 0; i < props_to_use.size(); i++) { + if (props_to_use[i] == 0) { + compact_properties[i] = quantize_channel(); + } else if (props_to_use[i] == 1) { + compact_properties[i] = quantize_group_id(); + } else if (props_to_use[i] == 2 || props_to_use[i] == 3) { + compact_properties[i] = quantize_coordinate(); + } else if (props_to_use[i] == 6 || props_to_use[i] == 7 || + props_to_use[i] == 8 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) { + compact_properties[i] = quantize_pixel_property(); + } else if (props_to_use[i] == 4 || props_to_use[i] == 5 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) { + compact_properties[i] = quantize_abs_pixel_property(); + } else if (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 2) { + compact_properties[i] = quantize_abs_diff_property(); + } else if (props_to_use[i] == kWPProp) { + compact_properties[i] = quantize_wp(); + } else { + compact_properties[i] = quantize_diff_property(); + } + property_mapping[i].resize(kPropertyRange * 2 + 1); + size_t mapped = 0; + for (size_t j = 0; j < property_mapping[i].size(); j++) { + while (mapped < compact_properties[i].size() && + static_cast(j) - kPropertyRange > + compact_properties[i][mapped]) { + mapped++; + } + // property_mapping[i] of a value V is `mapped` if + // compact_properties[i][mapped] <= j and + // compact_properties[i][mapped-1] > j + // This is because the decision node in the tree splits on (property) > j, + // hence everything that is not > of a threshold should be clustered + // together. + property_mapping[i][j] = mapped; + } + } +} + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector &group_pixel_count, + std::vector &channel_pixel_count, + std::vector &pixel_samples, + std::vector &diff_samples) { + if (options.nb_repeats == 0) return; + if (group_pixel_count.size() <= group_id) { + group_pixel_count.resize(group_id + 1); + } + if (channel_pixel_count.size() < image.channel.size()) { + channel_pixel_count.resize(image.channel.size()); + } + Rng rng(group_id); + // Sample 10% of the final number of samples for property quantization. + float fraction = std::min(options.nb_repeats * 0.1, 0.99); + Rng::GeometricDistribution dist(fraction); + size_t total_pixels = 0; + std::vector channel_ids; + for (size_t i = 0; i < image.channel.size(); i++) { + if (image.channel[i].w <= 1 || image.channel[i].h == 0) { + continue; // skip empty or width-1 channels. + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + channel_ids.push_back(i); + group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h; + channel_pixel_count[i] += image.channel[i].w * image.channel[i].h; + total_pixels += image.channel[i].w * image.channel[i].h; + } + if (channel_ids.empty()) return; + pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels); + diff_samples.reserve(diff_samples.size() + fraction * total_pixels); + size_t i = 0; + size_t y = 0; + size_t x = 0; + auto advance = [&](size_t amount) { + x += amount; + // Detect row overflow (rare). + while (x >= image.channel[channel_ids[i]].w) { + x -= image.channel[channel_ids[i]].w; + y++; + // Detect end-of-channel (even rarer). + if (y == image.channel[channel_ids[i]].h) { + i++; + y = 0; + if (i >= channel_ids.size()) { + return; + } + } + } + }; + advance(rng.Geometric(dist)); + for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) { + const pixel_type *row = image.channel[channel_ids[i]].Row(y); + pixel_samples.push_back(row[x]); + size_t xp = x == 0 ? 1 : x - 1; + diff_samples.push_back(row[x] - row[xp]); + } +} + +// TODO(veluca): very simple encoding scheme. This should be improved. +void TokenizeTree(const Tree &tree, std::vector *tokens, + Tree *decoder_tree) { + JXL_ASSERT(tree.size() <= kMaxTreeSize); + std::queue q; + q.push(0); + size_t leaf_id = 0; + decoder_tree->clear(); + while (!q.empty()) { + int cur = q.front(); + q.pop(); + JXL_ASSERT(tree[cur].property >= -1); + tokens->emplace_back(kPropertyContext, tree[cur].property + 1); + if (tree[cur].property == -1) { + tokens->emplace_back(kPredictorContext, + static_cast(tree[cur].predictor)); + tokens->emplace_back(kOffsetContext, + PackSigned(tree[cur].predictor_offset)); + uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier); + uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1; + tokens->emplace_back(kMultiplierLogContext, mul_log); + tokens->emplace_back(kMultiplierBitsContext, mul_bits); + JXL_ASSERT(tree[cur].predictor < Predictor::Best); + decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor, + tree[cur].predictor_offset, + tree[cur].multiplier); + continue; + } + decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval, + decoder_tree->size() + q.size() + 1, + decoder_tree->size() + q.size() + 2, + Predictor::Zero, 0, 1); + q.push(tree[cur].lchild); + q.push(tree[cur].rchild); + tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval)); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/modular/encoding/enc_ma.h b/media/libjxl/src/lib/jxl/modular/encoding/enc_ma.h new file mode 100644 index 000000000..ede37c802 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/enc_ma.h @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ + +#include + +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// Struct to collect all the data needed to build a tree. +struct TreeSamples { + bool HasSamples() const { + return !residuals.empty() && !residuals[0].empty(); + } + size_t NumDistinctSamples() const { return sample_counts.size(); } + size_t NumSamples() const { return num_samples; } + // Set the predictor to use. Must be called before adding any samples. + Status SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode); + // Set the properties to use. Must be called before adding any samples. + Status SetProperties(const std::vector &properties, + ModularOptions::TreeMode wp_tree_mode); + + size_t Token(size_t pred, size_t i) const { return residuals[pred][i].tok; } + size_t NBits(size_t pred, size_t i) const { return residuals[pred][i].nbits; } + size_t Count(size_t i) const { return sample_counts[i]; } + size_t PredictorIndex(Predictor predictor) const { + const auto predictor_elem = + std::find(predictors.begin(), predictors.end(), predictor); + JXL_DASSERT(predictor_elem != predictors.end()); + return predictor_elem - predictors.begin(); + } + size_t PropertyIndex(size_t property) const { + const auto property_elem = + std::find(props_to_use.begin(), props_to_use.end(), property); + JXL_DASSERT(property_elem != props_to_use.end()); + return property_elem - props_to_use.begin(); + } + size_t NumPropertyValues(size_t property_index) const { + return compact_properties[property_index].size() + 1; + } + // Returns the *quantized* property value. + size_t Property(size_t property_index, size_t i) const { + return props[property_index][i]; + } + int UnquantizeProperty(size_t property_index, uint32_t quant) const { + JXL_ASSERT(quant < compact_properties[property_index].size()); + return compact_properties[property_index][quant]; + } + + Predictor PredictorFromIndex(size_t index) const { + JXL_DASSERT(index < predictors.size()); + return predictors[index]; + } + size_t PropertyFromIndex(size_t index) const { + JXL_DASSERT(index < props_to_use.size()); + return props_to_use[index]; + } + size_t NumPredictors() const { return predictors.size(); } + size_t NumProperties() const { return props_to_use.size(); } + + // Preallocate data for a given number of samples. MUST be called before + // adding any sample. + void PrepareForSamples(size_t num_samples); + // Add a sample. + void AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions); + // Pre-cluster property values. + void PreQuantizeProperties( + const StaticPropRange &range, + const std::vector &multiplier_info, + const std::vector &group_pixel_count, + const std::vector &channel_pixel_count, + std::vector &pixel_samples, + std::vector &diff_samples, size_t max_property_values); + + void AllSamplesDone() { dedup_table_ = std::vector(); } + + uint32_t QuantizeProperty(uint32_t prop, pixel_type v) const { + v = std::min(std::max(v, -kPropertyRange), kPropertyRange) + kPropertyRange; + return property_mapping[prop][v]; + } + + // Swaps samples in position a and b. Does nothing if a == b. + void Swap(size_t a, size_t b); + + // Cycles samples: a -> b -> c -> a. We assume a <= b <= c, so that we can + // just call Swap(a, b) if b==c. + void ThreeShuffle(size_t a, size_t b, size_t c); + + private: + // TODO(veluca): as the total number of properties and predictors are known + // before adding any samples, it might be better to interleave predictors, + // properties and counts in a single vector to improve locality. + // A first attempt at doing this actually results in much slower encoding, + // possibly because of the more complex addressing. + struct ResidualToken { + uint8_t tok; + uint8_t nbits; + }; + // Residual information: token and number of extra bits, per predictor. + std::vector> residuals; + // Number of occurrences of each sample. + std::vector sample_counts; + // Property values, quantized to at most 256 distinct values. + std::vector> props; + // Decompactification info for `props`. + std::vector> compact_properties; + // List of properties to use. + std::vector props_to_use; + // List of predictors to use. + std::vector predictors; + // Mapping property value -> quantized property value. + static constexpr int32_t kPropertyRange = 511; + std::vector> property_mapping; + // Number of samples seen. + size_t num_samples = 0; + // Table for deduplication. + static constexpr uint32_t kDedupEntryUnused{static_cast(-1)}; + std::vector dedup_table_; + + // Functions for sample deduplication. + bool IsSameSample(size_t a, size_t b) const; + size_t Hash1(size_t a) const; + size_t Hash2(size_t a) const; + void InitTable(size_t size); + // Returns true if `a` was already present in the table. + bool AddToTableAndMerge(size_t a); + void AddToTable(size_t a); +}; + +void TokenizeTree(const Tree &tree, std::vector *tokens, + Tree *decoder_tree); + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector &group_pixel_count, + std::vector &channel_pixel_count, + std::vector &pixel_samples, + std::vector &diff_samples); + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree); + +} // namespace jxl +#endif // LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ diff --git a/media/libjxl/src/lib/jxl/modular/encoding/encoding.cc b/media/libjxl/src/lib/jxl/modular/encoding/encoding.cc new file mode 100644 index 000000000..9d2c3e5cf --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/encoding.cc @@ -0,0 +1,622 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/encoding.h" + +#include +#include + +#include + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/scope_guard.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// Removes all nodes that use a static property (i.e. channel or group ID) from +// the tree and collapses each node on even levels with its two children to +// produce a flatter tree. Also computes whether the resulting tree requires +// using the weighted predictor. +FlatTree FilterTree(const Tree &global_tree, + std::array &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only) { + *num_props = 0; + bool has_wp = false; + bool has_non_wp = false; + *gradient_only = true; + const auto mark_property = [&](int32_t p) { + if (p == kWPProp) { + has_wp = true; + } else if (p >= kNumStaticProperties) { + has_non_wp = true; + } + if (p >= kNumStaticProperties && p != kGradientProp) { + *gradient_only = false; + } + }; + FlatTree output; + std::queue nodes; + nodes.push(0); + // Produces a trimmed and flattened tree by doing a BFS visit of the original + // tree, ignoring branches that are known to be false and proceeding two + // levels at a time to collapse nodes in a flatter tree; if an inner parent + // node has a leaf as a child, the leaf is duplicated and an implicit fake + // node is added. This allows to reduce the number of branches when traversing + // the resulting flat tree. + while (!nodes.empty()) { + size_t cur = nodes.front(); + nodes.pop(); + // Skip nodes that we can decide now, by jumping directly to their children. + while (global_tree[cur].property < kNumStaticProperties && + global_tree[cur].property != -1) { + if (static_props[global_tree[cur].property] > global_tree[cur].splitval) { + cur = global_tree[cur].lchild; + } else { + cur = global_tree[cur].rchild; + } + } + FlatDecisionNode flat; + if (global_tree[cur].property == -1) { + flat.property0 = -1; + flat.childID = global_tree[cur].lchild; + flat.predictor = global_tree[cur].predictor; + flat.predictor_offset = global_tree[cur].predictor_offset; + flat.multiplier = global_tree[cur].multiplier; + *gradient_only &= flat.predictor == Predictor::Gradient; + has_wp |= flat.predictor == Predictor::Weighted; + has_non_wp |= flat.predictor != Predictor::Weighted; + output.push_back(flat); + continue; + } + flat.childID = output.size() + nodes.size() + 1; + + flat.property0 = global_tree[cur].property; + *num_props = std::max(flat.property0 + 1, *num_props); + flat.splitval0 = global_tree[cur].splitval; + + for (size_t i = 0; i < 2; i++) { + size_t cur_child = + i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild; + // Skip nodes that we can decide now. + while (global_tree[cur_child].property < kNumStaticProperties && + global_tree[cur_child].property != -1) { + if (static_props[global_tree[cur_child].property] > + global_tree[cur_child].splitval) { + cur_child = global_tree[cur_child].lchild; + } else { + cur_child = global_tree[cur_child].rchild; + } + } + // We ended up in a leaf, add a dummy decision and two copies of the leaf. + if (global_tree[cur_child].property == -1) { + flat.properties[i] = 0; + flat.splitvals[i] = 0; + nodes.push(cur_child); + nodes.push(cur_child); + } else { + flat.properties[i] = global_tree[cur_child].property; + flat.splitvals[i] = global_tree[cur_child].splitval; + nodes.push(global_tree[cur_child].lchild); + nodes.push(global_tree[cur_child].rchild); + *num_props = std::max(flat.properties[i] + 1, *num_props); + } + } + + for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]); + mark_property(flat.property0); + output.push_back(flat); + } + if (*num_props > kNumNonrefProperties) { + *num_props = + DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) * + kExtraPropsPerChannel + + kNumNonrefProperties; + } else { + *num_props = kNumNonrefProperties; + } + *use_wp = has_wp; + *wp_only = has_wp && !has_non_wp; + + return output; +} + +Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader, + const std::vector &context_map, + const Tree &global_tree, + const weighted::Header &wp_header, + pixel_type chan, size_t group_id, + Image *image) { + Channel &channel = image->channel[chan]; + + std::array static_props = { + {chan, (int)group_id}}; + // TODO(veluca): filter the tree according to static_props. + + // zero pixel channel? could happen + if (channel.w == 0 || channel.h == 0) return true; + + bool tree_has_wp_prop_or_pred = false; + bool is_wp_only = false; + bool is_gradient_only = false; + size_t num_props; + FlatTree tree = + FilterTree(global_tree, static_props, &num_props, + &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only); + + // From here on, tree lookup returns a *clustered* context ID. + // This avoids an extra memory lookup after tree traversal. + for (size_t i = 0; i < tree.size(); i++) { + if (tree[i].property0 == -1) { + tree[i].childID = context_map[tree[i].childID]; + } + } + + JXL_DEBUG_V(3, "Decoded MA tree with %" PRIuS " nodes", tree.size()); + + // MAANS decode + const auto make_pixel = [](uint64_t v, pixel_type multiplier, + pixel_type_w offset) -> pixel_type { + JXL_DASSERT((v & 0xFFFFFFFF) == v); + pixel_type_w val = UnpackSigned(v); + // if it overflows, it overflows, and we have a problem anyway + return val * multiplier + offset; + }; + + if (tree.size() == 1) { + // special optimized case: no meta-adaptation, so no need + // to compute properties. + Predictor predictor = tree[0].predictor; + int64_t offset = tree[0].predictor_offset; + int32_t multiplier = tree[0].multiplier; + size_t ctx_id = tree[0].childID; + if (predictor == Predictor::Zero) { + uint32_t value; + if (reader->IsSingleValueAndAdvance(ctx_id, &value, + channel.w * channel.h)) { + // Special-case: histogram has a single symbol, with no extra bits, and + // we use ANS mode. + JXL_DEBUG_V(8, "Fastest track."); + pixel_type v = make_pixel(value, multiplier, offset); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + std::fill(r, r + channel.w, v); + } + } else { + JXL_DEBUG_V(8, "Fast track."); + if (multiplier == 1 && offset == 0) { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = UnpackSigned(v); + } + } + } else { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multiplier, offset); + } + } + } + } + } else if (predictor == Predictor::Gradient && offset == 0 && + multiplier == 1 && reader->HuffRleOnly()) { + JXL_DEBUG_V(8, "Gradient RLE (fjxl) very fast track."); + uint32_t run = 0; + uint32_t v = 0; + pixel_type_w sv = 0; + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1); + const pixel_type *JXL_RESTRICT rtopleft = + (y ? channel.Row(y - 1) - 1 : r - 1); + pixel_type_w guess = (y ? rtop[0] : 0); + if (run == 0) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[0] = sv + guess; + for (size_t x = 1; x < channel.w; x++) { + pixel_type left = r[x - 1]; + pixel_type top = rtop[x]; + pixel_type topleft = rtopleft[x]; + pixel_type_w guess = ClampedGradient(top, left, topleft); + if (!run) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[x] = sv + guess; + } + } + } else if (predictor == Predictor::Gradient && offset == 0 && + multiplier == 1) { + JXL_DEBUG_V(8, "Gradient very fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type top = (y ? *(r + x - onerow) : left); + pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type guess = ClampedGradient(top, left, topleft); + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, 1, guess); + } + } + } else if (predictor != Predictor::Weighted) { + // special optimized case: no wp + JXL_DEBUG_V(8, "Quite fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = + PredictNoTreeNoWP(channel.w, r + x, onerow, x, y, predictor); + pixel_type_w g = pred.guess + offset; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + // NOTE: pred.multiplier is unset. + r[x] = make_pixel(v, multiplier, g); + } + } + } else { + JXL_DEBUG_V(8, "Somewhat fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w g = PredictNoTreeWP(channel.w, r + x, onerow, x, y, + predictor, &wp_state) + .guess + + offset; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multiplier, g); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } + return true; + } + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. + uint8_t context_lookup[2 * kPropRangeFast] = {}; + int8_t multipliers[2 * kPropRangeFast] = {}; + int8_t offsets[2 * kPropRangeFast] = {}; + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, context_lookup, offsets, multipliers); + } + if (is_gradient_only) { + is_gradient_only = + TreeToLookupTable(tree, context_lookup, offsets, multipliers); + } + + if (is_gradient_only) { + JXL_DEBUG_V(8, "Gradient fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min( + std::max(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multipliers[pos], + static_cast(offsets[pos]) + guess); + } + } + } else if (is_wp_only) { + JXL_DEBUG_V(8, "WP fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multipliers[pos], + static_cast(offsets[pos]) + guess); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (!tree_has_wp_prop_or_pred) { + // special optimized case: the weighted predictor and its properties are not + // used, so no need to compute weights and properties. + JXL_DEBUG_V(8, "Slow track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, *image, chan, &references); + InitPropsRow(&properties, static_props, y); + if (y > 1 && channel.w > 8 && references.w == 0) { + for (size_t x = 0; x < 2; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = 2; x < channel.w - 2; x++) { + PredictionResult res = + PredictTreeNoWPNEC(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = channel.w - 2; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } else { + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } + } + } else { + JXL_DEBUG_V(8, "Slowest track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + InitPropsRow(&properties, static_props, y); + PrecomputeReferences(channel, y, *image, chan, &references); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + return true; +} + +GroupHeader::GroupHeader() { Bundle::Init(this); } + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options) { + size_t nb_channels = image.channel.size(); + for (bool is_dc : {true, false}) { + size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1); + size_t c = image.nb_meta_channels; + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w > options.group_dim || ch.h > options.group_dim) break; + } + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w == 0 || ch.h == 0) continue; // skip empty + bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3; + if (is_dc_channel != is_dc) continue; + size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift); + if (tile_dim == 0) { + return JXL_FAILURE("Inconsistent transforms"); + } + } + } + return true; +} + +Status ModularDecode(BitReader *br, Image &image, GroupHeader &header, + size_t group_id, ModularOptions *options, + const Tree *global_tree, const ANSCode *global_code, + const std::vector *global_ctx_map, + bool allow_truncated_group) { + if (image.channel.empty()) return true; + + // decode transforms + Status status = Bundle::Read(br, &header); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(status); + if (status.IsFatalError()) return status; + if (!br->AllReadsWithinBounds()) { + // Don't do/undo transforms if header is incomplete. + header.transforms.clear(); + image.transform = header.transforms; + for (size_t c = 0; c < image.channel.size(); c++) { + ZeroFillImage(&image.channel[c].plane); + } + return Status(StatusCode::kNotEnoughBytes); + } + + JXL_DEBUG_V(3, "Image data underwent %" PRIuS " transformations: ", + header.transforms.size()); + image.transform = header.transforms; + for (Transform &transform : image.transform) { + JXL_RETURN_IF_ERROR(transform.MetaApply(image)); + } + if (image.error) { + return JXL_FAILURE("Corrupt file. Aborting."); + } + JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options)); + + size_t nb_channels = image.channel.size(); + + size_t num_chans = 0; + size_t distance_multiplier = 0; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + if (channel.w > distance_multiplier) { + distance_multiplier = channel.w; + } + num_chans++; + } + if (num_chans == 0) return true; + + size_t next_channel = 0; + auto scope_guard = MakeScopeGuard([&]() { + // Do not do anything if truncated groups are not allowed. + if (!allow_truncated_group) return; + for (size_t c = next_channel; c < nb_channels; c++) { + ZeroFillImage(&image.channel[c].plane); + } + }); + + // Read tree. + Tree tree_storage; + std::vector context_map_storage; + ANSCode code_storage; + const Tree *tree = &tree_storage; + const ANSCode *code = &code_storage; + const std::vector *context_map = &context_map_storage; + if (!header.use_global_tree) { + size_t max_tree_size = 1024; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + size_t pixels = channel.w * channel.h; + if (pixels / channel.w != channel.h) { + return JXL_FAILURE("Tree size overflow"); + } + max_tree_size += pixels; + if (max_tree_size < pixels) return JXL_FAILURE("Tree size overflow"); + } + max_tree_size = std::min(static_cast(1 << 20), max_tree_size); + JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size)); + JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2, + &code_storage, &context_map_storage)); + } else { + if (!global_tree || !global_code || !global_ctx_map || + global_tree->empty()) { + return JXL_FAILURE("No global tree available but one was requested"); + } + tree = global_tree; + code = global_code; + context_map = global_ctx_map; + } + + // Read channels + ANSSymbolReader reader(code, br, distance_multiplier); + for (; next_channel < nb_channels; next_channel++) { + Channel &channel = image.channel[next_channel]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (next_channel >= image.nb_meta_channels && + (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS( + br, &reader, *context_map, *tree, header.wp_header, next_channel, + group_id, &image)); + // Truncated group. + if (!br->AllReadsWithinBounds()) { + if (!allow_truncated_group) return JXL_FAILURE("Truncated input"); + return Status(StatusCode::kNotEnoughBytes); + } + } + + // Make sure no zero-filling happens even if next_channel < nb_channels. + scope_guard.Disarm(); + + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, bool undo_transforms, + const Tree *tree, const ANSCode *code, + const std::vector *ctx_map, + bool allow_truncated_group) { +#ifdef JXL_ENABLE_ASSERT + std::vector> req_sizes(image.channel.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + req_sizes[c] = {image.channel[c].w, image.channel[c].h}; + } +#endif + GroupHeader local_header; + if (header == nullptr) header = &local_header; + size_t bit_pos = br->TotalBitsConsumed(); + auto dec_status = ModularDecode(br, image, *header, group_id, options, tree, + code, ctx_map, allow_truncated_group); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); + if (dec_status.IsFatalError()) return dec_status; + if (undo_transforms) image.undo_transforms(header->wp_header); + if (image.error) return JXL_FAILURE("Corrupt file. Aborting."); + JXL_DEBUG_V(4, + "Modular-decoded a %" PRIuS "x%" PRIuS " nbchans=%" PRIuS + " image from %" PRIuS " bytes", + image.w, image.h, image.channel.size(), + (br->TotalBitsConsumed() - bit_pos) / 8); + JXL_DEBUG_V(5, "Modular image: %s", image.DebugString().c_str()); + (void)bit_pos; +#ifdef JXL_ENABLE_ASSERT + // Check that after applying all transforms we are back to the requested image + // sizes, otherwise there's a programming error with the transformations. + if (undo_transforms) { + JXL_ASSERT(image.channel.size() == req_sizes.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + JXL_ASSERT(req_sizes[c].first == image.channel[c].w); + JXL_ASSERT(req_sizes[c].second == image.channel[c].h); + } + } +#endif + return dec_status; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/encoding/encoding.h b/media/libjxl/src/lib/jxl/modular/encoding/encoding.h new file mode 100644 index 000000000..89697bce8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/encoding.h @@ -0,0 +1,135 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENCODING_H_ + +#include +#include + +#include + +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +// Valid range of properties for using lookup tables instead of trees. +constexpr int32_t kPropRangeFast = 512; + +struct GroupHeader : public Fields { + GroupHeader(); + + JXL_FIELDS_NAME(GroupHeader) + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &use_global_tree)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&wp_header)); + uint32_t num_transforms = static_cast(transforms.size()); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2), + BitsOffset(8, 18), 0, + &num_transforms)); + if (visitor->IsReading()) transforms.resize(num_transforms); + for (size_t i = 0; i < num_transforms; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&transforms[i])); + } + return true; + } + + bool use_global_tree; + weighted::Header wp_header; + + std::vector transforms; +}; + +FlatTree FilterTree(const Tree &global_tree, + std::array &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only); + +template +bool TreeToLookupTable(const FlatTree &tree, + T context_lookup[2 * kPropRangeFast], + int8_t offsets[2 * kPropRangeFast], + int8_t multipliers[2 * kPropRangeFast] = nullptr) { + struct TreeRange { + // Begin *excluded*, end *included*. This works best with > vs <= decision + // nodes. + int begin, end; + size_t pos; + }; + std::vector ranges; + ranges.push_back(TreeRange{-kPropRangeFast - 1, kPropRangeFast - 1, 0}); + while (!ranges.empty()) { + TreeRange cur = ranges.back(); + ranges.pop_back(); + if (cur.begin < -kPropRangeFast - 1 || cur.begin >= kPropRangeFast - 1 || + cur.end > kPropRangeFast - 1) { + // Tree is outside the allowed range, exit. + return false; + } + auto &node = tree[cur.pos]; + // Leaf. + if (node.property0 == -1) { + if (node.predictor_offset < std::numeric_limits::min() || + node.predictor_offset > std::numeric_limits::max()) { + return false; + } + if (node.multiplier < std::numeric_limits::min() || + node.multiplier > std::numeric_limits::max()) { + return false; + } + if (multipliers == nullptr && node.multiplier != 1) { + return false; + } + for (int i = cur.begin + 1; i < cur.end + 1; i++) { + context_lookup[i + kPropRangeFast] = node.childID; + if (multipliers) multipliers[i + kPropRangeFast] = node.multiplier; + offsets[i + kPropRangeFast] = node.predictor_offset; + } + continue; + } + // > side of top node. + if (node.properties[0] >= kNumStaticProperties) { + ranges.push_back(TreeRange({node.splitvals[0], cur.end, node.childID})); + ranges.push_back( + TreeRange({node.splitval0, node.splitvals[0], node.childID + 1})); + } else { + ranges.push_back(TreeRange({node.splitval0, cur.end, node.childID})); + } + // <= side + if (node.properties[1] >= kNumStaticProperties) { + ranges.push_back( + TreeRange({node.splitvals[1], node.splitval0, node.childID + 2})); + ranges.push_back( + TreeRange({cur.begin, node.splitvals[1], node.childID + 3})); + } else { + ranges.push_back( + TreeRange({cur.begin, node.splitval0, node.childID + 2})); + } + } + return true; +} +// TODO(veluca): make cleaner interfaces. + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options); + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, + bool undo_transforms = true, + const Tree *tree = nullptr, + const ANSCode *code = nullptr, + const std::vector *ctx_map = nullptr, + bool allow_truncated_group = false); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENCODING_H_ diff --git a/media/libjxl/src/lib/jxl/modular/encoding/ma_common.h b/media/libjxl/src/lib/jxl/modular/encoding/ma_common.h new file mode 100644 index 000000000..71b784732 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/encoding/ma_common.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ +#define LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ + +#include + +namespace jxl { + +enum MATreeContext : size_t { + kSplitValContext = 0, + kPropertyContext = 1, + kPredictorContext = 2, + kOffsetContext = 3, + kMultiplierLogContext = 4, + kMultiplierBitsContext = 5, + + kNumTreeContexts = 6, +}; + +static constexpr size_t kMaxTreeSize = 1 << 22; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ diff --git a/media/libjxl/src/lib/jxl/modular/modular_image.cc b/media/libjxl/src/lib/jxl/modular/modular_image.cc new file mode 100644 index 000000000..785d0c544 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/modular_image.cc @@ -0,0 +1,77 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/modular_image.h" + +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +void Image::undo_transforms(const weighted::Header &wp_header, + jxl::ThreadPool *pool) { + while (!transform.empty()) { + Transform t = transform.back(); + JXL_DEBUG_V(4, "Undoing transform"); + Status result = t.Inverse(*this, wp_header, pool); + if (result == false) { + JXL_NOTIFY_ERROR("Error while undoing transform."); + error = true; + return; + } + JXL_DEBUG_V(8, "Undoing transform: done"); + transform.pop_back(); + } +} + +Image::Image(size_t iw, size_t ih, int bitdepth, int nb_chans) + : w(iw), h(ih), bitdepth(bitdepth), nb_meta_channels(0), error(false) { + for (int i = 0; i < nb_chans; i++) channel.emplace_back(Channel(iw, ih)); +} + +Image::Image() : w(0), h(0), bitdepth(8), nb_meta_channels(0), error(true) {} + +Image &Image::operator=(Image &&other) noexcept { + w = other.w; + h = other.h; + bitdepth = other.bitdepth; + nb_meta_channels = other.nb_meta_channels; + error = other.error; + channel = std::move(other.channel); + transform = std::move(other.transform); + return *this; +} + +Image Image::clone() { + Image c(w, h, bitdepth, 0); + c.nb_meta_channels = nb_meta_channels; + c.error = error; + c.transform = transform; + for (Channel &ch : channel) { + Channel a(ch.w, ch.h, ch.hshift, ch.vshift); + CopyImageTo(ch.plane, &a.plane); + c.channel.push_back(std::move(a)); + } + return c; +} + +std::string Image::DebugString() const { + std::ostringstream os; + os << w << "x" << h << ", depth: " << bitdepth; + if (!channel.empty()) { + os << ", channels:"; + for (size_t i = 0; i < channel.size(); ++i) { + os << " " << channel[i].w << "x" << channel[i].h + << "(shift: " << channel[i].hshift << "," << channel[i].vshift << ")"; + if (i < nb_meta_channels) os << "*"; + } + } + return os.str(); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/modular_image.h b/media/libjxl/src/lib/jxl/modular/modular_image.h new file mode 100644 index 000000000..3e9b5a8a0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/modular_image.h @@ -0,0 +1,118 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_MODULAR_IMAGE_H_ +#define LIB_JXL_MODULAR_MODULAR_IMAGE_H_ + +#include +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +typedef int32_t pixel_type; // can use int16_t if it's only for 8-bit images. + // Need some wiggle room for YCoCg / Squeeze etc + +typedef int64_t pixel_type_w; + +namespace weighted { +struct Header; +} + +class Channel { + public: + jxl::Plane plane; + size_t w, h; + int hshift, vshift; // w ~= image.w >> hshift; h ~= image.h >> vshift + Channel(size_t iw, size_t ih, int hsh = 0, int vsh = 0) + : plane(iw, ih), w(iw), h(ih), hshift(hsh), vshift(vsh) {} + + Channel(const Channel& other) = delete; + Channel& operator=(const Channel& other) = delete; + + // Move assignment + Channel& operator=(Channel&& other) noexcept { + w = other.w; + h = other.h; + hshift = other.hshift; + vshift = other.vshift; + plane = std::move(other.plane); + return *this; + } + + // Move constructor + Channel(Channel&& other) noexcept = default; + + void shrink() { + if (plane.xsize() == w && plane.ysize() == h) return; + jxl::Plane resizedplane(w, h); + plane = std::move(resizedplane); + } + void shrink(int nw, int nh) { + w = nw; + h = nh; + shrink(); + } + + JXL_INLINE pixel_type* Row(const size_t y) { return plane.Row(y); } + JXL_INLINE const pixel_type* Row(const size_t y) const { + return plane.Row(y); + } +}; + +class Transform; + +class Image { + public: + // image data, transforms can dramatically change the number of channels and + // their semantics + std::vector channel; + // transforms that have been applied (and that have to be undone) + std::vector transform; + + // image dimensions (channels may have different dimensions due to transforms) + size_t w, h; + int bitdepth; + size_t nb_meta_channels; // first few channels might contain palette(s) + bool error; // true if a fatal error occurred, false otherwise + + Image(size_t iw, size_t ih, int bitdepth, int nb_chans); + Image(); + + Image(const Image& other) = delete; + Image& operator=(const Image& other) = delete; + + Image& operator=(Image&& other) noexcept; + Image(Image&& other) noexcept = default; + + bool empty() const { + for (const auto& ch : channel) { + if (ch.w && ch.h) return false; + } + return true; + } + + Image clone(); + + void undo_transforms(const weighted::Header& wp_header, + jxl::ThreadPool* pool = nullptr); + + std::string DebugString() const; +}; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_MODULAR_IMAGE_H_ diff --git a/media/libjxl/src/lib/jxl/modular/options.h b/media/libjxl/src/lib/jxl/modular/options.h new file mode 100644 index 000000000..ce6596b91 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/options.h @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_OPTIONS_H_ +#define LIB_JXL_MODULAR_OPTIONS_H_ + +#include + +#include +#include + +namespace jxl { + +using PropertyVal = int32_t; +using Properties = std::vector; + +enum class Predictor : uint32_t { + Zero = 0, + Left = 1, + Top = 2, + Average0 = 3, + Select = 4, + Gradient = 5, + Weighted = 6, + TopRight = 7, + TopLeft = 8, + LeftLeft = 9, + Average1 = 10, + Average2 = 11, + Average3 = 12, + Average4 = 13, + // The following predictors are encoder-only. + Best = 14, // Best of Gradient and Weighted + Variable = + 15, // Find the best decision tree for predictors/predictor per row +}; + +constexpr size_t kNumModularPredictors = + static_cast(Predictor::Average4) + 1; +constexpr size_t kNumModularEncoderPredictors = + static_cast(Predictor::Variable) + 1; + +static constexpr ssize_t kNumStaticProperties = 2; // channel, group_id. + +using StaticPropRange = + std::array, kNumStaticProperties>; + +struct ModularMultiplierInfo { + StaticPropRange range; + uint32_t multiplier; +}; + +struct ModularOptions { + /// Used in both encode and decode: + + // Stop encoding/decoding when reaching a (non-meta) channel that has a + // dimension bigger than max_chan_size. + size_t max_chan_size = 0xFFFFFF; + + // Used during decoding for validation of transforms (sqeeezing) scheme. + size_t group_dim = 0x1FFFFFFF; + + /// Encode options: + // Fraction of pixels to look at to learn a MA tree + // Number of iterations to do to learn a MA tree + // (if zero there is no MA context model) + float nb_repeats = .5f; + + // Maximum number of (previous channel) properties to use in the MA trees + int max_properties = 0; // no previous channels + + // Alternative heuristic tweaks. + // Properties default to channel, group, weighted, gradient residual, W-NW, + // NW-N, N-NE, N-NN + std::vector splitting_heuristics_properties = {0, 1, 15, 9, + 10, 11, 12, 13}; + float splitting_heuristics_node_threshold = 96; + size_t max_property_values = 32; + + // Predictor to use for each channel. + Predictor predictor = static_cast(-1); + + int wp_mode = 0; + + float fast_decode_multiplier = 1.01f; + + // Forces the encoder to produce a tree that is compatible with the WP-only + // decode path (or with the no-wp path, or the gradient-only path). + enum class TreeMode { kGradientOnly, kWPOnly, kNoWP, kDefault }; + TreeMode wp_tree_mode = TreeMode::kDefault; + + // Skip fast paths in the encoder. + bool skip_encoder_fast_path = false; + + // Kind of tree to use. + // TODO(veluca): add tree kinds for JPEG recompression with CfL enabled, + // general AC metadata, different DC qualities, and others. + enum class TreeKind { + kTrivialTreeNoPredictor, + kLearn, + kJpegTranscodeACMeta, + kFalconACMeta, + kACMeta, + kWPFixedDC, + kGradientFixedDC, + }; + TreeKind tree_kind = TreeKind::kLearn; + + // Ignore the image and just pretend all tokens are zeroes + bool zero_tokens = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_OPTIONS_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_palette.cc b/media/libjxl/src/lib/jxl/modular/transform/enc_palette.cc new file mode 100644 index 000000000..7065f8081 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_palette.cc @@ -0,0 +1,603 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_palette.h" + +#include +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/enc_transform.h" +#include "lib/jxl/modular/transform/palette.h" + +namespace jxl { + +namespace palette_internal { + +static constexpr bool kEncodeToHighQualityImplicitPalette = true; + +// Inclusive. +static constexpr int kMinImplicitPaletteIndex = -(2 * 72 - 1); + +float ColorDistance(const std::vector &JXL_RESTRICT a, + const std::vector &JXL_RESTRICT b) { + JXL_ASSERT(a.size() == b.size()); + float distance = 0; + float ave3 = 0; + if (a.size() >= 3) { + ave3 = (a[0] + b[0] + a[1] + b[1] + a[2] + b[2]) * (1.21f / 3.0f); + } + float sum_a = 0, sum_b = 0; + for (size_t c = 0; c < a.size(); ++c) { + const float difference = + static_cast(a[c]) - static_cast(b[c]); + float weight = c == 0 ? 3 : c == 1 ? 5 : 2; + if (c < 3 && (a[c] + b[c] >= ave3)) { + const float add_w[3] = { + 1.15, + 1.15, + 1.12, + }; + weight += add_w[c]; + if (c == 2 && ((a[2] + b[2]) < 1.22 * ave3)) { + weight -= 0.5; + } + } + distance += difference * difference * weight * weight; + const int sum_weight = c == 0 ? 3 : c == 1 ? 5 : 1; + sum_a += a[c] * sum_weight; + sum_b += b[c] * sum_weight; + } + distance *= 4; + float sum_difference = sum_a - sum_b; + distance += sum_difference * sum_difference; + return distance; +} + +static int QuantizeColorToImplicitPaletteIndex( + const std::vector &color, const int palette_size, + const int bit_depth, bool high_quality) { + int index = 0; + if (high_quality) { + int multiplier = 1; + for (size_t c = 0; c < color.size(); c++) { + int quantized = ((kLargeCube - 1) * color[c] + (1 << (bit_depth - 1))) / + ((1 << bit_depth) - 1); + JXL_ASSERT((quantized % kLargeCube) == quantized); + index += quantized * multiplier; + multiplier *= kLargeCube; + } + return index + palette_size + kLargeCubeOffset; + } else { + int multiplier = 1; + for (size_t c = 0; c < color.size(); c++) { + int value = color[c]; + value -= 1 << (std::max(0, bit_depth - 3)); + value = std::max(0, value); + int quantized = ((kLargeCube - 1) * value + (1 << (bit_depth - 1))) / + ((1 << bit_depth) - 1); + JXL_ASSERT((quantized % kLargeCube) == quantized); + if (quantized > kSmallCube - 1) { + quantized = kSmallCube - 1; + } + index += quantized * multiplier; + multiplier *= kSmallCube; + } + return index + palette_size; + } +} + +} // namespace palette_internal + +int RoundInt(int value, int div) { // symmetric rounding around 0 + if (value < 0) return -RoundInt(-value, div); + return (value + div / 2) / div; +} + +struct PaletteIterationData { + static constexpr int kMaxDeltas = 128; + bool final_run = false; + std::vector deltas[3]; + std::vector delta_distances; + std::vector frequent_deltas[3]; + + // Populates `frequent_deltas` with items from `deltas` based on frequencies + // and color distances. + void FindFrequentColorDeltas(int num_pixels, int bitdepth) { + using pixel_type_3d = std::array; + std::map delta_frequency_map; + pixel_type bucket_size = 3 << std::max(0, bitdepth - 8); + // Store frequency weighted by delta distance from quantized value. + for (size_t i = 0; i < deltas[0].size(); ++i) { + pixel_type_3d delta = { + {RoundInt(deltas[0][i], bucket_size), + RoundInt(deltas[1][i], bucket_size), + RoundInt(deltas[2][i], bucket_size)}}; // a basic form of clustering + if (delta[0] == 0 && delta[1] == 0 && delta[2] == 0) continue; + delta_frequency_map[delta] += sqrt(sqrt(delta_distances[i])); + } + + const float delta_distance_multiplier = 1.0f / num_pixels; + + // Weigh frequencies by magnitude and normalize. + for (auto &delta_frequency : delta_frequency_map) { + std::vector current_delta = {delta_frequency.first[0], + delta_frequency.first[1], + delta_frequency.first[2]}; + float delta_distance = + sqrt(palette_internal::ColorDistance({0, 0, 0}, current_delta)) + 1; + delta_frequency.second *= delta_distance * delta_distance_multiplier; + } + + // Sort by weighted frequency. + using pixel_type_3d_frequency = std::pair; + std::vector sorted_delta_frequency_map( + delta_frequency_map.begin(), delta_frequency_map.end()); + std::sort( + sorted_delta_frequency_map.begin(), sorted_delta_frequency_map.end(), + [](const pixel_type_3d_frequency &a, const pixel_type_3d_frequency &b) { + return a.second > b.second; + }); + + // Store the top deltas. + for (auto &delta_frequency : sorted_delta_frequency_map) { + if (frequent_deltas[0].size() >= kMaxDeltas) break; + // Number obtained by optimizing on jyrki31 corpus: + if (delta_frequency.second < 17) break; + for (int c = 0; c < 3; ++c) { + frequent_deltas[c].push_back(delta_frequency.first[c] * bucket_size); + } + } + } +}; + +Status FwdPaletteIteration(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, + bool ordered, bool lossy, Predictor &predictor, + const weighted::Header &wp_header, + PaletteIterationData &palette_iteration_data) { + JXL_QUIET_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c)); + JXL_ASSERT(begin_c >= input.nb_meta_channels); + uint32_t nb = end_c - begin_c + 1; + + size_t w = input.channel[begin_c].w; + size_t h = input.channel[begin_c].h; + + if (!lossy && nb == 1) { + // Channel palette special case + if (nb_colors == 0) return false; + std::vector lookup; + pixel_type minval, maxval; + compute_minmax(input.channel[begin_c], &minval, &maxval); + size_t lookup_table_size = + static_cast(maxval) - static_cast(minval) + 1; + if (lookup_table_size > palette_internal::kMaxPaletteLookupTableSize) { + // a lookup table would use too much memory, instead use a slower approach + // with std::set + std::set chpalette; + pixel_type idx = 0; + for (size_t y = 0; y < h; y++) { + const pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + const bool new_color = chpalette.insert(p[x]).second; + if (new_color) { + idx++; + if (idx > (int)nb_colors) return false; + } + } + } + JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx); + Channel pch(idx, 1); + pch.hshift = -1; + nb_colors = idx; + idx = 0; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + for (pixel_type p : chpalette) { + p_palette[idx++] = p; + } + for (size_t y = 0; y < h; y++) { + pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + for (idx = 0; p[x] != p_palette[idx] && idx < (int)nb_colors; idx++) { + } + JXL_DASSERT(idx < (int)nb_colors); + p[x] = idx; + } + } + predictor = Predictor::Zero; + input.nb_meta_channels++; + input.channel.insert(input.channel.begin(), std::move(pch)); + + return true; + } + lookup.resize(lookup_table_size, 0); + pixel_type idx = 0; + for (size_t y = 0; y < h; y++) { + const pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + if (lookup[p[x] - minval] == 0) { + lookup[p[x] - minval] = 1; + idx++; + if (idx > (int)nb_colors) return false; + } + } + } + JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx); + Channel pch(idx, 1); + pch.hshift = -1; + nb_colors = idx; + idx = 0; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + for (size_t i = 0; i < lookup_table_size; i++) { + if (lookup[i]) { + p_palette[idx] = i + minval; + lookup[i] = idx; + idx++; + } + } + for (size_t y = 0; y < h; y++) { + pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) p[x] = lookup[p[x] - minval]; + } + predictor = Predictor::Zero; + input.nb_meta_channels++; + input.channel.insert(input.channel.begin(), std::move(pch)); + return true; + } + + Image quantized_input; + if (lossy) { + quantized_input = Image(w, h, input.bitdepth, nb); + for (size_t c = 0; c < nb; c++) { + CopyImageTo(input.channel[begin_c + c].plane, + &quantized_input.channel[c].plane); + } + } + + JXL_DEBUG_V( + 7, "Trying to represent channels %i-%i using at most a %i-color palette.", + begin_c, end_c, nb_colors); + nb_deltas = 0; + bool delta_used = false; + std::set> + candidate_palette; // ordered lexicographically + std::vector> candidate_palette_imageorder; + std::vector color(nb); + std::vector color_with_error(nb); + std::vector p_in(nb); + + if (lossy) { + palette_iteration_data.FindFrequentColorDeltas(w * h, input.bitdepth); + nb_deltas = palette_iteration_data.frequent_deltas[0].size(); + + // Count color frequency for colors that make a cross. + std::map, size_t> color_freq_map; + for (size_t y = 1; y + 1 < h; y++) { + for (uint32_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + } + for (size_t x = 1; x + 1 < w; x++) { + for (uint32_t c = 0; c < nb; c++) { + color[c] = p_in[c][x]; + } + int offsets[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}}; + bool makes_cross = true; + for (int i = 0; i < 4 && makes_cross; ++i) { + int dx = offsets[i][0]; + int dy = offsets[i][1]; + for (uint32_t c = 0; c < nb && makes_cross; c++) { + if (input.channel[begin_c + c].Row(y + dy)[x + dx] != color[c]) { + makes_cross = false; + } + } + } + if (makes_cross) color_freq_map[color] += 1; + } + } + // Add colors satisfying frequency condition to the palette. + constexpr float kImageFraction = 0.01f; + size_t color_frequency_lower_bound = 5 + input.h * input.w * kImageFraction; + for (const auto &color_freq : color_freq_map) { + if (color_freq.second > color_frequency_lower_bound) { + candidate_palette.insert(color_freq.first); + candidate_palette_imageorder.push_back(color_freq.first); + } + } + } + + for (size_t y = 0; y < h; y++) { + for (uint32_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + } + for (size_t x = 0; x < w; x++) { + if (lossy && candidate_palette.size() >= nb_colors) break; + for (uint32_t c = 0; c < nb; c++) { + color[c] = p_in[c][x]; + } + const bool new_color = candidate_palette.insert(color).second; + if (new_color) { + candidate_palette_imageorder.push_back(color); + } + if (candidate_palette.size() > nb_colors) { + return false; // too many colors + } + } + } + + nb_colors = nb_deltas + candidate_palette.size(); + JXL_DEBUG_V(6, "Channels %i-%i can be represented using a %i-color palette.", + begin_c, end_c, nb_colors); + + Channel pch(nb_colors, nb); + pch.hshift = -1; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + intptr_t onerow = pch.plane.PixelsPerRow(); + intptr_t onerow_image = input.channel[begin_c].plane.PixelsPerRow(); + const int bit_depth = std::min(input.bitdepth, 24); + + if (lossy) { + for (uint32_t i = 0; i < nb_deltas; i++) { + for (size_t c = 0; c < 3; c++) { + p_palette[c * onerow + i] = + palette_iteration_data.frequent_deltas[c][i]; + } + } + } + + int x = 0; + if (ordered) { + JXL_DEBUG_V(7, "Palette of %i colors, using lexicographic order", + nb_colors); + for (auto pcol : candidate_palette) { + JXL_DEBUG_V(9, " Color %i : ", x); + for (size_t i = 0; i < nb; i++) { + p_palette[nb_deltas + i * onerow + x] = pcol[i]; + } + for (size_t i = 0; i < nb; i++) { + JXL_DEBUG_V(9, "%i ", pcol[i]); + } + x++; + } + } else { + JXL_DEBUG_V(7, "Palette of %i colors, using image order", nb_colors); + for (auto pcol : candidate_palette_imageorder) { + JXL_DEBUG_V(9, " Color %i : ", x); + for (size_t i = 0; i < nb; i++) + p_palette[nb_deltas + i * onerow + x] = pcol[i]; + for (size_t i = 0; i < nb; i++) JXL_DEBUG_V(9, "%i ", pcol[i]); + x++; + } + } + std::vector wp_states; + for (size_t c = 0; c < nb; c++) { + wp_states.emplace_back(wp_header, w, h); + } + std::vector p_quant(nb); + // Three rows of error for dithering: y to y + 2. + // Each row has two pixels of padding in the ends, which is + // beneficial for both precision and encoding speed. + std::vector> error_row[3]; + if (lossy) { + for (int i = 0; i < 3; ++i) { + error_row[i].resize(nb); + for (size_t c = 0; c < nb; ++c) { + error_row[i][c].resize(w + 4); + } + } + } + for (size_t y = 0; y < h; y++) { + for (size_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + if (lossy) p_quant[c] = quantized_input.channel[c].Row(y); + } + pixel_type *JXL_RESTRICT p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + int index; + if (!lossy) { + for (size_t c = 0; c < nb; c++) color[c] = p_in[c][x]; + // Exact search. + for (index = 0; static_cast(index) < nb_colors; index++) { + bool found = true; + for (size_t c = 0; c < nb; c++) { + if (color[c] != p_palette[c * onerow + index]) { + found = false; + break; + } + } + if (found) break; + } + if (index < static_cast(nb_deltas)) { + delta_used = true; + } + } else { + int best_index = 0; + bool best_is_delta = false; + float best_distance = std::numeric_limits::infinity(); + std::vector best_val(nb, 0); + std::vector ideal_residual(nb, 0); + std::vector quantized_val(nb); + std::vector predictions(nb); + static const double kDiffusionMultiplier[] = {0.55, 0.75}; + for (int diffusion_index = 0; diffusion_index < 2; ++diffusion_index) { + for (size_t c = 0; c < nb; c++) { + color_with_error[c] = + p_in[c][x] + palette_iteration_data.final_run * + kDiffusionMultiplier[diffusion_index] * + error_row[0][c][x + 2]; + color[c] = Clamp1(lroundf(color_with_error[c]), 0l, + (1l << input.bitdepth) - 1); + } + + for (size_t c = 0; c < nb; ++c) { + predictions[c] = PredictNoTreeWP(w, p_quant[c] + x, onerow_image, x, + y, predictor, &wp_states[c]) + .guess; + } + const auto TryIndex = [&](const int index) { + for (size_t c = 0; c < nb; c++) { + quantized_val[c] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/nb_colors, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + if (index < static_cast(nb_deltas)) { + quantized_val[c] += predictions[c]; + } + } + const float color_distance = + 32.0 / (1LL << std::max(0, 2 * (bit_depth - 8))) * + palette_internal::ColorDistance(color_with_error, + quantized_val); + float index_penalty = 0; + if (index == -1) { + index_penalty = -124; + } else if (index < 0) { + index_penalty = -2 * index; + } else if (index < static_cast(nb_deltas)) { + index_penalty = 250; + } else if (index < static_cast(nb_colors)) { + index_penalty = 150; + } else if (index < static_cast(nb_colors) + + palette_internal::kLargeCubeOffset) { + index_penalty = 70; + } else { + index_penalty = 256; + } + const float distance = color_distance + index_penalty; + if (distance < best_distance) { + best_distance = distance; + best_index = index; + best_is_delta = index < static_cast(nb_deltas); + best_val.swap(quantized_val); + for (size_t c = 0; c < nb; ++c) { + ideal_residual[c] = color_with_error[c] - predictions[c]; + } + } + }; + for (index = palette_internal::kMinImplicitPaletteIndex; + index < static_cast(nb_colors); index++) { + TryIndex(index); + } + TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex( + color, nb_colors, bit_depth, + /*high_quality=*/false)); + if (palette_internal::kEncodeToHighQualityImplicitPalette) { + TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex( + color, nb_colors, bit_depth, + /*high_quality=*/true)); + } + } + index = best_index; + delta_used |= best_is_delta; + if (!palette_iteration_data.final_run) { + for (size_t c = 0; c < 3; ++c) { + palette_iteration_data.deltas[c].push_back(ideal_residual[c]); + } + palette_iteration_data.delta_distances.push_back(best_distance); + } + + for (size_t c = 0; c < nb; ++c) { + wp_states[c].UpdateErrors(best_val[c], x, y, w); + p_quant[c][x] = best_val[c]; + } + float len_error = 0; + for (size_t c = 0; c < nb; ++c) { + float local_error = color_with_error[c] - best_val[c]; + len_error += local_error * local_error; + } + len_error = sqrt(len_error); + float modulate = 1.0; + int len_limit = 38 << std::max(0, bit_depth - 8); + if (len_error > len_limit) { + modulate *= len_limit / len_error; + } + for (size_t c = 0; c < nb; ++c) { + float total_error = (color_with_error[c] - best_val[c]); + + // If the neighboring pixels have some error in the opposite + // direction of total_error, cancel some or all of it out before + // spreading among them. + constexpr int offsets[12][2] = {{1, 2}, {0, 3}, {0, 4}, {1, 1}, + {1, 3}, {2, 2}, {1, 0}, {1, 4}, + {2, 1}, {2, 3}, {2, 0}, {2, 4}}; + float total_available = 0; + for (int i = 0; i < 11; ++i) { + const int row = offsets[i][0]; + const int col = offsets[i][1]; + if (std::signbit(error_row[row][c][x + col]) != + std::signbit(total_error)) { + total_available += error_row[row][c][x + col]; + } + } + float weight = + std::abs(total_error) / (std::abs(total_available) + 1e-3); + weight = std::min(weight, 1.0f); + for (int i = 0; i < 11; ++i) { + const int row = offsets[i][0]; + const int col = offsets[i][1]; + if (std::signbit(error_row[row][c][x + col]) != + std::signbit(total_error)) { + total_error += weight * error_row[row][c][x + col]; + error_row[row][c][x + col] *= (1 - weight); + } + } + total_error *= modulate; + const float remaining_error = (1.0f / 14.) * total_error; + error_row[0][c][x + 3] += 2 * remaining_error; + error_row[0][c][x + 4] += remaining_error; + error_row[1][c][x + 0] += remaining_error; + for (int i = 0; i < 5; ++i) { + error_row[1][c][x + i] += remaining_error; + error_row[2][c][x + i] += remaining_error; + } + } + } + if (palette_iteration_data.final_run) p[x] = index; + } + if (lossy) { + for (size_t c = 0; c < nb; ++c) { + error_row[0][c].swap(error_row[1][c]); + error_row[1][c].swap(error_row[2][c]); + std::fill(error_row[2][c].begin(), error_row[2][c].end(), 0.f); + } + } + } + if (!delta_used) { + predictor = Predictor::Zero; + } + if (palette_iteration_data.final_run) { + input.nb_meta_channels++; + input.channel.erase(input.channel.begin() + begin_c + 1, + input.channel.begin() + end_c + 1); + input.channel.insert(input.channel.begin(), std::move(pch)); + } + nb_colors -= nb_deltas; + return true; +} + +Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered, + bool lossy, Predictor &predictor, + const weighted::Header &wp_header) { + PaletteIterationData palette_iteration_data; + uint32_t nb_colors_orig = nb_colors; + uint32_t nb_deltas_orig = nb_deltas; + // preprocessing pass in case of lossy palette + if (lossy && input.bitdepth >= 8) { + JXL_RETURN_IF_ERROR(FwdPaletteIteration( + input, begin_c, end_c, nb_colors_orig, nb_deltas_orig, ordered, lossy, + predictor, wp_header, palette_iteration_data)); + } + palette_iteration_data.final_run = true; + return FwdPaletteIteration(input, begin_c, end_c, nb_colors, nb_deltas, + ordered, lossy, predictor, wp_header, + palette_iteration_data); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_palette.h b/media/libjxl/src/lib/jxl/modular/transform/enc_palette.h new file mode 100644 index 000000000..0f3d66825 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_palette.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered, + bool lossy, Predictor &predictor, + const weighted::Header &wp_header); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_rct.cc b/media/libjxl/src/lib/jxl/modular/transform/enc_rct.cc new file mode 100644 index 000000000..050563a3c --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_rct.cc @@ -0,0 +1,73 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_rct.h" + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +Status FwdRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2)); + if (rct_type == 0) { // noop + return false; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + size_t m = begin_c; + size_t w = input.channel[m + 0].w; + size_t h = input.channel[m + 0].h; + int second = (custom % 7) >> 1; + int third = (custom % 7) & 1; + const auto do_rct = [&](const int y, const int thread) { + const pixel_type* in0 = input.channel[m + (permutation % 3)].Row(y); + const pixel_type* in1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + const pixel_type* in2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + pixel_type* out0 = input.channel[m].Row(y); + pixel_type* out1 = input.channel[m + 1].Row(y); + pixel_type* out2 = input.channel[m + 2].Row(y); + if (custom == 6) { + for (size_t x = 0; x < w; x++) { + pixel_type R = in0[x]; + pixel_type G = in1[x]; + pixel_type B = in2[x]; + out1[x] = R - B; + pixel_type tmp = B + (out1[x] >> 1); + out2[x] = G - tmp; + out0[x] = tmp + (out2[x] >> 1); + } + } else { + for (size_t x = 0; x < w; x++) { + pixel_type First = in0[x]; + pixel_type Second = in1[x]; + pixel_type Third = in2[x]; + if (second == 1) { + Second = Second - First; + } else if (second == 2) { + Second = Second - ((First + Third) >> 1); + } + if (third) Third = Third - First; + out0[x] = First; + out1[x] = Second; + out2[x] = Third; + } + } + }; + return RunOnPool(pool, 0, h, ThreadPool::NoInit, do_rct, "FwdRCT"); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_rct.h b/media/libjxl/src/lib/jxl/modular/transform/enc_rct.h new file mode 100644 index 000000000..cb5a193c8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_rct.h @@ -0,0 +1,17 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ + +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +Status FwdRCT(Image &input, size_t begin_c, size_t rct_type, ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_squeeze.cc b/media/libjxl/src/lib/jxl/modular/transform/enc_squeeze.cc new file mode 100644 index 000000000..dfd90cde6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_squeeze.cc @@ -0,0 +1,141 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_squeeze.h" + +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/squeeze.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +void FwdHSqueeze(Image &input, int c, int rc) { + const Channel &chin = input.channel[c]; + + JXL_DEBUG_V(4, "Doing horizontal squeeze of channel %i to new channel %i", c, + rc); + + Channel chout((chin.w + 1) / 2, chin.h, chin.hshift + 1, chin.vshift); + Channel chout_residual(chin.w - chout.w, chout.h, chin.hshift + 1, + chin.vshift); + + for (size_t y = 0; y < chout.h; y++) { + const pixel_type *JXL_RESTRICT p_in = chin.Row(y); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y); + for (size_t x = 0; x < chout_residual.w; x++) { + pixel_type A = p_in[x * 2]; + pixel_type B = p_in[x * 2 + 1]; + pixel_type avg = (A + B + (A > B)) >> 1; + p_out[x] = avg; + + pixel_type diff = A - B; + + pixel_type next_avg = avg; + if (x + 1 < chout_residual.w) { + next_avg = (p_in[x * 2 + 2] + p_in[x * 2 + 3] + + (p_in[x * 2 + 2] > p_in[x * 2 + 3])) >> + 1; // which will be chout.value(y,x+1) + } else if (chin.w & 1) + next_avg = p_in[x * 2 + 2]; + pixel_type left = (x > 0 ? p_in[x * 2 - 1] : avg); + pixel_type tendency = SmoothTendency(left, avg, next_avg); + + p_res[x] = diff - tendency; + } + if (chin.w & 1) { + int x = chout.w - 1; + p_out[x] = p_in[x * 2]; + } + } + input.channel[c] = std::move(chout); + input.channel.insert(input.channel.begin() + rc, std::move(chout_residual)); +} + +void FwdVSqueeze(Image &input, int c, int rc) { + const Channel &chin = input.channel[c]; + + JXL_DEBUG_V(4, "Doing vertical squeeze of channel %i to new channel %i", c, + rc); + + Channel chout(chin.w, (chin.h + 1) / 2, chin.hshift, chin.vshift + 1); + Channel chout_residual(chin.w, chin.h - chout.h, chin.hshift, + chin.vshift + 1); + intptr_t onerow_in = chin.plane.PixelsPerRow(); + for (size_t y = 0; y < chout_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_in = chin.Row(y * 2); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y); + for (size_t x = 0; x < chout.w; x++) { + pixel_type A = p_in[x]; + pixel_type B = p_in[x + onerow_in]; + pixel_type avg = (A + B + (A > B)) >> 1; + p_out[x] = avg; + + pixel_type diff = A - B; + + pixel_type next_avg = avg; + if (y + 1 < chout_residual.h) { + next_avg = (p_in[x + 2 * onerow_in] + p_in[x + 3 * onerow_in] + + (p_in[x + 2 * onerow_in] > p_in[x + 3 * onerow_in])) >> + 1; // which will be chout.value(y+1,x) + } else if (chin.h & 1) { + next_avg = p_in[x + 2 * onerow_in]; + } + pixel_type top = + (y > 0 ? p_in[static_cast(x) - onerow_in] : avg); + pixel_type tendency = SmoothTendency(top, avg, next_avg); + + p_res[x] = diff - tendency; + } + } + if (chin.h & 1) { + size_t y = chout.h - 1; + const pixel_type *p_in = chin.Row(y * 2); + pixel_type *p_out = chout.Row(y); + for (size_t x = 0; x < chout.w; x++) { + p_out[x] = p_in[x]; + } + } + input.channel[c] = std::move(chout); + input.channel.insert(input.channel.begin() + rc, std::move(chout_residual)); +} + +Status FwdSqueeze(Image &input, std::vector parameters, + ThreadPool *pool) { + if (parameters.empty()) { + DefaultSqueezeParameters(¶meters, input); + } + // if nothing to do, don't do squeeze + if (parameters.empty()) return false; + for (size_t i = 0; i < parameters.size(); i++) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams(parameters[i], input.channel.size())); + bool horizontal = parameters[i].horizontal; + bool in_place = parameters[i].in_place; + uint32_t beginc = parameters[i].begin_c; + uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; + uint32_t offset; + if (in_place) { + offset = endc + 1; + } else { + offset = input.channel.size(); + } + for (uint32_t c = beginc; c <= endc; c++) { + if (horizontal) { + FwdHSqueeze(input, c, offset + c - beginc); + } else { + FwdVSqueeze(input, c, offset + c - beginc); + } + } + } + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_squeeze.h b/media/libjxl/src/lib/jxl/modular/transform/enc_squeeze.h new file mode 100644 index 000000000..39b001017 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_squeeze.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Status FwdSqueeze(Image &input, std::vector parameters, + ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_transform.cc b/media/libjxl/src/lib/jxl/modular/transform/enc_transform.cc new file mode 100644 index 000000000..bdaaf9f87 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_transform.cc @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_transform.h" + +#include "lib/jxl/modular/transform/enc_palette.h" +#include "lib/jxl/modular/transform/enc_rct.h" +#include "lib/jxl/modular/transform/enc_squeeze.h" + +namespace jxl { + +Status TransformForward(Transform &t, Image &input, + const weighted::Header &wp_header, ThreadPool *pool) { + switch (t.id) { + case TransformId::kRCT: + return FwdRCT(input, t.begin_c, t.rct_type, pool); + case TransformId::kSqueeze: + return FwdSqueeze(input, t.squeezes, pool); + case TransformId::kPalette: + return FwdPalette(input, t.begin_c, t.begin_c + t.num_c - 1, t.nb_colors, + t.nb_deltas, t.ordered_palette, t.lossy_palette, + t.predictor, wp_header); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast(t.id)); + } +} + +void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max) { + pixel_type realmin = std::numeric_limits::max(); + pixel_type realmax = std::numeric_limits::min(); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type *JXL_RESTRICT p = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + if (p[x] < realmin) realmin = p[x]; + if (p[x] > realmax) realmax = p[x]; + } + } + + if (min) *min = realmin; + if (max) *max = realmax; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/transform/enc_transform.h b/media/libjxl/src/lib/jxl/modular/transform/enc_transform.h new file mode 100644 index 000000000..07659e1b0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/enc_transform.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Status TransformForward(Transform &t, Image &input, + const weighted::Header &wp_header, ThreadPool *pool); + +void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/palette.h b/media/libjxl/src/lib/jxl/modular/transform/palette.h new file mode 100644 index 000000000..ed2d33bed --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/palette.h @@ -0,0 +1,287 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ + +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +namespace palette_internal { + +static constexpr int kMaxPaletteLookupTableSize = 1 << 16; + +static constexpr int kRgbChannels = 3; + +// 5x5x5 color cube for the larger cube. +static constexpr int kLargeCube = 5; + +// Smaller interleaved color cube to fill the holes of the larger cube. +static constexpr int kSmallCube = 4; +static constexpr int kSmallCubeBits = 2; +// kSmallCube ** 3 +static constexpr int kLargeCubeOffset = kSmallCube * kSmallCube * kSmallCube; + +static inline pixel_type Scale(uint64_t value, uint64_t bit_depth, + uint64_t denom) { + // return (value * ((static_cast(1) << bit_depth) - 1)) / denom; + // We only call this function with kSmallCube or kLargeCube - 1 as denom, + // allowing us to avoid a division here. + JXL_ASSERT(denom == 4); + return (value * ((static_cast(1) << bit_depth) - 1)) >> 2; +} + +// The purpose of this function is solely to extend the interpretation of +// palette indices to implicit values. If index < nb_deltas, indicating that the +// result is a delta palette entry, it is the responsibility of the caller to +// treat it as such. +static pixel_type GetPaletteValue(const pixel_type *const palette, int index, + const size_t c, const int palette_size, + const int onerow, const int bit_depth) { + if (index < 0) { + static constexpr std::array, 72> kDeltaPalette = { + { + {{0, 0, 0}}, {{4, 4, 4}}, {{11, 0, 0}}, + {{0, 0, -13}}, {{0, -12, 0}}, {{-10, -10, -10}}, + {{-18, -18, -18}}, {{-27, -27, -27}}, {{-18, -18, 0}}, + {{0, 0, -32}}, {{-32, 0, 0}}, {{-37, -37, -37}}, + {{0, -32, -32}}, {{24, 24, 45}}, {{50, 50, 50}}, + {{-45, -24, -24}}, {{-24, -45, -45}}, {{0, -24, -24}}, + {{-34, -34, 0}}, {{-24, 0, -24}}, {{-45, -45, -24}}, + {{64, 64, 64}}, {{-32, 0, -32}}, {{0, -32, 0}}, + {{-32, 0, 32}}, {{-24, -45, -24}}, {{45, 24, 45}}, + {{24, -24, -45}}, {{-45, -24, 24}}, {{80, 80, 80}}, + {{64, 0, 0}}, {{0, 0, -64}}, {{0, -64, -64}}, + {{-24, -24, 45}}, {{96, 96, 96}}, {{64, 64, 0}}, + {{45, -24, -24}}, {{34, -34, 0}}, {{112, 112, 112}}, + {{24, -45, -45}}, {{45, 45, -24}}, {{0, -32, 32}}, + {{24, -24, 45}}, {{0, 96, 96}}, {{45, -24, 24}}, + {{24, -45, -24}}, {{-24, -45, 24}}, {{0, -64, 0}}, + {{96, 0, 0}}, {{128, 128, 128}}, {{64, 0, 64}}, + {{144, 144, 144}}, {{96, 96, 0}}, {{-36, -36, 36}}, + {{45, -24, -45}}, {{45, -45, -24}}, {{0, 0, -96}}, + {{0, 128, 128}}, {{0, 96, 0}}, {{45, 24, -45}}, + {{-128, 0, 0}}, {{24, -45, 24}}, {{-45, 24, -45}}, + {{64, 0, -64}}, {{64, -64, -64}}, {{96, 0, 96}}, + {{45, -45, 24}}, {{24, 45, -45}}, {{64, 64, -64}}, + {{128, 128, 0}}, {{0, 0, -128}}, {{-24, 45, -45}}, + }}; + if (c >= kRgbChannels) { + return 0; + } + // Do not open the brackets, otherwise INT32_MIN negation could overflow. + index = -(index + 1); + index %= 1 + 2 * (kDeltaPalette.size() - 1); + static constexpr int kMultiplier[] = {-1, 1}; + pixel_type result = + kDeltaPalette[((index + 1) >> 1)][c] * kMultiplier[index & 1]; + if (bit_depth > 8) { + result *= static_cast(1) << (bit_depth - 8); + } + return result; + } else if (palette_size <= index && index < palette_size + kLargeCubeOffset) { + if (c >= kRgbChannels) return 0; + index -= palette_size; + index >>= c * kSmallCubeBits; + return Scale(index % kSmallCube, bit_depth, kSmallCube) + + (1 << (std::max(0, bit_depth - 3))); + } else if (palette_size + kLargeCubeOffset <= index) { + if (c >= kRgbChannels) return 0; + index -= palette_size + kLargeCubeOffset; + // TODO(eustas): should we take care of ambiguity created by + // index >= kLargeCube ** 3 ? + switch (c) { + case 0: + break; + case 1: + index /= kLargeCube; + break; + case 2: + index /= kLargeCube * kLargeCube; + break; + } + return Scale(index % kLargeCube, bit_depth, kLargeCube - 1); + } + return palette[c * onerow + static_cast(index)]; +} + +} // namespace palette_internal + +static Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors, + uint32_t nb_deltas, Predictor predictor, + const weighted::Header &wp_header, ThreadPool *pool) { + if (input.nb_meta_channels < 1) { + return JXL_FAILURE("Error: Palette transform without palette."); + } + std::atomic num_errors{0}; + int nb = input.channel[0].h; + uint32_t c0 = begin_c + 1; + if (c0 >= input.channel.size()) { + return JXL_FAILURE("Channel is out of range."); + } + size_t w = input.channel[c0].w; + size_t h = input.channel[c0].h; + if (nb < 1) return JXL_FAILURE("Corrupted transforms"); + for (int i = 1; i < nb; i++) { + input.channel.insert( + input.channel.begin() + c0 + 1, + Channel(w, h, input.channel[c0].hshift, input.channel[c0].vshift)); + } + const Channel &palette = input.channel[0]; + const pixel_type *JXL_RESTRICT p_palette = input.channel[0].Row(0); + intptr_t onerow = input.channel[0].plane.PixelsPerRow(); + intptr_t onerow_image = input.channel[c0].plane.PixelsPerRow(); + const int bit_depth = std::min(input.bitdepth, 24); + + if (w == 0) { + // Nothing to do. + // Avoid touching "empty" channels with non-zero height. + } else if (nb_deltas == 0 && predictor == Predictor::Zero) { + if (nb == 1) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + pixel_type *p = input.channel[c0].Row(y); + for (size_t x = 0; x < w; x++) { + const int index = Clamp1(p[x], 0, (pixel_type)palette.w - 1); + p[x] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/0, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + } + }, + "UndoChannelPalette")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + std::vector p_out(nb); + const pixel_type *p_index = input.channel[c0].Row(y); + for (int c = 0; c < nb; c++) + p_out[c] = input.channel[c0 + c].Row(y); + for (size_t x = 0; x < w; x++) { + const int index = p_index[x]; + for (int c = 0; c < nb; c++) { + p_out[c][x] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + } + } + }, + "UndoPalette")); + } + } else { + // Parallelized per channel. + ImageI indices = CopyImage(input.channel[c0].plane); + if (predictor == Predictor::Weighted) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, nb, ThreadPool::NoInit, + [&](const uint32_t c, size_t /* thread */) { + Channel &channel = input.channel[c0 + c]; + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + const pixel_type *JXL_RESTRICT idx = indices.Row(y); + for (size_t x = 0; x < channel.w; x++) { + int index = idx[x]; + pixel_type_w val = 0; + const pixel_type palette_entry = + palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, /*onerow=*/onerow, + /*bit_depth=*/bit_depth); + if (index < static_cast(nb_deltas)) { + PredictionResult pred = + PredictNoTreeWP(channel.w, p + x, onerow_image, x, y, + predictor, &wp_state); + val = pred.guess + palette_entry; + } else { + val = palette_entry; + } + p[x] = val; + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + }, + "UndoDeltaPaletteWP")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, nb, ThreadPool::NoInit, + [&](const uint32_t c, size_t /* thread */) { + Channel &channel = input.channel[c0 + c]; + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + const pixel_type *JXL_RESTRICT idx = indices.Row(y); + for (size_t x = 0; x < channel.w; x++) { + int index = idx[x]; + pixel_type_w val = 0; + const pixel_type palette_entry = + palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + if (index < static_cast(nb_deltas)) { + PredictionResult pred = PredictNoTreeNoWP( + channel.w, p + x, onerow_image, x, y, predictor); + val = pred.guess + palette_entry; + } else { + val = palette_entry; + } + p[x] = val; + } + } + }, + "UndoDeltaPaletteNoWP")); + } + } + if (c0 >= input.nb_meta_channels) { + // Palette was done on normal channels + input.nb_meta_channels--; + } else { + // Palette was done on metachannels + JXL_ASSERT(static_cast(input.nb_meta_channels) >= 2 - nb); + input.nb_meta_channels -= 2 - nb; + JXL_ASSERT(begin_c + nb - 1 < input.nb_meta_channels); + } + input.channel.erase(input.channel.begin(), input.channel.begin() + 1); + return num_errors.load(std::memory_order_relaxed) == 0; +} + +static Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t nb_colors, uint32_t nb_deltas, bool lossy) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c)); + + size_t nb = end_c - begin_c + 1; + if (begin_c >= input.nb_meta_channels) { + // Palette was done on normal channels + input.nb_meta_channels++; + } else { + // Palette was done on metachannels + JXL_ASSERT(end_c < input.nb_meta_channels); + // we remove nb-1 metachannels and add one + input.nb_meta_channels += 2 - nb; + } + input.channel.erase(input.channel.begin() + begin_c + 1, + input.channel.begin() + end_c + 1); + Channel pch(nb_colors + nb_deltas, nb); + pch.hshift = -1; + input.channel.insert(input.channel.begin(), std::move(pch)); + return true; +} + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/rct.cc b/media/libjxl/src/lib/jxl/modular/transform/rct.cc new file mode 100644 index 000000000..f3002a5ac --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/rct.cc @@ -0,0 +1,153 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/rct.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/rct.cc" +#include +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; + +template +void InvRCTRow(const pixel_type* in0, const pixel_type* in1, + const pixel_type* in2, pixel_type* out0, pixel_type* out1, + pixel_type* out2, size_t w) { + static_assert(transform_type >= 0 && transform_type < 7, + "Invalid transform type"); + int second = transform_type >> 1; + int third = transform_type & 1; + + size_t x = 0; + const HWY_FULL(pixel_type) d; + const size_t N = Lanes(d); + for (; x + N - 1 < w; x += N) { + if (transform_type == 6) { + auto Y = Load(d, in0 + x); + auto Co = Load(d, in1 + x); + auto Cg = Load(d, in2 + x); + Y = Sub(Y, ShiftRight<1>(Cg)); + auto G = Add(Cg, Y); + Y = Sub(Y, ShiftRight<1>(Co)); + auto R = Add(Y, Co); + Store(R, d, out0 + x); + Store(G, d, out1 + x); + Store(Y, d, out2 + x); + } else { + auto First = Load(d, in0 + x); + auto Second = Load(d, in1 + x); + auto Third = Load(d, in2 + x); + if (third) Third = Add(Third, First); + if (second == 1) { + Second = Add(Second, First); + } else if (second == 2) { + Second = Add(Second, ShiftRight<1>(Add(First, Third))); + } + Store(First, d, out0 + x); + Store(Second, d, out1 + x); + Store(Third, d, out2 + x); + } + } + for (; x < w; x++) { + if (transform_type == 6) { + pixel_type Y = in0[x]; + pixel_type Co = in1[x]; + pixel_type Cg = in2[x]; + pixel_type tmp = PixelAdd(Y, -(Cg >> 1)); + pixel_type G = PixelAdd(Cg, tmp); + pixel_type B = PixelAdd(tmp, -(Co >> 1)); + pixel_type R = PixelAdd(B, Co); + out0[x] = R; + out1[x] = G; + out2[x] = B; + } else { + pixel_type First = in0[x]; + pixel_type Second = in1[x]; + pixel_type Third = in2[x]; + if (third) Third = PixelAdd(Third, First); + if (second == 1) { + Second = PixelAdd(Second, First); + } else if (second == 2) { + Second = PixelAdd(Second, (PixelAdd(First, Third) >> 1)); + } + out0[x] = First; + out1[x] = Second; + out2[x] = Third; + } + } +} + +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2)); + size_t m = begin_c; + Channel& c0 = input.channel[m + 0]; + size_t w = c0.w; + size_t h = c0.h; + if (rct_type == 0) { // noop + return true; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + JXL_CHECK(permutation < 6); + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + // Special case: permute-only. Swap channels around. + if (custom == 0) { + Channel ch0 = std::move(input.channel[m]); + Channel ch1 = std::move(input.channel[m + 1]); + Channel ch2 = std::move(input.channel[m + 2]); + input.channel[m + (permutation % 3)] = std::move(ch0); + input.channel[m + ((permutation + 1 + permutation / 3) % 3)] = + std::move(ch1); + input.channel[m + ((permutation + 2 - permutation / 3) % 3)] = + std::move(ch2); + return true; + } + constexpr decltype(&InvRCTRow<0>) inv_rct_row[] = { + InvRCTRow<0>, InvRCTRow<1>, InvRCTRow<2>, InvRCTRow<3>, + InvRCTRow<4>, InvRCTRow<5>, InvRCTRow<6>}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* in0 = input.channel[m].Row(y); + const pixel_type* in1 = input.channel[m + 1].Row(y); + const pixel_type* in2 = input.channel[m + 2].Row(y); + pixel_type* out0 = input.channel[m + (permutation % 3)].Row(y); + pixel_type* out1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + pixel_type* out2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + inv_rct_row[custom](in0, in1, in2, out0, out1, out2, w); + }, + "InvRCT")); + return true; +} + +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(InvRCT); +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(InvRCT)(input, begin_c, rct_type, pool); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/modular/transform/rct.h b/media/libjxl/src/lib/jxl/modular/transform/rct.h new file mode 100644 index 000000000..aef65621d --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/rct.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_RCT_H_ +#define LIB_JXL_MODULAR_TRANSFORM_RCT_H_ + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_RCT_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/squeeze.cc b/media/libjxl/src/lib/jxl/modular/transform/squeeze.cc new file mode 100644 index 000000000..34311895d --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/squeeze.cc @@ -0,0 +1,477 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/squeeze.h" + +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/squeeze.cc" +#include +#include + +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenZeroElse; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::MulEven; +using hwy::HWY_NAMESPACE::Ne; +using hwy::HWY_NAMESPACE::Neg; +using hwy::HWY_NAMESPACE::OddEven; +using hwy::HWY_NAMESPACE::RebindToUnsigned; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Xor; + +#if HWY_TARGET != HWY_SCALAR + +JXL_INLINE void FastUnsqueeze(const pixel_type *JXL_RESTRICT p_residual, + const pixel_type *JXL_RESTRICT p_avg, + const pixel_type *JXL_RESTRICT p_navg, + const pixel_type *p_pout, + pixel_type *JXL_RESTRICT p_out, + pixel_type *p_nout) { + const HWY_CAPPED(pixel_type, 8) d; + const RebindToUnsigned du; + const size_t N = Lanes(d); + auto onethird = Set(d, 0x55555556); + for (size_t x = 0; x < 8; x += N) { + auto avg = Load(d, p_avg + x); + auto next_avg = Load(d, p_navg + x); + auto top = Load(d, p_pout + x); + // Equivalent to SmoothTendency(top,avg,next_avg), but without branches + auto Ba = Sub(top, avg); + auto an = Sub(avg, next_avg); + auto nonmono = Xor(Ba, an); + auto absBa = Abs(Ba); + auto absan = Abs(an); + auto absBn = Abs(Sub(top, next_avg)); + // Compute a3 = absBa / 3 + auto a3e = BitCast(d, ShiftRight<32>(MulEven(absBa, onethird))); + auto a3oi = MulEven(Reverse(d, absBa), onethird); + auto a3o = BitCast( + d, Reverse(hwy::HWY_NAMESPACE::Repartition(), + a3oi)); + auto a3 = OddEven(a3o, a3e); + a3 = Add(a3, Add(absBn, Set(d, 2))); + auto absdiff = ShiftRight<2>(a3); + auto skipdiff = Ne(Ba, Zero(d)); + skipdiff = And(skipdiff, Ne(an, Zero(d))); + skipdiff = And(skipdiff, Lt(nonmono, Zero(d))); + auto absBa2 = Add(ShiftLeft<1>(absBa), And(absdiff, Set(d, 1))); + absdiff = IfThenElse(Gt(absdiff, absBa2), + Add(ShiftLeft<1>(absBa), Set(d, 1)), absdiff); + auto absan2 = ShiftLeft<1>(absan); + absdiff = IfThenElse(Gt(Add(absdiff, And(absdiff, Set(d, 1))), absan2), + absan2, absdiff); + auto diff1 = IfThenElse(Lt(top, next_avg), Neg(absdiff), absdiff); + auto tendency = IfThenZeroElse(skipdiff, diff1); + + auto diff_minus_tendency = Load(d, p_residual + x); + auto diff = Add(diff_minus_tendency, tendency); + auto out = + Add(avg, ShiftRight<1>( + Add(diff, BitCast(d, ShiftRight<31>(BitCast(du, diff)))))); + Store(out, d, p_out + x); + Store(Sub(out, diff), d, p_nout + x); + } +} + +#endif + +Status InvHSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { + JXL_ASSERT(c < input.channel.size()); + JXL_ASSERT(rc < input.channel.size()); + Channel &chin = input.channel[c]; + const Channel &chin_residual = input.channel[rc]; + // These must be valid since we ran MetaApply already. + JXL_ASSERT(chin.w == DivCeil(chin.w + chin_residual.w, 2)); + JXL_ASSERT(chin.h == chin_residual.h); + + if (chin_residual.w == 0) { + // Short-circuit: output channel has same dimensions as input. + input.channel[c].hshift--; + return true; + } + + // Note: chin.w >= chin_residual.w and at most 1 different. + Channel chout(chin.w + chin_residual.w, chin.h, chin.hshift - 1, chin.vshift); + JXL_DEBUG_V(4, + "Undoing horizontal squeeze of channel %i using residuals in " + "channel %i (going from width %" PRIuS " to %" PRIuS ")", + c, rc, chin.w, chout.w); + + if (chin_residual.h == 0) { + // Short-circuit: channel with no pixels. + input.channel[c] = std::move(chout); + return true; + } + auto unsqueeze_row = [&](size_t y, size_t x0) { + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y); + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + for (size_t x = x0; x < chin_residual.w; x++) { + pixel_type_w diff_minus_tendency = p_residual[x]; + pixel_type_w avg = p_avg[x]; + pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg); + pixel_type_w left = (x ? p_out[(x << 1) - 1] : avg); + pixel_type_w tendency = SmoothTendency(left, avg, next_avg); + pixel_type_w diff = diff_minus_tendency + tendency; + pixel_type_w A = avg + (diff / 2); + p_out[(x << 1)] = A; + pixel_type_w B = A - diff; + p_out[(x << 1) + 1] = B; + } + if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1]; + }; + + // somewhat complicated trickery just to be able to SIMD this. + // Horizontal unsqueeze has horizontal data dependencies, so we do + // 8 rows at a time and treat it as a vertical unsqueeze of a + // transposed 8x8 block (or 9x8 for one input). + static constexpr const size_t kRowsPerThread = 8; + const auto unsqueeze_span = [&](const uint32_t task, size_t /* thread */) { + const size_t y0 = task * kRowsPerThread; + const size_t rows = std::min(kRowsPerThread, chin.h - y0); + size_t x = 0; + +#if HWY_TARGET != HWY_SCALAR + intptr_t onerow_in = chin.plane.PixelsPerRow(); + intptr_t onerow_inr = chin_residual.plane.PixelsPerRow(); + intptr_t onerow_out = chout.plane.PixelsPerRow(); + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y0); + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y0); + pixel_type *JXL_RESTRICT p_out = chout.Row(y0); + HWY_ALIGN pixel_type b_p_avg[9 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_residual[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_even[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_odd[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_evenT[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_oddT[8 * kRowsPerThread]; + const HWY_CAPPED(pixel_type, 8) d; + const size_t N = Lanes(d); + if (chin_residual.w > 16 && rows == kRowsPerThread) { + for (; x < chin_residual.w - 9; x += 8) { + Transpose8x8Block(p_residual + x, b_p_residual, onerow_inr); + Transpose8x8Block(p_avg + x, b_p_avg, onerow_in); + for (size_t y = 0; y < kRowsPerThread; y++) { + b_p_avg[8 * 8 + y] = p_avg[x + 8 + onerow_in * y]; + } + for (size_t i = 0; i < 8; i++) { + FastUnsqueeze( + b_p_residual + 8 * i, b_p_avg + 8 * i, b_p_avg + 8 * (i + 1), + (x + i ? b_p_out_odd + 8 * ((x + i - 1) & 7) : b_p_avg + 8 * i), + b_p_out_even + 8 * i, b_p_out_odd + 8 * i); + } + + Transpose8x8Block(b_p_out_even, b_p_out_evenT, 8); + Transpose8x8Block(b_p_out_odd, b_p_out_oddT, 8); + for (size_t y = 0; y < kRowsPerThread; y++) { + for (size_t i = 0; i < kRowsPerThread; i += N) { + auto even = Load(d, b_p_out_evenT + 8 * y + i); + auto odd = Load(d, b_p_out_oddT + 8 * y + i); + StoreInterleaved(d, even, odd, + p_out + ((x + i) << 1) + onerow_out * y); + } + } + } + } +#endif + for (size_t y = 0; y < rows; y++) { + unsqueeze_row(y0 + y, x); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.h, kRowsPerThread), + ThreadPool::NoInit, unsqueeze_span, + "InvHorizontalSqueeze")); + input.channel[c] = std::move(chout); + return true; +} + +Status InvVSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { + JXL_ASSERT(c < input.channel.size()); + JXL_ASSERT(rc < input.channel.size()); + const Channel &chin = input.channel[c]; + const Channel &chin_residual = input.channel[rc]; + // These must be valid since we ran MetaApply already. + JXL_ASSERT(chin.h == DivCeil(chin.h + chin_residual.h, 2)); + JXL_ASSERT(chin.w == chin_residual.w); + + if (chin_residual.h == 0) { + // Short-circuit: output channel has same dimensions as input. + input.channel[c].vshift--; + return true; + } + + // Note: chin.h >= chin_residual.h and at most 1 different. + Channel chout(chin.w, chin.h + chin_residual.h, chin.hshift, chin.vshift - 1); + JXL_DEBUG_V( + 4, + "Undoing vertical squeeze of channel %i using residuals in channel " + "%i (going from height %" PRIuS " to %" PRIuS ")", + c, rc, chin.h, chout.h); + + if (chin_residual.w == 0) { + // Short-circuit: channel with no pixels. + input.channel[c] = std::move(chout); + return true; + } + + static constexpr const int kColsPerThread = 64; + const auto unsqueeze_slice = [&](const uint32_t task, size_t /* thread */) { + const size_t x0 = task * kColsPerThread; + const size_t x1 = std::min((size_t)(task + 1) * kColsPerThread, chin.w); + const size_t w = x1 - x0; + // We only iterate up to std::min(chin_residual.h, chin.h) which is + // always chin_residual.h. + for (size_t y = 0; y < chin_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y) + x0; + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y) + x0; + const pixel_type *JXL_RESTRICT p_navg = + chin.Row(y + 1 < chin.h ? y + 1 : y) + x0; + pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1) + x0; + pixel_type *JXL_RESTRICT p_nout = chout.Row((y << 1) + 1) + x0; + const pixel_type *p_pout = y > 0 ? chout.Row((y << 1) - 1) + x0 : p_avg; + size_t x = 0; +#if HWY_TARGET != HWY_SCALAR + for (; x + 7 < w; x += 8) { + FastUnsqueeze(p_residual + x, p_avg + x, p_navg + x, p_pout + x, + p_out + x, p_nout + x); + } +#endif + for (; x < w; x++) { + pixel_type_w avg = p_avg[x]; + pixel_type_w next_avg = p_navg[x]; + pixel_type_w top = p_pout[x]; + pixel_type_w tendency = SmoothTendency(top, avg, next_avg); + pixel_type_w diff_minus_tendency = p_residual[x]; + pixel_type_w diff = diff_minus_tendency + tendency; + pixel_type_w out = avg + (diff / 2); + p_out[x] = out; + // If the chin_residual.h == chin.h, the output has an even number + // of rows so the next line is fine. Otherwise, this loop won't + // write to the last output row which is handled separately. + p_nout[x] = out - diff; + } + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.w, kColsPerThread), + ThreadPool::NoInit, unsqueeze_slice, + "InvVertSqueeze")); + + if (chout.h & 1) { + size_t y = chin.h - 1; + const pixel_type *p_avg = chin.Row(y); + pixel_type *p_out = chout.Row(y << 1); + for (size_t x = 0; x < chin.w; x++) { + p_out[x] = p_avg[x]; + } + } + input.channel[c] = std::move(chout); + return true; +} + +Status InvSqueeze(Image &input, std::vector parameters, + ThreadPool *pool) { + for (int i = parameters.size() - 1; i >= 0; i--) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams(parameters[i], input.channel.size())); + bool horizontal = parameters[i].horizontal; + bool in_place = parameters[i].in_place; + uint32_t beginc = parameters[i].begin_c; + uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; + uint32_t offset; + if (in_place) { + offset = endc + 1; + } else { + offset = input.channel.size() + beginc - endc - 1; + } + if (beginc < input.nb_meta_channels) { + // This is checked in MetaSqueeze. + JXL_ASSERT(input.nb_meta_channels > parameters[i].num_c); + input.nb_meta_channels -= parameters[i].num_c; + } + + for (uint32_t c = beginc; c <= endc; c++) { + uint32_t rc = offset + c - beginc; + // MetaApply should imply that `rc` is within range, otherwise there's a + // programming bug. + JXL_ASSERT(rc < input.channel.size()); + if ((input.channel[c].w < input.channel[rc].w) || + (input.channel[c].h < input.channel[rc].h)) { + return JXL_FAILURE("Corrupted squeeze transform"); + } + if (horizontal) { + JXL_RETURN_IF_ERROR(InvHSqueeze(input, c, rc, pool)); + } else { + JXL_RETURN_IF_ERROR(InvVSqueeze(input, c, rc, pool)); + } + } + input.channel.erase(input.channel.begin() + offset, + input.channel.begin() + offset + (endc - beginc + 1)); + } + return true; +} + +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { + +HWY_EXPORT(InvSqueeze); +Status InvSqueeze(Image &input, std::vector parameters, + ThreadPool *pool) { + return HWY_DYNAMIC_DISPATCH(InvSqueeze)(input, parameters, pool); +} + +void DefaultSqueezeParameters(std::vector *parameters, + const Image &image) { + int nb_channels = image.channel.size() - image.nb_meta_channels; + + parameters->clear(); + size_t w = image.channel[image.nb_meta_channels].w; + size_t h = image.channel[image.nb_meta_channels].h; + JXL_DEBUG_V( + 7, "Default squeeze parameters for %" PRIuS "x%" PRIuS " image: ", w, h); + + // do horizontal first on wide images; vertical first on tall images + bool wide = (w > h); + + if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w && + image.channel[image.nb_meta_channels + 1].h == h) { + // assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0 + // previews + JXL_DEBUG_V(7, "(4:2:0 chroma), %" PRIuS "x%" PRIuS " image", w, h); + SqueezeParams params; + // horizontal chroma squeeze + params.horizontal = true; + params.in_place = false; + params.begin_c = image.nb_meta_channels + 1; + params.num_c = 2; + parameters->push_back(params); + params.horizontal = false; + // vertical chroma squeeze + parameters->push_back(params); + } + SqueezeParams params; + params.begin_c = image.nb_meta_channels; + params.num_c = nb_channels; + params.in_place = true; + + if (!wide) { + if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = false; + parameters->push_back(params); + h = (h + 1) / 2; + JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); + } + } + while (w > JXL_MAX_FIRST_PREVIEW_SIZE || h > JXL_MAX_FIRST_PREVIEW_SIZE) { + if (w > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = true; + parameters->push_back(params); + w = (w + 1) / 2; + JXL_DEBUG_V(7, "Horizontal (%" PRIuS "x%" PRIuS "), ", w, h); + } + if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = false; + parameters->push_back(params); + h = (h + 1) / 2; + JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); + } + } + JXL_DEBUG_V(7, "that's it"); +} + +Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, + int num_channels) { + int c1 = parameter.begin_c; + int c2 = parameter.begin_c + parameter.num_c - 1; + if (c1 < 0 || c1 >= num_channels || c2 < 0 || c2 >= num_channels || c2 < c1) { + return JXL_FAILURE("Invalid channel range"); + } + return true; +} + +Status MetaSqueeze(Image &image, std::vector *parameters) { + if (parameters->empty()) { + DefaultSqueezeParameters(parameters, image); + } + + for (size_t i = 0; i < parameters->size(); i++) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams((*parameters)[i], image.channel.size())); + bool horizontal = (*parameters)[i].horizontal; + bool in_place = (*parameters)[i].in_place; + uint32_t beginc = (*parameters)[i].begin_c; + uint32_t endc = (*parameters)[i].begin_c + (*parameters)[i].num_c - 1; + + uint32_t offset; + if (beginc < image.nb_meta_channels) { + if (endc >= image.nb_meta_channels) { + return JXL_FAILURE("Invalid squeeze: mix of meta and nonmeta channels"); + } + if (!in_place) { + return JXL_FAILURE( + "Invalid squeeze: meta channels require in-place residuals"); + } + image.nb_meta_channels += (*parameters)[i].num_c; + } + if (in_place) { + offset = endc + 1; + } else { + offset = image.channel.size(); + } + for (uint32_t c = beginc; c <= endc; c++) { + if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) { + return JXL_FAILURE("Too many squeezes: shift > 30"); + } + size_t w = image.channel[c].w; + size_t h = image.channel[c].h; + if (horizontal) { + image.channel[c].w = (w + 1) / 2; + image.channel[c].hshift++; + w = w - (w + 1) / 2; + } else { + image.channel[c].h = (h + 1) / 2; + image.channel[c].vshift++; + h = h - (h + 1) / 2; + } + image.channel[c].shrink(); + Channel dummy(w, h); + dummy.hshift = image.channel[c].hshift; + dummy.vshift = image.channel[c].vshift; + + image.channel.insert(image.channel.begin() + offset + (c - beginc), + std::move(dummy)); + JXL_DEBUG_V(8, "MetaSqueeze applied, current image: %s", + image.DebugString().c_str()); + } + } + return true; +} + +} // namespace jxl + +#endif diff --git a/media/libjxl/src/lib/jxl/modular/transform/squeeze.h b/media/libjxl/src/lib/jxl/modular/transform/squeeze.h new file mode 100644 index 000000000..fb18710a6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/squeeze.h @@ -0,0 +1,90 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ + +// Haar-like transform: halves the resolution in one direction +// A B -> (A+B)>>1 in one channel (average) -> same range as +// original channel +// A-B - tendency in a new channel ('residual' needed to make +// the transform reversible) +// -> theoretically range could be 2.5 +// times larger (2 times without the +// 'tendency'), but there should be lots +// of zeroes +// Repeated application (alternating horizontal and vertical squeezes) results +// in downscaling +// +// The default coefficient ordering is low-frequency to high-frequency, as in +// M. Antonini, M. Barlaud, P. Mathieu and I. Daubechies, "Image coding using +// wavelet transform", IEEE Transactions on Image Processing, vol. 1, no. 2, pp. +// 205-220, April 1992, doi: 10.1109/83.136597. + +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +#define JXL_MAX_FIRST_PREVIEW_SIZE 8 + +namespace jxl { + +/* + int avg=(A+B)>>1; + int diff=(A-B); + int rA=(diff+(avg<<1)+(diff&1))>>1; + int rB=rA-diff; + +*/ +// |A B|C D|E F| +// p a n p=avg(A,B), a=avg(C,D), n=avg(E,F) +// +// Goal: estimate C-D (avoiding ringing artifacts) +// (ensuring that in smooth areas, a zero residual corresponds to a smooth +// gradient) + +// best estimate for C: (B + 2*a)/3 +// best estimate for D: (n + 3*a)/4 +// best estimate for C-D: 4*B - 3*n - a /12 + +// avoid ringing by 1) only doing this if B <= a <= n or B >= a >= n +// (otherwise, this is not a smooth area and we cannot really estimate C-D) +// 2) making sure that B <= C <= D <= n or B >= C >= D >= n + +inline pixel_type_w SmoothTendency(pixel_type_w B, pixel_type_w a, + pixel_type_w n) { + pixel_type_w diff = 0; + if (B >= a && a >= n) { + diff = (4 * B - 3 * n - a + 6) / 12; + // 2C = a<<1 + diff - diff&1 <= 2B so diff - diff&1 <= 2B - 2a + // 2D = a<<1 - diff - diff&1 >= 2n so diff + diff&1 <= 2a - 2n + if (diff - (diff & 1) > 2 * (B - a)) diff = 2 * (B - a) + 1; + if (diff + (diff & 1) > 2 * (a - n)) diff = 2 * (a - n); + } else if (B <= a && a <= n) { + diff = (4 * B - 3 * n - a - 6) / 12; + // 2C = a<<1 + diff + diff&1 >= 2B so diff + diff&1 >= 2B - 2a + // 2D = a<<1 - diff + diff&1 <= 2n so diff - diff&1 >= 2a - 2n + if (diff + (diff & 1) < 2 * (B - a)) diff = 2 * (B - a) - 1; + if (diff - (diff & 1) < 2 * (a - n)) diff = 2 * (a - n); + } + return diff; +} + +void DefaultSqueezeParameters(std::vector *parameters, + const Image &image); + +Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, int num_channels); + +Status MetaSqueeze(Image &image, std::vector *parameters); + +Status InvSqueeze(Image &input, std::vector parameters, + ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ diff --git a/media/libjxl/src/lib/jxl/modular/transform/transform.cc b/media/libjxl/src/lib/jxl/modular/transform/transform.cc new file mode 100644 index 000000000..d9f2b435b --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/transform.cc @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/transform.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/palette.h" +#include "lib/jxl/modular/transform/rct.h" +#include "lib/jxl/modular/transform/squeeze.h" + +namespace jxl { + +SqueezeParams::SqueezeParams() { Bundle::Init(this); } +Transform::Transform(TransformId id) { + Bundle::Init(this); + this->id = id; +} + +Status Transform::Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool) { + JXL_DEBUG_V(6, "Input channels (%" PRIuS ", %" PRIuS " meta): ", + input.channel.size(), input.nb_meta_channels); + switch (id) { + case TransformId::kRCT: + return InvRCT(input, begin_c, rct_type, pool); + case TransformId::kSqueeze: + return InvSqueeze(input, squeezes, pool); + case TransformId::kPalette: + return InvPalette(input, begin_c, nb_colors, nb_deltas, predictor, + wp_header, pool); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast(id)); + } +} + +Status Transform::MetaApply(Image &input) { + JXL_DEBUG_V(6, "MetaApply input: %s", input.DebugString().c_str()); + switch (id) { + case TransformId::kRCT: + JXL_DEBUG_V(2, "Transform: kRCT, rct_type=%" PRIu32, rct_type); + return CheckEqualChannels(input, begin_c, begin_c + 2); + case TransformId::kSqueeze: + JXL_DEBUG_V(2, "Transform: kSqueeze:"); +#if JXL_DEBUG_V_LEVEL >= 2 + { + auto squeezes_copy = squeezes; + if (squeezes_copy.empty()) { + DefaultSqueezeParameters(&squeezes_copy, input); + } + for (const auto ¶ms : squeezes_copy) { + JXL_DEBUG_V( + 2, + " squeeze params: horizontal=%d, in_place=%d, begin_c=%" PRIu32 + ", num_c=%" PRIu32, + params.horizontal, params.in_place, params.begin_c, params.num_c); + } + } +#endif + return MetaSqueeze(input, &squeezes); + case TransformId::kPalette: + JXL_DEBUG_V(2, + "Transform: kPalette, begin_c=%" PRIu32 ", num_c=%" PRIu32 + ", nb_colors=%" PRIu32 ", nb_deltas=%" PRIu32, + begin_c, num_c, nb_colors, nb_deltas); + return MetaPalette(input, begin_c, begin_c + num_c - 1, nb_colors, + nb_deltas, lossy_palette); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast(id)); + } +} + +Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2) { + if (c1 > image.channel.size() || c2 >= image.channel.size() || c2 < c1) { + return JXL_FAILURE("Invalid channel range: %u..%u (there are only %" PRIuS + " channels)", + c1, c2, image.channel.size()); + } + if (c1 < image.nb_meta_channels && c2 >= image.nb_meta_channels) { + return JXL_FAILURE("Invalid: transforming mix of meta and nonmeta"); + } + const auto &ch1 = image.channel[c1]; + for (size_t c = c1 + 1; c <= c2; c++) { + const auto &ch2 = image.channel[c]; + if (ch1.w != ch2.w || ch1.h != ch2.h || ch1.hshift != ch2.hshift || + ch1.vshift != ch2.vshift) { + return false; + } + } + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/modular/transform/transform.h b/media/libjxl/src/lib/jxl/modular/transform/transform.h new file mode 100644 index 000000000..d5d3259f7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular/transform/transform.h @@ -0,0 +1,148 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ +#define LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ + +#include +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +enum class TransformId : uint32_t { + // G, R-G, B-G and variants (including YCoCg). + kRCT = 0, + + // Color palette. Parameters are: [begin_c] [end_c] [nb_colors] + kPalette = 1, + + // Squeezing (Haar-style) + kSqueeze = 2, + + // Invalid for now. + kInvalid = 3, +}; + +struct SqueezeParams : public Fields { + JXL_FIELDS_NAME(SqueezeParams) + bool horizontal; + bool in_place; + uint32_t begin_c; + uint32_t num_c; + SqueezeParams(); + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &horizontal)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &in_place)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(3), BitsOffset(6, 8), + BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), BitsOffset(4, 4), 2, &num_c)); + return true; + } +}; + +class Transform : public Fields { + public: + TransformId id; + // for Palette and RCT. + uint32_t begin_c; + // for RCT. 42 possible values starting from 0. + uint32_t rct_type; + // Only for Palette and NearLossless. + uint32_t num_c; + // Only for Palette. + uint32_t nb_colors; + uint32_t nb_deltas; + // for Squeeze. Default squeeze if empty. + std::vector squeezes; + // for NearLossless, not serialized. + int max_delta_error; + // Serialized for Palette. + Predictor predictor; + // for Palette, not serialized. + bool ordered_palette = true; + bool lossy_palette = false; + + explicit Transform(TransformId id); + // default constructor for bundles. + Transform() : Transform(TransformId::kInvalid) {} + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val((uint32_t)TransformId::kRCT), Val((uint32_t)TransformId::kPalette), + Val((uint32_t)TransformId::kSqueeze), + Val((uint32_t)TransformId::kInvalid), (uint32_t)TransformId::kRCT, + reinterpret_cast(&id))); + if (id == TransformId::kInvalid) { + return JXL_FAILURE("Invalid transform ID"); + } + if (visitor->Conditional(id == TransformId::kRCT || + id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(3), BitsOffset(6, 8), BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + } + if (visitor->Conditional(id == TransformId::kRCT)) { + // 0-41, default YCoCg. + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(6), Bits(2), BitsOffset(4, 2), + BitsOffset(6, 10), 6, &rct_type)); + if (rct_type >= 42) { + return JXL_FAILURE("Invalid transform RCT type"); + } + } + if (visitor->Conditional(id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(3), Val(4), BitsOffset(13, 1), 3, &num_c)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + BitsOffset(8, 0), BitsOffset(10, 256), BitsOffset(12, 1280), + BitsOffset(16, 5376), 256, &nb_colors)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(8, 1), BitsOffset(10, 257), + BitsOffset(16, 1281), 0, &nb_deltas)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bits(4, (uint32_t)Predictor::Zero, + reinterpret_cast(&predictor))); + if (predictor >= Predictor::Best) { + return JXL_FAILURE("Invalid predictor"); + } + } + + if (visitor->Conditional(id == TransformId::kSqueeze)) { + uint32_t num_squeezes = static_cast(squeezes.size()); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(4, 1), BitsOffset(6, 9), + BitsOffset(8, 41), 0, &num_squeezes)); + if (visitor->IsReading()) squeezes.resize(num_squeezes); + for (size_t i = 0; i < num_squeezes; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&squeezes[i])); + } + } + return true; + } + + JXL_FIELDS_NAME(Transform) + + Status Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool = nullptr); + Status MetaApply(Image &input); +}; + +Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2); + +static inline pixel_type PixelAdd(pixel_type a, pixel_type b) { + return static_cast(static_cast(a) + + static_cast(b)); +} + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ diff --git a/media/libjxl/src/lib/jxl/modular_test.cc b/media/libjxl/src/lib/jxl/modular_test.cc new file mode 100644 index 000000000..c87be6822 --- /dev/null +++ b/media/libjxl/src/lib/jxl/modular_test.cc @@ -0,0 +1,545 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/extras/dec/jxl.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_butteraugli_pnorm.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_toc.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/modular/encoding/enc_encoding.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { +using test::Roundtrip; + +void TestLosslessGroups(size_t group_size_shift) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CompressParams cparams; + cparams.SetLossless(); + cparams.modular_group_size_shift = group_size_shift; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 4, io.ysize() / 4); + + compressed_size = Roundtrip(&io, cparams, {}, pool, &io_out); + EXPECT_LE(compressed_size, 280000u); + EXPECT_LE(ButteraugliDistance(io, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + 0.0); +} + +TEST(ModularTest, RoundtripLosslessGroups128) { TestLosslessGroups(0); } + +TEST(ModularTest, JXL_TSAN_SLOW_TEST(RoundtripLosslessGroups512)) { + TestLosslessGroups(2); +} + +TEST(ModularTest, JXL_TSAN_SLOW_TEST(RoundtripLosslessGroups1024)) { + TestLosslessGroups(3); +} + +TEST(ModularTest, RoundtripLosslessCustomWP_PermuteRCT) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.SetLossless(); + // 9 = permute to GBR, to test the special case of permutation-only + cparams.colorspace = 9; + // slowest speed so different WP modes are tried + cparams.speed_tier = SpeedTier::kTortoise; + cparams.options.predictor = {Predictor::Weighted}; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(100, 100); + + compressed_size = Roundtrip(&io, cparams, {}, pool, &io_out); + EXPECT_LE(compressed_size, 10150u); + EXPECT_LE(ButteraugliDistance(io, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + 0.0); +} + +TEST(ModularTest, RoundtripLossyDeltaPalette) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.lossy_palette = true; + cparams.palette_colors = 0; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(300, 100); + + compressed_size = Roundtrip(&io, cparams, {}, pool, &io_out); + EXPECT_LE(compressed_size, 6800u); + cparams.ba_params.intensity_target = 80.0f; + EXPECT_THAT(ButteraugliDistance(io, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.5)); +} +TEST(ModularTest, RoundtripLossyDeltaPaletteWP) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.SetLossless(); + cparams.lossy_palette = true; + cparams.palette_colors = 0; + cparams.options.predictor = jxl::Predictor::Weighted; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(300, 100); + + compressed_size = Roundtrip(&io, cparams, {}, pool, &io_out); + EXPECT_LE(compressed_size, 7000u); + cparams.ba_params.intensity_target = 80.0f; + EXPECT_THAT(ButteraugliDistance(io, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(10.0)); +} + +TEST(ModularTest, RoundtripLossy) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.butteraugli_distance = 2.f; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + compressed_size = Roundtrip(&io, cparams, {}, pool, &io_out); + EXPECT_LE(compressed_size, 30000u); + cparams.ba_params.intensity_target = 80.0f; + EXPECT_THAT(ButteraugliDistance(io, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(2.3)); +} + +TEST(ModularTest, RoundtripLossy16) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/raw.pixls/DJI-FC6310-16bit_709_v4_krita.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.butteraugli_distance = 2.f; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + JXL_CHECK(io.TransformTo(ColorEncoding::SRGB(), GetJxlCms(), pool)); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + + compressed_size = Roundtrip(&io, cparams, {}, pool, &io_out); + EXPECT_LE(compressed_size, 300u); + cparams.ba_params.intensity_target = 80.0f; + EXPECT_THAT(ButteraugliDistance(io, io_out, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.6)); +} + +TEST(ModularTest, RoundtripExtraProperties) { + constexpr size_t kSize = 250; + Image image(kSize, kSize, /*bitdepth=*/8, 3); + ModularOptions options; + options.max_properties = 4; + options.predictor = Predictor::Zero; + Rng rng(0); + for (size_t y = 0; y < kSize; y++) { + for (size_t x = 0; x < kSize; x++) { + image.channel[0].plane.Row(y)[x] = image.channel[2].plane.Row(y)[x] = + rng.UniformU(0, 9); + } + } + ZeroFillImage(&image.channel[1].plane); + BitWriter writer; + ASSERT_TRUE(ModularGenericCompress(image, options, &writer)); + writer.ZeroPadToByte(); + Image decoded(kSize, kSize, /*bitdepth=*/8, image.channel.size()); + for (size_t i = 0; i < image.channel.size(); i++) { + const Channel& ch = image.channel[i]; + decoded.channel[i] = Channel(ch.w, ch.h, ch.hshift, ch.vshift); + } + Status status = true; + { + BitReader reader(writer.GetSpan()); + BitReaderScopedCloser closer(&reader, &status); + ASSERT_TRUE(ModularGenericDecompress(&reader, decoded, /*header=*/nullptr, + /*group_id=*/0, &options)); + } + ASSERT_TRUE(status); + ASSERT_EQ(image.channel.size(), decoded.channel.size()); + for (size_t c = 0; c < image.channel.size(); c++) { + for (size_t y = 0; y < image.channel[c].plane.ysize(); y++) { + for (size_t x = 0; x < image.channel[c].plane.xsize(); x++) { + EXPECT_EQ(image.channel[c].plane.Row(y)[x], + decoded.channel[c].plane.Row(y)[x]) + << "c = " << c << ", x = " << x << ", y = " << y; + } + } + } +} + +TEST(ModularTest, RoundtripLosslessCustomSqueeze) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + cparams.options.predictor = {Predictor::Zero}; + cparams.speed_tier = SpeedTier::kThunder; + cparams.responsive = 1; + // Custom squeeze params, atm just for testing + SqueezeParams p; + p.horizontal = true; + p.in_place = false; + p.begin_c = 0; + p.num_c = 3; + cparams.squeezes.push_back(p); + p.begin_c = 1; + p.in_place = true; + p.horizontal = false; + cparams.squeezes.push_back(p); + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 265000u); + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool)); +} + +struct RoundtripLosslessConfig { + int bitdepth; + int responsive; +}; +class ModularTestParam + : public ::testing::TestWithParam {}; + +std::vector GenerateLosslessTests() { + std::vector all; + for (int responsive = 0; responsive <= 1; responsive++) { + for (int bitdepth = 1; bitdepth < 32; bitdepth++) { + if (responsive && bitdepth > 30) continue; + all.push_back({bitdepth, responsive}); + } + } + return all; +} +std::string LosslessTestDescription( + const testing::TestParamInfo& info) { + std::stringstream name; + name << info.param.bitdepth << "bit"; + if (info.param.responsive) name << "Squeeze"; + return name.str(); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(RoundtripLossless, ModularTestParam, + testing::ValuesIn(GenerateLosslessTests()), + LosslessTestDescription); + +TEST_P(ModularTestParam, RoundtripLossless) { + RoundtripLosslessConfig config = GetParam(); + int bitdepth = config.bitdepth; + int responsive = config.responsive; + + ThreadPool* pool = nullptr; + Rng generator(123); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io1; + ASSERT_TRUE(SetFromBytes(Span(orig), &io1, pool)); + + // vary the dimensions a bit, in case of bugs related to + // even vs odd width or height. + size_t xsize = 423 + bitdepth; + size_t ysize = 467 + bitdepth; + + CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.SetUintSamples(bitdepth); + + double factor = ((1lu << bitdepth) - 1lu); + double ifactor = 1.0 / factor; + Image3F noise_added(xsize, ysize); + + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize; y++) { + const float* in = io1.Main().color()->PlaneRow(c, y); + float* out = noise_added.PlaneRow(c, y); + for (size_t x = 0; x < xsize; x++) { + // make the least significant bits random + float f = in[x] + generator.UniformF(0.0f, 1.f / 255.f); + if (f > 1.f) f = 1.f; + // quantize to the bitdepth we're testing + unsigned int u = f * factor + 0.5; + out[x] = u * ifactor; + } + } + } + io.SetFromImage(std::move(noise_added), jxl::ColorEncoding::SRGB(false)); + + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + cparams.options.predictor = {Predictor::Zero}; + cparams.speed_tier = SpeedTier::kThunder; + cparams.responsive = responsive; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), + bitdepth * xsize * ysize / 3); + EXPECT_LE(0, ComputeDistance2(io.Main(), io2.Main(), GetJxlCms())); + size_t different = 0; + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize; y++) { + const float* in = io.Main().color()->PlaneRow(c, y); + const float* out = io2.Main().color()->PlaneRow(c, y); + for (size_t x = 0; x < xsize; x++) { + uint32_t uin = in[x] * factor + 0.5; + uint32_t uout = out[x] * factor + 0.5; + // check that the integer values are identical + if (uin != uout) different++; + } + } + } + EXPECT_EQ(different, 0); +} + +TEST(ModularTest, RoundtripLosslessCustomFloat) { + ThreadPool* pool = nullptr; + CodecInOut io; + size_t xsize = 100, ysize = 300; + io.SetSize(xsize, ysize); + io.metadata.m.bit_depth.bits_per_sample = 18; + io.metadata.m.bit_depth.exponent_bits_per_sample = 6; + io.metadata.m.bit_depth.floating_point_sample = true; + io.metadata.m.modular_16_bit_buffer_sufficient = false; + ColorEncoding color_encoding; + color_encoding.tf.SetTransferFunction(TransferFunction::kLinear); + color_encoding.SetColorSpace(ColorSpace::kRGB); + Image3F testimage(xsize, ysize); + float factor = 1.f / (1 << 14); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize; y++) { + float* const JXL_RESTRICT row = testimage.PlaneRow(c, y); + for (size_t x = 0; x < xsize; x++) { + row[x] = factor * (x ^ y); + } + } + } + io.SetFromImage(std::move(testimage), color_encoding); + io.metadata.m.color_encoding = color_encoding; + io.metadata.m.SetIntensityTarget(255); + + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + cparams.options.predictor = {Predictor::Zero}; + cparams.speed_tier = SpeedTier::kThunder; + cparams.decoding_speed_tier = 2; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 23000u); + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool)); +} + +void WriteHeaders(BitWriter* writer, size_t xsize, size_t ysize) { + BitWriter::Allotment allotment(writer, 16); + writer->Write(8, 0xFF); + writer->Write(8, kCodestreamMarker); + ReclaimAndCharge(writer, &allotment, 0, nullptr); + CodecMetadata metadata; + EXPECT_TRUE(metadata.size.Set(xsize, ysize)); + EXPECT_TRUE(WriteSizeHeader(metadata.size, writer, 0, nullptr)); + metadata.m.color_encoding = ColorEncoding::LinearSRGB(/*is_gray=*/true); + metadata.m.xyb_encoded = false; + metadata.m.SetUintSamples(31); + EXPECT_TRUE(WriteImageMetadata(metadata.m, writer, 0, nullptr)); + metadata.transform_data.nonserialized_xyb_encoded = metadata.m.xyb_encoded; + EXPECT_TRUE(Bundle::Write(metadata.transform_data, writer, 0, nullptr)); + writer->ZeroPadToByte(); + FrameHeader frame_header(&metadata); + frame_header.encoding = FrameEncoding::kModular; + frame_header.loop_filter.gab = false; + frame_header.loop_filter.epf_iters = 0; + EXPECT_TRUE(WriteFrameHeader(frame_header, writer, nullptr)); +} + +// Tree with single node, zero predictor, offset is 1 and multiplier is 1, +// entropy code is prefix tree with alphabet size 256 and all bits lengths 8. +void WriteHistograms(BitWriter* writer) { + writer->Write(1, 1); // default DC quant + writer->Write(1, 1); // has_tree + // tree histograms + writer->Write(1, 0); // LZ77 disabled + writer->Write(3, 1); // simple context map + writer->Write(1, 1); // prefix code + writer->Write(7, 0x63); // UnintConfig(3, 2, 1) + writer->Write(12, 0xfef); // alphabet_size = 256 + writer->Write(32, 0x10003); // all bit lengths 8 + // tree tokens + writer->Write(8, 0); // tree leaf + writer->Write(8, 0); // zero predictor + writer->Write(8, 64); // offset = UnpackSigned(ReverseBits(64)) = 1 + writer->Write(16, 0); // multiplier = 1 + // histograms + writer->Write(1, 0); // LZ77 disabled + writer->Write(1, 1); // prefix code + writer->Write(7, 0x63); // UnintConfig(3, 2, 1) + writer->Write(12, 0xfef); // alphabet_size = 256 + writer->Write(32, 0x10003); // all bit lengths 8 +} + +TEST(ModularTest, PredictorIntegerOverflow) { + const size_t xsize = 1; + const size_t ysize = 1; + BitWriter writer; + WriteHeaders(&writer, xsize, ysize); + std::vector group_codes(1); + { + BitWriter* bw = &group_codes[0]; + BitWriter::Allotment allotment(bw, 1 << 20); + WriteHistograms(bw); + GroupHeader header; + header.use_global_tree = true; + EXPECT_TRUE(Bundle::Write(header, bw, 0, nullptr)); + // After UnpackSigned this becomes (1 << 31) - 1, the largest pixel_type, + // and after adding the offset we get -(1 << 31). + bw->Write(8, 119); + bw->Write(28, 0xfffffff); + bw->ZeroPadToByte(); + ReclaimAndCharge(bw, &allotment, 0, nullptr); + } + EXPECT_TRUE(WriteGroupOffsets(group_codes, nullptr, &writer, nullptr)); + writer.AppendByteAligned(group_codes); + + PaddedBytes compressed = std::move(writer).TakeBytes(); + extras::PackedPixelFile ppf; + extras::JXLDecompressParams params; + params.accepted_formats.push_back({1, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}); + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), params, + nullptr, &ppf)); + ASSERT_EQ(1, ppf.frames.size()); + const auto& img = ppf.frames[0].color; + const auto pixels = reinterpret_cast(img.pixels()); + EXPECT_EQ(-1.0f, pixels[0]); +} + +TEST(ModularTest, UnsqueezeIntegerOverflow) { + // Image width is 9 so we can test both the SIMD and non-vector code paths. + const size_t xsize = 9; + const size_t ysize = 2; + BitWriter writer; + WriteHeaders(&writer, xsize, ysize); + std::vector group_codes(1); + { + BitWriter* bw = &group_codes[0]; + BitWriter::Allotment allotment(bw, 1 << 20); + WriteHistograms(bw); + GroupHeader header; + header.use_global_tree = true; + header.transforms.emplace_back(); + header.transforms[0].id = TransformId::kSqueeze; + SqueezeParams params; + params.horizontal = false; + params.in_place = true; + params.begin_c = 0; + params.num_c = 1; + header.transforms[0].squeezes.emplace_back(params); + EXPECT_TRUE(Bundle::Write(header, bw, 0, nullptr)); + for (size_t i = 0; i < xsize * ysize; ++i) { + // After UnpackSigned and adding offset, this becomes (1 << 31) - 1, both + // in the image and in the residual channels, and unsqueeze makes them + // ~(3 << 30) and (1 << 30) (in pixel_type_w) and the first wraps around + // to about -(1 << 30). + bw->Write(8, 119); + bw->Write(28, 0xffffffe); + } + bw->ZeroPadToByte(); + ReclaimAndCharge(bw, &allotment, 0, nullptr); + } + EXPECT_TRUE(WriteGroupOffsets(group_codes, nullptr, &writer, nullptr)); + writer.AppendByteAligned(group_codes); + + PaddedBytes compressed = std::move(writer).TakeBytes(); + extras::PackedPixelFile ppf; + extras::JXLDecompressParams params; + params.accepted_formats.push_back({1, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}); + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), params, + nullptr, &ppf)); + ASSERT_EQ(1, ppf.frames.size()); + const auto& img = ppf.frames[0].color; + const auto pixels = reinterpret_cast(img.pixels()); + for (size_t x = 0; x < xsize; ++x) { + EXPECT_NEAR(-0.5f, pixels[x], 1e-10); + EXPECT_NEAR(0.5f, pixels[xsize + x], 1e-10); + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/noise.h b/media/libjxl/src/lib/jxl/noise.h new file mode 100644 index 000000000..d897ea3ab --- /dev/null +++ b/media/libjxl/src/lib/jxl/noise.h @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_NOISE_H_ +#define LIB_JXL_NOISE_H_ + +// Noise parameters shared by encoder/decoder. + +#include + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +const float kNoisePrecision = 1 << 10; + +struct NoiseParams { + // LUT index is an intensity of pixel / mean intensity of patch + static constexpr size_t kNumNoisePoints = 8; + float lut[kNumNoisePoints]; + + void Clear() { + for (float& i : lut) i = 0.f; + } + bool HasAny() const { + for (float i : lut) { + if (std::abs(i) > 1e-3f) return true; + } + return false; + } +}; + +static inline std::pair IndexAndFrac(float x) { + constexpr size_t kScaleNumerator = NoiseParams::kNumNoisePoints - 2; + // TODO: instead of 1, this should be a proper Y range. + constexpr float kScale = kScaleNumerator / 1; + float scaled_x = std::max(0.f, x * kScale); + float floor_x; + float frac_x = std::modf(scaled_x, &floor_x); + if (JXL_UNLIKELY(scaled_x >= kScaleNumerator + 1)) { + floor_x = kScaleNumerator; + frac_x = 1.f; + } + return std::make_pair(static_cast(floor_x), frac_x); +} + +struct NoiseLevel { + float noise_level; + float intensity; +}; + +} // namespace jxl + +#endif // LIB_JXL_NOISE_H_ diff --git a/media/libjxl/src/lib/jxl/opsin_image_test.cc b/media/libjxl/src/lib/jxl/opsin_image_test.cc new file mode 100644 index 000000000..7573d6b8b --- /dev/null +++ b/media/libjxl/src/lib/jxl/opsin_image_test.cc @@ -0,0 +1,123 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/linalg.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { +namespace { + +// Convert a single linear sRGB color to xyb, using the exact image conversion +// procedure that jpeg xl uses. +void LinearSrgbToOpsin(float rgb_r, float rgb_g, float rgb_b, + float* JXL_RESTRICT xyb_x, float* JXL_RESTRICT xyb_y, + float* JXL_RESTRICT xyb_b) { + Image3F linear(1, 1); + linear.PlaneRow(0, 0)[0] = rgb_r; + linear.PlaneRow(1, 0)[0] = rgb_g; + linear.PlaneRow(2, 0)[0] = rgb_b; + + ImageMetadata metadata; + metadata.SetFloat32Samples(); + metadata.color_encoding = ColorEncoding::LinearSRGB(); + ImageBundle ib(&metadata); + ib.SetFromImage(std::move(linear), metadata.color_encoding); + Image3F opsin(1, 1); + (void)ToXYB(ib, /*pool=*/nullptr, &opsin, GetJxlCms()); + + *xyb_x = opsin.PlaneRow(0, 0)[0]; + *xyb_y = opsin.PlaneRow(1, 0)[0]; + *xyb_b = opsin.PlaneRow(2, 0)[0]; +} + +// Convert a single XYB color to linear sRGB, using the exact image conversion +// procedure that jpeg xl uses. +void OpsinToLinearSrgb(float xyb_x, float xyb_y, float xyb_b, + float* JXL_RESTRICT rgb_r, float* JXL_RESTRICT rgb_g, + float* JXL_RESTRICT rgb_b) { + Image3F opsin(1, 1); + opsin.PlaneRow(0, 0)[0] = xyb_x; + opsin.PlaneRow(1, 0)[0] = xyb_y; + opsin.PlaneRow(2, 0)[0] = xyb_b; + Image3F linear(1, 1); + OpsinParams opsin_params; + opsin_params.Init(/*intensity_target=*/255.0f); + OpsinToLinear(opsin, Rect(opsin), nullptr, &linear, opsin_params); + *rgb_r = linear.PlaneRow(0, 0)[0]; + *rgb_g = linear.PlaneRow(1, 0)[0]; + *rgb_b = linear.PlaneRow(2, 0)[0]; +} + +void OpsinRoundtripTestRGB(float r, float g, float b) { + float xyb_x, xyb_y, xyb_b; + LinearSrgbToOpsin(r, g, b, &xyb_x, &xyb_y, &xyb_b); + float r2, g2, b2; + OpsinToLinearSrgb(xyb_x, xyb_y, xyb_b, &r2, &g2, &b2); + EXPECT_NEAR(r, r2, 1e-3); + EXPECT_NEAR(g, g2, 1e-3); + EXPECT_NEAR(b, b2, 1e-3); +} + +TEST(OpsinImageTest, VerifyOpsinAbsorbanceInverseMatrix) { + float matrix[9]; // writable copy + for (int i = 0; i < 9; i++) { + matrix[i] = GetOpsinAbsorbanceInverseMatrix()[i]; + } + EXPECT_TRUE(Inv3x3Matrix(matrix)); + for (int i = 0; i < 9; i++) { + EXPECT_NEAR(matrix[i], kOpsinAbsorbanceMatrix[i], 1e-6); + } +} + +TEST(OpsinImageTest, OpsinRoundtrip) { + OpsinRoundtripTestRGB(0, 0, 0); + OpsinRoundtripTestRGB(1. / 255, 1. / 255, 1. / 255); + OpsinRoundtripTestRGB(128. / 255, 128. / 255, 128. / 255); + OpsinRoundtripTestRGB(1, 1, 1); + + OpsinRoundtripTestRGB(0, 0, 1. / 255); + OpsinRoundtripTestRGB(0, 0, 128. / 255); + OpsinRoundtripTestRGB(0, 0, 1); + + OpsinRoundtripTestRGB(0, 1. / 255, 0); + OpsinRoundtripTestRGB(0, 128. / 255, 0); + OpsinRoundtripTestRGB(0, 1, 0); + + OpsinRoundtripTestRGB(1. / 255, 0, 0); + OpsinRoundtripTestRGB(128. / 255, 0, 0); + OpsinRoundtripTestRGB(1, 0, 0); +} + +TEST(OpsinImageTest, VerifyZero) { + // Test that black color (zero energy) is 0,0,0 in xyb. + float x, y, b; + LinearSrgbToOpsin(0, 0, 0, &x, &y, &b); + EXPECT_NEAR(0, x, 1e-9); + EXPECT_NEAR(0, y, 1e-7); + EXPECT_NEAR(0, b, 1e-7); +} + +TEST(OpsinImageTest, VerifyGray) { + // Test that grayscale colors have a fixed y/b ratio and x==0. + for (size_t i = 1; i < 255; i++) { + float x, y, b; + LinearSrgbToOpsin(i / 255., i / 255., i / 255., &x, &y, &b); + EXPECT_NEAR(0, x, 1e-6); + EXPECT_NEAR(kYToBRatio, b / y, 3e-5); + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/opsin_inverse_test.cc b/media/libjxl/src/lib/jxl/opsin_inverse_test.cc new file mode 100644 index 000000000..9fa8290e2 --- /dev/null +++ b/media/libjxl/src/lib/jxl/opsin_inverse_test.cc @@ -0,0 +1,57 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "gtest/gtest.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +TEST(OpsinInverseTest, LinearInverseInverts) { + Image3F linear(128, 128); + RandomFillImage(&linear, 0.0f, 1.0f); + + CodecInOut io; + io.metadata.m.SetFloat32Samples(); + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + io.SetFromImage(CopyImage(linear), io.metadata.m.color_encoding); + ThreadPool* null_pool = nullptr; + Image3F opsin(io.xsize(), io.ysize()); + (void)ToXYB(io.Main(), null_pool, &opsin, GetJxlCms()); + + OpsinParams opsin_params; + opsin_params.Init(/*intensity_target=*/255.0f); + OpsinToLinearInplace(&opsin, /*pool=*/nullptr, opsin_params); + + VerifyRelativeError(linear, opsin, 3E-3, 2E-4); +} + +TEST(OpsinInverseTest, YcbCrInverts) { + Image3F rgb(128, 128); + RandomFillImage(&rgb, 0.0f, 1.0f); + + ThreadPool* null_pool = nullptr; + Image3F ycbcr(rgb.xsize(), rgb.ysize()); + EXPECT_TRUE(RgbToYcbcr(rgb.Plane(0), rgb.Plane(1), rgb.Plane(2), + &ycbcr.Plane(1), &ycbcr.Plane(0), &ycbcr.Plane(2), + null_pool)); + + Image3F rgb2(rgb.xsize(), rgb.ysize()); + YcbcrToRgb(ycbcr, &rgb2, Rect(rgb)); + + VerifyRelativeError(rgb, rgb2, 4E-5, 4E-7); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/opsin_params.cc b/media/libjxl/src/lib/jxl/opsin_params.cc new file mode 100644 index 000000000..f80a18af8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/opsin_params.cc @@ -0,0 +1,44 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/opsin_params.h" + +#include + +#include "lib/jxl/linalg.h" + +namespace jxl { + +#define INVERSE_OPSIN_FROM_SPEC 1 + +const float* GetOpsinAbsorbanceInverseMatrix() { +#if INVERSE_OPSIN_FROM_SPEC + return DefaultInverseOpsinAbsorbanceMatrix(); +#else // INVERSE_OPSIN_FROM_SPEC + // Compute the inverse opsin matrix from the forward matrix. Less precise + // than taking the values from the specification, but must be used if the + // forward transform is changed and the spec will require updating. + static const float* const kInverse = [] { + static float inverse[9]; + for (int i = 0; i < 9; i++) { + inverse[i] = kOpsinAbsorbanceMatrix[i]; + } + Inv3x3Matrix(inverse); + return inverse; + }(); + return kInverse; +#endif // INVERSE_OPSIN_FROM_SPEC +} + +void InitSIMDInverseMatrix(const float* JXL_RESTRICT inverse, + float* JXL_RESTRICT simd_inverse, + float intensity_target) { + for (size_t i = 0; i < 9; ++i) { + simd_inverse[4 * i] = simd_inverse[4 * i + 1] = simd_inverse[4 * i + 2] = + simd_inverse[4 * i + 3] = inverse[i] * (255.0f / intensity_target); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/opsin_params.h b/media/libjxl/src/lib/jxl/opsin_params.h new file mode 100644 index 000000000..e8e2e4331 --- /dev/null +++ b/media/libjxl/src/lib/jxl/opsin_params.h @@ -0,0 +1,74 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_OPSIN_PARAMS_H_ +#define LIB_JXL_OPSIN_PARAMS_H_ + +// Constants that define the XYB color space. + +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// Parameters for opsin absorbance. +static const float kM02 = 0.078f; +static const float kM00 = 0.30f; +static const float kM01 = 1.0f - kM02 - kM00; + +static const float kM12 = 0.078f; +static const float kM10 = 0.23f; +static const float kM11 = 1.0f - kM12 - kM10; + +static const float kM20 = 0.24342268924547819f; +static const float kM21 = 0.20476744424496821f; +static const float kM22 = 1.0f - kM20 - kM21; + +static const float kBScale = 1.0f; +static const float kYToBRatio = 1.0f; // works better with 0.50017729543783418 +static const float kBToYRatio = 1.0f / kYToBRatio; + +static const float kB0 = 0.0037930732552754493f; +static const float kB1 = kB0; +static const float kB2 = kB0; + +// Opsin absorbance matrix is now frozen. +static const float kOpsinAbsorbanceMatrix[9] = { + kM00, kM01, kM02, kM10, kM11, kM12, kM20, kM21, kM22, +}; + +// Must be the inverse matrix of kOpsinAbsorbanceMatrix and match the spec. +static inline const float* DefaultInverseOpsinAbsorbanceMatrix() { + static float kDefaultInverseOpsinAbsorbanceMatrix[9] = { + 11.031566901960783f, -9.866943921568629f, -0.16462299647058826f, + -3.254147380392157f, 4.418770392156863f, -0.16462299647058826f, + -3.6588512862745097f, 2.7129230470588235f, 1.9459282392156863f}; + return kDefaultInverseOpsinAbsorbanceMatrix; +} + +// Returns 3x3 row-major matrix inverse of kOpsinAbsorbanceMatrix. +// opsin_image_test verifies this is actually the inverse. +const float* GetOpsinAbsorbanceInverseMatrix(); + +void InitSIMDInverseMatrix(const float* JXL_RESTRICT inverse, + float* JXL_RESTRICT simd_inverse, + float intensity_target); + +static const float kOpsinAbsorbanceBias[3] = { + kB0, + kB1, + kB2, +}; + +static const float kNegOpsinAbsorbanceBiasRGB[4] = { + -kOpsinAbsorbanceBias[0], -kOpsinAbsorbanceBias[1], + -kOpsinAbsorbanceBias[2], 1.0f}; + +} // namespace jxl + +#endif // LIB_JXL_OPSIN_PARAMS_H_ diff --git a/media/libjxl/src/lib/jxl/optimize.cc b/media/libjxl/src/lib/jxl/optimize.cc new file mode 100644 index 000000000..081659636 --- /dev/null +++ b/media/libjxl/src/lib/jxl/optimize.cc @@ -0,0 +1,163 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/optimize.h" + +#include + +#include "lib/jxl/base/status.h" + +namespace jxl { + +namespace optimize { + +namespace { + +// simplex vector must be sorted by first element of its elements +std::vector Midpoint(const std::vector>& simplex) { + JXL_CHECK(!simplex.empty()); + JXL_CHECK(simplex.size() == simplex[0].size()); + int dim = simplex.size() - 1; + std::vector result(dim + 1, 0); + for (int i = 0; i < dim; i++) { + for (int k = 0; k < dim; k++) { + result[i + 1] += simplex[k][i + 1]; + } + result[i + 1] /= dim; + } + return result; +} + +// first element ignored +std::vector Subtract(const std::vector& a, + const std::vector& b) { + JXL_CHECK(a.size() == b.size()); + std::vector result(a.size()); + result[0] = 0; + for (size_t i = 1; i < result.size(); i++) { + result[i] = a[i] - b[i]; + } + return result; +} + +// first element ignored +std::vector Add(const std::vector& a, + const std::vector& b) { + JXL_CHECK(a.size() == b.size()); + std::vector result(a.size()); + result[0] = 0; + for (size_t i = 1; i < result.size(); i++) { + result[i] = a[i] + b[i]; + } + return result; +} + +// first element ignored +std::vector Average(const std::vector& a, + const std::vector& b) { + JXL_CHECK(a.size() == b.size()); + std::vector result(a.size()); + result[0] = 0; + for (size_t i = 1; i < result.size(); i++) { + result[i] = 0.5 * (a[i] + b[i]); + } + return result; +} + +// vec: [0] will contain the objective function, [1:] will +// contain the vector position for the objective function. +// fun: the function evaluates the value. +void Eval(std::vector* vec, + const std::function&)>& fun) { + std::vector args(vec->begin() + 1, vec->end()); + (*vec)[0] = fun(args); +} + +void Sort(std::vector>* simplex) { + std::sort(simplex->begin(), simplex->end()); +} + +// Main iteration step of Nelder-Mead like optimization. +void Reflect(std::vector>* simplex, + const std::function&)>& fun) { + Sort(simplex); + const std::vector& last = simplex->back(); + std::vector mid = Midpoint(*simplex); + std::vector diff = Subtract(mid, last); + std::vector mirrored = Add(mid, diff); + Eval(&mirrored, fun); + if (mirrored[0] > (*simplex)[simplex->size() - 2][0]) { + // Still the worst, shrink towards the best. + std::vector shrinking = Average(simplex->back(), (*simplex)[0]); + Eval(&shrinking, fun); + simplex->back() = shrinking; + } else if (mirrored[0] < (*simplex)[0][0]) { + // new best + std::vector even_further = Add(mirrored, diff); + Eval(&even_further, fun); + if (even_further[0] < mirrored[0]) { + mirrored = even_further; + } + simplex->back() = mirrored; + } else { + // not a best, not a worst point + simplex->back() = mirrored; + } +} + +// Initialize the simplex at origin. +std::vector> InitialSimplex( + int dim, double amount, const std::vector& init, + const std::function&)>& fun) { + std::vector best(1 + dim, 0); + std::copy(init.begin(), init.end(), best.begin() + 1); + Eval(&best, fun); + std::vector> result{best}; + for (int i = 0; i < dim; i++) { + best = result[0]; + best[i + 1] += amount; + Eval(&best, fun); + result.push_back(best); + Sort(&result); + } + return result; +} + +// For comparing the same with the python tool +/*void RunSimplexExternal( + int dim, double amount, int max_iterations, + const std::function&))>& fun) { + vector vars; + for (int i = 0; i < dim; i++) { + vars.push_back(atof(getenv(StrCat("VAR", i).c_str()))); + } + double result = fun(vars); + std::cout << "Result=" << result; +}*/ + +} // namespace + +std::vector RunSimplex( + int dim, double amount, int max_iterations, const std::vector& init, + const std::function&)>& fun) { + std::vector> simplex = + InitialSimplex(dim, amount, init, fun); + for (int i = 0; i < max_iterations; i++) { + Sort(&simplex); + Reflect(&simplex, fun); + } + return simplex[0]; +} + +std::vector RunSimplex( + int dim, double amount, int max_iterations, + const std::function&)>& fun) { + std::vector init(dim, 0.0); + return RunSimplex(dim, amount, max_iterations, init, fun); +} + +} // namespace optimize + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/optimize.h b/media/libjxl/src/lib/jxl/optimize.h new file mode 100644 index 000000000..0a6019821 --- /dev/null +++ b/media/libjxl/src/lib/jxl/optimize.h @@ -0,0 +1,218 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Utility functions for optimizing multi-dimensional nonlinear functions. + +#ifndef LIB_JXL_OPTIMIZE_H_ +#define LIB_JXL_OPTIMIZE_H_ + +#include + +#include +#include +#include +#include + +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace optimize { + +// An array type of numeric values that supports math operations with operator-, +// operator+, etc. +template +class Array { + public: + Array() = default; + explicit Array(T v) { + for (size_t i = 0; i < N; i++) v_[i] = v; + } + + size_t size() const { return N; } + + T& operator[](size_t index) { + JXL_DASSERT(index < N); + return v_[index]; + } + T operator[](size_t index) const { + JXL_DASSERT(index < N); + return v_[index]; + } + + private: + // The values used by this Array. + T v_[N]; +}; + +template +Array operator+(const Array& x, const Array& y) { + Array z; + for (size_t i = 0; i < N; ++i) { + z[i] = x[i] + y[i]; + } + return z; +} + +template +Array operator-(const Array& x, const Array& y) { + Array z; + for (size_t i = 0; i < N; ++i) { + z[i] = x[i] - y[i]; + } + return z; +} + +template +Array operator*(T v, const Array& x) { + Array y; + for (size_t i = 0; i < N; ++i) { + y[i] = v * x[i]; + } + return y; +} + +template +T operator*(const Array& x, const Array& y) { + T r = 0.0; + for (size_t i = 0; i < N; ++i) { + r += x[i] * y[i]; + } + return r; +} + +// Runs Nelder-Mead like optimization. Runs for max_iterations times, +// fun gets called with a vector of size dim as argument, and returns the score +// based on those parameters (lower is better). Returns a vector of dim+1 +// dimensions, where the first value is the optimal value of the function and +// the rest is the argmin value. Use init to pass an initial guess or where +// the optimal value is. +// +// Usage example: +// +// RunSimplex(2, 0.1, 100, [](const vector& v) { +// return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7); +// }); +// +// Returns (0.0, 5, 7) +std::vector RunSimplex( + int dim, double amount, int max_iterations, + const std::function&)>& fun); +std::vector RunSimplex( + int dim, double amount, int max_iterations, const std::vector& init, + const std::function&)>& fun); + +// Implementation of the Scaled Conjugate Gradient method described in the +// following paper: +// Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised +// Learning", Neural Networks, Vol. 6. pp. 525-533, 1993 +// http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf +// +// The Function template parameter is a class that has the following method: +// +// // Returns the value of the function at point w and sets *df to be the +// // negative gradient vector of the function at point w. +// double Compute(const optimize::Array& w, +// optimize::Array* df) const; +// +// Returns a vector w, such that |df(w)| < grad_norm_threshold. +template +Array OptimizeWithScaledConjugateGradientMethod( + const Function& f, const Array& w0, const T grad_norm_threshold, + size_t max_iters) { + const size_t n = w0.size(); + const T rsq_threshold = grad_norm_threshold * grad_norm_threshold; + const T sigma0 = static_cast(0.0001); + const T l_min = static_cast(1.0e-15); + const T l_max = static_cast(1.0e15); + + Array w = w0; + Array wp; + Array r; + Array rt; + Array e; + Array p; + T psq; + T fp; + T D; + T d; + T m; + T a; + T b; + T s; + T t; + + T fw = f.Compute(w, &r); + T rsq = r * r; + e = r; + p = r; + T l = static_cast(1.0); + bool success = true; + size_t n_success = 0; + size_t k = 0; + + while (k++ < max_iters) { + if (success) { + m = -(p * r); + if (m >= 0) { + p = r; + m = -(p * r); + } + psq = p * p; + s = sigma0 / std::sqrt(psq); + f.Compute(w + (s * p), &rt); + t = (p * (r - rt)) / s; + } + + d = t + l * psq; + if (d <= 0) { + d = l * psq; + l = l - t / psq; + } + + a = -m / d; + wp = w + a * p; + fp = f.Compute(wp, &rt); + + D = 2.0 * (fp - fw) / (a * m); + if (D >= 0.0) { + success = true; + n_success++; + w = wp; + } else { + success = false; + } + + if (success) { + e = r; + r = rt; + rsq = r * r; + fw = fp; + if (rsq <= rsq_threshold) { + break; + } + } + + if (D < 0.25) { + l = std::min(4.0 * l, l_max); + } else if (D > 0.75) { + l = std::max(0.25 * l, l_min); + } + + if ((n_success % n) == 0) { + p = r; + l = 1.0; + } else if (success) { + b = ((e - r) * r) / m; + p = b * p + r; + } + } + + return w; +} + +} // namespace optimize +} // namespace jxl + +#endif // LIB_JXL_OPTIMIZE_H_ diff --git a/media/libjxl/src/lib/jxl/optimize_test.cc b/media/libjxl/src/lib/jxl/optimize_test.cc new file mode 100644 index 000000000..c606a035c --- /dev/null +++ b/media/libjxl/src/lib/jxl/optimize_test.cc @@ -0,0 +1,109 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/optimize.h" + +#include + +#include "gtest/gtest.h" + +namespace jxl { +namespace optimize { +namespace { + +// The maximum number of iterations for the test. +static const size_t kMaxTestIter = 100000; + +// F(w) = (w - w_min)^2. +struct SimpleQuadraticFunction { + typedef Array ArrayType; + explicit SimpleQuadraticFunction(const ArrayType& w0) : w_min(w0) {} + + double Compute(const ArrayType& w, ArrayType* df) const { + ArrayType dw = w - w_min; + *df = -2.0 * dw; + return dw * dw; + } + + ArrayType w_min; +}; + +// F(alpha, beta, gamma| x,y) = \sum_i(y_i - (alpha x_i ^ gamma + beta))^2. +struct PowerFunction { + explicit PowerFunction(const std::vector& x0, + const std::vector& y0) + : x(x0), y(y0) {} + + typedef Array ArrayType; + double Compute(const ArrayType& w, ArrayType* df) const { + double loss_function = 0; + (*df)[0] = 0; + (*df)[1] = 0; + (*df)[2] = 0; + for (size_t ind = 0; ind < y.size(); ++ind) { + if (x[ind] != 0) { + double l_f = y[ind] - (w[0] * pow(x[ind], w[1]) + w[2]); + (*df)[0] += 2.0 * l_f * pow(x[ind], w[1]); + (*df)[1] += 2.0 * l_f * w[0] * pow(x[ind], w[1]) * log(x[ind]); + (*df)[2] += 2.0 * l_f * 1; + loss_function += l_f * l_f; + } + } + return loss_function; + } + + std::vector x; + std::vector y; +}; + +TEST(OptimizeTest, SimpleQuadraticFunction) { + SimpleQuadraticFunction::ArrayType w_min; + w_min[0] = 1.0; + w_min[1] = 2.0; + SimpleQuadraticFunction f(w_min); + SimpleQuadraticFunction::ArrayType w(0.); + static const double kPrecision = 1e-8; + w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision, + kMaxTestIter); + EXPECT_NEAR(w[0], 1.0, kPrecision); + EXPECT_NEAR(w[1], 2.0, kPrecision); +} + +TEST(OptimizeTest, PowerFunction) { + std::vector x(10); + std::vector y(10); + for (int ind = 0; ind < 10; ++ind) { + x[ind] = 1. * ind; + y[ind] = 2. * pow(x[ind], 3) + 5.; + } + PowerFunction f(x, y); + PowerFunction::ArrayType w(0.); + + static const double kPrecision = 0.01; + w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision, + kMaxTestIter); + EXPECT_NEAR(w[0], 2.0, kPrecision); + EXPECT_NEAR(w[1], 3.0, kPrecision); + EXPECT_NEAR(w[2], 5.0, kPrecision); +} + +TEST(OptimizeTest, SimplexOptTest) { + auto f = [](const std::vector& x) -> double { + double t1 = x[0] - 1.0; + double t2 = x[1] + 1.5; + return 2.0 + t1 * t1 + t2 * t2; + }; + auto opt = RunSimplex(2, 0.01, 100, f); + EXPECT_EQ(opt.size(), 3u); + + static const double kPrecision = 0.01; + EXPECT_NEAR(opt[0], 2.0, kPrecision); + EXPECT_NEAR(opt[1], 1.0, kPrecision); + EXPECT_NEAR(opt[2], -1.5, kPrecision); +} + +} // namespace +} // namespace optimize +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/padded_bytes_test.cc b/media/libjxl/src/lib/jxl/padded_bytes_test.cc new file mode 100644 index 000000000..d8005e469 --- /dev/null +++ b/media/libjxl/src/lib/jxl/padded_bytes_test.cc @@ -0,0 +1,126 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/padded_bytes.h" + +#include // iota +#include + +#include "gtest/gtest.h" + +namespace jxl { +namespace { + +TEST(PaddedBytesTest, TestNonEmptyFirstByteZero) { + PaddedBytes pb(1); + EXPECT_EQ(0, pb[0]); + // Even after resizing.. + pb.resize(20); + EXPECT_EQ(0, pb[0]); + // And reserving. + pb.reserve(200); + EXPECT_EQ(0, pb[0]); +} + +TEST(PaddedBytesTest, TestEmptyFirstByteZero) { + PaddedBytes pb(0); + // After resizing - new zero is written despite there being nothing to copy. + pb.resize(20); + EXPECT_EQ(0, pb[0]); +} + +TEST(PaddedBytesTest, TestFillWithoutReserve) { + PaddedBytes pb; + for (size_t i = 0; i < 170u; ++i) { + pb.push_back(i); + } + EXPECT_EQ(170u, pb.size()); + EXPECT_GE(pb.capacity(), 170u); +} + +TEST(PaddedBytesTest, TestFillWithExactReserve) { + PaddedBytes pb; + pb.reserve(170); + for (size_t i = 0; i < 170u; ++i) { + pb.push_back(i); + } + EXPECT_EQ(170u, pb.size()); + EXPECT_EQ(pb.capacity(), 170u); +} + +TEST(PaddedBytesTest, TestFillWithMoreReserve) { + PaddedBytes pb; + pb.reserve(171); + for (size_t i = 0; i < 170u; ++i) { + pb.push_back(i); + } + EXPECT_EQ(170u, pb.size()); + EXPECT_GT(pb.capacity(), 170u); +} + +// Can assign() a subset of the valid data. +TEST(PaddedBytesTest, TestAssignFromWithin) { + PaddedBytes pb; + pb.reserve(256); + for (size_t i = 0; i < 256; ++i) { + pb.push_back(i); + } + pb.assign(pb.data() + 64, pb.data() + 192); + EXPECT_EQ(128u, pb.size()); + for (size_t i = 0; i < 128; ++i) { + EXPECT_EQ(i + 64, pb[i]); + } +} + +// Can assign() a range with both valid and previously-allocated data. +TEST(PaddedBytesTest, TestAssignReclaim) { + PaddedBytes pb; + pb.reserve(256); + for (size_t i = 0; i < 256; ++i) { + pb.push_back(i); + } + + const uint8_t* mem = pb.data(); + pb.resize(200); + // Just shrank without reallocating + EXPECT_EQ(mem, pb.data()); + EXPECT_EQ(256u, pb.capacity()); + + // Reclaim part of initial allocation + pb.assign(pb.data() + 100, pb.data() + 240); + EXPECT_EQ(140u, pb.size()); + + for (size_t i = 0; i < 140; ++i) { + EXPECT_EQ(i + 100, pb[i]); + } +} + +// Can assign() smaller and larger ranges outside the current allocation. +TEST(PaddedBytesTest, TestAssignOutside) { + PaddedBytes pb; + pb.resize(400); + std::iota(pb.begin(), pb.end(), 1); + + std::vector small(64); + std::iota(small.begin(), small.end(), 500); + + pb.assign(small.data(), small.data() + small.size()); + EXPECT_EQ(64u, pb.size()); + for (size_t i = 0; i < 64; ++i) { + EXPECT_EQ((i + 500) & 0xFF, pb[i]); + } + + std::vector large(1000); + std::iota(large.begin(), large.end(), 600); + + pb.assign(large.data(), large.data() + large.size()); + EXPECT_EQ(1000u, pb.size()); + for (size_t i = 0; i < 1000; ++i) { + EXPECT_EQ((i + 600) & 0xFF, pb[i]); + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/passes_state.cc b/media/libjxl/src/lib/jxl/passes_state.cc new file mode 100644 index 000000000..2f287ec9b --- /dev/null +++ b/media/libjxl/src/lib/jxl/passes_state.cc @@ -0,0 +1,70 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/passes_state.h" + +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/common.h" + +namespace jxl { + +Status InitializePassesSharedState(const FrameHeader& frame_header, + PassesSharedState* JXL_RESTRICT shared, + bool encoder) { + JXL_ASSERT(frame_header.nonserialized_metadata != nullptr); + shared->frame_header = frame_header; + shared->metadata = frame_header.nonserialized_metadata; + shared->frame_dim = frame_header.ToFrameDimensions(); + shared->image_features.patches.SetPassesSharedState(shared); + + const FrameDimensions& frame_dim = shared->frame_dim; + + shared->ac_strategy = + AcStrategyImage(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->raw_quant_field = + ImageI(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->epf_sharpness = + ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->cmap = ColorCorrelationMap(frame_dim.xsize, frame_dim.ysize); + + // In the decoder, we allocate coeff orders afterwards, when we know how many + // we will actually need. + shared->coeff_order_size = kCoeffOrderMaxSize; + if (encoder && + shared->coeff_orders.size() < + frame_header.passes.num_passes * kCoeffOrderMaxSize && + frame_header.encoding == FrameEncoding::kVarDCT) { + shared->coeff_orders.resize(frame_header.passes.num_passes * + kCoeffOrderMaxSize); + } + + shared->quant_dc = ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + + bool use_dc_frame = !!(frame_header.flags & FrameHeader::kUseDcFrame); + if (!encoder && use_dc_frame) { + if (frame_header.dc_level == 4) { + return JXL_FAILURE("Invalid DC level for kUseDcFrame: %u", + frame_header.dc_level); + } + shared->dc_storage = Image3F(); + shared->dc = &shared->dc_frames[frame_header.dc_level]; + if (shared->dc->xsize() == 0) { + return JXL_FAILURE( + "kUseDcFrame specified for dc_level %u, but no frame was decoded " + "with level %u", + frame_header.dc_level, frame_header.dc_level + 1); + } + ZeroFillImage(&shared->quant_dc); + } else { + shared->dc_storage = + Image3F(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->dc = &shared->dc_storage; + } + + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/passes_state.h b/media/libjxl/src/lib/jxl/passes_state.h new file mode 100644 index 000000000..069d7acdf --- /dev/null +++ b/media/libjxl/src/lib/jxl/passes_state.h @@ -0,0 +1,138 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_PASSES_STATE_H_ +#define LIB_JXL_PASSES_STATE_H_ + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/noise.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/splines.h" + +// Structures that hold the (en/de)coder state for a JPEG XL kVarDCT +// (en/de)coder. + +namespace jxl { + +struct ImageFeatures { + NoiseParams noise_params; + PatchDictionary patches; + Splines splines; +}; + +// State common to both encoder and decoder. +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct PassesSharedState { + PassesSharedState() : frame_header(nullptr) {} + + // Headers and metadata. + const CodecMetadata* metadata; + FrameHeader frame_header; + + FrameDimensions frame_dim; + + // Control fields and parameters. + AcStrategyImage ac_strategy; + + // Dequant matrices + quantizer. + DequantMatrices matrices; + Quantizer quantizer{&matrices}; + ImageI raw_quant_field; + + // Per-block side information for EPF detail preservation. + ImageB epf_sharpness; + + ColorCorrelationMap cmap; + + ImageFeatures image_features; + + // Memory area for storing coefficient orders. + // `coeff_order_size` is the size used by *one* set of coefficient orders (at + // most kMaxCoeffOrderSize). A set of coefficient orders is present for each + // pass. + size_t coeff_order_size = 0; + std::vector coeff_orders; + + // Decoder-side DC and quantized DC. + ImageB quant_dc; + Image3F dc_storage; + const Image3F* JXL_RESTRICT dc = &dc_storage; + + BlockCtxMap block_ctx_map; + + Image3F dc_frames[4]; + + struct { + ImageBundle storage; + // Can either point to `storage`, if this is a frame that is not stored in + // the CodecInOut, or can point to an existing ImageBundle. + // TODO(veluca): pointing to ImageBundles in CodecInOut is not possible for + // now, as they are stored in a vector and thus may be moved. Fix this. + ImageBundle* JXL_RESTRICT frame = &storage; + // ImageBundle doesn't yet have a simple way to state it is in XYB. + bool ib_is_in_xyb = false; + } reference_frames[4] = {}; + + // Number of pre-clustered set of histograms (with the same ctx map), per + // pass. Encoded as num_histograms_ - 1. + size_t num_histograms = 0; + + bool IsGrayscale() const { return metadata->m.color_encoding.IsGray(); } + + Rect GroupRect(size_t group_index) const { + const size_t gx = group_index % frame_dim.xsize_groups; + const size_t gy = group_index / frame_dim.xsize_groups; + const Rect rect(gx * frame_dim.group_dim, gy * frame_dim.group_dim, + frame_dim.group_dim, frame_dim.group_dim, frame_dim.xsize, + frame_dim.ysize); + return rect; + } + + Rect PaddedGroupRect(size_t group_index) const { + const size_t gx = group_index % frame_dim.xsize_groups; + const size_t gy = group_index / frame_dim.xsize_groups; + const Rect rect(gx * frame_dim.group_dim, gy * frame_dim.group_dim, + frame_dim.group_dim, frame_dim.group_dim, + frame_dim.xsize_padded, frame_dim.ysize_padded); + return rect; + } + + Rect BlockGroupRect(size_t group_index) const { + const size_t gx = group_index % frame_dim.xsize_groups; + const size_t gy = group_index / frame_dim.xsize_groups; + const Rect rect(gx * (frame_dim.group_dim >> 3), + gy * (frame_dim.group_dim >> 3), frame_dim.group_dim >> 3, + frame_dim.group_dim >> 3, frame_dim.xsize_blocks, + frame_dim.ysize_blocks); + return rect; + } + + Rect DCGroupRect(size_t group_index) const { + const size_t gx = group_index % frame_dim.xsize_dc_groups; + const size_t gy = group_index / frame_dim.xsize_dc_groups; + const Rect rect(gx * frame_dim.group_dim, gy * frame_dim.group_dim, + frame_dim.group_dim, frame_dim.group_dim, + frame_dim.xsize_blocks, frame_dim.ysize_blocks); + return rect; + } +}; + +// Initialized the state information that is shared between encoder and decoder. +Status InitializePassesSharedState(const FrameHeader& frame_header, + PassesSharedState* JXL_RESTRICT shared, + bool encoder = false); + +} // namespace jxl + +#endif // LIB_JXL_PASSES_STATE_H_ diff --git a/media/libjxl/src/lib/jxl/passes_test.cc b/media/libjxl/src/lib/jxl/passes_test.cc new file mode 100644 index 000000000..a58aadc01 --- /dev/null +++ b/media/libjxl/src/lib/jxl/passes_test.cc @@ -0,0 +1,387 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { +using test::Roundtrip; + +TEST(PassesTest, RoundtripSmallPasses) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + cparams.progressive_mode = true; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, pool, &io2); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.0)); +} + +TEST(PassesTest, RoundtripUnalignedPasses) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 12, io.ysize() / 7); + + CompressParams cparams; + cparams.butteraugli_distance = 2.0; + cparams.progressive_mode = true; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, pool, &io2); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.6)); +} + +TEST(PassesTest, RoundtripMultiGroupPasses) { + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + { + ThreadPoolInternal pool(4); + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + } + io.ShrinkTo(600, 1024); // partial X, full Y group + + auto test = [&](float target_distance, float threshold) { + ThreadPoolInternal pool(4); + CompressParams cparams; + cparams.butteraugli_distance = target_distance; + cparams.progressive_mode = true; + CodecInOut io2; + Roundtrip(&io, cparams, {}, &pool, &io2); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(target_distance + threshold)); + }; + + auto run1 = std::async(std::launch::async, test, 1.0f, 0.3f); + auto run2 = std::async(std::launch::async, test, 2.0f, 0.3f); +} + +TEST(PassesTest, RoundtripLargeFastPasses) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = true; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, &pool, &io2); +} + +// Checks for differing size/distance in two consecutive runs of distance 2, +// which involves additional processing including adaptive reconstruction. +// Failing this may be a sign of race conditions or invalid memory accesses. +TEST(PassesTest, RoundtripProgressiveConsistent) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = true; + cparams.butteraugli_distance = 2.0; + + // Try each xsize mod kBlockDim to verify right border handling. + for (size_t xsize = 48; xsize > 40; --xsize) { + io.ShrinkTo(xsize, 15); + + CodecInOut io2; + const size_t size2 = Roundtrip(&io, cparams, {}, &pool, &io2); + + CodecInOut io3; + const size_t size3 = Roundtrip(&io, cparams, {}, &pool, &io3); + + // Exact same compressed size. + EXPECT_EQ(size2, size3); + + // Exact same distance. + const float dist2 = + ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool); + const float dist3 = + ButteraugliDistance(io, io3, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, &pool); + EXPECT_EQ(dist2, dist3); + } +} + +TEST(PassesTest, AllDownsampleFeasible) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + PaddedBytes compressed; + AuxOut aux; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = true; + cparams.butteraugli_distance = 1.0; + PassesEncoderState enc_state; + ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + &aux, &pool)); + + EXPECT_LE(compressed.size(), 240000u); + float target_butteraugli[9] = {}; + target_butteraugli[1] = 2.5f; + target_butteraugli[2] = 16.0f; + target_butteraugli[4] = 20.0f; + target_butteraugli[8] = 80.0f; + + // The default progressive encoding scheme should make all these downsampling + // factors achievable. + // TODO(veluca): re-enable downsampling 16. + std::vector downsamplings = {1, 2, 4, 8}; //, 16}; + + auto check = [&](const uint32_t task, size_t /* thread */) -> void { + const size_t downsampling = downsamplings[task]; + extras::JXLDecompressParams dparams; + dparams.max_downsampling = downsampling; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, compressed, &output, nullptr)); + EXPECT_EQ(output.xsize(), io.xsize()) << "downsampling = " << downsampling; + EXPECT_EQ(output.ysize(), io.ysize()) << "downsampling = " << downsampling; + EXPECT_LE(ButteraugliDistance(io, output, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, nullptr), + target_butteraugli[downsampling]) + << "downsampling: " << downsampling; + }; + EXPECT_TRUE(RunOnPool(&pool, 0, downsamplings.size(), ThreadPool::NoInit, + check, "TestDownsampling")); +} + +TEST(PassesTest, AllDownsampleFeasibleQProgressive) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + PaddedBytes compressed; + AuxOut aux; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.qprogressive_mode = true; + cparams.butteraugli_distance = 1.0; + PassesEncoderState enc_state; + ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + &aux, &pool)); + + EXPECT_LE(compressed.size(), 220000u); + + float target_butteraugli[9] = {}; + target_butteraugli[1] = 3.0f; + target_butteraugli[2] = 6.0f; + target_butteraugli[4] = 10.0f; + target_butteraugli[8] = 80.0f; + + // The default progressive encoding scheme should make all these downsampling + // factors achievable. + std::vector downsamplings = {1, 2, 4, 8}; + + auto check = [&](const uint32_t task, size_t /* thread */) -> void { + const size_t downsampling = downsamplings[task]; + extras::JXLDecompressParams dparams; + dparams.max_downsampling = downsampling; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, compressed, &output, nullptr)); + EXPECT_EQ(output.xsize(), io.xsize()) << "downsampling = " << downsampling; + EXPECT_EQ(output.ysize(), io.ysize()) << "downsampling = " << downsampling; + EXPECT_LE(ButteraugliDistance(io, output, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, nullptr), + target_butteraugli[downsampling]) + << "downsampling: " << downsampling; + }; + EXPECT_TRUE(RunOnPool(&pool, 0, downsamplings.size(), ThreadPool::NoInit, + check, "TestQProgressive")); +} + +TEST(PassesTest, ProgressiveDownsample2DegradesCorrectlyGrayscale) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData( + "external/wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + CodecInOut io_orig; + ASSERT_TRUE(SetFromBytes(Span(orig), &io_orig, &pool)); + Rect rect(0, 0, io_orig.xsize(), 128); + // need 2 DC groups for the DC frame to actually be progressive. + Image3F large(4242, rect.ysize()); + ZeroFillImage(&large); + CopyImageTo(rect, *io_orig.Main().color(), rect, &large); + CodecInOut io; + io.metadata = io_orig.metadata; + io.SetFromImage(std::move(large), io_orig.Main().c_current()); + + PaddedBytes compressed; + AuxOut aux; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_dc = 1; + cparams.responsive = true; + cparams.qprogressive_mode = true; + cparams.butteraugli_distance = 1.0; + PassesEncoderState enc_state; + ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + &aux, &pool)); + + EXPECT_LE(compressed.size(), 10000u); + + extras::JXLDecompressParams dparams; + dparams.max_downsampling = 1; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, compressed, &output, nullptr)); + + dparams.max_downsampling = 2; + CodecInOut output_d2; + ASSERT_TRUE(test::DecodeFile(dparams, compressed, &output_d2, nullptr)); + + // 0 if reading all the passes, ~15 if skipping the 8x pass. + float butteraugli_distance_down2_full = + ButteraugliDistance(output, output_d2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, nullptr); + + EXPECT_LE(butteraugli_distance_down2_full, 3.2f); + EXPECT_GE(butteraugli_distance_down2_full, 1.0f); +} + +TEST(PassesTest, ProgressiveDownsample2DegradesCorrectly) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io_orig; + ASSERT_TRUE(SetFromBytes(Span(orig), &io_orig, &pool)); + Rect rect(0, 0, io_orig.xsize(), 128); + // need 2 DC groups for the DC frame to actually be progressive. + Image3F large(4242, rect.ysize()); + ZeroFillImage(&large); + CopyImageTo(rect, *io_orig.Main().color(), rect, &large); + CodecInOut io; + io.SetFromImage(std::move(large), io_orig.Main().c_current()); + + PaddedBytes compressed; + AuxOut aux; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_dc = 1; + cparams.responsive = true; + cparams.qprogressive_mode = true; + cparams.butteraugli_distance = 1.0; + PassesEncoderState enc_state; + ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + &aux, &pool)); + + EXPECT_LE(compressed.size(), 220000u); + + extras::JXLDecompressParams dparams; + dparams.max_downsampling = 1; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, compressed, &output, nullptr)); + + dparams.max_downsampling = 2; + CodecInOut output_d2; + ASSERT_TRUE(test::DecodeFile(dparams, compressed, &output_d2, nullptr)); + + // 0 if reading all the passes, ~15 if skipping the 8x pass. + float butteraugli_distance_down2_full = + ButteraugliDistance(output, output_d2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, nullptr); + + EXPECT_LE(butteraugli_distance_down2_full, 3.0f); + EXPECT_GE(butteraugli_distance_down2_full, 1.0f); +} + +TEST(PassesTest, NonProgressiveDCImage) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + PaddedBytes compressed; + AuxOut aux; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = false; + cparams.butteraugli_distance = 2.0; + PassesEncoderState enc_state; + ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, GetJxlCms(), + &aux, &pool)); + + // Even in non-progressive mode, it should be possible to return a DC-only + // image. + extras::JXLDecompressParams dparams; + dparams.max_downsampling = 100; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, compressed, &output, &pool)); + EXPECT_EQ(output.xsize(), io.xsize()); + EXPECT_EQ(output.ysize(), io.ysize()); +} + +TEST(PassesTest, RoundtripSmallNoGaborishPasses) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.gaborish = Override::kOff; + cparams.butteraugli_distance = 1.0; + cparams.progressive_mode = true; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, pool, &io2); + EXPECT_THAT(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + IsSlightlyBelow(1.2)); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/patch_dictionary_internal.h b/media/libjxl/src/lib/jxl/patch_dictionary_internal.h new file mode 100644 index 000000000..e4172f6db --- /dev/null +++ b/media/libjxl/src/lib/jxl/patch_dictionary_internal.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_ +#define LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_ + +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/passes_state.h" // for PassesSharedState + +namespace jxl { + +// Context numbers as specified in Section C.4.5, Listing C.2: +enum Contexts { + kNumRefPatchContext = 0, + kReferenceFrameContext = 1, + kPatchSizeContext = 2, + kPatchReferencePositionContext = 3, + kPatchPositionContext = 4, + kPatchBlendModeContext = 5, + kPatchOffsetContext = 6, + kPatchCountContext = 7, + kPatchAlphaChannelContext = 8, + kPatchClampContext = 9, + kNumPatchDictionaryContexts +}; + +} // namespace jxl + +#endif // LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_ diff --git a/media/libjxl/src/lib/jxl/patch_dictionary_test.cc b/media/libjxl/src/lib/jxl/patch_dictionary_test.cc new file mode 100644 index 000000000..3a34b83d0 --- /dev/null +++ b/media/libjxl/src/lib/jxl/patch_dictionary_test.cc @@ -0,0 +1,54 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { + +using ::jxl::test::Roundtrip; + +TEST(PatchDictionaryTest, GrayscaleModular) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/grayscale_patches.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + CompressParams cparams; + cparams.SetLossless(); + cparams.patches = jxl::Override::kOn; + + CodecInOut io2; + // Without patches: ~25k + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 8000u); + VerifyRelativeError(*io.Main().color(), *io2.Main().color(), 1e-7f, 0); +} + +TEST(PatchDictionaryTest, GrayscaleVarDCT) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/grayscale_patches.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + CompressParams cparams; + cparams.patches = jxl::Override::kOn; + + CodecInOut io2; + // Without patches: ~47k + EXPECT_LE(Roundtrip(&io, cparams, {}, pool, &io2), 14000u); + // Without patches: ~1.2 + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + 1.1); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/preview_test.cc b/media/libjxl/src/lib/jxl/preview_test.cc new file mode 100644 index 000000000..35ec70b57 --- /dev/null +++ b/media/libjxl/src/lib/jxl/preview_test.cc @@ -0,0 +1,70 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { +using test::Roundtrip; + +TEST(PreviewTest, RoundtripGivenPreview) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + // Same as main image + io.preview_frame = io.Main().Copy(); + const size_t preview_xsize = 15; + const size_t preview_ysize = 27; + io.preview_frame.ShrinkTo(preview_xsize, preview_ysize); + io.metadata.m.have_preview = true; + ASSERT_TRUE(io.metadata.m.preview_size.Set(io.preview_frame.xsize(), + io.preview_frame.ysize())); + + CompressParams cparams; + cparams.butteraugli_distance = 2.0; + cparams.speed_tier = SpeedTier::kSquirrel; + + CodecInOut io2; + Roundtrip(&io, cparams, {}, pool, &io2); + EXPECT_EQ(preview_xsize, io2.metadata.m.preview_size.xsize()); + EXPECT_EQ(preview_ysize, io2.metadata.m.preview_size.ysize()); + EXPECT_EQ(preview_xsize, io2.preview_frame.xsize()); + EXPECT_EQ(preview_ysize, io2.preview_frame.ysize()); + + EXPECT_LE(ButteraugliDistance(io.preview_frame, io2.preview_frame, + cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + 2.5); + EXPECT_LE( + ButteraugliDistance(io.Main(), io2.Main(), cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, pool), + 2.5); +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/progressive_split.cc b/media/libjxl/src/lib/jxl/progressive_split.cc new file mode 100644 index 000000000..d0a16b915 --- /dev/null +++ b/media/libjxl/src/lib/jxl/progressive_split.cc @@ -0,0 +1,128 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/progressive_split.h" + +#include + +#include +#include + +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +bool ProgressiveSplitter::SuperblockIsSalient(size_t row_start, + size_t col_start, size_t num_rows, + size_t num_cols) const { + if (saliency_map_ == nullptr || saliency_map_->xsize() == 0 || + saliency_threshold_ == 0.0) { + // If we do not have a saliency-map, or the threshold says to include + // every block, we straightaway classify the superblock as 'salient'. + return true; + } + const size_t row_end = std::min(saliency_map_->ysize(), row_start + num_rows); + const size_t col_end = std::min(saliency_map_->xsize(), col_start + num_cols); + for (size_t num_row = row_start; num_row < row_end; num_row++) { + const float* JXL_RESTRICT map_row = saliency_map_->ConstRow(num_row); + for (size_t num_col = col_start; num_col < col_end; num_col++) { + if (map_row[num_col] >= saliency_threshold_) { + // One of the blocks covered by this superblock is above the saliency + // threshold. + return true; + } + } + } + // We did not see any block above the saliency threshold. + return false; +} + +template +void ProgressiveSplitter::SplitACCoefficients( + const T* JXL_RESTRICT block, size_t size, const AcStrategy& acs, size_t bx, + size_t by, size_t offset, T* JXL_RESTRICT output[kMaxNumPasses][3]) { + auto shift_right_round0 = [&](T v, int shift) { + T one_if_negative = static_cast(v) >> 31; + T add = (one_if_negative << shift) - one_if_negative; + return (v + add) >> shift; + }; + // Early quit for the simple case of only one pass. + if (mode_.num_passes == 1) { + for (size_t c = 0; c < 3; c++) { + memcpy(output[0][c] + offset, block + c * size, sizeof(T) * size); + } + return; + } + size_t ncoeffs_all_done_from_earlier_passes = 1; + size_t previous_pass_salient_only = false; + + int previous_pass_shift = 0; + for (size_t num_pass = 0; num_pass < mode_.num_passes; num_pass++) { // pass + // Zero out output block. + for (size_t c = 0; c < 3; c++) { + memset(output[num_pass][c] + offset, 0, size * sizeof(T)); + } + const bool current_pass_salient_only = mode_.passes[num_pass].salient_only; + const int pass_shift = mode_.passes[num_pass].shift; + size_t frame_ncoeffs = mode_.passes[num_pass].num_coefficients; + for (size_t c = 0; c < 3; c++) { // color-channel + size_t xsize = acs.covered_blocks_x(); + size_t ysize = acs.covered_blocks_y(); + CoefficientLayout(&ysize, &xsize); + if (current_pass_salient_only || previous_pass_salient_only) { + // Current or previous pass is salient-only. + const bool superblock_is_salient = + SuperblockIsSalient(by, bx, ysize, xsize); + if (current_pass_salient_only != superblock_is_salient) { + // Current pass is salient-only, but block is not salient, + // OR last pass was salient-only, and block is salient + // (hence was already included in last pass). + continue; + } + } + for (size_t y = 0; y < ysize * frame_ncoeffs; y++) { // superblk-y + for (size_t x = 0; x < xsize * frame_ncoeffs; x++) { // superblk-x + size_t pos = y * xsize * kBlockDim + x; + if (x < xsize * ncoeffs_all_done_from_earlier_passes && + y < ysize * ncoeffs_all_done_from_earlier_passes) { + // This coefficient was already included in an earlier pass, + // which included a genuinely smaller set of coefficients + // (= is not about saliency-splitting). + continue; + } + T v = block[c * size + pos]; + // Previous pass discarded some bits: do not encode them again. + if (previous_pass_shift != 0) { + T previous_v = shift_right_round0(v, previous_pass_shift) * + (1 << previous_pass_shift); + v -= previous_v; + } + output[num_pass][c][offset + pos] = shift_right_round0(v, pass_shift); + } // superblk-x + } // superblk-y + } // color-channel + if (!current_pass_salient_only) { + // We just finished a non-salient pass. + // Hence, we are now guaranteed to have included all coeffs up to + // frame_ncoeffs in every block, unless the current pass is shifted. + if (mode_.passes[num_pass].shift == 0) { + ncoeffs_all_done_from_earlier_passes = frame_ncoeffs; + } + } + previous_pass_salient_only = current_pass_salient_only; + previous_pass_shift = mode_.passes[num_pass].shift; + } // num_pass +} + +template void ProgressiveSplitter::SplitACCoefficients( + const int32_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t, + size_t, int32_t* JXL_RESTRICT[kMaxNumPasses][3]); + +template void ProgressiveSplitter::SplitACCoefficients( + const int16_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t, + size_t, int16_t* JXL_RESTRICT[kMaxNumPasses][3]); + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/progressive_split.h b/media/libjxl/src/lib/jxl/progressive_split.h new file mode 100644 index 000000000..aeae98044 --- /dev/null +++ b/media/libjxl/src/lib/jxl/progressive_split.h @@ -0,0 +1,152 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_PROGRESSIVE_SPLIT_H_ +#define LIB_JXL_PROGRESSIVE_SPLIT_H_ + +#include +#include + +#include +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/splines.h" + +// Functions to split DCT coefficients in multiple passes. All the passes of a +// single frame are added together. + +namespace jxl { + +constexpr size_t kNoDownsamplingFactor = std::numeric_limits::max(); + +struct PassDefinition { + // Side of the square of the coefficients that should be kept in each 8x8 + // block. Must be greater than 1, and at most 8. Should be in non-decreasing + // order. + size_t num_coefficients; + + // How much to shift the encoded values by, with rounding. + size_t shift; + + // Whether or not we should include only salient blocks. + // TODO(veluca): ignored for now. + bool salient_only; + + // If specified, this indicates that if the requested downsampling factor is + // sufficiently high, then it is fine to stop decoding after this pass. + // By default, passes are not marked as being suitable for any downsampling. + size_t suitable_for_downsampling_of_at_least; +}; + +struct ProgressiveMode { + size_t num_passes = 1; + PassDefinition passes[kMaxNumPasses] = {PassDefinition{ + /*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/1}}; + + ProgressiveMode() = default; + + template + explicit ProgressiveMode(const PassDefinition (&p)[nump]) { + JXL_ASSERT(nump <= kMaxNumPasses); + num_passes = nump; + PassDefinition previous_pass{ + /*num_coefficients=*/1, /*shift=*/0, + /*salient_only=*/false, + /*suitable_for_downsampling_of_at_least=*/kNoDownsamplingFactor}; + size_t last_downsampling_factor = kNoDownsamplingFactor; + for (size_t i = 0; i < nump; i++) { + JXL_ASSERT(p[i].num_coefficients > previous_pass.num_coefficients || + (p[i].num_coefficients == previous_pass.num_coefficients && + !p[i].salient_only && previous_pass.salient_only) || + (p[i].num_coefficients == previous_pass.num_coefficients && + p[i].shift < previous_pass.shift)); + JXL_ASSERT(p[i].suitable_for_downsampling_of_at_least == + kNoDownsamplingFactor || + p[i].suitable_for_downsampling_of_at_least <= + last_downsampling_factor); + if (p[i].suitable_for_downsampling_of_at_least != kNoDownsamplingFactor) { + last_downsampling_factor = p[i].suitable_for_downsampling_of_at_least; + } + previous_pass = passes[i] = p[i]; + } + } +}; + +class ProgressiveSplitter { + public: + void SetProgressiveMode(ProgressiveMode mode) { mode_ = mode; } + + void SetSaliencyMap(const ImageF* saliency_map) { + saliency_map_ = saliency_map; + } + + void SetSaliencyThreshold(float threshold) { + saliency_threshold_ = threshold; + } + + size_t GetNumPasses() const { return mode_.num_passes; } + + void InitPasses(Passes* JXL_RESTRICT passes) const { + passes->num_passes = static_cast(GetNumPasses()); + passes->num_downsample = 0; + JXL_ASSERT(passes->num_passes != 0); + passes->shift[passes->num_passes - 1] = 0; + if (passes->num_passes == 1) return; // Done, arrays are empty + + for (uint32_t i = 0; i < mode_.num_passes - 1; ++i) { + const size_t min_downsampling_factor = + mode_.passes[i].suitable_for_downsampling_of_at_least; + passes->shift[i] = mode_.passes[i].shift; + if (1 < min_downsampling_factor && + min_downsampling_factor != kNoDownsamplingFactor) { + passes->downsample[passes->num_downsample] = min_downsampling_factor; + passes->last_pass[passes->num_downsample] = i; + if (mode_.passes[i + 1].suitable_for_downsampling_of_at_least < + min_downsampling_factor) { + passes->num_downsample += 1; + } + } + } + } + + template + void SplitACCoefficients(const T* JXL_RESTRICT block, size_t size, + const AcStrategy& acs, size_t bx, size_t by, + size_t offset, + T* JXL_RESTRICT output[kMaxNumPasses][3]); + + private: + bool SuperblockIsSalient(size_t row_start, size_t col_start, size_t num_rows, + size_t num_cols) const; + ProgressiveMode mode_; + + // Not owned, must remain valid. + const ImageF* saliency_map_ = nullptr; + float saliency_threshold_ = 0.0; +}; + +extern template void ProgressiveSplitter::SplitACCoefficients( + const int32_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t, + size_t, int32_t* JXL_RESTRICT[kMaxNumPasses][3]); + +extern template void ProgressiveSplitter::SplitACCoefficients( + const int16_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t, + size_t, int16_t* JXL_RESTRICT[kMaxNumPasses][3]); + +} // namespace jxl + +#endif // LIB_JXL_PROGRESSIVE_SPLIT_H_ diff --git a/media/libjxl/src/lib/jxl/quant_weights.cc b/media/libjxl/src/lib/jxl/quant_weights.cc new file mode 100644 index 000000000..756a48141 --- /dev/null +++ b/media/libjxl/src/lib/jxl/quant_weights.cc @@ -0,0 +1,1240 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "lib/jxl/quant_weights.h" + +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc" +#include +#include + +#include "lib/jxl/fast_math-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sqrt; + +// kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y) +// coefficient in component c. Higher weights correspond to finer quantization +// intervals and more bits spent in encoding. + +static constexpr const float kAlmostZero = 1e-8f; + +void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights, + float* weights) { + for (size_t c = 0; c < 3; c++) { + size_t start = c * 64; + weights[start] = 0xBAD; + weights[start + 1] = weights[start + 8] = dct2weights[c][0]; + weights[start + 9] = dct2weights[c][1]; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + weights[start + y * 8 + x + 2] = dct2weights[c][2]; + weights[start + (y + 2) * 8 + x] = dct2weights[c][2]; + } + } + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3]; + } + } + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + weights[start + y * 8 + x + 4] = dct2weights[c][4]; + weights[start + (y + 4) * 8 + x] = dct2weights[c][4]; + } + } + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5]; + } + } + } +} + +void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights, + float* weights) { + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 64; i++) { + weights[64 * c + i] = idweights[c][0]; + } + weights[64 * c + 1] = idweights[c][1]; + weights[64 * c + 8] = idweights[c][1]; + weights[64 * c + 9] = idweights[c][2]; + } +} + +float Interpolate(float pos, float max, const float* array, size_t len) { + float scaled_pos = pos * (len - 1) / max; + size_t idx = scaled_pos; + JXL_DASSERT(idx + 1 < len); + float a = array[idx]; + float b = array[idx + 1]; + return a * FastPowf(b / a, scaled_pos - idx); +} + +float Mult(float v) { + if (v > 0.0f) return 1.0f + v; + return 1.0f / (1.0f - v); +} + +using DF4 = HWY_CAPPED(float, 4); + +hwy::HWY_NAMESPACE::Vec InterpolateVec( + hwy::HWY_NAMESPACE::Vec scaled_pos, const float* array) { + HWY_CAPPED(int32_t, 4) di; + + auto idx = ConvertTo(di, scaled_pos); + + auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx)); + + // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but + // it's probably slower. + auto a = GatherIndex(DF4(), array, idx); + auto b = GatherIndex(DF4(), array + 1, idx); + + return Mul(a, FastPowf(DF4(), Div(b, a), frac)); +} + +// Computes quant weights for a COLS*ROWS-sized transform, using num_bands +// eccentricity bands and num_ebands eccentricity bands. If print_mode is 1, +// prints the resulting matrix; if print_mode is 2, prints the matrix in a +// format suitable for a 3d plot with gnuplot. +Status GetQuantWeights( + size_t ROWS, size_t COLS, + const DctQuantWeightParams::DistanceBandsArray& distance_bands, + size_t num_bands, float* out) { + for (size_t c = 0; c < 3; c++) { + float bands[DctQuantWeightParams::kMaxDistanceBands] = { + distance_bands[c][0]}; + if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); + for (size_t i = 1; i < num_bands; i++) { + bands[i] = bands[i - 1] * Mult(distance_bands[c][i]); + if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); + } + float scale = (num_bands - 1) / (kSqrt2 + 1e-6f); + float rcpcol = scale / (COLS - 1); + float rcprow = scale / (ROWS - 1); + JXL_ASSERT(COLS >= Lanes(DF4())); + HWY_ALIGN float l0123[4] = {0, 1, 2, 3}; + for (uint32_t y = 0; y < ROWS; y++) { + float dy = y * rcprow; + float dy2 = dy * dy; + for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) { + auto dx = + Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol)); + auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2))); + auto weight = num_bands == 1 ? Set(DF4(), bands[0]) + : InterpolateVec(scaled_distance, bands); + StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x); + } + } + } + return true; +} + +// TODO(veluca): SIMD-fy. With 256x256, this is actually slow. +Status ComputeQuantTable(const QuantEncoding& encoding, + float* JXL_RESTRICT table, + float* JXL_RESTRICT inv_table, size_t table_num, + DequantMatrices::QuantTable kind, size_t* pos) { + constexpr size_t N = kBlockDim; + size_t wrows = 8 * DequantMatrices::required_size_x[kind], + wcols = 8 * DequantMatrices::required_size_y[kind]; + size_t num = wrows * wcols; + + std::vector weights(3 * num); + + switch (encoding.mode) { + case QuantEncoding::kQuantModeLibrary: { + // Library and copy quant encoding should get replaced by the actual + // parameters by the caller. + JXL_ASSERT(false); + break; + } + case QuantEncoding::kQuantModeID: { + JXL_ASSERT(num == kDCTBlockSize); + GetQuantWeightsIdentity(encoding.idweights, weights.data()); + break; + } + case QuantEncoding::kQuantModeDCT2: { + JXL_ASSERT(num == kDCTBlockSize); + GetQuantWeightsDCT2(encoding.dct2weights, weights.data()); + break; + } + case QuantEncoding::kQuantModeDCT4: { + JXL_ASSERT(num == kDCTBlockSize); + float weights4x4[3 * 4 * 4]; + // Always use 4x4 GetQuantWeights for DCT4 quantization tables. + JXL_RETURN_IF_ERROR( + GetQuantWeights(4, 4, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x4)); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + weights[c * num + y * kBlockDim + x] = + weights4x4[c * 16 + (y / 2) * 4 + (x / 2)]; + } + } + weights[c * num + 1] /= encoding.dct4multipliers[c][0]; + weights[c * num + N] /= encoding.dct4multipliers[c][0]; + weights[c * num + N + 1] /= encoding.dct4multipliers[c][1]; + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + JXL_ASSERT(num == kDCTBlockSize); + float weights4x8[3 * 4 * 8]; + // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables. + JXL_RETURN_IF_ERROR( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8)); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + weights[c * num + y * kBlockDim + x] = + weights4x8[c * 32 + (y / 2) * 8 + x]; + } + } + weights[c * num + N] /= encoding.dct4x8multipliers[c]; + } + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(GetQuantWeights( + wrows, wcols, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights.data())); + break; + } + case QuantEncoding::kQuantModeRAW: { + if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) { + return JXL_FAILURE("Invalid table encoding"); + } + for (size_t i = 0; i < 3 * num; i++) { + weights[i] = + 1.f / (encoding.qraw.qtable_den * (*encoding.qraw.qtable)[i]); + } + break; + } + case QuantEncoding::kQuantModeAFV: { + constexpr float kFreqs[] = { + 0xBAD, + 0xBAD, + 0.8517778890324296, + 5.37778436506804, + 0xBAD, + 0xBAD, + 4.734747904497923, + 5.449245381693219, + 1.6598270267479331, + 4, + 7.275749096817861, + 10.423227632456525, + 2.662932286148962, + 7.630657783650829, + 8.962388608184032, + 12.97166202570235, + }; + + float weights4x8[3 * 4 * 8]; + JXL_RETURN_IF_ERROR(( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8))); + float weights4x4[3 * 4 * 4]; + JXL_RETURN_IF_ERROR((GetQuantWeights( + 4, 4, encoding.dct_params_afv_4x4.distance_bands, + encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); + + constexpr float lo = 0.8517778890324296; + constexpr float hi = 12.97166202570235f - lo + 1e-6f; + for (size_t c = 0; c < 3; c++) { + float bands[4]; + bands[0] = encoding.afv_weights[c][5]; + if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); + for (size_t i = 1; i < 4; i++) { + bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]); + if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); + } + size_t start = c * 64; + auto set_weight = [&start, &weights](size_t x, size_t y, float val) { + weights[start + y * 8 + x] = val; + }; + weights[start] = 1; // Not used, but causes MSAN error otherwise. + // Weights for (0, 1) and (1, 0). + set_weight(0, 1, encoding.afv_weights[c][0]); + set_weight(1, 0, encoding.afv_weights[c][1]); + // AFV special weights for 3-pixel corner. + set_weight(0, 2, encoding.afv_weights[c][2]); + set_weight(2, 0, encoding.afv_weights[c][3]); + set_weight(2, 2, encoding.afv_weights[c][4]); + + // All other AFV weights. + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + if (x < 2 && y < 2) continue; + float val = Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4); + set_weight(2 * x, 2 * y, val); + } + } + + // Put 4x8 weights in odd rows, except (1, 0). + for (size_t y = 0; y < kBlockDim / 2; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + if (x == 0 && y == 0) continue; + weights[c * num + (2 * y + 1) * kBlockDim + x] = + weights4x8[c * 32 + y * 8 + x]; + } + } + // Put 4x4 weights in even rows / odd columns, except (0, 1). + for (size_t y = 0; y < kBlockDim / 2; y++) { + for (size_t x = 0; x < kBlockDim / 2; x++) { + if (x == 0 && y == 0) continue; + weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] = + weights4x4[c * 16 + y * 4 + x]; + } + } + } + break; + } + } + size_t prev_pos = *pos; + HWY_CAPPED(float, 64) d; + for (size_t i = 0; i < num * 3; i += Lanes(d)) { + auto inv_val = LoadU(d, weights.data() + i); + if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) || + !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) { + return JXL_FAILURE("Invalid quantization table"); + } + auto val = Div(Set(d, 1.0f), inv_val); + StoreU(val, d, table + *pos + i); + StoreU(inv_val, d, inv_table + *pos + i); + } + (*pos) += 3 * num; + + // Ensure that the lowest frequencies have a 0 inverse table. + // This does not affect en/decoding, but allows AC strategy selection to be + // slightly simpler. + size_t xs = DequantMatrices::required_size_x[kind]; + size_t ys = DequantMatrices::required_size_y[kind]; + CoefficientLayout(&ys, &xs); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ys; y++) { + for (size_t x = 0; x < xs; x++) { + inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs + + x] = 0; + } + } + } + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { +namespace { + +HWY_EXPORT(ComputeQuantTable); + +static constexpr const float kAlmostZero = 1e-8f; + +Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) { + params->num_distance_bands = + br->ReadFixedBits() + 1; + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < params->num_distance_bands; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, ¶ms->distance_bands[c][i])); + } + if (params->distance_bands[c][0] < kAlmostZero) { + return JXL_FAILURE("Distance band seed is too small"); + } + params->distance_bands[c][0] *= 64.0f; + } + return true; +} + +Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x, + size_t required_size_y, size_t idx, + ModularFrameDecoder* modular_frame_decoder) { + size_t required_size = required_size_x * required_size_y; + required_size_x *= kBlockDim; + required_size_y *= kBlockDim; + int mode = br->ReadFixedBits(); + switch (mode) { + case QuantEncoding::kQuantModeLibrary: { + encoding->predefined = br->ReadFixedBits(); + if (encoding->predefined >= kNumPredefinedTables) { + return JXL_FAILURE("Invalid predefined table"); + } + break; + } + case QuantEncoding::kQuantModeID: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 3; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i])); + if (std::abs(encoding->idweights[c][i]) < kAlmostZero) { + return JXL_FAILURE("ID Quantizer is too small"); + } + encoding->idweights[c][i] *= 64; + } + } + break; + } + case QuantEncoding::kQuantModeDCT2: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 6; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i])); + if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) { + return JXL_FAILURE("Quantizer is too small"); + } + encoding->dct2weights[c][i] *= 64; + } + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR( + F16Coder::Read(br, &encoding->dct4x8multipliers[c])); + if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) { + return JXL_FAILURE("DCT4X8 multiplier is too small"); + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeDCT4: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 2; i++) { + JXL_RETURN_IF_ERROR( + F16Coder::Read(br, &encoding->dct4multipliers[c][i])); + if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) { + return JXL_FAILURE("DCT4 multiplier is too small"); + } + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeAFV: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 9; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i])); + } + for (size_t i = 0; i < 6; i++) { + encoding->afv_weights[c][i] *= 64; + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4)); + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeRAW: { + // Set mode early, to avoid mem-leak. + encoding->mode = QuantEncoding::kQuantModeRAW; + JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable( + required_size_x, required_size_y, br, encoding, idx, + modular_frame_decoder)); + break; + } + default: + return JXL_FAILURE("Invalid quantization table encoding"); + } + encoding->mode = QuantEncoding::Mode(mode); + return true; +} + +} // namespace + +// These definitions are needed before C++17. +constexpr size_t DequantMatrices::required_size_[]; +constexpr size_t DequantMatrices::required_size_x[]; +constexpr size_t DequantMatrices::required_size_y[]; +constexpr DequantMatrices::QuantTable DequantMatrices::kQuantTable[]; + +Status DequantMatrices::Decode(BitReader* br, + ModularFrameDecoder* modular_frame_decoder) { + size_t all_default = br->ReadBits(1); + size_t num_tables = all_default ? 0 : static_cast(kNum); + encodings_.clear(); + encodings_.resize(kNum, QuantEncoding::Library(0)); + for (size_t i = 0; i < num_tables; i++) { + JXL_RETURN_IF_ERROR( + jxl::Decode(br, &encodings_[i], required_size_x[i % kNum], + required_size_y[i % kNum], i, modular_frame_decoder)); + } + computed_mask_ = 0; + return true; +} + +Status DequantMatrices::DecodeDC(BitReader* br) { + bool all_default = br->ReadBits(1); + if (!br->AllReadsWithinBounds()) return JXL_FAILURE("EOS during DecodeDC"); + if (!all_default) { + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &dc_quant_[c])); + dc_quant_[c] *= 1.0f / 128.0f; + // Negative values and nearly zero are invalid values. + if (dc_quant_[c] < kAlmostZero) { + return JXL_FAILURE("Invalid dc_quant: coefficient is too small."); + } + inv_dc_quant_[c] = 1.0f / dc_quant_[c]; + } + } + return true; +} + +constexpr float V(float v) { return static_cast(v); } + +namespace { +struct DequantMatricesLibraryDef { + // DCT8 + static constexpr const QuantEncodingInternal DCT() { + return QuantEncodingInternal::DCT(DctQuantWeightParams({{{{ + V(3150.0), + V(0.0), + V(-0.4), + V(-0.4), + V(-0.4), + V(-2.0), + }}, + {{ + V(560.0), + V(0.0), + V(-0.3), + V(-0.3), + V(-0.3), + V(-0.3), + }}, + {{ + V(512.0), + V(-2.0), + V(-1.0), + V(0.0), + V(-1.0), + V(-2.0), + }}}}, + 6)); + } + + // Identity + static constexpr const QuantEncodingInternal IDENTITY() { + return QuantEncodingInternal::Identity({{{{ + V(280.0), + V(3160.0), + V(3160.0), + }}, + {{ + V(60.0), + V(864.0), + V(864.0), + }}, + {{ + V(18.0), + V(200.0), + V(200.0), + }}}}); + } + + // DCT2 + static constexpr const QuantEncodingInternal DCT2X2() { + return QuantEncodingInternal::DCT2({{{{ + V(3840.0), + V(2560.0), + V(1280.0), + V(640.0), + V(480.0), + V(300.0), + }}, + {{ + V(960.0), + V(640.0), + V(320.0), + V(180.0), + V(140.0), + V(120.0), + }}, + {{ + V(640.0), + V(320.0), + V(128.0), + V(64.0), + V(32.0), + V(16.0), + }}}}); + } + + // DCT4 (quant_kind 3) + static constexpr const QuantEncodingInternal DCT4X4() { + return QuantEncodingInternal::DCT4(DctQuantWeightParams({{{{ + V(2200.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + V(392.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + V(112.0), + V(-0.25), + V(-0.25), + V(-0.5), + }}}}, + 4), + /* kMul */ + {{{{ + V(1.0), + V(1.0), + }}, + {{ + V(1.0), + V(1.0), + }}, + {{ + V(1.0), + V(1.0), + }}}}); + } + + // DCT16 + static constexpr const QuantEncodingInternal DCT16X16() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(8996.8725711814115328), + V(-1.3000777393353804), + V(-0.49424529824571225), + V(-0.439093774457103443), + V(-0.6350101832695744), + V(-0.90177264050827612), + V(-1.6162099239887414), + }}, + {{ + V(3191.48366296844234752), + V(-0.67424582104194355), + V(-0.80745813428471001), + V(-0.44925837484843441), + V(-0.35865440981033403), + V(-0.31322389111877305), + V(-0.37615025315725483), + }}, + {{ + V(1157.50408145487200256), + V(-2.0531423165804414), + V(-1.4), + V(-0.50687130033378396), + V(-0.42708730624733904), + V(-1.4856834539296244), + V(-4.9209142884401604), + }}}}, + 7)); + } + + // DCT32 + static constexpr const QuantEncodingInternal DCT32X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(15718.40830982518931456), + V(-1.025), + V(-0.98), + V(-0.9012), + V(-0.4), + V(-0.48819395464), + V(-0.421064), + V(-0.27), + }}, + {{ + V(7305.7636810695983104), + V(-0.8041958212306401), + V(-0.7633036457487539), + V(-0.55660379990111464), + V(-0.49785304658857626), + V(-0.43699592683512467), + V(-0.40180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(3803.53173721215041536), + V(-3.060733579805728), + V(-2.0413270132490346), + V(-2.0235650159727417), + V(-0.5495389509954993), + V(-0.4), + V(-0.4), + V(-0.3), + }}}}, + 8)); + } + + // DCT16X8 + static constexpr const QuantEncodingInternal DCT8X16() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(7240.7734393502), + V(-0.7), + V(-0.7), + V(-0.2), + V(-0.2), + V(-0.2), + V(-0.5), + }}, + {{ + V(1448.15468787004), + V(-0.5), + V(-0.5), + V(-0.5), + V(-0.2), + V(-0.2), + V(-0.2), + }}, + {{ + V(506.854140754517), + V(-1.4), + V(-0.2), + V(-0.5), + V(-0.5), + V(-1.5), + V(-3.6), + }}}}, + 7)); + } + + // DCT32X8 + static constexpr const QuantEncodingInternal DCT8X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(16283.2494710648897), + V(-1.7812845336559429), + V(-1.6309059012653515), + V(-1.0382179034313539), + V(-0.85), + V(-0.7), + V(-0.9), + V(-1.2360638576849587), + }}, + {{ + V(5089.15750884921511936), + V(-0.320049391452786891), + V(-0.35362849922161446), + V(-0.30340000000000003), + V(-0.61), + V(-0.5), + V(-0.5), + V(-0.6), + }}, + {{ + V(3397.77603275308720128), + V(-0.321327362693153371), + V(-0.34507619223117997), + V(-0.70340000000000003), + V(-0.9), + V(-1.0), + V(-1.0), + V(-1.1754605576265209), + }}}}, + 8)); + } + + // DCT32X16 + static constexpr const QuantEncodingInternal DCT16X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(13844.97076442300573), + V(-0.97113799999999995), + V(-0.658), + V(-0.42026), + V(-0.22712), + V(-0.2206), + V(-0.226), + V(-0.6), + }}, + {{ + V(4798.964084220744293), + V(-0.61125308982767057), + V(-0.83770786552491361), + V(-0.79014862079498627), + V(-0.2692727459704829), + V(-0.38272769465388551), + V(-0.22924222653091453), + V(-0.20719098826199578), + }}, + {{ + V(1807.236946760964614), + V(-1.2), + V(-1.2), + V(-0.7), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT4X8 and 8x4 + static constexpr const QuantEncodingInternal DCT4X8() { + return QuantEncodingInternal::DCT4X8( + DctQuantWeightParams({{ + {{ + V(2198.050556016380522), + V(-0.96269623020744692), + V(-0.76194253026666783), + V(-0.6551140670773547), + }}, + {{ + V(764.3655248643528689), + V(-0.92630200888366945), + V(-0.9675229603596517), + V(-0.27845290869168118), + }}, + {{ + V(527.107573587542228), + V(-1.4594385811273854), + V(-1.450082094097871593), + V(-1.5843722511996204), + }}, + }}, + 4), + /* kMuls */ + {{ + V(1.0), + V(1.0), + V(1.0), + }}); + } + // AFV + static const QuantEncodingInternal AFV0() { + return QuantEncodingInternal::AFV(DCT4X8().dct_params, DCT4X4().dct_params, + {{{{ + // 4x4/4x8 DC tendency. + V(3072.0), + V(3072.0), + // AFV corner. + V(256.0), + V(256.0), + V(256.0), + // AFV high freqs. + V(414.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + // 4x4/4x8 DC tendency. + V(1024.0), + V(1024.0), + // AFV corner. + V(50), + V(50), + V(50), + // AFV high freqs. + V(58.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + // 4x4/4x8 DC tendency. + V(384.0), + V(384.0), + // AFV corner. + V(12.0), + V(12.0), + V(12.0), + // AFV high freqs. + V(22.0), + V(-0.25), + V(-0.25), + V(-0.25), + }}}}); + } + + // DCT64 + static const QuantEncodingInternal DCT64X64() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(0.9 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(0.9 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(0.9 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT64X32 + static const QuantEncodingInternal DCT32X64() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(0.65 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(0.65 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(0.65 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + // DCT128X128 + static const QuantEncodingInternal DCT128X128() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(1.8 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(1.8 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(1.8 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT128X64 + static const QuantEncodingInternal DCT64X128() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(1.3 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(1.3 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(1.3 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + // DCT256X256 + static const QuantEncodingInternal DCT256X256() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(3.6 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(3.6 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(3.6 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT256X128 + static const QuantEncodingInternal DCT128X256() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(2.6 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(2.6 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(2.6 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } +}; +} // namespace + +const DequantMatrices::DequantLibraryInternal DequantMatrices::LibraryInit() { + static_assert(kNum == 17, + "Update this function when adding new quantization kinds."); + static_assert(kNumPredefinedTables == 1, + "Update this function when adding new quantization matrices to " + "the library."); + + // The library and the indices need to be kept in sync manually. + static_assert(0 == DCT, "Update the DequantLibrary array below."); + static_assert(1 == IDENTITY, "Update the DequantLibrary array below."); + static_assert(2 == DCT2X2, "Update the DequantLibrary array below."); + static_assert(3 == DCT4X4, "Update the DequantLibrary array below."); + static_assert(4 == DCT16X16, "Update the DequantLibrary array below."); + static_assert(5 == DCT32X32, "Update the DequantLibrary array below."); + static_assert(6 == DCT8X16, "Update the DequantLibrary array below."); + static_assert(7 == DCT8X32, "Update the DequantLibrary array below."); + static_assert(8 == DCT16X32, "Update the DequantLibrary array below."); + static_assert(9 == DCT4X8, "Update the DequantLibrary array below."); + static_assert(10 == AFV0, "Update the DequantLibrary array below."); + static_assert(11 == DCT64X64, "Update the DequantLibrary array below."); + static_assert(12 == DCT32X64, "Update the DequantLibrary array below."); + static_assert(13 == DCT128X128, "Update the DequantLibrary array below."); + static_assert(14 == DCT64X128, "Update the DequantLibrary array below."); + static_assert(15 == DCT256X256, "Update the DequantLibrary array below."); + static_assert(16 == DCT128X256, "Update the DequantLibrary array below."); + return DequantMatrices::DequantLibraryInternal{{ + DequantMatricesLibraryDef::DCT(), + DequantMatricesLibraryDef::IDENTITY(), + DequantMatricesLibraryDef::DCT2X2(), + DequantMatricesLibraryDef::DCT4X4(), + DequantMatricesLibraryDef::DCT16X16(), + DequantMatricesLibraryDef::DCT32X32(), + DequantMatricesLibraryDef::DCT8X16(), + DequantMatricesLibraryDef::DCT8X32(), + DequantMatricesLibraryDef::DCT16X32(), + DequantMatricesLibraryDef::DCT4X8(), + DequantMatricesLibraryDef::AFV0(), + DequantMatricesLibraryDef::DCT64X64(), + DequantMatricesLibraryDef::DCT32X64(), + // Same default for large transforms (128+) as for 64x* transforms. + DequantMatricesLibraryDef::DCT128X128(), + DequantMatricesLibraryDef::DCT64X128(), + DequantMatricesLibraryDef::DCT256X256(), + DequantMatricesLibraryDef::DCT128X256(), + }}; +} + +const QuantEncoding* DequantMatrices::Library() { + static const DequantMatrices::DequantLibraryInternal kDequantLibrary = + DequantMatrices::LibraryInit(); + // Downcast the result to a const QuantEncoding* from QuantEncodingInternal* + // since the subclass (QuantEncoding) doesn't add any new members and users + // will need to upcast to QuantEncodingInternal to access the members of that + // class. This allows to have kDequantLibrary as a constexpr value while still + // allowing to create QuantEncoding::RAW() instances that use std::vector in + // C++11. + return reinterpret_cast(kDequantLibrary.data()); +} + +DequantMatrices::DequantMatrices() { + encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0)); + size_t pos = 0; + size_t offsets[kNum * 3]; + for (size_t i = 0; i < size_t(QuantTable::kNum); i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; + } + pos += 3 * num; + } + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + for (size_t c = 0; c < 3; c++) { + table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; + } + } +} + +Status DequantMatrices::EnsureComputed(uint32_t acs_mask) { + const QuantEncoding* library = Library(); + + if (!table_storage_) { + table_storage_ = hwy::AllocateAligned(2 * kTotalTableSize); + table_ = table_storage_.get(); + inv_table_ = table_storage_.get() + kTotalTableSize; + } + + size_t offsets[kNum * 3 + 1]; + size_t pos = 0; + for (size_t i = 0; i < kNum; i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; + } + pos += 3 * num; + } + offsets[kNum * 3] = pos; + JXL_ASSERT(pos == kTotalTableSize); + + uint32_t kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (acs_mask & (1u << i)) { + kind_mask |= 1u << kQuantTable[i]; + } + } + uint32_t computed_kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (computed_mask_ & (1u << i)) { + computed_kind_mask |= 1u << kQuantTable[i]; + } + } + for (size_t table = 0; table < kNum; table++) { + if ((1 << table) & computed_kind_mask) continue; + if ((1 << table) & ~kind_mask) continue; + size_t pos = offsets[table * 3]; + if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { + JXL_CHECK(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + library[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); + } else { + JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + encodings_[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); + } + JXL_ASSERT(pos == offsets[table * 3 + 3]); + } + computed_mask_ |= acs_mask; + + return true; +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/quant_weights.h b/media/libjxl/src/lib/jxl/quant_weights.h new file mode 100644 index 000000000..92a2d9ea7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/quant_weights.h @@ -0,0 +1,449 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_QUANT_WEIGHTS_H_ +#define LIB_JXL_QUANT_WEIGHTS_H_ + +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" + +namespace jxl { + +template +constexpr T ArraySum(T (&a)[N], size_t i = N - 1) { + static_assert(N > 0, "Trying to compute the sum of an empty array"); + return i == 0 ? a[0] : a[i] + ArraySum(a, i - 1); +} + +static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea; +static constexpr size_t kNumPredefinedTables = 1; +static constexpr size_t kCeilLog2NumPredefinedTables = 0; +static constexpr size_t kLog2NumQuantModes = 3; + +struct DctQuantWeightParams { + static constexpr size_t kLog2MaxDistanceBands = 4; + static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands); + typedef std::array, 3> + DistanceBandsArray; + + size_t num_distance_bands = 0; + DistanceBandsArray distance_bands = {}; + + constexpr DctQuantWeightParams() : num_distance_bands(0) {} + + constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands, + size_t num_dist_bands) + : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {} + + template + explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) { + num_distance_bands = num_dist_bands; + for (size_t c = 0; c < 3; c++) { + memcpy(distance_bands[c].data(), dist_bands[c], + sizeof(float) * num_dist_bands); + } + } +}; + +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct QuantEncodingInternal { + enum Mode { + kQuantModeLibrary, + kQuantModeID, + kQuantModeDCT2, + kQuantModeDCT4, + kQuantModeDCT4X8, + kQuantModeAFV, + kQuantModeDCT, + kQuantModeRAW, + }; + + template + struct Tag {}; + + typedef std::array, 3> IdWeights; + typedef std::array, 3> DCT2Weights; + typedef std::array, 3> DCT4Multipliers; + typedef std::array, 3> AFVWeights; + typedef std::array DCT4x8Multipliers; + + static constexpr QuantEncodingInternal Library(uint8_t predefined) { + return ((predefined < kNumPredefinedTables) || + JXL_ABORT("Assert predefined < kNumPredefinedTables")), + QuantEncodingInternal(Tag(), predefined); + } + constexpr QuantEncodingInternal(Tag /* tag */, + uint8_t predefined) + : mode(kQuantModeLibrary), predefined(predefined) {} + + // Identity + // xybweights is an array of {xweights, yweights, bweights}. + static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { + return QuantEncodingInternal(Tag(), xybweights); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const IdWeights& xybweights) + : mode(kQuantModeID), idweights(xybweights) {} + + // DCT2 + static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { + return QuantEncodingInternal(Tag(), xybweights); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DCT2Weights& xybweights) + : mode(kQuantModeDCT2), dct2weights(xybweights) {} + + // DCT4 + static constexpr QuantEncodingInternal DCT4( + const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { + return QuantEncodingInternal(Tag(), params, xybmul); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params, + const DCT4Multipliers& xybmul) + : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} + + // DCT4x8 + static constexpr QuantEncodingInternal DCT4X8( + const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { + return QuantEncodingInternal(Tag(), params, xybmul); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params, + const DCT4x8Multipliers& xybmul) + : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} + + // DCT + static constexpr QuantEncodingInternal DCT( + const DctQuantWeightParams& params) { + return QuantEncodingInternal(Tag(), params); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params) + : mode(kQuantModeDCT), dct_params(params) {} + + // AFV + static constexpr QuantEncodingInternal AFV( + const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, const AFVWeights& weights) { + return QuantEncodingInternal(Tag(), params4x8, params4x4, + weights); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, + const AFVWeights& weights) + : mode(kQuantModeAFV), + dct_params(params4x8), + afv_weights(weights), + dct_params_afv_4x4(params4x4) {} + + // This constructor is not constexpr so it can't be used in any of the + // constexpr cases above. + explicit QuantEncodingInternal(Mode mode) : mode(mode) {} + + Mode mode; + + // Weights for DCT4+ tables. + DctQuantWeightParams dct_params; + + union { + // Weights for identity. + IdWeights idweights; + + // Weights for DCT2. + DCT2Weights dct2weights; + + // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. + DCT4Multipliers dct4multipliers; + + // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, + // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + + // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated + // as in GetQuantWeights for DC and are used for other coefficients. + AFVWeights afv_weights = {}; + + // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. + DCT4x8Multipliers dct4x8multipliers; + + // Only used in kQuantModeRAW mode. + struct { + // explicit quantization table (like in JPEG) + std::vector* qtable = nullptr; + float qtable_den = 1.f / (8 * 255); + } qraw; + }; + + // Weights for 4x4 sub-block in AFV. + DctQuantWeightParams dct_params_afv_4x4; + + union { + // Which predefined table to use. Only used if mode is kQuantModeLibrary. + uint8_t predefined = 0; + + // Which other quant table to copy; must copy from a table that comes before + // the current one. Only used if mode is kQuantModeCopy. + uint8_t source; + }; +}; + +class QuantEncoding final : public QuantEncodingInternal { + public: + QuantEncoding(const QuantEncoding& other) + : QuantEncodingInternal( + static_cast(other)) { + if (mode == kQuantModeRAW && qraw.qtable) { + // Need to make a copy of the passed *qtable. + qraw.qtable = new std::vector(*other.qraw.qtable); + } + } + QuantEncoding(QuantEncoding&& other) noexcept + : QuantEncodingInternal( + static_cast(other)) { + // Steal the qtable from the other object if any. + if (mode == kQuantModeRAW) { + other.qraw.qtable = nullptr; + } + } + QuantEncoding& operator=(const QuantEncoding& other) { + if (mode == kQuantModeRAW && qraw.qtable) { + delete qraw.qtable; + } + *static_cast(this) = + QuantEncodingInternal(static_cast(other)); + if (mode == kQuantModeRAW && qraw.qtable) { + // Need to make a copy of the passed *qtable. + qraw.qtable = new std::vector(*other.qraw.qtable); + } + return *this; + } + + ~QuantEncoding() { + if (mode == kQuantModeRAW && qraw.qtable) { + delete qraw.qtable; + } + } + + // Wrappers of the QuantEncodingInternal:: static functions that return a + // QuantEncoding instead. This is using the explicit and private cast from + // QuantEncodingInternal to QuantEncoding, which would be inlined anyway. + // In general, you should use this wrappers. The only reason to directly + // create a QuantEncodingInternal instance is if you need a constexpr version + // of this class. Note that RAW() is not supported in that case since it uses + // a std::vector. + static QuantEncoding Library(uint8_t predefined) { + return QuantEncoding(QuantEncodingInternal::Library(predefined)); + } + static QuantEncoding Identity(const IdWeights& xybweights) { + return QuantEncoding(QuantEncodingInternal::Identity(xybweights)); + } + static QuantEncoding DCT2(const DCT2Weights& xybweights) { + return QuantEncoding(QuantEncodingInternal::DCT2(xybweights)); + } + static QuantEncoding DCT4(const DctQuantWeightParams& params, + const DCT4Multipliers& xybmul) { + return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul)); + } + static QuantEncoding DCT4X8(const DctQuantWeightParams& params, + const DCT4x8Multipliers& xybmul) { + return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul)); + } + static QuantEncoding DCT(const DctQuantWeightParams& params) { + return QuantEncoding(QuantEncodingInternal::DCT(params)); + } + static QuantEncoding AFV(const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, + const AFVWeights& weights) { + return QuantEncoding( + QuantEncodingInternal::AFV(params4x8, params4x4, weights)); + } + + // RAW, note that this one is not a constexpr one. + static QuantEncoding RAW(const std::vector& qtable, int shift = 0) { + QuantEncoding encoding(kQuantModeRAW); + encoding.qraw.qtable = new std::vector(); + *encoding.qraw.qtable = qtable; + encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255)); + return encoding; + } + + private: + explicit QuantEncoding(const QuantEncodingInternal& other) + : QuantEncodingInternal(other) {} + + explicit QuantEncoding(QuantEncodingInternal::Mode mode) + : QuantEncodingInternal(mode) {} +}; + +// A constexpr QuantEncodingInternal instance is often downcasted to the +// QuantEncoding subclass even if the instance wasn't an instance of the +// subclass. This is safe because user will upcast to QuantEncodingInternal to +// access any of its members. +static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), + "Don't add any members to QuantEncoding"); + +// Let's try to keep these 2**N for possible future simplicity. +const float kInvDCQuant[3] = { + 4096.0f, + 512.0f, + 256.0f, +}; + +const float kDCQuant[3] = { + 1.0f / kInvDCQuant[0], + 1.0f / kInvDCQuant[1], + 1.0f / kInvDCQuant[2], +}; + +class ModularFrameEncoder; +class ModularFrameDecoder; + +class DequantMatrices { + public: + enum QuantTable : size_t { + DCT = 0, + IDENTITY, + DCT2X2, + DCT4X4, + DCT16X16, + DCT32X32, + // DCT16X8 + DCT8X16, + // DCT32X8 + DCT8X32, + // DCT32X16 + DCT16X32, + DCT4X8, + // DCT8X4 + AFV0, + // AFV1 + // AFV2 + // AFV3 + DCT64X64, + // DCT64X32, + DCT32X64, + DCT128X128, + // DCT128X64, + DCT64X128, + DCT256X256, + // DCT256X128, + DCT128X256, + kNum + }; + + static constexpr QuantTable kQuantTable[] = { + QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, + QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, + QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, + QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, + QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, + QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, + QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, + QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, + QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, + }; + static_assert(AcStrategy::kNumValidStrategies == + sizeof(kQuantTable) / sizeof *kQuantTable, + "Update this array when adding or removing AC strategies."); + + DequantMatrices(); + + static const QuantEncoding* Library(); + + typedef std::array + DequantLibraryInternal; + // Return the array of library kNumPredefinedTables QuantEncoding entries as + // a constexpr array. Use Library() to obtain a pointer to the copy in the + // .cc file. + static const DequantLibraryInternal LibraryInit(); + + // Returns aligned memory. + JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const { + JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &table_[table_offsets_[quant_kind * 3 + c]]; + } + + JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const { + JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &inv_table_[table_offsets_[quant_kind * 3 + c]]; + } + + // DC quants are used in modular mode for XYB multipliers. + JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } + JXL_INLINE const float* DCQuants() const { return dc_quant_; } + + JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } + + // For encoder. + void SetEncodings(const std::vector& encodings) { + encodings_ = encodings; + computed_mask_ = 0; + } + + // For encoder. + void SetDCQuant(const float dc[3]) { + for (size_t c = 0; c < 3; c++) { + dc_quant_[c] = 1.0f / dc[c]; + inv_dc_quant_[c] = dc[c]; + } + } + + Status Decode(BitReader* br, + ModularFrameDecoder* modular_frame_decoder = nullptr); + Status DecodeDC(BitReader* br); + + const std::vector& encodings() const { return encodings_; } + + static constexpr size_t required_size_x[] = {1, 1, 1, 1, 2, 4, 1, 1, 2, + 1, 1, 8, 4, 16, 8, 32, 16}; + static_assert(kNum == sizeof(required_size_x) / sizeof(*required_size_x), + "Update this array when adding or removing quant tables."); + + static constexpr size_t required_size_y[] = {1, 1, 1, 1, 2, 4, 2, 4, 4, + 1, 1, 8, 8, 16, 16, 32, 32}; + static_assert(kNum == sizeof(required_size_y) / sizeof(*required_size_y), + "Update this array when adding or removing quant tables."); + + Status EnsureComputed(uint32_t acs_mask); + + private: + static constexpr size_t required_size_[] = { + 1, 1, 1, 1, 4, 16, 2, 4, 8, 1, 1, 64, 32, 256, 128, 1024, 512}; + static_assert(kNum == sizeof(required_size_) / sizeof(*required_size_), + "Update this array when adding or removing quant tables."); + static constexpr size_t kTotalTableSize = + ArraySum(required_size_) * kDCTBlockSize * 3; + + uint32_t computed_mask_ = 0; + // kTotalTableSize entries followed by kTotalTableSize for inv_table + hwy::AlignedFreeUniquePtr table_storage_; + const float* table_; + const float* inv_table_; + float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; + float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; + size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; + std::vector encodings_; +}; + +} // namespace jxl + +#endif // LIB_JXL_QUANT_WEIGHTS_H_ diff --git a/media/libjxl/src/lib/jxl/quant_weights_test.cc b/media/libjxl/src/lib/jxl/quant_weights_test.cc new file mode 100644 index 000000000..f0497948a --- /dev/null +++ b/media/libjxl/src/lib/jxl/quant_weights_test.cc @@ -0,0 +1,240 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "lib/jxl/quant_weights.h" + +#include + +#include +#include +#include // HWY_ALIGN_MAX +#include +#include + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dct_for_test.h" +#include "lib/jxl/dec_transforms_testonly.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/enc_transforms.h" + +namespace jxl { +namespace { + +template +void CheckSimilar(T a, T b) { + EXPECT_EQ(a, b); +} +// minimum exponent = -15. +template <> +void CheckSimilar(float a, float b) { + float m = std::max(std::abs(a), std::abs(b)); + // 10 bits of precision are used in the format. Relative error should be + // below 2^-10. + EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b; +} + +TEST(QuantWeightsTest, DC) { + DequantMatrices mat; + float dc_quant[3] = {1e+5, 1e+3, 1e+1}; + DequantMatricesSetCustomDC(&mat, dc_quant); + for (size_t c = 0; c < 3; c++) { + CheckSimilar(mat.InvDCQuant(c), dc_quant[c]); + } +} + +void RoundtripMatrices(const std::vector& encodings) { + ASSERT_TRUE(encodings.size() == DequantMatrices::kNum); + DequantMatrices mat; + CodecMetadata metadata; + FrameHeader frame_header(&metadata); + ModularFrameEncoder encoder(frame_header, CompressParams{}); + DequantMatricesSetCustom(&mat, encodings, &encoder); + const std::vector& encodings_dec = mat.encodings(); + for (size_t i = 0; i < encodings.size(); i++) { + const QuantEncoding& e = encodings[i]; + const QuantEncoding& d = encodings_dec[i]; + // Check values roundtripped correctly. + EXPECT_EQ(e.mode, d.mode); + EXPECT_EQ(e.predefined, d.predefined); + EXPECT_EQ(e.source, d.source); + + EXPECT_EQ(static_cast(e.dct_params.num_distance_bands), + static_cast(d.dct_params.num_distance_bands)); + for (size_t c = 0; c < 3; c++) { + for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { + CheckSimilar(e.dct_params.distance_bands[c][j], + d.dct_params.distance_bands[c][j]); + } + } + + if (e.mode == QuantEncoding::kQuantModeRAW) { + EXPECT_FALSE(!e.qraw.qtable); + EXPECT_FALSE(!d.qraw.qtable); + EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size()); + for (size_t j = 0; j < e.qraw.qtable->size(); j++) { + EXPECT_EQ((*e.qraw.qtable)[j], (*d.qraw.qtable)[j]); + } + EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f); + } else { + // modes different than kQuantModeRAW use one of the other fields used + // here, which all happen to be arrays of floats. + for (size_t c = 0; c < 3; c++) { + for (size_t j = 0; j < 3; j++) { + CheckSimilar(e.idweights[c][j], d.idweights[c][j]); + } + for (size_t j = 0; j < 6; j++) { + CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]); + } + for (size_t j = 0; j < 2; j++) { + CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]); + } + CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]); + for (size_t j = 0; j < 9; j++) { + CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]); + } + for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { + CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j], + d.dct_params_afv_4x4.distance_bands[c][j]); + } + } + } + } +} + +TEST(QuantWeightsTest, AllDefault) { + std::vector encodings(DequantMatrices::kNum, + QuantEncoding::Library(0)); + RoundtripMatrices(encodings); +} + +void TestSingleQuantMatrix(DequantMatrices::QuantTable kind) { + std::vector encodings(DequantMatrices::kNum, + QuantEncoding::Library(0)); + encodings[kind] = DequantMatrices::Library()[kind]; + RoundtripMatrices(encodings); +} + +// Ensure we can reasonably represent default quant tables. +TEST(QuantWeightsTest, DCT) { TestSingleQuantMatrix(DequantMatrices::DCT); } +TEST(QuantWeightsTest, IDENTITY) { + TestSingleQuantMatrix(DequantMatrices::IDENTITY); +} +TEST(QuantWeightsTest, DCT2X2) { + TestSingleQuantMatrix(DequantMatrices::DCT2X2); +} +TEST(QuantWeightsTest, DCT4X4) { + TestSingleQuantMatrix(DequantMatrices::DCT4X4); +} +TEST(QuantWeightsTest, DCT16X16) { + TestSingleQuantMatrix(DequantMatrices::DCT16X16); +} +TEST(QuantWeightsTest, DCT32X32) { + TestSingleQuantMatrix(DequantMatrices::DCT32X32); +} +TEST(QuantWeightsTest, DCT8X16) { + TestSingleQuantMatrix(DequantMatrices::DCT8X16); +} +TEST(QuantWeightsTest, DCT8X32) { + TestSingleQuantMatrix(DequantMatrices::DCT8X32); +} +TEST(QuantWeightsTest, DCT16X32) { + TestSingleQuantMatrix(DequantMatrices::DCT16X32); +} +TEST(QuantWeightsTest, DCT4X8) { + TestSingleQuantMatrix(DequantMatrices::DCT4X8); +} +TEST(QuantWeightsTest, AFV0) { TestSingleQuantMatrix(DequantMatrices::AFV0); } +TEST(QuantWeightsTest, RAW) { + std::vector encodings(DequantMatrices::kNum, + QuantEncoding::Library(0)); + std::vector matrix(3 * 32 * 32); + Rng rng(0); + for (size_t i = 0; i < matrix.size(); i++) matrix[i] = rng.UniformI(1, 256); + encodings[DequantMatrices::kQuantTable[AcStrategy::DCT32X32]] = + QuantEncoding::RAW(matrix, 2); + RoundtripMatrices(encodings); +} + +class QuantWeightsTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest); + +TEST_P(QuantWeightsTargetTest, DCTUniform) { + constexpr float kUniformQuant = 4; + float weights[3][2] = {{1.0f / kUniformQuant, 0}, + {1.0f / kUniformQuant, 0}, + {1.0f / kUniformQuant, 0}}; + DctQuantWeightParams dct_params(weights); + std::vector encodings(DequantMatrices::kNum, + QuantEncoding::DCT(dct_params)); + DequantMatrices dequant_matrices; + CodecMetadata metadata; + FrameHeader frame_header(&metadata); + ModularFrameEncoder encoder(frame_header, CompressParams{}); + DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder); + JXL_CHECK(dequant_matrices.EnsureComputed(~0u)); + + const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, + 1.0f / kUniformQuant}; + DequantMatricesSetCustomDC(&dequant_matrices, dc_quant); + + HWY_ALIGN_MAX float scratch_space[16 * 16 * 2]; + + // DCT8 + { + HWY_ALIGN_MAX float pixels[64]; + std::iota(std::begin(pixels), std::end(pixels), 0); + HWY_ALIGN_MAX float coeffs[64]; + const AcStrategy::Type dct = AcStrategy::DCT; + TransformFromPixels(dct, pixels, 8, coeffs, scratch_space); + HWY_ALIGN_MAX double slow_coeffs[64]; + for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i]; + DCTSlow<8>(slow_coeffs); + + for (size_t i = 0; i < 64; i++) { + // DCTSlow doesn't multiply/divide by 1/N, so we do it manually. + slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; + coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * + dequant_matrices.Matrix(dct, 0)[i]; + } + IDCTSlow<8>(slow_coeffs); + TransformToPixels(dct, coeffs, pixels, 8, scratch_space); + for (size_t i = 0; i < 64; i++) { + EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); + } + } + + // DCT16 + { + HWY_ALIGN_MAX float pixels[64 * 4]; + std::iota(std::begin(pixels), std::end(pixels), 0); + HWY_ALIGN_MAX float coeffs[64 * 4]; + const AcStrategy::Type dct = AcStrategy::DCT16X16; + TransformFromPixels(dct, pixels, 16, coeffs, scratch_space); + HWY_ALIGN_MAX double slow_coeffs[64 * 4]; + for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i]; + DCTSlow<16>(slow_coeffs); + + for (size_t i = 0; i < 64 * 4; i++) { + slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; + coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * + dequant_matrices.Matrix(dct, 0)[i]; + } + + IDCTSlow<16>(slow_coeffs); + TransformToPixels(dct, coeffs, pixels, 16, scratch_space); + for (size_t i = 0; i < 64 * 4; i++) { + EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); + } + } + + // Check that all matrices have the same DC quantization, i.e. that they all + // have the same scaling. + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + EXPECT_NEAR(dequant_matrices.Matrix(i, 0)[0], kUniformQuant, 1e-6); + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/quantizer-inl.h b/media/libjxl/src/lib/jxl/quantizer-inl.h new file mode 100644 index 000000000..64d273c55 --- /dev/null +++ b/media/libjxl/src/lib/jxl/quantizer-inl.h @@ -0,0 +1,74 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_QUANTIZER_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_QUANTIZER_INL_H_ +#undef LIB_JXL_QUANTIZER_INL_H_ +#else +#define LIB_JXL_QUANTIZER_INL_H_ +#endif + +#include + +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::AndNot; +using hwy::HWY_NAMESPACE::ApproximateReciprocal; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::Xor; + +template +HWY_INLINE HWY_MAYBE_UNUSED Vec> AdjustQuantBias( + DI di, const size_t c, const Vec quant_i, + const float* HWY_RESTRICT biases) { + const Rebind df; + + const auto quant = ConvertTo(df, quant_i); + + // Compare |quant|, keep sign bit for negating result. + const auto kSign = BitCast(df, Set(di, INT32_MIN)); + const auto sign = And(quant, kSign); // TODO(janwas): = abs ^ orig + const auto abs_quant = AndNot(kSign, quant); + + // If |x| is 1, kZeroBias creates a different bias for each channel. + // We're implementing the following: + // if (quant == 0) return 0; + // if (quant == 1) return biases[c]; + // if (quant == -1) return -biases[c]; + // return quant - biases[3] / quant; + + // Integer comparison is not helpful because Clang incurs bypass penalties + // from unnecessarily mixing integer and float. + const auto is_01 = Lt(abs_quant, Set(df, 1.125f)); + const auto not_0 = Gt(abs_quant, Zero(df)); + + // Bitwise logic is faster than quant * biases[c]. + const auto one_bias = IfThenElseZero(not_0, Xor(Set(df, biases[c]), sign)); + + // About 2E-5 worse than ReciprocalNR or division. + const auto bias = + NegMulAdd(Set(df, biases[3]), ApproximateReciprocal(quant), quant); + + return IfThenElse(is_01, one_bias, bias); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_QUANTIZER_INL_H_ diff --git a/media/libjxl/src/lib/jxl/quantizer.cc b/media/libjxl/src/lib/jxl/quantizer.cc new file mode 100644 index 000000000..814aea276 --- /dev/null +++ b/media/libjxl/src/lib/jxl/quantizer.cc @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/quantizer.h" + +#include +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +static const int32_t kDefaultQuant = 64; + +constexpr int32_t Quantizer::kQuantMax; + +Quantizer::Quantizer(const DequantMatrices* dequant) + : Quantizer(dequant, kDefaultQuant, kGlobalScaleDenom / kDefaultQuant) {} + +Quantizer::Quantizer(const DequantMatrices* dequant, int quant_dc, + int global_scale) + : global_scale_(global_scale), quant_dc_(quant_dc), dequant_(dequant) { + JXL_ASSERT(dequant_ != nullptr); + RecomputeFromGlobalScale(); + inv_quant_dc_ = inv_global_scale_ / quant_dc_; + + memcpy(zero_bias_, kZeroBiasDefault, sizeof(kZeroBiasDefault)); +} + +void Quantizer::ComputeGlobalScaleAndQuant(float quant_dc, float quant_median, + float quant_median_absd) { + // Target value for the median value in the quant field. + const float kQuantFieldTarget = 5; + // We reduce the median of the quant field by the median absolute deviation: + // higher resolution on highly varying quant fields. + float scale = kGlobalScaleDenom * (quant_median - quant_median_absd) / + kQuantFieldTarget; + // Ensure that new_global_scale is positive and no more than 1<<15. + if (scale < 1) scale = 1; + if (scale > (1 << 15)) scale = 1 << 15; + int new_global_scale = static_cast(scale); + // Ensure that quant_dc_ will always be at least + // 0.625 * kGlobalScaleDenom/kGlobalScaleNumerator = 10. + const int scaled_quant_dc = + static_cast(quant_dc * kGlobalScaleNumerator * 1.6); + if (new_global_scale > scaled_quant_dc) { + new_global_scale = scaled_quant_dc; + if (new_global_scale <= 0) new_global_scale = 1; + } + global_scale_ = new_global_scale; + // Code below uses inv_global_scale_. + RecomputeFromGlobalScale(); + + float fval = quant_dc * inv_global_scale_ + 0.5f; + fval = std::min(1 << 16, fval); + const int new_quant_dc = static_cast(fval); + quant_dc_ = new_quant_dc; + + // quant_dc_ was updated, recompute values. + RecomputeFromGlobalScale(); +} + +void Quantizer::SetQuantFieldRect(const ImageF& qf, const Rect& rect, + ImageI* JXL_RESTRICT raw_quant_field) const { + for (size_t y = 0; y < rect.ysize(); ++y) { + const float* JXL_RESTRICT row_qf = rect.ConstRow(qf, y); + int32_t* JXL_RESTRICT row_qi = rect.Row(raw_quant_field, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + int val = ClampVal(row_qf[x] * inv_global_scale_ + 0.5f); + row_qi[x] = val; + } + } +} + +void Quantizer::SetQuantField(const float quant_dc, const ImageF& qf, + ImageI* JXL_RESTRICT raw_quant_field) { + std::vector data(qf.xsize() * qf.ysize()); + for (size_t y = 0; y < qf.ysize(); ++y) { + const float* JXL_RESTRICT row_qf = qf.Row(y); + for (size_t x = 0; x < qf.xsize(); ++x) { + float quant = row_qf[x]; + data[qf.xsize() * y + x] = quant; + } + } + std::nth_element(data.begin(), data.begin() + data.size() / 2, data.end()); + const float quant_median = data[data.size() / 2]; + std::vector deviations(data.size()); + for (size_t i = 0; i < data.size(); i++) { + deviations[i] = fabsf(data[i] - quant_median); + } + std::nth_element(deviations.begin(), + deviations.begin() + deviations.size() / 2, + deviations.end()); + const float quant_median_absd = deviations[deviations.size() / 2]; + ComputeGlobalScaleAndQuant(quant_dc, quant_median, quant_median_absd); + if (raw_quant_field) { + JXL_CHECK(SameSize(*raw_quant_field, qf)); + SetQuantFieldRect(qf, Rect(qf), raw_quant_field); + } +} + +void Quantizer::SetQuant(float quant_dc, float quant_ac, + ImageI* JXL_RESTRICT raw_quant_field) { + ComputeGlobalScaleAndQuant(quant_dc, quant_ac, 0); + int32_t val = ClampVal(quant_ac * inv_global_scale_ + 0.5f); + FillImage(val, raw_quant_field); +} + +Status QuantizerParams::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + BitsOffset(11, 1), BitsOffset(11, 2049), BitsOffset(12, 4097), + BitsOffset(16, 8193), 1, &global_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), BitsOffset(5, 1), + BitsOffset(8, 1), BitsOffset(16, 1), 1, + &quant_dc)); + return true; +} + +Status Quantizer::Encode(BitWriter* writer, size_t layer, + AuxOut* aux_out) const { + QuantizerParams params; + params.global_scale = global_scale_; + params.quant_dc = quant_dc_; + return Bundle::Write(params, writer, layer, aux_out); +} + +Status Quantizer::Decode(BitReader* reader) { + QuantizerParams params; + JXL_RETURN_IF_ERROR(Bundle::Read(reader, ¶ms)); + global_scale_ = static_cast(params.global_scale); + quant_dc_ = static_cast(params.quant_dc); + RecomputeFromGlobalScale(); + return true; +} + +void Quantizer::DumpQuantizationMap(const ImageI& raw_quant_field) const { + printf("Global scale: %d (%.7f)\nDC quant: %d\n", global_scale_, + global_scale_ * 1.0 / kGlobalScaleDenom, quant_dc_); + printf("AC quantization Map:\n"); + for (size_t y = 0; y < raw_quant_field.ysize(); ++y) { + for (size_t x = 0; x < raw_quant_field.xsize(); ++x) { + printf(" %3d", raw_quant_field.Row(y)[x]); + } + printf("\n"); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/quantizer.h b/media/libjxl/src/lib/jxl/quantizer.h new file mode 100644 index 000000000..09e2e5e45 --- /dev/null +++ b/media/libjxl/src/lib/jxl/quantizer.h @@ -0,0 +1,183 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_QUANTIZER_H_ +#define LIB_JXL_QUANTIZER_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/linalg.h" +#include "lib/jxl/quant_weights.h" + +// Quantizes DC and AC coefficients, with separate quantization tables according +// to the quant_kind (which is currently computed from the AC strategy and the +// block index inside that strategy). + +namespace jxl { + +static constexpr int kGlobalScaleDenom = 1 << 16; +static constexpr int kGlobalScaleNumerator = 4096; + +// zero-biases for quantizing channels X, Y, B +static constexpr float kZeroBiasDefault[3] = {0.5f, 0.5f, 0.5f}; + +// Returns adjusted version of a quantized integer, such that its value is +// closer to the expected value of the original. +// The residuals of AC coefficients that we quantize are not uniformly +// distributed. Numerical experiments show that they have a distribution with +// the "shape" of 1/(1+x^2) [up to some coefficients]. This means that the +// expected value of a coefficient that gets quantized to x will not be x +// itself, but (at least with reasonable approximation): +// - 0 if x is 0 +// - x * biases[c] if x is 1 or -1 +// - x - biases[3]/x otherwise +// This follows from computing the distribution of the quantization bias, which +// can be approximated fairly well by /x when |x| is at least two. +static constexpr float kBiasNumerator = 0.145f; + +static constexpr float kDefaultQuantBias[4] = { + 1.0f - 0.05465007330715401f, + 1.0f - 0.07005449891748593f, + 1.0f - 0.049935103337343655f, + 0.145f, +}; + +class Quantizer { + public: + explicit Quantizer(const DequantMatrices* dequant); + Quantizer(const DequantMatrices* dequant, int quant_dc, int global_scale); + + static constexpr int32_t kQuantMax = 256; + + static JXL_INLINE int32_t ClampVal(float val) { + return static_cast( + std::max(1.0f, std::min(val, kQuantMax))); + } + + float ScaleGlobalScale(const float scale) { + int new_global_scale = static_cast(global_scale_ * scale + 0.5f); + float scale_out = new_global_scale * 1.0f / global_scale_; + global_scale_ = new_global_scale; + RecomputeFromGlobalScale(); + return scale_out; + } + + // Recomputes other derived fields after global_scale_ has changed. + void RecomputeFromGlobalScale() { + global_scale_float_ = global_scale_ * (1.0 / kGlobalScaleDenom); + inv_global_scale_ = 1.0 * kGlobalScaleDenom / global_scale_; + inv_quant_dc_ = inv_global_scale_ / quant_dc_; + for (size_t c = 0; c < 3; c++) { + mul_dc_[c] = GetDcStep(c); + inv_mul_dc_[c] = GetInvDcStep(c); + } + } + + // Returns scaling factor such that Scale() * (RawDC() or RawQuantField()) + // pixels yields the same float values returned by GetQuantField. + JXL_INLINE float Scale() const { return global_scale_float_; } + + // Reciprocal of Scale(). + JXL_INLINE float InvGlobalScale() const { return inv_global_scale_; } + + void SetQuantFieldRect(const ImageF& qf, const Rect& rect, + ImageI* JXL_RESTRICT raw_quant_field) const; + + void SetQuantField(float quant_dc, const ImageF& qf, + ImageI* JXL_RESTRICT raw_quant_field); + + void SetQuant(float quant_dc, float quant_ac, + ImageI* JXL_RESTRICT raw_quant_field); + + // Returns the DC quantization base value, which is currently global (not + // adaptive). The actual scale factor used to dequantize pixels in channel c + // is: inv_quant_dc() * dequant_->DCQuant(c). + float inv_quant_dc() const { return inv_quant_dc_; } + + // Dequantize by multiplying with this times dequant_matrix. + float inv_quant_ac(int32_t quant) const { return inv_global_scale_ / quant; } + + Status Encode(BitWriter* writer, size_t layer, AuxOut* aux_out) const; + + Status Decode(BitReader* reader); + + void DumpQuantizationMap(const ImageI& raw_quant_field) const; + + JXL_INLINE const float* DequantMatrix(size_t quant_kind, size_t c) const { + return dequant_->Matrix(quant_kind, c); + } + + JXL_INLINE const float* InvDequantMatrix(size_t quant_kind, size_t c) const { + return dequant_->InvMatrix(quant_kind, c); + } + + // Calculates DC quantization step. + JXL_INLINE float GetDcStep(size_t c) const { + return inv_quant_dc_ * dequant_->DCQuant(c); + } + JXL_INLINE float GetInvDcStep(size_t c) const { + return dequant_->InvDCQuant(c) * (global_scale_float_ * quant_dc_); + } + + JXL_INLINE const float* MulDC() const { return mul_dc_; } + JXL_INLINE const float* InvMulDC() const { return inv_mul_dc_; } + + JXL_INLINE void ClearDCMul() { + std::fill(mul_dc_, mul_dc_ + 4, 1.f); + std::fill(inv_mul_dc_, inv_mul_dc_ + 4, 1.f); + } + + void ComputeGlobalScaleAndQuant(float quant_dc, float quant_median, + float quant_median_absd); + + private: + float mul_dc_[4]; + float inv_mul_dc_[4]; + + // These are serialized: + int global_scale_; + int quant_dc_; + + // These are derived from global_scale_: + float inv_global_scale_; + float global_scale_float_; // reciprocal of inv_global_scale_ + float inv_quant_dc_; + + float zero_bias_[3]; + const DequantMatrices* dequant_; +}; + +struct QuantizerParams : public Fields { + QuantizerParams() { Bundle::Init(this); } + JXL_FIELDS_NAME(QuantizerParams) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + uint32_t global_scale; + uint32_t quant_dc; +}; + +} // namespace jxl + +#endif // LIB_JXL_QUANTIZER_H_ diff --git a/media/libjxl/src/lib/jxl/quantizer_test.cc b/media/libjxl/src/lib/jxl/quantizer_test.cc new file mode 100644 index 000000000..d570bf6d4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/quantizer_test.cc @@ -0,0 +1,78 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/quantizer.h" + +#include "gtest/gtest.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +void TestEquivalence(int qxsize, int qysize, const Quantizer& quantizer1, + const Quantizer& quantizer2) { + ASSERT_NEAR(quantizer1.inv_quant_dc(), quantizer2.inv_quant_dc(), 1e-7); +} + +TEST(QuantizerTest, QuantizerParams) { + for (uint32_t i = 1; i < 10000; ++i) { + QuantizerParams p; + p.global_scale = i; + size_t extension_bits = 0, total_bits = 0; + EXPECT_TRUE(Bundle::CanEncode(p, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + EXPECT_GE(total_bits, 4u); + } +} + +TEST(QuantizerTest, BitStreamRoundtripSameQuant) { + const int qxsize = 8; + const int qysize = 8; + DequantMatrices dequant; + Quantizer quantizer1(&dequant); + ImageI raw_quant_field(qxsize, qysize); + quantizer1.SetQuant(0.17f, 0.17f, &raw_quant_field); + BitWriter writer; + EXPECT_TRUE(quantizer1.Encode(&writer, 0, nullptr)); + writer.ZeroPadToByte(); + const size_t bits_written = writer.BitsWritten(); + Quantizer quantizer2(&dequant); + BitReader reader(writer.GetSpan()); + EXPECT_TRUE(quantizer2.Decode(&reader)); + EXPECT_TRUE(reader.JumpToByteBoundary()); + EXPECT_EQ(reader.TotalBitsConsumed(), bits_written); + EXPECT_TRUE(reader.Close()); + TestEquivalence(qxsize, qysize, quantizer1, quantizer2); +} + +TEST(QuantizerTest, BitStreamRoundtripRandomQuant) { + const int qxsize = 8; + const int qysize = 8; + DequantMatrices dequant; + Quantizer quantizer1(&dequant); + ImageI raw_quant_field(qxsize, qysize); + quantizer1.SetQuant(0.17f, 0.17f, &raw_quant_field); + float quant_dc = 0.17f; + ImageF qf(qxsize, qysize); + RandomFillImage(&qf, 0.0f, 1.0f); + quantizer1.SetQuantField(quant_dc, qf, &raw_quant_field); + BitWriter writer; + EXPECT_TRUE(quantizer1.Encode(&writer, 0, nullptr)); + writer.ZeroPadToByte(); + const size_t bits_written = writer.BitsWritten(); + Quantizer quantizer2(&dequant); + BitReader reader(writer.GetSpan()); + EXPECT_TRUE(quantizer2.Decode(&reader)); + EXPECT_TRUE(reader.JumpToByteBoundary()); + EXPECT_EQ(reader.TotalBitsConsumed(), bits_written); + EXPECT_TRUE(reader.Close()); + TestEquivalence(qxsize, qysize, quantizer1, quantizer2); +} +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/rational_polynomial-inl.h b/media/libjxl/src/lib/jxl/rational_polynomial-inl.h new file mode 100644 index 000000000..176e24092 --- /dev/null +++ b/media/libjxl/src/lib/jxl/rational_polynomial-inl.h @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast SIMD evaluation of rational polynomials for approximating functions. + +#if defined(LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_ +#undef LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_ +#else +#define LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_ +#endif + +#include + +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::MulAdd; + +// Primary template: default to actual division. +template +struct FastDivision { + HWY_INLINE V operator()(const V n, const V d) const { return n / d; } +}; +// Partial specialization for float vectors. +template +struct FastDivision { + // One Newton-Raphson iteration. + static HWY_INLINE V ReciprocalNR(const V x) { + const auto rcp = ApproximateReciprocal(x); + const auto sum = Add(rcp, rcp); + const auto x_rcp = Mul(x, rcp); + return NegMulAdd(x_rcp, rcp, sum); + } + + V operator()(const V n, const V d) const { +#if 1 // Faster on SKX + return Div(n, d); +#else + return n * ReciprocalNR(d); +#endif + } +}; + +// Approximates smooth functions via rational polynomials (i.e. dividing two +// polynomials). Evaluates polynomials via Horner's scheme, which is faster than +// Clenshaw recurrence for Chebyshev polynomials. LoadDup128 allows us to +// specify constants (replicated 4x) independently of the lane count. +template +HWY_INLINE HWY_MAYBE_UNUSED V EvalRationalPolynomial(const D d, const V x, + const T (&p)[NP], + const T (&q)[NQ]) { + constexpr size_t kDegP = NP / 4 - 1; + constexpr size_t kDegQ = NQ / 4 - 1; + auto yp = LoadDup128(d, &p[kDegP * 4]); + auto yq = LoadDup128(d, &q[kDegQ * 4]); + // We use pointer arithmetic to refer to &p[(kDegP - n) * 4] to avoid a + // compiler warning that the index is out of bounds since we are already + // checking that it is not out of bounds with (kDegP >= n) and the access + // will be optimized away. Similarly with q and kDegQ. + HWY_FENCE; + if (kDegP >= 1) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 1) * 4))); + if (kDegQ >= 1) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 1) * 4))); + HWY_FENCE; + if (kDegP >= 2) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 2) * 4))); + if (kDegQ >= 2) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 2) * 4))); + HWY_FENCE; + if (kDegP >= 3) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 3) * 4))); + if (kDegQ >= 3) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 3) * 4))); + HWY_FENCE; + if (kDegP >= 4) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 4) * 4))); + if (kDegQ >= 4) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 4) * 4))); + HWY_FENCE; + if (kDegP >= 5) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 5) * 4))); + if (kDegQ >= 5) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 5) * 4))); + HWY_FENCE; + if (kDegP >= 6) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 6) * 4))); + if (kDegQ >= 6) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 6) * 4))); + HWY_FENCE; + if (kDegP >= 7) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 7) * 4))); + if (kDegQ >= 7) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 7) * 4))); + + return FastDivision()(yp, yq); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); +#endif // LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_ diff --git a/media/libjxl/src/lib/jxl/rational_polynomial_test.cc b/media/libjxl/src/lib/jxl/rational_polynomial_test.cc new file mode 100644 index 000000000..13fc044a5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/rational_polynomial_test.cc @@ -0,0 +1,238 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/rational_polynomial_test.cc" +#include +#include +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/rational_polynomial-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +using T = float; // required by EvalLog2 +using D = HWY_FULL(T); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; + +// Generic: only computes polynomial +struct EvalPoly { + template + T operator()(T x, const T (&p)[NP], const T (&q)[NQ]) const { + const HWY_FULL(T) d; + const auto vx = Set(d, x); + const auto approx = EvalRationalPolynomial(d, vx, p, q); + return GetLane(approx); + } +}; + +// Range reduction for log2 +struct EvalLog2 { + template + T operator()(T x, const T (&p)[NP], const T (&q)[NQ]) const { + const HWY_FULL(T) d; + auto vx = Set(d, x); + + const HWY_FULL(int32_t) di; + const auto x_bits = BitCast(di, vx); + // Cannot handle negative numbers / NaN. + JXL_DASSERT(AllTrue(di, Eq(Abs(x_bits), x_bits))); + + // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops + const auto exp_bits = Sub(x_bits, Set(di, 0x3f2aaaab)); // = 2/3 + // Shifted exponent = log2; also used to clear mantissa. + const auto exp_shifted = ShiftRight<23>(exp_bits); + const auto mantissa = BitCast(d, Sub(x_bits, ShiftLeft<23>(exp_shifted))); + const auto exp_val = ConvertTo(d, exp_shifted); + vx = Sub(mantissa, Set(d, 1.0f)); + + const auto approx = Add(EvalRationalPolynomial(d, vx, p, q), exp_val); + return GetLane(approx); + } +}; + +// Functions to approximate: + +T LinearToSrgb8Direct(T val) { + if (val < 0.0) return 0.0; + if (val >= 255.0) return 255.0; + if (val <= 10.0 / 12.92) return val * 12.92; + return 255.0 * (std::pow(val / 255.0, 1.0 / 2.4) * 1.055 - 0.055); +} + +T SimpleGamma(T v) { + static const T kGamma = 0.387494322593; + static const T limit = 43.01745241042018; + T bright = v - limit; + if (bright >= 0) { + static const T mul = 0.0383723643799; + v -= bright * mul; + } + static const T limit2 = 94.68634353321337; + T bright2 = v - limit2; + if (bright2 >= 0) { + static const T mul = 0.22885405968; + v -= bright2 * mul; + } + static const T offset = 0.156775786057; + static const T scale = 8.898059160493739; + T retval = scale * (offset + pow(v, kGamma)); + return retval; +} + +// Runs CaratheodoryFejer and verifies the polynomial using a lot of samples to +// return the biggest error. +template +T RunApproximation(T x0, T x1, const T (&p)[NP], const T (&q)[NQ], + const Eval& eval, T func_to_approx(T)) { + float maxerr = 0; + T lastPrint = 0; + // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter) + for (T x = x0; x <= x1; x += (x1 - x0) / 10000.0) { + const T f = func_to_approx(x); + const T g = eval(x, p, q); + maxerr = std::max(fabsf(g - f), maxerr); + if (x == x0 || x - lastPrint > (x1 - x0) / 20.0) { + printf("x: %11.6f, f: %11.6f, g: %11.6f, e: %11.6f\n", x, f, g, + fabs(g - f)); + lastPrint = x; + } + } + return maxerr; +} + +void TestSimpleGamma() { + const T p[4 * (6 + 1)] = { + HWY_REP4(-5.0646949363741811E-05), HWY_REP4(6.7369380528439771E-05), + HWY_REP4(8.9376652530412794E-05), HWY_REP4(2.1153513301520462E-06), + HWY_REP4(-6.9130322970386449E-08), HWY_REP4(3.9424752749293728E-10), + HWY_REP4(1.2360288207619576E-13)}; + + const T q[4 * (6 + 1)] = { + HWY_REP4(-6.6389733798591366E-06), HWY_REP4(1.3299859726565908E-05), + HWY_REP4(3.8538748358398873E-06), HWY_REP4(-2.8707687262928236E-08), + HWY_REP4(-6.6897385800005434E-10), HWY_REP4(6.1428748869186003E-12), + HWY_REP4(-2.5475738169252870E-15)}; + + const T err = RunApproximation(0.77, 274.579999999999984, p, q, EvalPoly(), + SimpleGamma); + EXPECT_LT(err, 0.05); +} + +void TestLinearToSrgb8Direct() { + const T p[4 * (5 + 1)] = { + HWY_REP4(-9.5357499040105154E-05), HWY_REP4(4.6761186249798248E-04), + HWY_REP4(2.5708174333943594E-04), HWY_REP4(1.5250087770436082E-05), + HWY_REP4(1.1946768008931187E-07), HWY_REP4(5.9916446295972850E-11)}; + + const T q[4 * (4 + 1)] = { + HWY_REP4(1.8932479758079768E-05), HWY_REP4(2.7312342474687321E-05), + HWY_REP4(4.3901204783327006E-06), HWY_REP4(1.0417787306920273E-07), + HWY_REP4(3.0084206762140419E-10)}; + + const T err = + RunApproximation(0.77, 255, p, q, EvalPoly(), LinearToSrgb8Direct); + EXPECT_LT(err, 0.05); +} + +void TestExp() { + const T p[4 * (2 + 1)] = {HWY_REP4(9.6266879665530902E-01), + HWY_REP4(4.8961265681586763E-01), + HWY_REP4(8.2619259189548433E-02)}; + const T q[4 * (2 + 1)] = {HWY_REP4(9.6259895571622622E-01), + HWY_REP4(-4.7272457588933831E-01), + HWY_REP4(7.4802088567547664E-02)}; + const T err = + RunApproximation(-1, 1, p, q, EvalPoly(), [](T x) { return T(exp(x)); }); + EXPECT_LT(err, 1E-4); +} + +void TestNegExp() { + // 4,3 is the min required for monotonicity; max error in 0,10: 751 ppm + // no benefit for k>50. + const T p[4 * (4 + 1)] = { + HWY_REP4(5.9580258551150123E-02), HWY_REP4(-2.5073728806886408E-02), + HWY_REP4(4.1561830213689248E-03), HWY_REP4(-3.1815408488900372E-04), + HWY_REP4(9.3866690094906802E-06)}; + const T q[4 * (3 + 1)] = { + HWY_REP4(5.9579108238812878E-02), HWY_REP4(3.4542074345478582E-02), + HWY_REP4(8.7263562483501714E-03), HWY_REP4(1.4095109143061216E-03)}; + + const T err = + RunApproximation(0, 10, p, q, EvalPoly(), [](T x) { return T(exp(-x)); }); + EXPECT_LT(err, sizeof(T) == 8 ? 2E-5 : 3E-5); +} + +void TestSin() { + const T p[4 * (6 + 1)] = { + HWY_REP4(1.5518122109203780E-05), HWY_REP4(2.3388958643675966E+00), + HWY_REP4(-8.6705520940849157E-01), HWY_REP4(-1.9702294764873535E-01), + HWY_REP4(1.2193404314472320E-01), HWY_REP4(-1.7373966109788839E-02), + HWY_REP4(7.8829435883034796E-04)}; + const T q[4 * (5 + 1)] = { + HWY_REP4(2.3394371422557279E+00), HWY_REP4(-8.7028221081288615E-01), + HWY_REP4(2.0052872219658430E-01), HWY_REP4(-3.2460335995264836E-02), + HWY_REP4(3.1546157932479282E-03), HWY_REP4(-1.6692542019380155E-04)}; + + const T err = RunApproximation(0, Pi(1) * 2, p, q, EvalPoly(), + [](T x) { return T(sin(x)); }); + EXPECT_LT(err, sizeof(T) == 8 ? 5E-4 : 7E-4); +} + +void TestLog() { + HWY_ALIGN const T p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06), + HWY_REP4(1.4287160470083755E+00), + HWY_REP4(7.4245873327820566E-01)}; + HWY_ALIGN const T q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01), + HWY_REP4(1.0096718572241148E+00), + HWY_REP4(1.7409343003366853E-01)}; + const T err = RunApproximation(1E-6, 1000, p, q, EvalLog2(), std::log2); + printf("%E\n", err); +} + +HWY_NOINLINE void TestRationalPolynomial() { + TestSimpleGamma(); + TestLinearToSrgb8Direct(); + TestExp(); + TestNegExp(); + TestSin(); + TestLog(); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class RationalPolynomialTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(RationalPolynomialTest); + +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestSimpleGamma); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestLinearToSrgb8Direct); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestExp); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestNegExp); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestSin); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestLog); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/render_pipeline/low_memory_render_pipeline.cc b/media/libjxl/src/lib/jxl/render_pipeline/low_memory_render_pipeline.cc new file mode 100644 index 000000000..91147303a --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/low_memory_render_pipeline.cc @@ -0,0 +1,865 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/low_memory_render_pipeline.h" + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/arch_macros.h" + +namespace jxl { +std::pair +LowMemoryRenderPipeline::ColorDimensionsToChannelDimensions( + std::pair in, size_t c, size_t stage) const { + std::pair ret; + std::pair shift = channel_shifts_[stage][c]; + ret.first = + ((in.first << base_color_shift_) + (1 << shift.first) - 1) >> shift.first; + ret.second = ((in.second << base_color_shift_) + (1 << shift.second) - 1) >> + shift.second; + return ret; +} + +std::pair LowMemoryRenderPipeline::BorderToStore( + size_t c) const { + auto ret = ColorDimensionsToChannelDimensions(group_border_, c, 0); + ret.first += padding_[0][c].first; + ret.second += padding_[0][c].second; + return ret; +} + +void LowMemoryRenderPipeline::SaveBorders(size_t group_id, size_t c, + const ImageF& in) { + size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t gx = group_id % frame_dimensions_.xsize_groups; + size_t hshift = channel_shifts_[0][c].first; + size_t vshift = channel_shifts_[0][c].second; + size_t x0 = gx * GroupInputXSize(c); + size_t x1 = std::min((gx + 1) * GroupInputXSize(c), + DivCeil(frame_dimensions_.xsize_upsampled, 1 << hshift)); + size_t y0 = gy * GroupInputYSize(c); + size_t y1 = std::min((gy + 1) * GroupInputYSize(c), + DivCeil(frame_dimensions_.ysize_upsampled, 1 << vshift)); + + auto borders = BorderToStore(c); + size_t borderx_write = borders.first; + size_t bordery_write = borders.second; + + if (gy > 0) { + Rect from(group_data_x_border_, group_data_y_border_, x1 - x0, + bordery_write); + Rect to(x0, (gy * 2 - 1) * bordery_write, x1 - x0, bordery_write); + CopyImageTo(from, in, to, &borders_horizontal_[c]); + } + if (gy + 1 < frame_dimensions_.ysize_groups) { + Rect from(group_data_x_border_, + group_data_y_border_ + y1 - y0 - bordery_write, x1 - x0, + bordery_write); + Rect to(x0, (gy * 2) * bordery_write, x1 - x0, bordery_write); + CopyImageTo(from, in, to, &borders_horizontal_[c]); + } + if (gx > 0) { + Rect from(group_data_x_border_, group_data_y_border_, borderx_write, + y1 - y0); + Rect to((gx * 2 - 1) * borderx_write, y0, borderx_write, y1 - y0); + CopyImageTo(from, in, to, &borders_vertical_[c]); + } + if (gx + 1 < frame_dimensions_.xsize_groups) { + Rect from(group_data_x_border_ + x1 - x0 - borderx_write, + group_data_y_border_, borderx_write, y1 - y0); + Rect to((gx * 2) * borderx_write, y0, borderx_write, y1 - y0); + CopyImageTo(from, in, to, &borders_vertical_[c]); + } +} + +void LowMemoryRenderPipeline::LoadBorders(size_t group_id, size_t c, + const Rect& r, ImageF* out) { + size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t gx = group_id % frame_dimensions_.xsize_groups; + size_t hshift = channel_shifts_[0][c].first; + size_t vshift = channel_shifts_[0][c].second; + // Coordinates of the group in the image. + size_t x0 = gx * GroupInputXSize(c); + size_t x1 = std::min((gx + 1) * GroupInputXSize(c), + DivCeil(frame_dimensions_.xsize_upsampled, 1 << hshift)); + size_t y0 = gy * GroupInputYSize(c); + size_t y1 = std::min((gy + 1) * GroupInputYSize(c), + DivCeil(frame_dimensions_.ysize_upsampled, 1 << vshift)); + + size_t paddingx = padding_[0][c].first; + size_t paddingy = padding_[0][c].second; + + auto borders = BorderToStore(c); + size_t borderx_write = borders.first; + size_t bordery_write = borders.second; + + // Limits of the area to copy from, in image coordinates. + JXL_DASSERT(r.x0() == 0 || (r.x0() << base_color_shift_) >= paddingx); + size_t x0src = DivCeil(r.x0() << base_color_shift_, 1 << hshift); + if (x0src != 0) { + x0src -= paddingx; + } + // r may be such that r.x1 (namely x0() + xsize()) is within paddingx of the + // right side of the image, so we use min() here. + size_t x1src = + DivCeil((r.x0() + r.xsize()) << base_color_shift_, 1 << hshift); + x1src = std::min(x1src + paddingx, + DivCeil(frame_dimensions_.xsize_upsampled, 1 << hshift)); + + // Similar computation for y. + JXL_DASSERT(r.y0() == 0 || (r.y0() << base_color_shift_) >= paddingy); + size_t y0src = DivCeil(r.y0() << base_color_shift_, 1 << vshift); + if (y0src != 0) { + y0src -= paddingy; + } + size_t y1src = + DivCeil((r.y0() + r.ysize()) << base_color_shift_, 1 << vshift); + y1src = std::min(y1src + paddingy, + DivCeil(frame_dimensions_.ysize_upsampled, 1 << vshift)); + + // Copy other groups' borders from the border storage. + if (y0src < y0) { + JXL_DASSERT(gy > 0); + CopyImageTo( + Rect(x0src, (gy * 2 - 2) * bordery_write, x1src - x0src, bordery_write), + borders_horizontal_[c], + Rect(group_data_x_border_ + x0src - x0, + group_data_y_border_ - bordery_write, x1src - x0src, + bordery_write), + out); + } + if (y1src > y1) { + // When copying the bottom border we must not be on the bottom groups. + JXL_DASSERT(gy + 1 < frame_dimensions_.ysize_groups); + CopyImageTo( + Rect(x0src, (gy * 2 + 1) * bordery_write, x1src - x0src, bordery_write), + borders_horizontal_[c], + Rect(group_data_x_border_ + x0src - x0, group_data_y_border_ + y1 - y0, + x1src - x0src, bordery_write), + out); + } + if (x0src < x0) { + JXL_DASSERT(gx > 0); + CopyImageTo( + Rect((gx * 2 - 2) * borderx_write, y0src, borderx_write, y1src - y0src), + borders_vertical_[c], + Rect(group_data_x_border_ - borderx_write, + group_data_y_border_ + y0src - y0, borderx_write, y1src - y0src), + out); + } + if (x1src > x1) { + // When copying the right border we must not be on the rightmost groups. + JXL_DASSERT(gx + 1 < frame_dimensions_.xsize_groups); + CopyImageTo( + Rect((gx * 2 + 1) * borderx_write, y0src, borderx_write, y1src - y0src), + borders_vertical_[c], + Rect(group_data_x_border_ + x1 - x0, group_data_y_border_ + y0src - y0, + borderx_write, y1src - y0src), + out); + } +} + +size_t LowMemoryRenderPipeline::GroupInputXSize(size_t c) const { + return (frame_dimensions_.group_dim << base_color_shift_) >> + channel_shifts_[0][c].first; +} + +size_t LowMemoryRenderPipeline::GroupInputYSize(size_t c) const { + return (frame_dimensions_.group_dim << base_color_shift_) >> + channel_shifts_[0][c].second; +} + +void LowMemoryRenderPipeline::EnsureBordersStorage() { + const auto& shifts = channel_shifts_[0]; + if (borders_horizontal_.size() < shifts.size()) { + borders_horizontal_.resize(shifts.size()); + borders_vertical_.resize(shifts.size()); + } + for (size_t c = 0; c < shifts.size(); c++) { + auto borders = BorderToStore(c); + size_t borderx = borders.first; + size_t bordery = borders.second; + JXL_DASSERT(frame_dimensions_.xsize_groups > 0); + size_t num_xborders = (frame_dimensions_.xsize_groups - 1) * 2; + JXL_DASSERT(frame_dimensions_.ysize_groups > 0); + size_t num_yborders = (frame_dimensions_.ysize_groups - 1) * 2; + size_t downsampled_xsize = + DivCeil(frame_dimensions_.xsize_upsampled_padded, 1 << shifts[c].first); + size_t downsampled_ysize = DivCeil(frame_dimensions_.ysize_upsampled_padded, + 1 << shifts[c].second); + Rect horizontal = Rect(0, 0, downsampled_xsize, bordery * num_yborders); + if (!SameSize(horizontal, borders_horizontal_[c])) { + borders_horizontal_[c] = ImageF(horizontal.xsize(), horizontal.ysize()); + } + Rect vertical = Rect(0, 0, borderx * num_xborders, downsampled_ysize); + if (!SameSize(vertical, borders_vertical_[c])) { + borders_vertical_[c] = ImageF(vertical.xsize(), vertical.ysize()); + } + } +} + +void LowMemoryRenderPipeline::Init() { + group_border_ = {0, 0}; + base_color_shift_ = CeilLog2Nonzero(frame_dimensions_.xsize_upsampled_padded / + frame_dimensions_.xsize_padded); + + const auto& shifts = channel_shifts_[0]; + + // Ensure that each channel has enough many border pixels. + for (size_t c = 0; c < shifts.size(); c++) { + group_border_.first = + std::max(group_border_.first, + DivCeil(padding_[0][c].first << channel_shifts_[0][c].first, + 1 << base_color_shift_)); + group_border_.second = + std::max(group_border_.second, + DivCeil(padding_[0][c].second << channel_shifts_[0][c].second, + 1 << base_color_shift_)); + } + + // Ensure that all channels have an integer number of border pixels in the + // input. + for (size_t c = 0; c < shifts.size(); c++) { + if (channel_shifts_[0][c].first >= base_color_shift_) { + group_border_.first = + RoundUpTo(group_border_.first, + 1 << (channel_shifts_[0][c].first - base_color_shift_)); + } + if (channel_shifts_[0][c].second >= base_color_shift_) { + group_border_.second = + RoundUpTo(group_border_.second, + 1 << (channel_shifts_[0][c].second - base_color_shift_)); + } + } + // Ensure that the X border on color channels is a multiple of kBlockDim or + // the vector size (required for EPF stages). Vectors on ARM NEON are never + // wider than 4 floats, so rounding to multiples of 4 is enough. +#if JXL_ARCH_ARM + constexpr size_t kGroupXAlign = 4; +#else + constexpr size_t kGroupXAlign = 16; +#endif + group_border_.first = RoundUpTo(group_border_.first, kGroupXAlign); + // Allocate borders in group images that are just enough for storing the + // borders to be copied in, plus any rounding to ensure alignment. + std::pair max_border = {0, 0}; + for (size_t c = 0; c < shifts.size(); c++) { + max_border.first = std::max(BorderToStore(c).first, max_border.first); + max_border.second = std::max(BorderToStore(c).second, max_border.second); + } + group_data_x_border_ = RoundUpTo(max_border.first, kGroupXAlign); + group_data_y_border_ = max_border.second; + + EnsureBordersStorage(); + group_border_assigner_.Init(frame_dimensions_); + + for (first_trailing_stage_ = stages_.size(); first_trailing_stage_ > 0; + first_trailing_stage_--) { + bool has_inout_c = false; + for (size_t c = 0; c < shifts.size(); c++) { + if (stages_[first_trailing_stage_ - 1]->GetChannelMode(c) == + RenderPipelineChannelMode::kInOut) { + has_inout_c = true; + } + } + if (has_inout_c) { + break; + } + } + + first_image_dim_stage_ = stages_.size(); + for (size_t i = 0; i < stages_.size(); i++) { + std::vector> input_sizes(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + input_sizes[c] = + std::make_pair(DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[i][c].first), + DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[i][c].second)); + } + stages_[i]->SetInputSizes(input_sizes); + if (stages_[i]->SwitchToImageDimensions()) { + // We don't allow kInOut after switching to image dimensions. + JXL_ASSERT(i >= first_trailing_stage_); + first_image_dim_stage_ = i + 1; + stages_[i]->GetImageDimensions(&full_image_xsize_, &full_image_ysize_, + &frame_origin_); + break; + } + } + for (size_t i = first_image_dim_stage_; i < stages_.size(); i++) { + if (stages_[i]->SwitchToImageDimensions()) { + JXL_ABORT("Cannot switch to image dimensions multiple times"); + } + std::vector> input_sizes(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + input_sizes[c] = {full_image_xsize_, full_image_ysize_}; + } + stages_[i]->SetInputSizes(input_sizes); + } + + anyc_.resize(stages_.size()); + for (size_t i = 0; i < stages_.size(); i++) { + for (size_t c = 0; c < shifts.size(); c++) { + if (stages_[i]->GetChannelMode(c) != + RenderPipelineChannelMode::kIgnored) { + anyc_[i] = c; + } + } + } + + stage_input_for_channel_ = std::vector>( + stages_.size(), std::vector(shifts.size())); + for (size_t c = 0; c < shifts.size(); c++) { + int input = -1; + for (size_t i = 0; i < stages_.size(); i++) { + stage_input_for_channel_[i][c] = input; + if (stages_[i]->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + input = i; + } + } + } + + image_rect_.resize(stages_.size()); + for (size_t i = 0; i < stages_.size(); i++) { + size_t x1 = DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[i][anyc_[i]].first); + size_t y1 = DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[i][anyc_[i]].second); + image_rect_[i] = Rect(0, 0, x1, y1); + } + + virtual_ypadding_for_output_.resize(stages_.size()); + xpadding_for_output_.resize(stages_.size()); + for (size_t c = 0; c < shifts.size(); c++) { + int ypad = 0; + int xpad = 0; + for (size_t i = stages_.size(); i-- > 0;) { + if (stages_[i]->GetChannelMode(c) != + RenderPipelineChannelMode::kIgnored) { + virtual_ypadding_for_output_[i] = + std::max(ypad, virtual_ypadding_for_output_[i]); + xpadding_for_output_[i] = std::max(xpad, xpadding_for_output_[i]); + } + if (stages_[i]->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + ypad = (DivCeil(ypad, 1 << channel_shifts_[i][c].second) + + stages_[i]->settings_.border_y) + << channel_shifts_[i][c].second; + xpad = DivCeil(xpad, 1 << stages_[i]->settings_.shift_x) + + stages_[i]->settings_.border_x; + } + } + } +} + +void LowMemoryRenderPipeline::PrepareForThreadsInternal(size_t num, + bool use_group_ids) { + const auto& shifts = channel_shifts_[0]; + + use_group_ids_ = use_group_ids; + size_t num_buffers = use_group_ids_ ? frame_dimensions_.num_groups : num; + for (size_t t = group_data_.size(); t < num_buffers; t++) { + group_data_.emplace_back(); + group_data_[t].resize(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + group_data_[t][c] = ImageF(GroupInputXSize(c) + group_data_x_border_ * 2, + GroupInputYSize(c) + group_data_y_border_ * 2); + } + } + // TODO(veluca): avoid reallocating buffers if not needed. + stage_data_.resize(num); + size_t upsampling = 1u << base_color_shift_; + size_t group_dim = frame_dimensions_.group_dim * upsampling; + size_t padding = + 2 * group_data_x_border_ * upsampling + // maximum size of a rect + 2 * kRenderPipelineXOffset; // extra padding for processing + size_t stage_buffer_xsize = group_dim + padding; + for (size_t t = 0; t < num; t++) { + stage_data_[t].resize(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + stage_data_[t][c].resize(stages_.size()); + size_t next_y_border = 0; + for (size_t i = stages_.size(); i-- > 0;) { + if (stages_[i]->GetChannelMode(c) == + RenderPipelineChannelMode::kInOut) { + size_t stage_buffer_ysize = + 2 * next_y_border + (1 << stages_[i]->settings_.shift_y); + stage_buffer_ysize = 1 << CeilLog2Nonzero(stage_buffer_ysize); + next_y_border = stages_[i]->settings_.border_y; + stage_data_[t][c][i] = ImageF(stage_buffer_xsize, stage_buffer_ysize); + } + } + } + } + if (first_image_dim_stage_ != stages_.size()) { + RectT image_rect(0, 0, frame_dimensions_.xsize_upsampled, + frame_dimensions_.ysize_upsampled); + RectT full_image_rect(0, 0, full_image_xsize_, full_image_ysize_); + image_rect = image_rect.Translate(frame_origin_.x0, frame_origin_.y0); + image_rect = image_rect.Intersection(full_image_rect); + if (image_rect.xsize() == 0 || image_rect.ysize() == 0) { + image_rect = RectT(0, 0, 0, 0); + } + size_t left_padding = image_rect.x0(); + size_t middle_padding = group_dim; + size_t right_padding = full_image_xsize_ - image_rect.x1(); + size_t out_of_frame_xsize = + padding + + std::max(left_padding, std::max(middle_padding, right_padding)); + out_of_frame_data_.resize(num); + for (size_t t = 0; t < num; t++) { + out_of_frame_data_[t] = ImageF(out_of_frame_xsize, shifts.size()); + } + } +} + +std::vector> LowMemoryRenderPipeline::PrepareBuffers( + size_t group_id, size_t thread_id) { + std::vector> ret(channel_shifts_[0].size()); + const size_t gx = group_id % frame_dimensions_.xsize_groups; + const size_t gy = group_id / frame_dimensions_.xsize_groups; + for (size_t c = 0; c < channel_shifts_[0].size(); c++) { + ret[c].first = &group_data_[use_group_ids_ ? group_id : thread_id][c]; + ret[c].second = Rect(group_data_x_border_, group_data_y_border_, + GroupInputXSize(c), GroupInputYSize(c), + DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[0][c].first) - + gx * GroupInputXSize(c) + group_data_x_border_, + DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[0][c].second) - + gy * GroupInputYSize(c) + group_data_y_border_); + } + return ret; +} + +namespace { + +JXL_INLINE int GetMirroredY(int y, ssize_t group_y0, ssize_t image_ysize) { + if (group_y0 == 0 && (y < 0 || y + group_y0 >= image_ysize)) { + return Mirror(y, image_ysize); + } + if (y + group_y0 >= image_ysize) { + // Here we know that the one mirroring step is sufficient. + return 2 * image_ysize - (y + group_y0) - 1 - group_y0; + } + return y; +} + +JXL_INLINE void ApplyXMirroring(float* row, ssize_t borderx, ssize_t group_x0, + ssize_t group_xsize, ssize_t image_xsize) { + if (image_xsize <= borderx) { + if (group_x0 == 0) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset - ix - 1] = + row[kRenderPipelineXOffset + Mirror(-ix - 1, image_xsize)]; + } + } + if (group_xsize + borderx + group_x0 >= image_xsize) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset + image_xsize + ix - group_x0] = + row[kRenderPipelineXOffset + Mirror(image_xsize + ix, image_xsize) - + group_x0]; + } + } + } else { + // Here we know that the one mirroring step is sufficient. + if (group_x0 == 0) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset - ix - 1] = row[kRenderPipelineXOffset + ix]; + } + } + if (group_xsize + borderx + group_x0 >= image_xsize) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset + image_xsize - group_x0 + ix] = + row[kRenderPipelineXOffset + image_xsize - group_x0 - ix - 1]; + } + } + } +} + +// Information about where the *output* of each stage is stored. +class Rows { + public: + Rows(const std::vector>& stages, + const Rect data_max_color_channel_rect, int group_data_x_border, + int group_data_y_border, + const std::vector>& group_data_shift, + size_t base_color_shift, std::vector>& thread_data, + std::vector& input_data) { + size_t num_stages = stages.size(); + size_t num_channels = input_data.size(); + + JXL_ASSERT(thread_data.size() == num_channels); + JXL_ASSERT(group_data_shift.size() == num_channels); + +#if JXL_ENABLE_ASSERT + for (const auto& td : thread_data) { + JXL_ASSERT(td.size() == num_stages); + } +#endif + + rows_.resize(num_stages + 1, std::vector(num_channels)); + + for (size_t i = 0; i < num_stages; i++) { + for (size_t c = 0; c < input_data.size(); c++) { + if (stages[i]->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + rows_[i + 1][c].ymod_minus_1 = thread_data[c][i].ysize() - 1; + rows_[i + 1][c].base_ptr = thread_data[c][i].Row(0); + rows_[i + 1][c].stride = thread_data[c][i].PixelsPerRow(); + } + } + } + + for (size_t c = 0; c < input_data.size(); c++) { + auto channel_group_data_rect = + data_max_color_channel_rect.As() + .Translate(-group_data_x_border, -group_data_y_border) + .ShiftLeft(base_color_shift) + .CeilShiftRight(group_data_shift[c]) + .Translate(group_data_x_border - ssize_t(kRenderPipelineXOffset), + group_data_y_border); + rows_[0][c].base_ptr = channel_group_data_rect.Row(&input_data[c], 0); + rows_[0][c].stride = input_data[c].PixelsPerRow(); + rows_[0][c].ymod_minus_1 = -1; + } + } + + // Stage -1 refers to the input data; all other values must be nonnegative and + // refer to the data for the output of that stage. + JXL_INLINE float* GetBuffer(int stage, int y, size_t c) const { + JXL_DASSERT(stage >= -1); + const RowInfo& info = rows_[stage + 1][c]; + return info.base_ptr + ssize_t(info.stride) * (y & info.ymod_minus_1); + } + + private: + struct RowInfo { + // Pointer to beginning of the first row. + float* base_ptr; + // Modulo value for the y axis minus 1 (ymod is guaranteed to be a power of + // 2, which allows efficient mod computation by masking). + int ymod_minus_1; + // Number of floats per row. + size_t stride; + }; + std::vector> rows_; +}; + +} // namespace + +void LowMemoryRenderPipeline::RenderRect(size_t thread_id, + std::vector& input_data, + Rect data_max_color_channel_rect, + Rect image_max_color_channel_rect) { + // For each stage, the rect corresponding to the image area currently being + // processed, in the coordinates of that stage (i.e. with the scaling factor + // that that stage has). + std::vector group_rect; + group_rect.resize(stages_.size()); + Rect image_area_rect = + image_max_color_channel_rect.ShiftLeft(base_color_shift_) + .Crop(frame_dimensions_.xsize_upsampled, + frame_dimensions_.ysize_upsampled); + for (size_t i = 0; i < stages_.size(); i++) { + group_rect[i] = + image_area_rect.CeilShiftRight(channel_shifts_[i][anyc_[i]]); + } + + ssize_t frame_x0 = + first_image_dim_stage_ == stages_.size() ? 0 : frame_origin_.x0; + ssize_t frame_y0 = + first_image_dim_stage_ == stages_.size() ? 0 : frame_origin_.y0; + size_t full_image_xsize = first_image_dim_stage_ == stages_.size() + ? frame_dimensions_.xsize_upsampled + : full_image_xsize_; + size_t full_image_ysize = first_image_dim_stage_ == stages_.size() + ? frame_dimensions_.ysize_upsampled + : full_image_ysize_; + + // Compute actual x-axis bounds for the current image area in the context of + // the full image this frame is part of. As the left boundary may be negative, + // we also create the x_pixels_skip value, defined as follows: + // - both x_pixels_skip and full_image_x0 are >= 0, and at least one is 0; + // - full_image_x0 - x_pixels_skip is the position of the current frame area + // in the full image. + ssize_t full_image_x0 = frame_x0 + image_area_rect.x0(); + ssize_t x_pixels_skip = 0; + if (full_image_x0 < 0) { + x_pixels_skip = -full_image_x0; + full_image_x0 = 0; + } + ssize_t full_image_x1 = frame_x0 + image_area_rect.x1(); + full_image_x1 = std::min(full_image_x1, full_image_xsize); + + // If the current image area is entirely outside of the visible image, there + // is no point in proceeding. Note: this uses the assumption that if there is + // a stage with observable effects (i.e. a kInput stage), it only appears + // after the stage that switches to image dimensions. + if (full_image_x1 <= full_image_x0) return; + + // Data structures to hold information about input/output rows and their + // buffers. + Rows rows(stages_, data_max_color_channel_rect, group_data_x_border_, + group_data_y_border_, channel_shifts_[0], base_color_shift_, + stage_data_[thread_id], input_data); + + std::vector input_rows(first_trailing_stage_ + + 1); + for (size_t i = 0; i < first_trailing_stage_; i++) { + input_rows[i].resize(input_data.size()); + } + input_rows[first_trailing_stage_].resize(input_data.size(), + std::vector(1)); + + // Maximum possible shift is 3. + RenderPipelineStage::RowInfo output_rows(input_data.size(), + std::vector(8)); + + // Fills in input_rows and output_rows for a given y value (relative to the + // start of the group, measured in actual pixels at the appropriate vertical + // scaling factor) and a given stage, applying mirroring if necessary. This + // function is somewhat inefficient for trailing kInOut or kInput stages, + // where just filling the input row once ought to be sufficient. + auto prepare_io_rows = [&](int y, size_t i) { + ssize_t bordery = stages_[i]->settings_.border_y; + size_t shifty = stages_[i]->settings_.shift_y; + auto make_row = [&](size_t c, ssize_t iy) { + size_t mirrored_y = GetMirroredY(y + iy - bordery, group_rect[i].y0(), + image_rect_[i].ysize()); + input_rows[i][c][iy] = + rows.GetBuffer(stage_input_for_channel_[i][c], mirrored_y, c); + ApplyXMirroring(input_rows[i][c][iy], stages_[i]->settings_.border_x, + group_rect[i].x0(), group_rect[i].xsize(), + image_rect_[i].xsize()); + }; + for (size_t c = 0; c < input_data.size(); c++) { + RenderPipelineChannelMode mode = stages_[i]->GetChannelMode(c); + if (mode == RenderPipelineChannelMode::kIgnored) { + continue; + } + // If we already have rows from a previous iteration, we can just shift + // the rows by 1 and insert the new one. + if (input_rows[i][c].size() == 2 * size_t(bordery) + 1) { + for (ssize_t iy = 0; iy < 2 * bordery; iy++) { + input_rows[i][c][iy] = input_rows[i][c][iy + 1]; + } + make_row(c, bordery * 2); + } else { + input_rows[i][c].resize(2 * bordery + 1); + for (ssize_t iy = 0; iy < 2 * bordery + 1; iy++) { + make_row(c, iy); + } + } + + // If necessary, get the output buffers. + if (mode == RenderPipelineChannelMode::kInOut) { + for (size_t iy = 0; iy < (1u << shifty); iy++) { + output_rows[c][iy] = rows.GetBuffer(i, y * (1 << shifty) + iy, c); + } + } + } + }; + + // We pretend that every stage has a vertical shift of 0, i.e. it is as tall + // as the final image. + // We call each such row a "virtual" row, because it may or may not correspond + // to an actual row of the current processing stage; actual processing happens + // when vy % (1<> channel_shifts_[i][anyc_[i]].second; + + ssize_t image_y = ssize_t(group_rect[i].y0()) + y; + // Do not produce rows in out-of-bounds areas. + if (image_y < 0 || image_y >= ssize_t(image_rect_[i].ysize())) { + continue; + } + + // Get the input/output rows and potentially apply mirroring to the input. + prepare_io_rows(y, i); + + // Produce output rows. + stages_[i]->ProcessRow(input_rows[i], output_rows, + xpadding_for_output_[i], group_rect[i].xsize(), + group_rect[i].x0(), image_y, thread_id); + } + + // Process trailing stages, i.e. the final set of non-kInOut stages; they + // all have the same input buffer and no need to use any mirroring. + + int y = vy - num_extra_rows; + + for (size_t c = 0; c < input_data.size(); c++) { + // Skip pixels that are not part of the actual final image area. + input_rows[first_trailing_stage_][c][0] = + rows.GetBuffer(stage_input_for_channel_[first_trailing_stage_][c], y, + c) + + x_pixels_skip; + } + + // Check that we are not outside of the bounds for the current rendering + // rect. Not doing so might result in overwriting some rows that have been + // written (or will be written) by other threads. + if (y < 0 || y >= ssize_t(image_area_rect.ysize())) { + continue; + } + + // Avoid running pipeline stages on pixels that are outside the full image + // area. As trailing stages have no borders, this is a free optimization + // (and may be necessary for correctness, as some stages assume coordinates + // are within bounds). + ssize_t full_image_y = frame_y0 + image_area_rect.y0() + y; + if (full_image_y < 0 || full_image_y >= ssize_t(full_image_ysize)) { + continue; + } + + for (size_t i = first_trailing_stage_; i < stages_.size(); i++) { + // Before the first_image_dim_stage_, coordinates are relative to the + // current frame. + size_t x0 = + i < first_image_dim_stage_ ? full_image_x0 - frame_x0 : full_image_x0; + size_t y = + i < first_image_dim_stage_ ? full_image_y - frame_y0 : full_image_y; + stages_[i]->ProcessRow(input_rows[first_trailing_stage_], output_rows, + /*xextra=*/0, full_image_x1 - full_image_x0, x0, y, + thread_id); + } + } +} + +void LowMemoryRenderPipeline::RenderPadding(size_t thread_id, Rect rect) { + if (rect.xsize() == 0) return; + size_t numc = channel_shifts_[0].size(); + RenderPipelineStage::RowInfo input_rows(numc, std::vector(1)); + RenderPipelineStage::RowInfo output_rows; + + for (size_t c = 0; c < numc; c++) { + input_rows[c][0] = out_of_frame_data_[thread_id].Row(c); + } + + for (size_t y = 0; y < rect.ysize(); y++) { + stages_[first_image_dim_stage_ - 1]->ProcessPaddingRow( + input_rows, rect.xsize(), rect.x0(), rect.y0() + y); + for (size_t i = first_image_dim_stage_; i < stages_.size(); i++) { + stages_[i]->ProcessRow(input_rows, output_rows, + /*xextra=*/0, rect.xsize(), rect.x0(), + rect.y0() + y, thread_id); + } + } +} + +void LowMemoryRenderPipeline::ProcessBuffers(size_t group_id, + size_t thread_id) { + std::vector& input_data = + group_data_[use_group_ids_ ? group_id : thread_id]; + + // Copy the group borders to the border storage. + for (size_t c = 0; c < input_data.size(); c++) { + SaveBorders(group_id, c, input_data[c]); + } + + size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t gx = group_id % frame_dimensions_.xsize_groups; + + if (first_image_dim_stage_ != stages_.size()) { + size_t group_dim = frame_dimensions_.group_dim << base_color_shift_; + RectT group_rect(gx * group_dim, gy * group_dim, group_dim, + group_dim); + RectT image_rect(0, 0, frame_dimensions_.xsize_upsampled, + frame_dimensions_.ysize_upsampled); + RectT full_image_rect(0, 0, full_image_xsize_, full_image_ysize_); + group_rect = group_rect.Translate(frame_origin_.x0, frame_origin_.y0); + image_rect = image_rect.Translate(frame_origin_.x0, frame_origin_.y0); + image_rect = image_rect.Intersection(full_image_rect); + group_rect = group_rect.Intersection(image_rect); + size_t x0 = group_rect.x0(); + size_t y0 = group_rect.y0(); + size_t x1 = group_rect.x1(); + size_t y1 = group_rect.y1(); + JXL_DEBUG_V(6, + "Rendering padding for full image rect %s " + "outside group rect %s", + Description(full_image_rect).c_str(), + Description(group_rect).c_str()); + + if (group_id == 0 && (image_rect.xsize() == 0 || image_rect.ysize() == 0)) { + // If this frame does not intersect with the full image, we have to + // initialize the whole image area with RenderPadding. + RenderPadding(thread_id, + Rect(0, 0, full_image_xsize_, full_image_ysize_)); + } + + // Render padding for groups that intersect with the full image. The case + // where no groups intersect was handled above. + if (group_rect.xsize() > 0 && group_rect.ysize() > 0) { + if (gx == 0 && gy == 0) { + RenderPadding(thread_id, Rect(0, 0, x0, y0)); + } + if (gy == 0) { + RenderPadding(thread_id, Rect(x0, 0, x1 - x0, y0)); + } + if (gx == 0) { + RenderPadding(thread_id, Rect(0, y0, x0, y1 - y0)); + } + if (gx == 0 && gy + 1 == frame_dimensions_.ysize_groups) { + RenderPadding(thread_id, Rect(0, y1, x0, full_image_ysize_ - y1)); + } + if (gy + 1 == frame_dimensions_.ysize_groups) { + RenderPadding(thread_id, Rect(x0, y1, x1 - x0, full_image_ysize_ - y1)); + } + if (gy == 0 && gx + 1 == frame_dimensions_.xsize_groups) { + RenderPadding(thread_id, Rect(x1, 0, full_image_xsize_ - x1, y0)); + } + if (gx + 1 == frame_dimensions_.xsize_groups) { + RenderPadding(thread_id, Rect(x1, y0, full_image_xsize_ - x1, y1 - y0)); + } + if (gy + 1 == frame_dimensions_.ysize_groups && + gx + 1 == frame_dimensions_.xsize_groups) { + RenderPadding(thread_id, Rect(x1, y1, full_image_xsize_ - x1, + full_image_ysize_ - y1)); + } + } + } + + Rect ready_rects[GroupBorderAssigner::kMaxToFinalize]; + size_t num_ready_rects = 0; + group_border_assigner_.GroupDone(group_id, group_border_.first, + group_border_.second, ready_rects, + &num_ready_rects); + for (size_t i = 0; i < num_ready_rects; i++) { + const Rect& image_max_color_channel_rect = ready_rects[i]; + for (size_t c = 0; c < input_data.size(); c++) { + LoadBorders(group_id, c, image_max_color_channel_rect, &input_data[c]); + } + Rect data_max_color_channel_rect( + group_data_x_border_ + image_max_color_channel_rect.x0() - + gx * frame_dimensions_.group_dim, + group_data_y_border_ + image_max_color_channel_rect.y0() - + gy * frame_dimensions_.group_dim, + image_max_color_channel_rect.xsize(), + image_max_color_channel_rect.ysize()); + RenderRect(thread_id, input_data, data_max_color_channel_rect, + image_max_color_channel_rect); + } +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/render_pipeline/low_memory_render_pipeline.h b/media/libjxl/src/lib/jxl/render_pipeline/low_memory_render_pipeline.h new file mode 100644 index 000000000..b386f7c07 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/low_memory_render_pipeline.h @@ -0,0 +1,111 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_LOW_MEMORY_RENDER_PIPELINE_H_ +#define LIB_JXL_RENDER_PIPELINE_LOW_MEMORY_RENDER_PIPELINE_H_ + +#include + +#include "lib/jxl/dec_group_border.h" +#include "lib/jxl/render_pipeline/render_pipeline.h" + +namespace jxl { + +// A multithreaded, low-memory rendering pipeline that only allocates a minimal +// amount of buffers. +class LowMemoryRenderPipeline final : public RenderPipeline { + private: + std::vector> PrepareBuffers( + size_t group_id, size_t thread_id) override; + + void PrepareForThreadsInternal(size_t num, bool use_group_ids) override; + + void ProcessBuffers(size_t group_id, size_t thread_id) override; + + void ClearDone(size_t i) override { group_border_assigner_.ClearDone(i); } + + void Init() override; + + void EnsureBordersStorage(); + size_t GroupInputXSize(size_t c) const; + size_t GroupInputYSize(size_t c) const; + void RenderRect(size_t thread_id, std::vector& input_data, + Rect data_max_color_channel_rect, + Rect image_max_color_channel_rect); + void RenderPadding(size_t thread_id, Rect rect); + + void SaveBorders(size_t group_id, size_t c, const ImageF& in); + void LoadBorders(size_t group_id, size_t c, const Rect& r, ImageF* out); + + std::pair ColorDimensionsToChannelDimensions( + std::pair in, size_t c, size_t stage) const; + + std::pair BorderToStore(size_t c) const; + + bool use_group_ids_; + + // Storage for borders between groups. Borders of adjacent groups are stacked + // together, e.g. bottom border of current group is followed by top border + // of next group. + std::vector borders_horizontal_; + std::vector borders_vertical_; + + // Manages the status of borders. + GroupBorderAssigner group_border_assigner_; + + // Size (in color-channel-pixels) of the border around each group that might + // be assigned to that group. + std::pair group_border_; + // base_color_shift_ defines the size of groups in terms of final image + // pixels. + size_t base_color_shift_; + + // Buffer for decoded pixel data for a group, indexed by [thread][channel] or + // [group][channel] depending on `use_group_ids_`. + std::vector> group_data_; + + // Borders for storing group data. + size_t group_data_x_border_; + size_t group_data_y_border_; + + // Buffers for intermediate rows for the various stages, indexed by + // [thread][channel][stage]. + std::vector>> stage_data_; + + // Buffers for out-of-frame data, indexed by [thread]; every row is a + // different channel. + std::vector out_of_frame_data_; + + // For each stage, a non-kIgnored channel. + std::vector anyc_; + + // Size of the image at each stage. + std::vector image_rect_; + + // For each stage, for each channel, keep track of the kInOut stage that + // produced the input to that stage (which corresponds to the buffer index + // containing the data). -1 if data comes from the original input. + std::vector> stage_input_for_channel_; + + // Number of (virtual) extra rows that must be processed at each stage + // to produce sufficient output for future stages. + std::vector virtual_ypadding_for_output_; + + // Same thing for columns, except these are real columns and not virtual ones. + std::vector xpadding_for_output_; + + // First stage that doesn't have any kInOut channel. + size_t first_trailing_stage_; + + // Origin and size of the frame after switching to image dimensions. + FrameOrigin frame_origin_; + size_t full_image_xsize_; + size_t full_image_ysize_; + size_t first_image_dim_stage_; +}; + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_LOW_MEMORY_RENDER_PIPELINE_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline.cc b/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline.cc new file mode 100644 index 000000000..68b6ef613 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline.cc @@ -0,0 +1,132 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/render_pipeline.h" + +#include + +#include "lib/jxl/render_pipeline/low_memory_render_pipeline.h" +#include "lib/jxl/render_pipeline/simple_render_pipeline.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +void RenderPipeline::Builder::AddStage( + std::unique_ptr stage) { + stages_.push_back(std::move(stage)); +} + +std::unique_ptr RenderPipeline::Builder::Finalize( + FrameDimensions frame_dimensions) && { +#if JXL_ENABLE_ASSERT + // Check that the last stage is not an kInOut stage for any channel, and that + // there is at least one stage. + JXL_ASSERT(!stages_.empty()); + for (size_t c = 0; c < num_c_; c++) { + JXL_ASSERT(stages_.back()->GetChannelMode(c) != + RenderPipelineChannelMode::kInOut); + } +#endif + + std::unique_ptr res; + if (use_simple_implementation_) { + res = jxl::make_unique(); + } else { + res = jxl::make_unique(); + } + + res->padding_.resize(stages_.size()); + for (size_t i = stages_.size(); i-- > 0;) { + const auto& stage = stages_[i]; + res->padding_[i].resize(num_c_); + if (i + 1 == stages_.size()) { + continue; + } + for (size_t c = 0; c < num_c_; c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + res->padding_[i][c].first = DivCeil(res->padding_[i + 1][c].first, + 1 << stage->settings_.shift_x) + + stage->settings_.border_x; + res->padding_[i][c].second = DivCeil(res->padding_[i + 1][c].second, + 1 << stage->settings_.shift_y) + + stage->settings_.border_y; + } else { + res->padding_[i][c] = res->padding_[i + 1][c]; + } + } + } + + res->frame_dimensions_ = frame_dimensions; + res->group_completed_passes_.resize(frame_dimensions.num_groups); + res->channel_shifts_.resize(stages_.size()); + res->channel_shifts_[0].resize(num_c_); + for (size_t i = 1; i < stages_.size(); i++) { + auto& stage = stages_[i - 1]; + for (size_t c = 0; c < num_c_; c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + res->channel_shifts_[0][c].first += stage->settings_.shift_x; + res->channel_shifts_[0][c].second += stage->settings_.shift_y; + } + } + } + for (size_t i = 1; i < stages_.size(); i++) { + auto& stage = stages_[i - 1]; + res->channel_shifts_[i].resize(num_c_); + for (size_t c = 0; c < num_c_; c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + res->channel_shifts_[i][c].first = + res->channel_shifts_[i - 1][c].first - stage->settings_.shift_x; + res->channel_shifts_[i][c].second = + res->channel_shifts_[i - 1][c].second - stage->settings_.shift_y; + } else { + res->channel_shifts_[i][c].first = res->channel_shifts_[i - 1][c].first; + res->channel_shifts_[i][c].second = + res->channel_shifts_[i - 1][c].second; + } + } + } + res->stages_ = std::move(stages_); + res->Init(); + return res; +} + +RenderPipelineInput RenderPipeline::GetInputBuffers(size_t group_id, + size_t thread_id) { + RenderPipelineInput ret; + JXL_DASSERT(group_id < group_completed_passes_.size()); + ret.group_id_ = group_id; + ret.thread_id_ = thread_id; + ret.pipeline_ = this; + ret.buffers_ = PrepareBuffers(group_id, thread_id); + return ret; +} + +void RenderPipeline::InputReady( + size_t group_id, size_t thread_id, + const std::vector>& buffers) { + JXL_DASSERT(group_id < group_completed_passes_.size()); + group_completed_passes_[group_id]++; + for (size_t i = 0; i < buffers.size(); ++i) { + (void)i; + JXL_CHECK_PLANE_INITIALIZED(*buffers[i].first, buffers[i].second, i); + } + + ProcessBuffers(group_id, thread_id); +} + +Status RenderPipeline::PrepareForThreads(size_t num, bool use_group_ids) { + for (const auto& stage : stages_) { + JXL_RETURN_IF_ERROR(stage->PrepareForThreads(num)); + } + PrepareForThreadsInternal(num, use_group_ids); + return true; +} + +void RenderPipelineInput::Done() { + JXL_ASSERT(pipeline_); + pipeline_->InputReady(group_id_, thread_id_, buffers_); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline.h b/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline.h new file mode 100644 index 000000000..bf3ad4975 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline.h @@ -0,0 +1,139 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_H_ +#define LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_H_ + +#include + +#include "lib/jxl/image.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Interface to provide input to the rendering pipeline. When this object is +// destroyed, all the data in the provided ImageF's Rects must have been +// initialized. +class RenderPipelineInput { + public: + RenderPipelineInput(const RenderPipelineInput&) = delete; + RenderPipelineInput(RenderPipelineInput&& other) noexcept { + *this = std::move(other); + } + RenderPipelineInput& operator=(RenderPipelineInput&& other) noexcept { + pipeline_ = other.pipeline_; + group_id_ = other.group_id_; + thread_id_ = other.thread_id_; + buffers_ = std::move(other.buffers_); + other.pipeline_ = nullptr; + return *this; + } + + RenderPipelineInput() = default; + void Done(); + + const std::pair& GetBuffer(size_t c) const { + JXL_ASSERT(c < buffers_.size()); + return buffers_[c]; + } + + private: + RenderPipeline* pipeline_ = nullptr; + size_t group_id_; + size_t thread_id_; + std::vector> buffers_; + friend class RenderPipeline; +}; + +class RenderPipeline { + public: + class Builder { + public: + explicit Builder(size_t num_c) : num_c_(num_c) { JXL_ASSERT(num_c > 0); } + + // Adds a stage to the pipeline. Must be called at least once; the last + // added stage cannot have kInOut channels. + void AddStage(std::unique_ptr stage); + + // Enables using the simple (i.e. non-memory-efficient) implementation of + // the pipeline. + void UseSimpleImplementation() { use_simple_implementation_ = true; } + + // Finalizes setup of the pipeline. Shifts for all channels should be 0 at + // this point. + std::unique_ptr Finalize( + FrameDimensions frame_dimensions) &&; + + private: + std::vector> stages_; + size_t num_c_; + bool use_simple_implementation_ = false; + }; + + friend class Builder; + + virtual ~RenderPipeline() = default; + + Status IsInitialized() const { + for (const auto& stage : stages_) { + JXL_RETURN_IF_ERROR(stage->IsInitialized()); + } + return true; + } + + // Allocates storage to run with `num` threads. If `use_group_ids` is true, + // storage is allocated for each group, not each thread. The behaviour is + // undefined if calling this function multiple times with a different value + // for `use_group_ids`. + Status PrepareForThreads(size_t num, bool use_group_ids); + + // Retrieves a buffer where input data should be stored by the callee. When + // input has been provided for all buffers, the pipeline will complete its + // processing. This method may be called multiple times concurrently from + // different threads, provided that a different `thread_id` is given. + RenderPipelineInput GetInputBuffers(size_t group_id, size_t thread_id); + + size_t PassesWithAllInput() const { + return *std::min_element(group_completed_passes_.begin(), + group_completed_passes_.end()); + } + + virtual void ClearDone(size_t i) {} + + protected: + std::vector> stages_; + // Shifts for every channel at the input of each stage. + std::vector>> channel_shifts_; + + // Amount of (cumulative) padding required by each stage and channel, in + // either direction. + std::vector>> padding_; + + FrameDimensions frame_dimensions_; + + std::vector group_completed_passes_; + + friend class RenderPipelineInput; + + private: + void InputReady(size_t group_id, size_t thread_id, + const std::vector>& buffers); + + virtual std::vector> PrepareBuffers( + size_t group_id, size_t thread_id) = 0; + + virtual void ProcessBuffers(size_t group_id, size_t thread_id) = 0; + + // Note that this method may be called multiple times with different (or + // equal) `num`. + virtual void PrepareForThreadsInternal(size_t num, bool use_group_ids) = 0; + + // Called once frame dimensions and stages are known. + virtual void Init() {} +}; + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline_stage.h b/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline_stage.h new file mode 100644 index 000000000..d1a007416 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/render_pipeline_stage.h @@ -0,0 +1,171 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_STAGE_H_ +#define LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_STAGE_H_ + +#include + +#include "lib/jxl/base/arch_macros.h" +#include "lib/jxl/frame_header.h" + +namespace jxl { + +// The first pixel in the input to RenderPipelineStage will be located at +// this position. Pixels before this position may be accessed as padding. +// This should be at least the RoundUpTo(maximum padding / 2, maximum vector +// size) times 2: this is realized when using Gaborish + EPF + upsampling + +// chroma subsampling. +#if JXL_ARCH_ARM +constexpr size_t kRenderPipelineXOffset = 16; +#else +constexpr size_t kRenderPipelineXOffset = 32; +#endif + +enum class RenderPipelineChannelMode { + // This channel is not modified by this stage. + kIgnored = 0, + // This channel is modified in-place. + kInPlace = 1, + // This channel is modified and written to a new buffer. + kInOut = 2, + // This channel is only read. These are the only stages that are assumed to + // have observable effects, i.e. calls to ProcessRow for other stages may be + // omitted if it can be shown they can't affect any kInput stage ProcessRow + // call that happens inside image boundaries. + kInput = 3, +}; + +class RenderPipeline; + +class RenderPipelineStage { + protected: + using Row = float*; + using ChannelRows = std::vector; + + public: + using RowInfo = std::vector; + struct Settings { + // Amount of padding required in the various directions by all channels + // that have kInOut mode. + size_t border_x = 0; + size_t border_y = 0; + + // Log2 of the number of columns/rows of output that this stage will produce + // for every input row for kInOut channels. + size_t shift_x = 0; + size_t shift_y = 0; + + static Settings ShiftX(size_t shift, size_t border) { + Settings settings; + settings.border_x = border; + settings.shift_x = shift; + return settings; + } + + static Settings ShiftY(size_t shift, size_t border) { + Settings settings; + settings.border_y = border; + settings.shift_y = shift; + return settings; + } + + static Settings Symmetric(size_t shift, size_t border) { + Settings settings; + settings.border_x = settings.border_y = border; + settings.shift_x = settings.shift_y = shift; + return settings; + } + + static Settings SymmetricBorderOnly(size_t border) { + return Symmetric(0, border); + } + }; + + virtual ~RenderPipelineStage() = default; + + // Processes one row of input, producing the appropriate number of rows of + // output. Input/output rows can be obtained by calls to + // `GetInputRow`/`GetOutputRow`. `xsize+2*xextra` represents the total number + // of pixels to be processed in the input row, where the first pixel is at + // position `kRenderPipelineXOffset-xextra`. All pixels in the + // `[kRenderPipelineXOffset-xextra-border_x, + // kRenderPipelineXOffset+xsize+xextra+border_x)` range are initialized and + // accessible. `xpos` and `ypos` represent the position of the first + // (non-extra, i.e. in position kRenderPipelineXOffset) pixel in the center + // row of the input in the full image. `xpos` is a multiple of + // `GroupBorderAssigner::kPaddingXRound`. If `settings_.temp_buffer_size` is + // nonzero, `temp` will point to an HWY-aligned buffer of at least that number + // of floats; concurrent calls will have different buffers. + virtual void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const = 0; + + // How each channel will be processed. Channels are numbered starting from + // color channels (always 3) and followed by all other channels. + virtual RenderPipelineChannelMode GetChannelMode(size_t c) const = 0; + + protected: + explicit RenderPipelineStage(Settings settings) : settings_(settings) {} + + virtual Status IsInitialized() const { return true; } + + // Informs the stage about the total size of each channel. Few stages will + // actually need to use this information. + virtual void SetInputSizes( + const std::vector>& input_sizes) {} + + virtual Status PrepareForThreads(size_t num_threads) { return true; } + + // Returns a pointer to the input row of channel `c` with offset `y`. + // `y` must be in [-settings_.border_y, settings_.border_y]. `c` must be such + // that `GetChannelMode(c) != kIgnored`. The returned pointer points to the + // offset-ed row (i.e. kRenderPipelineXOffset has been applied). + float* GetInputRow(const RowInfo& input_rows, size_t c, int offset) const { + JXL_DASSERT(GetChannelMode(c) != RenderPipelineChannelMode::kIgnored); + JXL_DASSERT(-offset <= static_cast(settings_.border_y)); + JXL_DASSERT(offset <= static_cast(settings_.border_y)); + return input_rows[c][settings_.border_y + offset] + kRenderPipelineXOffset; + } + // Similar to `GetInputRow`, but can only be used if `GetChannelMode(c) == + // kInOut`. Offset must be less than `1< +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/fake_parallel_runner_testonly.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/render_pipeline/test_render_pipeline_stages.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { + +Status DecodeFile(const Span file, bool use_slow_pipeline, + CodecInOut* io, ThreadPool* pool) { + Status ret = true; + { + BitReader reader(file); + BitReaderScopedCloser reader_closer(&reader, &ret); + JXL_RETURN_IF_ERROR(reader.ReadFixedBits<16>() == 0x0AFF); + JXL_RETURN_IF_ERROR(ReadSizeHeader(&reader, &io->metadata.size)); + JXL_RETURN_IF_ERROR(ReadImageMetadata(&reader, &io->metadata.m)); + io->metadata.transform_data.nonserialized_xyb_encoded = + io->metadata.m.xyb_encoded; + JXL_RETURN_IF_ERROR(Bundle::Read(&reader, &io->metadata.transform_data)); + size_t xsize = io->metadata.xsize(); + size_t ysize = io->metadata.ysize(); + JXL_RETURN_IF_ERROR(VerifyDimensions(&io->constraints, xsize, ysize)); + if (io->metadata.m.color_encoding.WantICC()) { + PaddedBytes icc; + JXL_RETURN_IF_ERROR(ReadICC(&reader, &icc)); + JXL_RETURN_IF_ERROR(io->metadata.m.color_encoding.SetICC(std::move(icc))); + } + PassesDecoderState dec_state; + JXL_RETURN_IF_ERROR( + dec_state.output_encoding_info.SetFromMetadata(io->metadata)); + JXL_RETURN_IF_ERROR(reader.JumpToByteBoundary()); + io->frames.clear(); + do { + io->frames.emplace_back(&io->metadata.m); + // Skip frames that are not displayed. + do { + size_t frame_start = reader.TotalBitsConsumed() / kBitsPerByte; + size_t size_left = file.size() - frame_start; + JXL_RETURN_IF_ERROR( + DecodeFrame(&dec_state, pool, file.data() + frame_start, size_left, + &io->frames.back(), io->metadata, use_slow_pipeline)); + reader.SkipBits(io->frames.back().decoded_bytes() * kBitsPerByte); + } while (dec_state.shared->frame_header.frame_type != + FrameType::kRegularFrame && + dec_state.shared->frame_header.frame_type != + FrameType::kSkipProgressive); + } while (!dec_state.shared->frame_header.is_last); + + if (io->frames.empty()) return JXL_FAILURE("Not enough data."); + + if (reader.TotalBitsConsumed() != file.size() * kBitsPerByte) { + return JXL_FAILURE("Reader position not at EOF."); + } + if (!reader.AllReadsWithinBounds()) { + return JXL_FAILURE("Reader out of bounds read."); + } + io->CheckMetadata(); + // reader is closed here. + } + return ret; +} + +TEST(RenderPipelineTest, Build) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + builder.UseSimpleImplementation(); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + std::move(builder).Finalize(frame_dimensions); +} + +TEST(RenderPipelineTest, CallAllGroups) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + builder.UseSimpleImplementation(); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + auto pipeline = std::move(builder).Finalize(frame_dimensions); + ASSERT_TRUE(pipeline->PrepareForThreads(1, /*use_group_ids=*/false)); + + for (size_t i = 0; i < frame_dimensions.num_groups; i++) { + auto input_buffers = pipeline->GetInputBuffers(i, 0); + FillPlane(0.0f, input_buffers.GetBuffer(0).first, + input_buffers.GetBuffer(0).second); + input_buffers.Done(); + } + + EXPECT_EQ(pipeline->PassesWithAllInput(), 1); +} + +TEST(RenderPipelineTest, BuildFast) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + std::move(builder).Finalize(frame_dimensions); +} + +TEST(RenderPipelineTest, CallAllGroupsFast) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + builder.AddStage(jxl::make_unique()); + builder.UseSimpleImplementation(); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + auto pipeline = std::move(builder).Finalize(frame_dimensions); + ASSERT_TRUE(pipeline->PrepareForThreads(1, /*use_group_ids=*/false)); + + for (size_t i = 0; i < frame_dimensions.num_groups; i++) { + auto input_buffers = pipeline->GetInputBuffers(i, 0); + FillPlane(0.0f, input_buffers.GetBuffer(0).first, + input_buffers.GetBuffer(0).second); + input_buffers.Done(); + } + + EXPECT_EQ(pipeline->PassesWithAllInput(), 1); +} + +struct RenderPipelineTestInputSettings { + // Input image. + std::string input_path; + size_t xsize, ysize; + bool jpeg_transcode = false; + // Encoding settings. + CompressParams cparams; + // Short name for the encoder settings. + std::string cparams_descr; + + bool add_spot_color = false; + + Splines splines; +}; + +class RenderPipelineTestParam + : public ::testing::TestWithParam {}; + +TEST_P(RenderPipelineTestParam, PipelineTest) { + RenderPipelineTestInputSettings config = GetParam(); + + // Use a parallel runner that randomly shuffles tasks to detect possible + // border handling bugs. + FakeParallelRunner fake_pool(/*order_seed=*/123, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + const PaddedBytes orig = ReadTestData(config.input_path); + + CodecInOut io; + if (config.jpeg_transcode) { + ASSERT_TRUE(jpeg::DecodeImageJPG(Span(orig), &io)); + } else { + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + } + io.ShrinkTo(config.xsize, config.ysize); + + if (config.add_spot_color) { + jxl::ImageF spot(config.xsize, config.ysize); + jxl::ZeroFillImage(&spot); + + for (size_t y = 0; y < config.ysize; y++) { + float* JXL_RESTRICT row = spot.Row(y); + for (size_t x = 0; x < config.xsize; x++) { + row[x] = ((x ^ y) & 255) * (1.f / 255.f); + } + } + ExtraChannelInfo info; + info.bit_depth.bits_per_sample = 8; + info.dim_shift = 0; + info.type = jxl::ExtraChannel::kSpotColor; + info.spot_color[0] = 0.5f; + info.spot_color[1] = 0.2f; + info.spot_color[2] = 1.f; + info.spot_color[3] = 0.5f; + + io.metadata.m.extra_channel_info.push_back(info); + std::vector ec; + ec.push_back(std::move(spot)); + io.frames[0].SetExtraChannels(std::move(ec)); + } + + PaddedBytes compressed; + + PassesEncoderState enc_state; + enc_state.shared.image_features.splines = config.splines; + ASSERT_TRUE(EncodeFile(config.cparams, &io, &enc_state, &compressed, + GetJxlCms(), /*aux_out=*/nullptr, &pool)); + + + CodecInOut io_default; + ASSERT_TRUE(DecodeFile(Span(compressed), + /*use_slow_pipeline=*/false, &io_default, &pool)); + CodecInOut io_slow_pipeline; + ASSERT_TRUE(DecodeFile(Span(compressed), + /*use_slow_pipeline=*/true, &io_slow_pipeline, &pool)); + + ASSERT_EQ(io_default.frames.size(), io_slow_pipeline.frames.size()); + for (size_t i = 0; i < io_default.frames.size(); i++) { +#if JXL_HIGH_PRECISION + constexpr float kMaxError = 1e-5; +#else + constexpr float kMaxError = 1e-4; +#endif + Image3F def = std::move(*io_default.frames[i].color()); + Image3F pip = std::move(*io_slow_pipeline.frames[i].color()); + VerifyRelativeError(pip, def, kMaxError, kMaxError); + for (size_t ec = 0; ec < io_default.frames[i].extra_channels().size(); + ec++) { + VerifyRelativeError(io_slow_pipeline.frames[i].extra_channels()[ec], + io_default.frames[i].extra_channels()[ec], kMaxError, + kMaxError); + } + } +} + +Splines CreateTestSplines() { + const ColorCorrelationMap cmap; + std::vector control_points{{9, 54}, {118, 159}, {97, 3}, + {10, 40}, {150, 25}, {120, 300}}; + const Spline spline{ + control_points, + /*color_dct=*/ + {{0.03125f, 0.00625f, 0.003125f}, {1.f, 0.321875f}, {1.f, 0.24375f}}, + /*sigma_dct=*/{0.3125f, 0.f, 0.f, 0.0625f}}; + std::vector spline_data = {spline}; + std::vector quantized_splines; + std::vector starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, /*quantization_adjustment=*/0, + cmap.YtoXRatio(0), cmap.YtoBRatio(0)); + starting_points.push_back(spline.control_points.front()); + } + return Splines(/*quantization_adjustment=*/0, std::move(quantized_splines), + std::move(starting_points)); +} + +std::vector GeneratePipelineTests() { + std::vector all_tests; + + std::pair sizes[] = { + {3, 8}, {128, 128}, {256, 256}, {258, 258}, {533, 401}, {777, 777}, + }; + + for (auto size : sizes) { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/flower/flower.png"; + settings.xsize = size.first; + settings.ysize = size.second; + + // Base settings. + settings.cparams.butteraugli_distance = 1.0; + settings.cparams.patches = Override::kOff; + settings.cparams.dots = Override::kOff; + settings.cparams.gaborish = Override::kOff; + settings.cparams.epf = 0; + settings.cparams.color_transform = ColorTransform::kXYB; + + { + auto s = settings; + s.cparams_descr = "NoGabNoEpfNoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.color_transform = ColorTransform::kNone; + s.cparams_descr = "NoGabNoEpfNoPatchesNoXYB"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.gaborish = Override::kOn; + s.cparams_descr = "GabNoEpfNoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.epf = 1; + s.cparams_descr = "NoGabEpf1NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.epf = 2; + s.cparams_descr = "NoGabEpf2NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.epf = 3; + s.cparams_descr = "NoGabEpf3NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.gaborish = Override::kOn; + s.cparams.epf = 3; + s.cparams_descr = "GabEpf3NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "Splines"; + s.splines = CreateTestSplines(); + all_tests.push_back(s); + } + + for (size_t ups : {2, 4, 8}) { + { + auto s = settings; + s.cparams.resampling = ups; + s.cparams_descr = "Ups" + std::to_string(ups); + all_tests.push_back(s); + } + { + auto s = settings; + s.cparams.resampling = ups; + s.cparams.epf = 1; + s.cparams_descr = "Ups" + std::to_string(ups) + "EPF1"; + all_tests.push_back(s); + } + { + auto s = settings; + s.cparams.resampling = ups; + s.cparams.gaborish = Override::kOn; + s.cparams.epf = 1; + s.cparams_descr = "Ups" + std::to_string(ups) + "GabEPF1"; + all_tests.push_back(s); + } + } + + { + auto s = settings; + s.cparams_descr = "Noise"; + s.cparams.photon_noise_iso = 3200; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "NoiseUps"; + s.cparams.photon_noise_iso = 3200; + s.cparams.resampling = 2; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "ModularLossless"; + s.cparams.modular_mode = true; + s.cparams.butteraugli_distance = 0; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "ProgressiveDC"; + s.cparams.progressive_dc = 1; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "ModularLossy"; + s.cparams.modular_mode = true; + s.cparams.butteraugli_distance = 1.f; + all_tests.push_back(s); + } + + { + auto s = settings; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaVarDCT"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaVarDCTUpsamplingEPF"; + s.cparams.epf = 1; + s.cparams.ec_resampling = 2; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.modular_mode = true; + s.cparams.butteraugli_distance = 0; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaLossless"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaDownsample"; + s.cparams.ec_resampling = 2; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "SpotColor"; + s.add_spot_color = true; + all_tests.push_back(s); + } + } + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + for (const char* input : {"jxl/flower/flower.png.im_q85_444.jpg", + "jxl/flower/flower.png.im_q85_420.jpg", + "jxl/flower/flower.png.im_q85_422.jpg", + "jxl/flower/flower.png.im_q85_440.jpg"}) { + RenderPipelineTestInputSettings settings; + settings.input_path = input; + settings.jpeg_transcode = true; + settings.xsize = 2268; + settings.ysize = 1512; + settings.cparams_descr = "Default"; + all_tests.push_back(settings); + } + +#endif + + { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/grayscale_patches.png"; + settings.xsize = 1011; + settings.ysize = 277; + settings.cparams_descr = "Patches"; + all_tests.push_back(settings); + } + + { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/grayscale_patches.png"; + settings.xsize = 1011; + settings.ysize = 277; + settings.cparams.photon_noise_iso = 1000; + settings.cparams_descr = "PatchesAndNoise"; + all_tests.push_back(settings); + } + + { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/grayscale_patches.png"; + settings.xsize = 1011; + settings.ysize = 277; + settings.cparams.resampling = 2; + settings.cparams_descr = "PatchesAndUps2"; + all_tests.push_back(settings); + } + + return all_tests; +} + +std::ostream& operator<<(std::ostream& os, + const RenderPipelineTestInputSettings& c) { + std::string filename; + size_t pos = c.input_path.find_last_of('/'); + if (pos == std::string::npos) { + filename = c.input_path; + } else { + filename = c.input_path.substr(pos + 1); + } + std::replace_if( + filename.begin(), filename.end(), [](char c) { return !isalnum(c); }, + '_'); + os << filename << "_" << (c.jpeg_transcode ? "JPEG_" : "") << c.xsize << "x" + << c.ysize << "_" << c.cparams_descr; + return os; +} + +std::string PipelineTestDescription( + const testing::TestParamInfo& info) { + std::stringstream name; + name << info.param; + return name.str(); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(RenderPipelineTest, RenderPipelineTestParam, + testing::ValuesIn(GeneratePipelineTests()), + PipelineTestDescription); + +TEST(RenderPipelineDecodingTest, Animation) { + FakeParallelRunner fake_pool(/*order_seed=*/123, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + + PaddedBytes compressed = + ReadTestData("jxl/blending/cropped_traffic_light.jxl"); + + CodecInOut io_default; + ASSERT_TRUE(DecodeFile(Span(compressed), + /*use_slow_pipeline=*/false, &io_default, &pool)); + CodecInOut io_slow_pipeline; + ASSERT_TRUE(DecodeFile(Span(compressed), + /*use_slow_pipeline=*/true, &io_slow_pipeline, &pool)); + + ASSERT_EQ(io_default.frames.size(), io_slow_pipeline.frames.size()); + for (size_t i = 0; i < io_default.frames.size(); i++) { +#if JXL_HIGH_PRECISION + constexpr float kMaxError = 1e-5; +#else + constexpr float kMaxError = 1e-4; +#endif + + Image3F fast_pipeline = std::move(*io_default.frames[i].color()); + Image3F slow_pipeline = std::move(*io_slow_pipeline.frames[i].color()); + VerifyRelativeError(slow_pipeline, fast_pipeline, kMaxError, kMaxError); + for (size_t ec = 0; ec < io_default.frames[i].extra_channels().size(); + ec++) { + VerifyRelativeError(io_slow_pipeline.frames[i].extra_channels()[ec], + io_default.frames[i].extra_channels()[ec], kMaxError, + kMaxError); + } + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/render_pipeline/simple_render_pipeline.cc b/media/libjxl/src/lib/jxl/render_pipeline/simple_render_pipeline.cc new file mode 100644 index 000000000..6e6bcb764 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/simple_render_pipeline.cc @@ -0,0 +1,265 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/simple_render_pipeline.h" + +#include "hwy/base.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +void SimpleRenderPipeline::PrepareForThreadsInternal(size_t num, + bool use_group_ids) { + if (!channel_data_.empty()) { + return; + } + auto ch_size = [](size_t frame_size, size_t shift) { + return DivCeil(frame_size, 1 << shift) + kRenderPipelineXOffset * 2; + }; + for (size_t c = 0; c < channel_shifts_[0].size(); c++) { + channel_data_.push_back(ImageF( + ch_size(frame_dimensions_.xsize_upsampled, channel_shifts_[0][c].first), + ch_size(frame_dimensions_.ysize_upsampled, + channel_shifts_[0][c].second))); + msan::PoisonImage(channel_data_.back()); + } +} + +Rect SimpleRenderPipeline::MakeChannelRect(size_t group_id, size_t channel) { + size_t base_color_shift = + CeilLog2Nonzero(frame_dimensions_.xsize_upsampled_padded / + frame_dimensions_.xsize_padded); + + const size_t gx = group_id % frame_dimensions_.xsize_groups; + const size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t xgroupdim = (frame_dimensions_.group_dim << base_color_shift) >> + channel_shifts_[0][channel].first; + size_t ygroupdim = (frame_dimensions_.group_dim << base_color_shift) >> + channel_shifts_[0][channel].second; + return Rect( + kRenderPipelineXOffset + gx * xgroupdim, + kRenderPipelineXOffset + gy * ygroupdim, xgroupdim, ygroupdim, + kRenderPipelineXOffset + DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[0][channel].first), + kRenderPipelineXOffset + + DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[0][channel].second)); +} + +std::vector> SimpleRenderPipeline::PrepareBuffers( + size_t group_id, size_t thread_id) { + std::vector> ret; + for (size_t c = 0; c < channel_data_.size(); c++) { + ret.emplace_back(&channel_data_[c], MakeChannelRect(group_id, c)); + } + return ret; +} + +void SimpleRenderPipeline::ProcessBuffers(size_t group_id, size_t thread_id) { + for (size_t c = 0; c < channel_data_.size(); c++) { + Rect r = MakeChannelRect(group_id, c); + (void)r; + JXL_CHECK_PLANE_INITIALIZED(channel_data_[c], r, c); + } + + if (PassesWithAllInput() <= processed_passes_) return; + processed_passes_++; + + for (size_t stage_id = 0; stage_id < stages_.size(); stage_id++) { + const auto& stage = stages_[stage_id]; + // Prepare buffers for kInOut channels. + std::vector new_channels(channel_data_.size()); + std::vector output_channels(channel_data_.size()); + + std::vector> input_sizes(channel_data_.size()); + for (size_t c = 0; c < channel_data_.size(); c++) { + input_sizes[c] = + std::make_pair(channel_data_[c].xsize() - kRenderPipelineXOffset * 2, + channel_data_[c].ysize() - kRenderPipelineXOffset * 2); + } + + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) != RenderPipelineChannelMode::kInOut) { + continue; + } + // Ensure that the newly allocated channels are large enough to avoid + // problems with padding. + new_channels[c] = + ImageF(frame_dimensions_.xsize_upsampled_padded + + kRenderPipelineXOffset * 2 + hwy::kMaxVectorSize * 8, + frame_dimensions_.ysize_upsampled_padded + + kRenderPipelineXOffset * 2); + new_channels[c].ShrinkTo( + (input_sizes[c].first << stage->settings_.shift_x) + + kRenderPipelineXOffset * 2, + (input_sizes[c].second << stage->settings_.shift_y) + + kRenderPipelineXOffset * 2); + output_channels[c] = &new_channels[c]; + } + + auto get_row = [&](size_t c, int64_t y) { + return channel_data_[c].Row(kRenderPipelineXOffset + y) + + kRenderPipelineXOffset; + }; + + // Add mirrored pixes to all kInOut channels. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) != RenderPipelineChannelMode::kInOut) { + continue; + } + // Horizontal mirroring. + for (size_t y = 0; y < input_sizes[c].second; y++) { + float* row = get_row(c, y); + for (size_t ix = 0; ix < stage->settings_.border_x; ix++) { + *(row - ix - 1) = row[Mirror(-ssize_t(ix) - 1, input_sizes[c].first)]; + } + for (size_t ix = 0; ix < stage->settings_.border_x; ix++) { + *(row + ix + input_sizes[c].first) = + row[Mirror(ix + input_sizes[c].first, input_sizes[c].first)]; + } + } + // Vertical mirroring. + for (int y = 0; y < static_cast(stage->settings_.border_y); y++) { + memcpy(get_row(c, -y - 1) - stage->settings_.border_x, + get_row(c, Mirror(-ssize_t(y) - 1, input_sizes[c].second)) - + stage->settings_.border_x, + sizeof(float) * + (input_sizes[c].first + 2 * stage->settings_.border_x)); + } + for (int y = 0; y < static_cast(stage->settings_.border_y); y++) { + memcpy( + get_row(c, input_sizes[c].second + y) - stage->settings_.border_x, + get_row(c, + Mirror(input_sizes[c].second + y, input_sizes[c].second)) - + stage->settings_.border_x, + sizeof(float) * + (input_sizes[c].first + 2 * stage->settings_.border_x)); + } + } + + size_t ysize = 0; + size_t xsize = 0; + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kIgnored) { + continue; + } + ysize = std::max(input_sizes[c].second, ysize); + xsize = std::max(input_sizes[c].first, xsize); + } + + JXL_ASSERT(ysize != 0); + JXL_ASSERT(xsize != 0); + + RenderPipelineStage::RowInfo input_rows(channel_data_.size()); + RenderPipelineStage::RowInfo output_rows(channel_data_.size()); + + // Run the pipeline. + { + stage->SetInputSizes(input_sizes); + int border_y = stage->settings_.border_y; + for (size_t y = 0; y < ysize; y++) { + // Prepare input rows. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kIgnored) { + continue; + } + input_rows[c].resize(2 * border_y + 1); + for (int iy = -border_y; iy <= border_y; iy++) { + input_rows[c][iy + border_y] = + channel_data_[c].Row(y + kRenderPipelineXOffset + iy); + } + } + // Prepare output rows. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (!output_channels[c]) continue; + output_rows[c].resize(1 << stage->settings_.shift_y); + for (size_t iy = 0; iy < output_rows[c].size(); iy++) { + output_rows[c][iy] = output_channels[c]->Row( + (y << stage->settings_.shift_y) + iy + kRenderPipelineXOffset); + } + } + stage->ProcessRow(input_rows, output_rows, /*xextra=*/0, xsize, + /*xpos=*/0, y, thread_id); + } + } + + // Move new channels to current channels. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) != RenderPipelineChannelMode::kInOut) { + continue; + } + channel_data_[c] = std::move(new_channels[c]); + } + for (size_t c = 0; c < channel_data_.size(); c++) { + size_t next_stage = std::min(stage_id + 1, channel_shifts_.size() - 1); + size_t xsize = DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[next_stage][c].first); + size_t ysize = DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[next_stage][c].second); + channel_data_[c].ShrinkTo(xsize + 2 * kRenderPipelineXOffset, + ysize + 2 * kRenderPipelineXOffset); + JXL_CHECK_PLANE_INITIALIZED( + channel_data_[c], + Rect(kRenderPipelineXOffset, kRenderPipelineXOffset, xsize, ysize), + c); + } + + if (stage->SwitchToImageDimensions()) { + size_t image_xsize, image_ysize; + FrameOrigin frame_origin; + stage->GetImageDimensions(&image_xsize, &image_ysize, &frame_origin); + frame_dimensions_.Set(image_xsize, image_ysize, 0, 0, 0, false, 1); + std::vector old_channels = std::move(channel_data_); + channel_data_.clear(); + channel_data_.reserve(old_channels.size()); + for (size_t c = 0; c < old_channels.size(); c++) { + channel_data_.emplace_back(2 * kRenderPipelineXOffset + image_xsize, + 2 * kRenderPipelineXOffset + image_ysize); + } + for (size_t y = 0; y < image_ysize; ++y) { + for (size_t c = 0; c < channel_data_.size(); c++) { + output_rows[c].resize(1); + output_rows[c][0] = channel_data_[c].Row(kRenderPipelineXOffset + y); + } + // TODO(sboukortt): consider doing this only on the parts of the + // background that won't be occluded. + stage->ProcessPaddingRow(output_rows, image_xsize, 0, y); + } + ssize_t x0 = frame_origin.x0; + ssize_t y0 = frame_origin.y0; + size_t x0_fg = 0; + size_t y0_fg = 0; + if (x0 < 0) { + xsize += x0; + x0_fg -= x0; + x0 = 0; + } + if (x0 + xsize > image_xsize) { + xsize = image_xsize - x0; + } + if (y0 < 0) { + ysize += y0; + y0_fg -= x0; + y0 = 0; + } + if (y0 + ysize > image_ysize) { + ysize = image_ysize - y0; + } + const Rect rect_fg_relative_to_image = + Rect(x0, y0, xsize, ysize) + .Translate(kRenderPipelineXOffset, kRenderPipelineXOffset); + const Rect rect_fg = + Rect(x0_fg, y0_fg, xsize, ysize) + .Translate(kRenderPipelineXOffset, kRenderPipelineXOffset); + for (size_t c = 0; c < channel_data_.size(); c++) { + CopyImageTo(rect_fg, old_channels[c], rect_fg_relative_to_image, + &channel_data_[c]); + } + } + } +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/render_pipeline/simple_render_pipeline.h b/media/libjxl/src/lib/jxl/render_pipeline/simple_render_pipeline.h new file mode 100644 index 000000000..10f450591 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/simple_render_pipeline.h @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_SIMPLE_RENDER_PIPELINE_H_ +#define LIB_JXL_RENDER_PIPELINE_SIMPLE_RENDER_PIPELINE_H_ + +#include + +#include "lib/jxl/render_pipeline/render_pipeline.h" + +namespace jxl { + +// A RenderPipeline that is "obviously correct"; it may use potentially large +// amounts of memory and be slow. It is intended to be used mostly for testing +// purposes. +class SimpleRenderPipeline : public RenderPipeline { + std::vector> PrepareBuffers( + size_t group_id, size_t thread_id) override; + + void ProcessBuffers(size_t group_id, size_t thread_id) override; + + void PrepareForThreadsInternal(size_t num, bool use_group_ids) override; + + // Full frame buffers. Both X and Y dimensions are padded by + // kRenderPipelineXOffset. + std::vector channel_data_; + size_t processed_passes_ = 0; + + private: + Rect MakeChannelRect(size_t group_id, size_t channel); +}; + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_SIMPLE_RENDER_PIPELINE_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_blending.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_blending.cc new file mode 100644 index 000000000..5d36c0a7d --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_blending.cc @@ -0,0 +1,235 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_blending.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_blending.cc" +#include +#include + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/blending.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class BlendingStage : public RenderPipelineStage { + public: + explicit BlendingStage(const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding) + : RenderPipelineStage(RenderPipelineStage::Settings()), + state_(*dec_state->shared) { + image_xsize_ = state_.frame_header.nonserialized_metadata->xsize(); + image_ysize_ = state_.frame_header.nonserialized_metadata->ysize(); + extra_channel_info_ = + &state_.frame_header.nonserialized_metadata->m.extra_channel_info; + info_ = state_.frame_header.blending_info; + const std::vector& ec_info = + state_.frame_header.extra_channel_blending_info; + ImageBundle& bg = *state_.reference_frames[info_.source].frame; + bg_ = &bg; + if (bg.xsize() == 0 || bg.ysize() == 0) { + zeroes_.resize(image_xsize_, 0.f); + } else if (state_.reference_frames[info_.source].ib_is_in_xyb) { + initialized_ = JXL_FAILURE( + "Trying to blend XYB reference frame %i and non-XYB frame", + info_.source); + return; + } else if (std::any_of(ec_info.begin(), ec_info.end(), + [this](const BlendingInfo& info) { + const ImageBundle& bg = + *state_.reference_frames[info.source].frame; + return bg.xsize() == 0 || bg.ysize() == 0; + })) { + zeroes_.resize(image_xsize_, 0.f); + } + + if (bg.xsize() != 0 && bg.ysize() != 0 && + (bg.xsize() < image_xsize_ || bg.ysize() < image_ysize_ || + bg.origin.x0 != 0 || bg.origin.y0 != 0)) { + initialized_ = JXL_FAILURE("Trying to use a %" PRIuS "x%" PRIuS + " crop as a background", + bg.xsize(), bg.ysize()); + return; + } + if (state_.metadata->m.xyb_encoded) { + if (!dec_state->output_encoding_info.color_encoding_is_original) { + initialized_ = JXL_FAILURE("Blending in unsupported color space"); + return; + } + } + + blending_info_.resize(ec_info.size() + 1); + auto make_blending = [&](const BlendingInfo& info, PatchBlending* pb) { + pb->alpha_channel = info.alpha_channel; + pb->clamp = info.clamp; + switch (info.mode) { + case BlendMode::kReplace: { + pb->mode = PatchBlendMode::kReplace; + break; + } + case BlendMode::kAdd: { + pb->mode = PatchBlendMode::kAdd; + break; + } + case BlendMode::kMul: { + pb->mode = PatchBlendMode::kMul; + break; + } + case BlendMode::kBlend: { + pb->mode = PatchBlendMode::kBlendAbove; + break; + } + case BlendMode::kAlphaWeightedAdd: { + pb->mode = PatchBlendMode::kAlphaWeightedAddAbove; + break; + } + default: { + JXL_ABORT("Invalid blend mode"); // should have failed to decode + } + } + }; + make_blending(info_, &blending_info_[0]); + for (size_t i = 0; i < ec_info.size(); i++) { + make_blending(ec_info[i], &blending_info_[1 + i]); + } + } + + Status IsInitialized() const override { return initialized_; } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("Blend"); + JXL_ASSERT(initialized_); + const FrameOrigin& frame_origin = state_.frame_header.frame_origin; + ssize_t bg_xpos = frame_origin.x0 + static_cast(xpos); + ssize_t bg_ypos = frame_origin.y0 + static_cast(ypos); + int offset = 0; + if (bg_xpos + static_cast(xsize) <= 0 || + frame_origin.x0 >= static_cast(image_xsize_) || bg_ypos < 0 || + bg_ypos >= static_cast(image_ysize_)) { + return; + } + if (bg_xpos < 0) { + offset -= bg_xpos; + xsize += bg_xpos; + bg_xpos = 0; + } + if (bg_xpos + xsize > image_xsize_) { + xsize = + std::max(0, static_cast(image_xsize_) - bg_xpos); + } + std::vector bg_row_ptrs_(input_rows.size()); + std::vector fg_row_ptrs_(input_rows.size()); + size_t num_c = std::min(input_rows.size(), extra_channel_info_->size() + 3); + for (size_t c = 0; c < num_c; ++c) { + fg_row_ptrs_[c] = GetInputRow(input_rows, c, 0) + offset; + if (c < 3) { + bg_row_ptrs_[c] = + bg_->xsize() != 0 && bg_->ysize() != 0 + ? bg_->color()->ConstPlaneRow(c, bg_ypos) + bg_xpos + : zeroes_.data(); + } else { + const ImageBundle& ec_bg = + *state_ + .reference_frames[state_.frame_header + .extra_channel_blending_info[c - 3] + .source] + .frame; + bg_row_ptrs_[c] = + ec_bg.xsize() != 0 && ec_bg.ysize() != 0 + ? ec_bg.extra_channels()[c - 3].ConstRow(bg_ypos) + bg_xpos + : zeroes_.data(); + } + } + PerformBlending(bg_row_ptrs_.data(), fg_row_ptrs_.data(), + fg_row_ptrs_.data(), 0, xsize, blending_info_[0], + blending_info_.data() + 1, *extra_channel_info_); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInPlace; + } + + bool SwitchToImageDimensions() const override { return true; } + + void GetImageDimensions(size_t* xsize, size_t* ysize, + FrameOrigin* frame_origin) const override { + *xsize = image_xsize_; + *ysize = image_ysize_; + *frame_origin = state_.frame_header.frame_origin; + } + + void ProcessPaddingRow(const RowInfo& output_rows, size_t xsize, size_t xpos, + size_t ypos) const override { + if (bg_->xsize() == 0 || bg_->ysize() == 0) { + for (size_t c = 0; c < 3; ++c) { + memset(GetInputRow(output_rows, c, 0), 0, xsize * sizeof(float)); + } + } else { + for (size_t c = 0; c < 3; ++c) { + memcpy(GetInputRow(output_rows, c, 0), + bg_->color()->ConstPlaneRow(c, ypos) + xpos, + xsize * sizeof(float)); + } + } + for (size_t ec = 0; ec < extra_channel_info_->size(); ++ec) { + const ImageBundle& ec_bg = + *state_ + .reference_frames + [state_.frame_header.extra_channel_blending_info[ec].source] + .frame; + if (ec_bg.xsize() == 0 || ec_bg.ysize() == 0) { + memset(GetInputRow(output_rows, 3 + ec, 0), 0, xsize * sizeof(float)); + } else { + memcpy(GetInputRow(output_rows, 3 + ec, 0), + ec_bg.extra_channels()[ec].ConstRow(ypos) + xpos, + xsize * sizeof(float)); + } + } + } + + const char* GetName() const override { return "Blending"; } + + private: + const PassesSharedState& state_; + BlendingInfo info_; + ImageBundle* bg_; + Status initialized_ = true; + size_t image_xsize_; + size_t image_ysize_; + std::vector blending_info_; + const std::vector* extra_channel_info_; + std::vector zeroes_; +}; + +std::unique_ptr GetBlendingStage( + const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding) { + return jxl::make_unique(dec_state, frame_color_encoding); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetBlendingStage); + +std::unique_ptr GetBlendingStage( + const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding) { + return HWY_DYNAMIC_DISPATCH(GetBlendingStage)(dec_state, + frame_color_encoding); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_blending.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_blending.h new file mode 100644 index 000000000..c8db7490c --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_blending.h @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_BLENDING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_BLENDING_H_ + +#include + +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +// Applies blending if applicable. +std::unique_ptr GetBlendingStage( + const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_BLENDING_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_chroma_upsampling.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_chroma_upsampling.cc new file mode 100644 index 000000000..9b73ee91f --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_chroma_upsampling.cc @@ -0,0 +1,129 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_chroma_upsampling.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_chroma_upsampling.cc" +#include +#include + +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; + +class HorizontalChromaUpsamplingStage : public RenderPipelineStage { + public: + explicit HorizontalChromaUpsamplingStage(size_t channel) + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftX( + /*shift=*/1, /*border=*/1)), + c_(channel) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("HorizontalChromaUpsampling"); + HWY_FULL(float) df; + xextra = RoundUpTo(xextra, Lanes(df)); + auto threefour = Set(df, 0.75f); + auto onefour = Set(df, 0.25f); + const float* row_in = GetInputRow(input_rows, c_, 0); + float* row_out = GetOutputRow(output_rows, c_, 0); + for (ssize_t x = -xextra; x < static_cast(xsize + xextra); + x += Lanes(df)) { + auto current = Mul(LoadU(df, row_in + x), threefour); + auto prev = LoadU(df, row_in + x - 1); + auto next = LoadU(df, row_in + x + 1); + auto left = MulAdd(onefour, prev, current); + auto right = MulAdd(onefour, next, current); + StoreInterleaved(df, left, right, row_out + x * 2); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c == c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "HChromaUps"; } + + private: + size_t c_; +}; + +class VerticalChromaUpsamplingStage : public RenderPipelineStage { + public: + explicit VerticalChromaUpsamplingStage(size_t channel) + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftY( + /*shift=*/1, /*border=*/1)), + c_(channel) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("VerticalChromaUpsampling"); + HWY_FULL(float) df; + xextra = RoundUpTo(xextra, Lanes(df)); + auto threefour = Set(df, 0.75f); + auto onefour = Set(df, 0.25f); + const float* row_top = GetInputRow(input_rows, c_, -1); + const float* row_mid = GetInputRow(input_rows, c_, 0); + const float* row_bot = GetInputRow(input_rows, c_, 1); + float* row_out0 = GetOutputRow(output_rows, c_, 0); + float* row_out1 = GetOutputRow(output_rows, c_, 1); + for (ssize_t x = -xextra; x < static_cast(xsize + xextra); + x += Lanes(df)) { + auto it = LoadU(df, row_top + x); + auto im = LoadU(df, row_mid + x); + auto ib = LoadU(df, row_bot + x); + auto im_scaled = Mul(im, threefour); + Store(MulAdd(it, onefour, im_scaled), df, row_out0 + x); + Store(MulAdd(ib, onefour, im_scaled), df, row_out1 + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c == c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "VChromaUps"; } + + private: + size_t c_; +}; + +std::unique_ptr GetChromaUpsamplingStage(size_t channel, + bool horizontal) { + if (horizontal) { + return jxl::make_unique(channel); + } else { + return jxl::make_unique(channel); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetChromaUpsamplingStage); + +std::unique_ptr GetChromaUpsamplingStage(size_t channel, + bool horizontal) { + return HWY_DYNAMIC_DISPATCH(GetChromaUpsamplingStage)(channel, horizontal); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_chroma_upsampling.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_chroma_upsampling.h new file mode 100644 index 000000000..b8bfc15f5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_chroma_upsampling.h @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_CHROMA_UPSAMPLING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_CHROMA_UPSAMPLING_H_ +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Applies simple upsampling, either horizontal or vertical, to the given +// channel. +std::unique_ptr GetChromaUpsamplingStage(size_t channel, + bool horizontal); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_CHROMA_UPSAMPLING_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_epf.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_epf.cc new file mode 100644 index 000000000..d59c49784 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_epf.cc @@ -0,0 +1,524 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_epf.h" + +#include "lib/jxl/epf.h" +#include "lib/jxl/sanitizers.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_epf.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +// TODO(veluca): In principle, vectors could be not capped, if we want to deal +// with having two different sigma values in a single vector. +using DF = HWY_CAPPED(float, 8); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::VFromD; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +JXL_INLINE Vec Weight(Vec sad, Vec inv_sigma, Vec thres) { + auto v = MulAdd(sad, inv_sigma, Set(DF(), 1.0f)); + return ZeroIfNegative(v); +} + +// 5x5 plus-shaped kernel with 5 SADs per pixel (3x3 plus-shaped). So this makes +// this filter a 7x7 filter. +class EPF0Stage : public RenderPipelineStage { + public: + EPF0Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/3)), + lf_(lf), + sigma_(&sigma) {} + + template + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][7], ssize_t x, + Vec sad, Vec inv_sigma, + Vec* JXL_RESTRICT X, Vec* JXL_RESTRICT Y, + Vec* JXL_RESTRICT B, + Vec* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][3 + row] + x) + : LoadU(DF(), rows[0][3 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][3 + row] + x) + : LoadU(DF(), rows[1][3 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][3 + row] + x) + : LoadU(DF(), rows[2][3 + row] + x); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass1_zeroflush)); + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + + using V = decltype(Zero(df)); + V t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, tA, tB; + V* sads[12] = {&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7, &t8, &t9, &tA, &tB}; + + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = lf_.epf_pass0_sigma_scale * 1.65; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + float* JXL_RESTRICT rows[3][7]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 7; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 3); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][3 + 0] + x); + StoreU(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + + for (size_t i = 0; i < 12; i++) *sads[i] = Zero(df); + constexpr std::array sads_off[12] = { + {{-2, 0}}, {{-1, -1}}, {{-1, 0}}, {{-1, 1}}, {{0, -2}}, {{0, -1}}, + {{0, 1}}, {{0, 2}}, {{1, -1}}, {{1, 0}}, {{1, 1}}, {{2, 0}}, + }; + + // compute sads + // TODO(veluca): consider unrolling and optimizing this. + for (size_t c = 0; c < 3; c++) { + auto scale = Set(df, lf_.epf_channel_scale[c]); + for (size_t i = 0; i < 12; i++) { + auto sad = Zero(df); + constexpr std::array plus_off[] = { + {{0, 0}}, {{-1, 0}}, {{0, -1}}, {{1, 0}}, {{0, 1}}}; + for (size_t j = 0; j < 5; j++) { + const auto r11 = + LoadU(df, rows[c][3 + plus_off[j][0]] + x + plus_off[j][1]); + const auto c11 = + LoadU(df, rows[c][3 + sads_off[i][0] + plus_off[j][0]] + x + + sads_off[i][1] + plus_off[j][1]); + sad = Add(sad, AbsDiff(r11, c11)); + } + *sads[i] = MulAdd(sad, scale, *sads[i]); + } + } + const auto x_cc = Load(df, rows[0][3 + 0] + x); + const auto y_cc = Load(df, rows[1][3 + 0] + x); + const auto b_cc = Load(df, rows[2][3 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + for (size_t i = 0; i < 12; i++) { + AddPixel(/*row=*/sads_off[i][0], rows, + x + sads_off[i][1], *sads[i], inv_sigma, &X, + &Y, &B, &w); + } +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + StoreU(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + StoreU(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + StoreU(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF0"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +// 3x3 plus-shaped kernel with 5 SADs per pixel (also 3x3 plus-shaped). So this +// makes this filter a 5x5 filter. +class EPF1Stage : public RenderPipelineStage { + public: + EPF1Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/2)), + lf_(lf), + sigma_(&sigma) {} + + template + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][5], ssize_t x, + Vec sad, Vec inv_sigma, + Vec* JXL_RESTRICT X, Vec* JXL_RESTRICT Y, + Vec* JXL_RESTRICT B, + Vec* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][2 + row] + x) + : LoadU(DF(), rows[0][2 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][2 + row] + x) + : LoadU(DF(), rows[1][2 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][2 + row] + x) + : LoadU(DF(), rows[2][2 + row] + x); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass1_zeroflush)); + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = 1.65f; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + + float* JXL_RESTRICT rows[3][5]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 5; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 2); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][2 + 0] + x); + Store(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + auto sad0 = Zero(df); + auto sad1 = Zero(df); + auto sad2 = Zero(df); + auto sad3 = Zero(df); + + // compute sads + for (size_t c = 0; c < 3; c++) { + // center px = 22, px above = 21 + auto t = Undefined(df); + + const auto p20 = Load(df, rows[c][2 + -2] + x); + const auto p21 = Load(df, rows[c][2 + -1] + x); + auto sad0c = AbsDiff(p20, p21); // SAD 2, 1 + + const auto p11 = LoadU(df, rows[c][2 + -1] + x - 1); + auto sad1c = AbsDiff(p11, p21); // SAD 1, 2 + + const auto p31 = LoadU(df, rows[c][2 + -1] + x + 1); + auto sad2c = AbsDiff(p31, p21); // SAD 3, 2 + + const auto p02 = LoadU(df, rows[c][2 + 0] + x - 2); + const auto p12 = LoadU(df, rows[c][2 + 0] + x - 1); + sad1c = Add(sad1c, AbsDiff(p02, p12)); // SAD 1, 2 + sad0c = Add(sad0c, AbsDiff(p11, p12)); // SAD 2, 1 + + const auto p22 = LoadU(df, rows[c][2 + 0] + x); + t = AbsDiff(p12, p22); + sad1c = Add(sad1c, t); // SAD 1, 2 + sad2c = Add(sad2c, t); // SAD 3, 2 + t = AbsDiff(p22, p21); + auto sad3c = t; // SAD 2, 3 + sad0c = Add(sad0c, t); // SAD 2, 1 + + const auto p32 = LoadU(df, rows[c][2 + 0] + x + 1); + sad0c = Add(sad0c, AbsDiff(p31, p32)); // SAD 2, 1 + t = AbsDiff(p22, p32); + sad1c = Add(sad1c, t); // SAD 1, 2 + sad2c = Add(sad2c, t); // SAD 3, 2 + + const auto p42 = LoadU(df, rows[c][2 + 0] + x + 2); + sad2c = Add(sad2c, AbsDiff(p42, p32)); // SAD 3, 2 + + const auto p13 = LoadU(df, rows[c][2 + 1] + x - 1); + sad3c = Add(sad3c, AbsDiff(p13, p12)); // SAD 2, 3 + + const auto p23 = Load(df, rows[c][2 + 1] + x); + t = AbsDiff(p22, p23); + sad0c = Add(sad0c, t); // SAD 2, 1 + sad3c = Add(sad3c, t); // SAD 2, 3 + sad1c = Add(sad1c, AbsDiff(p13, p23)); // SAD 1, 2 + + const auto p33 = LoadU(df, rows[c][2 + 1] + x + 1); + sad2c = Add(sad2c, AbsDiff(p33, p23)); // SAD 3, 2 + sad3c = Add(sad3c, AbsDiff(p33, p32)); // SAD 2, 3 + + const auto p24 = Load(df, rows[c][2 + 2] + x); + sad3c = Add(sad3c, AbsDiff(p24, p23)); // SAD 2, 3 + + auto scale = Set(df, lf_.epf_channel_scale[c]); + sad0 = MulAdd(sad0c, scale, sad0); + sad1 = MulAdd(sad1c, scale, sad1); + sad2 = MulAdd(sad2c, scale, sad2); + sad3 = MulAdd(sad3c, scale, sad3); + } + const auto x_cc = Load(df, rows[0][2 + 0] + x); + const auto y_cc = Load(df, rows[1][2 + 0] + x); + const auto b_cc = Load(df, rows[2][2 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + // Top row + AddPixel(/*row=*/-1, rows, x, sad0, inv_sigma, &X, &Y, + &B, &w); + // Center + AddPixel(/*row=*/0, rows, x - 1, sad1, inv_sigma, &X, + &Y, &B, &w); + AddPixel(/*row=*/0, rows, x + 1, sad2, inv_sigma, &X, + &Y, &B, &w); + // Bottom + AddPixel(/*row=*/1, rows, x, sad3, inv_sigma, &X, &Y, + &B, &w); +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + Store(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + Store(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + Store(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF1"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +// 3x3 plus-shaped kernel with 1 SAD per pixel. So this makes this filter a 3x3 +// filter. +class EPF2Stage : public RenderPipelineStage { + public: + EPF2Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/1)), + lf_(lf), + sigma_(&sigma) {} + + template + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][3], ssize_t x, + Vec rx, Vec ry, Vec rb, + Vec inv_sigma, Vec* JXL_RESTRICT X, + Vec* JXL_RESTRICT Y, Vec* JXL_RESTRICT B, + Vec* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][1 + row] + x) + : LoadU(DF(), rows[0][1 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][1 + row] + x) + : LoadU(DF(), rows[1][1 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][1 + row] + x) + : LoadU(DF(), rows[2][1 + row] + x); + + auto sad = Mul(AbsDiff(cx, rx), Set(DF(), lf_.epf_channel_scale[0])); + sad = MulAdd(AbsDiff(cy, ry), Set(DF(), lf_.epf_channel_scale[1]), sad); + sad = MulAdd(AbsDiff(cb, rb), Set(DF(), lf_.epf_channel_scale[2]), sad); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass2_zeroflush)); + + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = lf_.epf_pass2_sigma_scale * 1.65; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + + float* JXL_RESTRICT rows[3][3]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 3; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 1); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][1 + 0] + x); + Store(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + + const auto x_cc = Load(df, rows[0][1 + 0] + x); + const auto y_cc = Load(df, rows[1][1 + 0] + x); + const auto b_cc = Load(df, rows[2][1 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + // Top row + AddPixel(/*row=*/-1, rows, x, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + // Center + AddPixel(/*row=*/0, rows, x - 1, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + AddPixel(/*row=*/0, rows, x + 1, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + // Bottom + AddPixel(/*row=*/1, rows, x, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + Store(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + Store(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + Store(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF2"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +std::unique_ptr GetEPFStage0(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique(lf, sigma); +} + +std::unique_ptr GetEPFStage1(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique(lf, sigma); +} + +std::unique_ptr GetEPFStage2(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique(lf, sigma); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetEPFStage0); +HWY_EXPORT(GetEPFStage1); +HWY_EXPORT(GetEPFStage2); + +std::unique_ptr GetEPFStage(const LoopFilter& lf, + const ImageF& sigma, + size_t epf_stage) { + JXL_ASSERT(lf.epf_iters != 0); + switch (epf_stage) { + case 0: + return HWY_DYNAMIC_DISPATCH(GetEPFStage0)(lf, sigma); + case 1: + return HWY_DYNAMIC_DISPATCH(GetEPFStage1)(lf, sigma); + case 2: + return HWY_DYNAMIC_DISPATCH(GetEPFStage2)(lf, sigma); + default: + JXL_ABORT("Invalid EPF stage"); + } +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_epf.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_epf.h new file mode 100644 index 000000000..c9d0d0c78 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_epf.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_EPF_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_EPF_H_ +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/image.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Applies the `epf_stage`-th EPF step with the given settings and `sigma`. +// `sigma` will be accessed with an offset of (kSigmaPadding, kSigmaPadding), +// and should have (kSigmaBorder, kSigmaBorder) mirrored sigma values available +// around the main image. See also filters.(h|cc) +std::unique_ptr GetEPFStage(const LoopFilter& lf, + const ImageF& sigma, + size_t epf_stage); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_EPF_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_from_linear.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_from_linear.cc new file mode 100644 index 000000000..81f546c6b --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_from_linear.cc @@ -0,0 +1,189 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_from_linear.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_from_linear.cc" +#include +#include + +#include "lib/jxl/dec_tone_mapping-inl.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::IfThenZeroElse; + +template +struct PerChannelOp { + explicit PerChannelOp(Op op) : op(op) {} + template + void Transform(D d, T* r, T* g, T* b) const { + *r = op.Transform(d, *r); + *g = op.Transform(d, *g); + *b = op.Transform(d, *b); + } + + Op op; +}; +template +PerChannelOp MakePerChannelOp(Op&& op) { + return PerChannelOp(std::forward(op)); +} + +struct OpLinear { + template + T Transform(D d, const T& linear) const { + return linear; + } +}; + +struct OpRgb { + template + T Transform(D d, const T& linear) const { +#if JXL_HIGH_PRECISION + return TF_SRGB().EncodedFromDisplay(d, linear); +#else + return FastLinearToSRGB(d, linear); +#endif + } +}; + +struct OpPq { + template + T Transform(D d, const T& linear) const { + return TF_PQ().EncodedFromDisplay(d, linear); + } +}; + +struct OpHlg { + explicit OpHlg(const float luminances[3], const float intensity_target) + : hlg_ootf_(HlgOOTF::ToSceneLight(/*display_luminance=*/intensity_target, + luminances)) {} + + template + void Transform(D d, T* r, T* g, T* b) const { + hlg_ootf_.Apply(r, g, b); + *r = TF_HLG().EncodedFromDisplay(d, *r); + *g = TF_HLG().EncodedFromDisplay(d, *g); + *b = TF_HLG().EncodedFromDisplay(d, *b); + } + HlgOOTF hlg_ootf_; +}; + +struct Op709 { + template + T Transform(D d, const T& linear) const { + return TF_709().EncodedFromDisplay(d, linear); + } +}; + +struct OpGamma { + const float inverse_gamma; + template + T Transform(D d, const T& linear) const { + return IfThenZeroElse(Le(linear, Set(d, 1e-5f)), + FastPowf(d, linear, Set(d, inverse_gamma))); + } +}; + +template +class FromLinearStage : public RenderPipelineStage { + public: + explicit FromLinearStage(Op op) + : RenderPipelineStage(RenderPipelineStage::Settings()), + op_(std::move(op)) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("FromLinear"); + const HWY_FULL(float) d; + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + auto r = LoadU(d, row0 + x); + auto g = LoadU(d, row1 + x); + auto b = LoadU(d, row2 + x); + op_.Transform(d, &r, &g, &b); + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "FromLinear"; } + + private: + Op op_; +}; + +template +std::unique_ptr> MakeFromLinearStage(Op&& op) { + return jxl::make_unique>(std::forward(op)); +} + +std::unique_ptr GetFromLinearStage( + const OutputEncodingInfo& output_encoding_info) { + if (output_encoding_info.color_encoding.tf.IsLinear()) { + return MakeFromLinearStage(MakePerChannelOp(OpLinear())); + } else if (output_encoding_info.color_encoding.tf.IsSRGB()) { + return MakeFromLinearStage(MakePerChannelOp(OpRgb())); + } else if (output_encoding_info.color_encoding.tf.IsPQ()) { + return MakeFromLinearStage(MakePerChannelOp(OpPq())); + } else if (output_encoding_info.color_encoding.tf.IsHLG()) { + return MakeFromLinearStage( + OpHlg(output_encoding_info.luminances, + output_encoding_info.desired_intensity_target)); + } else if (output_encoding_info.color_encoding.tf.Is709()) { + return MakeFromLinearStage(MakePerChannelOp(Op709())); + } else if (output_encoding_info.color_encoding.tf.IsGamma() || + output_encoding_info.color_encoding.tf.IsDCI()) { + return MakeFromLinearStage( + MakePerChannelOp(OpGamma{output_encoding_info.inverse_gamma})); + } else { + // This is a programming error. + JXL_ABORT("Invalid target encoding"); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetFromLinearStage); + +std::unique_ptr GetFromLinearStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetFromLinearStage)(output_encoding_info); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_from_linear.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_from_linear.h new file mode 100644 index 000000000..548ab50b8 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_from_linear.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_FROM_LINEAR_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_FROM_LINEAR_H_ + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from linear to the specified output encoding. +std::unique_ptr GetFromLinearStage( + const OutputEncodingInfo& output_encoding_info); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_FROM_LINEAR_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_gaborish.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_gaborish.cc new file mode 100644 index 000000000..fc90acb47 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_gaborish.cc @@ -0,0 +1,122 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_gaborish.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_gaborish.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; + +class GaborishStage : public RenderPipelineStage { + public: + explicit GaborishStage(const LoopFilter& lf) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/1)) { + weights_[0] = 1; + weights_[1] = lf.gab_x_weight1; + weights_[2] = lf.gab_x_weight2; + weights_[3] = 1; + weights_[4] = lf.gab_y_weight1; + weights_[5] = lf.gab_y_weight2; + weights_[6] = 1; + weights_[7] = lf.gab_b_weight1; + weights_[8] = lf.gab_b_weight2; + // Normalize + for (size_t c = 0; c < 3; c++) { + const float div = + weights_[3 * c] + 4 * (weights_[3 * c + 1] + weights_[3 * c + 2]); + const float mul = 1.0f / div; + weights_[3 * c] *= mul; + weights_[3 * c + 1] *= mul; + weights_[3 * c + 2] *= mul; + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("Gaborish"); + + const HWY_FULL(float) d; + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT row_t = GetInputRow(input_rows, c, -1); + float* JXL_RESTRICT row_m = GetInputRow(input_rows, c, 0); + float* JXL_RESTRICT row_b = GetInputRow(input_rows, c, 1); + float* JXL_RESTRICT row_out = GetOutputRow(output_rows, c, 0); + const auto w0 = Set(d, weights_[3 * c + 0]); + const auto w1 = Set(d, weights_[3 * c + 1]); + const auto w2 = Set(d, weights_[3 * c + 2]); +// Group data need only be aligned to a block; for >=512 bit vectors, this may +// result in unaligned loads. +#if HWY_CAP_GE512 +#define LoadMaybeU LoadU +#else +#define LoadMaybeU Load +#endif + // Since GetInputRow(input_rows, c, {-1, 0, 1}) is aligned, rounding + // xextra up to Lanes(d) doesn't access anything problematic. + for (ssize_t x = -RoundUpTo(xextra, Lanes(d)); + x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + const auto t = LoadMaybeU(d, row_t + x); + const auto tl = LoadU(d, row_t + x - 1); + const auto tr = LoadU(d, row_t + x + 1); + const auto m = LoadMaybeU(d, row_m + x); + const auto l = LoadU(d, row_m + x - 1); + const auto r = LoadU(d, row_m + x + 1); + const auto b = LoadMaybeU(d, row_b + x); + const auto bl = LoadU(d, row_b + x - 1); + const auto br = LoadU(d, row_b + x + 1); + const auto sum0 = m; + const auto sum1 = Add(Add(l, r), Add(t, b)); + const auto sum2 = Add(Add(tl, tr), Add(bl, br)); + auto pixels = MulAdd(sum2, w2, MulAdd(sum1, w1, Mul(sum0, w0))); + Store(pixels, d, row_out + x); + } + } + } +#undef LoadMaybeU + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Gab"; } + + private: + float weights_[9]; +}; + +std::unique_ptr GetGaborishStage(const LoopFilter& lf) { + return jxl::make_unique(lf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetGaborishStage); + +std::unique_ptr GetGaborishStage(const LoopFilter& lf) { + JXL_ASSERT(lf.gab == 1); + return HWY_DYNAMIC_DISPATCH(GetGaborishStage)(lf); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_gaborish.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_gaborish.h new file mode 100644 index 000000000..761800f66 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_gaborish.h @@ -0,0 +1,25 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_GABORISH_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_GABORISH_H_ +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Applies decoder-side Gaborish with the given settings. `lf.gab` must be 1. +std::unique_ptr GetGaborishStage(const LoopFilter& lf); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_GABORISH_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_noise.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_noise.cc new file mode 100644 index 000000000..9f0cee316 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_noise.cc @@ -0,0 +1,310 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_noise.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_noise.cc" +#include +#include + +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +using D = HWY_CAPPED(float, kBlockDim); +using DI = hwy::HWY_NAMESPACE::Rebind; +using DI8 = hwy::HWY_NAMESPACE::Repartition; + +// [0, max_value] +template +static HWY_INLINE V Clamp0ToMax(D d, const V x, const V max_value) { + const auto clamped = Min(x, max_value); + return ZeroIfNegative(clamped); +} + +// x is in [0+delta, 1+delta], delta ~= 0.06 +template +typename StrengthEval::V NoiseStrength(const StrengthEval& eval, + const typename StrengthEval::V x) { + return Clamp0ToMax(D(), eval(x), Set(D(), 1.0f)); +} + +// TODO(veluca): SIMD-fy. +class StrengthEvalLut { + public: + using V = Vec; + + explicit StrengthEvalLut(const NoiseParams& noise_params) +#if HWY_TARGET == HWY_SCALAR + : noise_params_(noise_params) +#endif + { +#if HWY_TARGET != HWY_SCALAR + uint32_t lut[8]; + memcpy(lut, noise_params.lut, sizeof(lut)); + for (size_t i = 0; i < 8; i++) { + low16_lut[2 * i] = (lut[i] >> 0) & 0xFF; + low16_lut[2 * i + 1] = (lut[i] >> 8) & 0xFF; + high16_lut[2 * i] = (lut[i] >> 16) & 0xFF; + high16_lut[2 * i + 1] = (lut[i] >> 24) & 0xFF; + } +#endif + } + + V operator()(const V vx) const { + constexpr size_t kScale = NoiseParams::kNumNoisePoints - 2; + auto scaled_vx = Max(Zero(D()), Mul(vx, Set(D(), kScale))); + auto floor_x = Floor(scaled_vx); + auto frac_x = Sub(scaled_vx, floor_x); + floor_x = IfThenElse(Ge(scaled_vx, Set(D(), kScale)), Set(D(), kScale - 1), + floor_x); + frac_x = IfThenElse(Ge(scaled_vx, Set(D(), kScale)), Set(D(), 1), frac_x); + auto floor_x_int = ConvertTo(DI(), floor_x); +#if HWY_TARGET == HWY_SCALAR + auto low = Set(D(), noise_params_.lut[floor_x_int.raw]); + auto hi = Set(D(), noise_params_.lut[floor_x_int.raw + 1]); +#else + // Set each lane's bytes to {0, 0, 2x+1, 2x}. + auto floorx_indices_low = + Add(Mul(floor_x_int, Set(DI(), 0x0202)), Set(DI(), 0x0100)); + // Set each lane's bytes to {2x+1, 2x, 0, 0}. + auto floorx_indices_hi = + Add(Mul(floor_x_int, Set(DI(), 0x02020000)), Set(DI(), 0x01000000)); + // load LUT + auto low16 = BitCast(DI(), LoadDup128(DI8(), low16_lut)); + auto lowm = Set(DI(), 0xFFFF); + auto hi16 = BitCast(DI(), LoadDup128(DI8(), high16_lut)); + auto him = Set(DI(), 0xFFFF0000); + // low = noise_params.lut[floor_x] + auto low = + BitCast(D(), Or(And(TableLookupBytes(low16, floorx_indices_low), lowm), + And(TableLookupBytes(hi16, floorx_indices_hi), him))); + // hi = noise_params.lut[floor_x+1] + floorx_indices_low = Add(floorx_indices_low, Set(DI(), 0x0202)); + floorx_indices_hi = Add(floorx_indices_hi, Set(DI(), 0x02020000)); + auto hi = + BitCast(D(), Or(And(TableLookupBytes(low16, floorx_indices_low), lowm), + And(TableLookupBytes(hi16, floorx_indices_hi), him))); +#endif + return MulAdd(Sub(hi, low), frac_x, low); + } + + private: +#if HWY_TARGET != HWY_SCALAR + // noise_params.lut transformed into two 16-bit lookup tables. + HWY_ALIGN uint8_t high16_lut[16]; + HWY_ALIGN uint8_t low16_lut[16]; +#else + const NoiseParams& noise_params_; +#endif +}; + +template +void AddNoiseToRGB(const D d, const Vec rnd_noise_r, + const Vec rnd_noise_g, const Vec rnd_noise_cor, + const Vec noise_strength_g, const Vec noise_strength_r, + float ytox, float ytob, float* JXL_RESTRICT out_x, + float* JXL_RESTRICT out_y, float* JXL_RESTRICT out_b) { + const auto kRGCorr = Set(d, 0.9921875f); // 127/128 + const auto kRGNCorr = Set(d, 0.0078125f); // 1/128 + + const auto red_noise = + Mul(noise_strength_r, + MulAdd(kRGNCorr, rnd_noise_r, Mul(kRGCorr, rnd_noise_cor))); + const auto green_noise = + Mul(noise_strength_g, + MulAdd(kRGNCorr, rnd_noise_g, Mul(kRGCorr, rnd_noise_cor))); + + auto vx = LoadU(d, out_x); + auto vy = LoadU(d, out_y); + auto vb = LoadU(d, out_b); + + const auto rg_noise = Add(red_noise, green_noise); + vx = Add(MulAdd(Set(d, ytox), rg_noise, Sub(red_noise, green_noise)), vx); + vy = Add(vy, rg_noise); + vb = MulAdd(Set(d, ytob), rg_noise, vb); + + StoreU(vx, d, out_x); + StoreU(vy, d, out_y); + StoreU(vb, d, out_b); +} + +class AddNoiseStage : public RenderPipelineStage { + public: + AddNoiseStage(const NoiseParams& noise_params, + const ColorCorrelationMap& cmap, size_t first_c) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/0)), + noise_params_(noise_params), + cmap_(cmap), + first_c_(first_c) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("Noise apply"); + + if (!noise_params_.HasAny()) return; + const StrengthEvalLut noise_model(noise_params_); + D d; + const auto half = Set(d, 0.5f); + + // With the prior subtract-random Laplacian approximation, rnd_* ranges were + // about [-1.5, 1.6]; Laplacian3 about doubles this to [-3.6, 3.6], so the + // normalizer is half of what it was before (0.5). + const auto norm_const = Set(d, 0.22f); + + float ytox = cmap_.YtoXRatio(0); + float ytob = cmap_.YtoBRatio(0); + + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + + float* JXL_RESTRICT row_x = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row_y = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row_b = GetInputRow(input_rows, 2, 0); + const float* JXL_RESTRICT row_rnd_r = + GetInputRow(input_rows, first_c_ + 0, 0); + const float* JXL_RESTRICT row_rnd_g = + GetInputRow(input_rows, first_c_ + 1, 0); + const float* JXL_RESTRICT row_rnd_c = + GetInputRow(input_rows, first_c_ + 2, 0); + // Needed by the calls to Floor() in StrengthEvalLut. Only arithmetic and + // shuffles are otherwise done on the data, so this is safe. + msan::UnpoisonMemory(row_x + xsize, (xsize_v - xsize) * sizeof(float)); + msan::UnpoisonMemory(row_y + xsize, (xsize_v - xsize) * sizeof(float)); + for (size_t x = 0; x < xsize_v; x += Lanes(d)) { + const auto vx = LoadU(d, row_x + x); + const auto vy = LoadU(d, row_y + x); + const auto in_g = Sub(vy, vx); + const auto in_r = Add(vy, vx); + const auto noise_strength_g = NoiseStrength(noise_model, Mul(in_g, half)); + const auto noise_strength_r = NoiseStrength(noise_model, Mul(in_r, half)); + const auto addit_rnd_noise_red = Mul(LoadU(d, row_rnd_r + x), norm_const); + const auto addit_rnd_noise_green = + Mul(LoadU(d, row_rnd_g + x), norm_const); + const auto addit_rnd_noise_correlated = + Mul(LoadU(d, row_rnd_c + x), norm_const); + AddNoiseToRGB(D(), addit_rnd_noise_red, addit_rnd_noise_green, + addit_rnd_noise_correlated, noise_strength_g, + noise_strength_r, ytox, ytob, row_x + x, row_y + x, + row_b + x); + } + msan::PoisonMemory(row_x + xsize, (xsize_v - xsize) * sizeof(float)); + msan::PoisonMemory(row_y + xsize, (xsize_v - xsize) * sizeof(float)); + msan::PoisonMemory(row_b + xsize, (xsize_v - xsize) * sizeof(float)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c >= first_c_ ? RenderPipelineChannelMode::kInput + : c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "AddNoise"; } + + private: + const NoiseParams& noise_params_; + const ColorCorrelationMap& cmap_; + size_t first_c_; +}; + +std::unique_ptr GetAddNoiseStage( + const NoiseParams& noise_params, const ColorCorrelationMap& cmap, + size_t noise_c_start) { + return jxl::make_unique(noise_params, cmap, noise_c_start); +} + +class ConvolveNoiseStage : public RenderPipelineStage { + public: + explicit ConvolveNoiseStage(size_t first_c) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/2)), + first_c_(first_c) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("Noise convolve"); + + const HWY_FULL(float) d; + for (size_t c = first_c_; c < first_c_ + 3; c++) { + float* JXL_RESTRICT rows[5]; + for (size_t i = 0; i < 5; i++) { + rows[i] = GetInputRow(input_rows, c, i - 2); + } + float* JXL_RESTRICT row_out = GetOutputRow(output_rows, c, 0); + for (ssize_t x = -RoundUpTo(xextra, Lanes(d)); + x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + const auto p00 = LoadU(d, rows[2] + x); + auto others = Zero(d); + // TODO(eustas): sum loaded values to reduce the calculation chain + for (ssize_t i = -2; i <= 2; i++) { + others = Add(others, LoadU(d, rows[0] + x + i)); + others = Add(others, LoadU(d, rows[1] + x + i)); + others = Add(others, LoadU(d, rows[3] + x + i)); + others = Add(others, LoadU(d, rows[4] + x + i)); + } + others = Add(others, LoadU(d, rows[2] + x - 2)); + others = Add(others, LoadU(d, rows[2] + x - 1)); + others = Add(others, LoadU(d, rows[2] + x + 1)); + others = Add(others, LoadU(d, rows[2] + x + 2)); + // 4 * (1 - box kernel) + auto pixels = MulAdd(others, Set(d, 0.16), Mul(p00, Set(d, -3.84))); + StoreU(pixels, d, row_out + x); + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c >= first_c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "ConvNoise"; } + + private: + size_t first_c_; +}; + +std::unique_ptr GetConvolveNoiseStage( + size_t noise_c_start) { + return jxl::make_unique(noise_c_start); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetAddNoiseStage); +HWY_EXPORT(GetConvolveNoiseStage); + +std::unique_ptr GetAddNoiseStage( + const NoiseParams& noise_params, const ColorCorrelationMap& cmap, + size_t noise_c_start) { + return HWY_DYNAMIC_DISPATCH(GetAddNoiseStage)(noise_params, cmap, + noise_c_start); +} + +std::unique_ptr GetConvolveNoiseStage( + size_t noise_c_start) { + return HWY_DYNAMIC_DISPATCH(GetConvolveNoiseStage)(noise_c_start); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_noise.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_noise.h new file mode 100644 index 000000000..bd7797f99 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_noise.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_NOISE_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_NOISE_H_ +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/dec_noise.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Adds noise to color channels. +std::unique_ptr GetAddNoiseStage( + const NoiseParams& noise_params, const ColorCorrelationMap& cmap, + size_t noise_c_start); + +// Applies a 5x5 subtract-box-filter convolution to the noise input channels. +std::unique_ptr GetConvolveNoiseStage( + size_t noise_c_start); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_NOISE_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_patches.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_patches.cc new file mode 100644 index 000000000..527be0383 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_patches.cc @@ -0,0 +1,48 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_patches.h" + +namespace jxl { +namespace { +class PatchDictionaryStage : public RenderPipelineStage { + public: + PatchDictionaryStage(const PatchDictionary* patches, size_t num_channels) + : RenderPipelineStage(RenderPipelineStage::Settings()), + patches_(*patches), + num_channels_(num_channels) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("RenderPatches"); + JXL_ASSERT(xpos == 0 || xpos >= xextra); + size_t x0 = xpos ? xpos - xextra : 0; + std::vector row_ptrs(num_channels_); + for (size_t i = 0; i < num_channels_; i++) { + row_ptrs[i] = GetInputRow(input_rows, i, 0) + x0 - xpos; + } + patches_.AddOneRow(row_ptrs.data(), ypos, x0, xsize + xextra + xpos - x0); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < num_channels_ ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Patches"; } + + private: + const PatchDictionary& patches_; + const size_t num_channels_; +}; +} // namespace + +std::unique_ptr GetPatchesStage( + const PatchDictionary* patches, size_t num_channels) { + return jxl::make_unique(patches, num_channels); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_patches.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_patches.h new file mode 100644 index 000000000..b35abdc2e --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_patches.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_PATCHES_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_PATCHES_H_ + +#include + +#include "lib/jxl/patch_dictionary_internal.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Draws patches if applicable. +std::unique_ptr GetPatchesStage( + const PatchDictionary* patches, size_t num_channels); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_PATCHES_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_splines.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_splines.cc new file mode 100644 index 000000000..d97d97e5f --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_splines.cc @@ -0,0 +1,63 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_splines.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_splines.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class SplineStage : public RenderPipelineStage { + public: + explicit SplineStage(const Splines* splines) + : RenderPipelineStage(RenderPipelineStage::Settings()), + splines_(*splines) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("RenderSplines"); + float* row_x = GetInputRow(input_rows, 0, 0); + float* row_y = GetInputRow(input_rows, 1, 0); + float* row_b = GetInputRow(input_rows, 2, 0); + splines_.AddToRow(row_x, row_y, row_b, Rect(xpos, ypos, xsize, 1)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Splines"; } + + private: + const Splines& splines_; +}; + +std::unique_ptr GetSplineStage(const Splines* splines) { + return jxl::make_unique(splines); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetSplineStage); + +std::unique_ptr GetSplineStage(const Splines* splines) { + return HWY_DYNAMIC_DISPATCH(GetSplineStage)(splines); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_splines.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_splines.h new file mode 100644 index 000000000..363af393e --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_splines.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_SPLINES_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_SPLINES_H_ + +#include + +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +// Draws splines if applicable. +std::unique_ptr GetSplineStage(const Splines* splines); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_SPLINES_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_spot.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_spot.cc new file mode 100644 index 000000000..d4f615299 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_spot.cc @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_spot.h" + +namespace jxl { +class SpotColorStage : public RenderPipelineStage { + public: + explicit SpotColorStage(size_t spot_c, const float* spot_color) + : RenderPipelineStage(RenderPipelineStage::Settings()), + spot_c_(spot_c), + spot_color_(spot_color) { + JXL_ASSERT(spot_c_ >= 3); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + // TODO(veluca): add SIMD. + PROFILER_ZONE("RenderSpotColors"); + float scale = spot_color_[3]; + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT p = GetInputRow(input_rows, c, 0); + const float* JXL_RESTRICT s = GetInputRow(input_rows, spot_c_, 0); + for (ssize_t x = -xextra; x < ssize_t(xsize + xextra); x++) { + float mix = scale * s[x]; + p[x] = mix * spot_color_[c] + (1.0f - mix) * p[x]; + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : c == spot_c_ ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Spot"; } + + private: + size_t spot_c_; + const float* spot_color_; +}; + +std::unique_ptr GetSpotColorStage( + size_t spot_c, const float* spot_color) { + return jxl::make_unique(spot_c, spot_color); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_spot.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_spot.h new file mode 100644 index 000000000..3e79c7582 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_spot.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_SPOT_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_SPOT_H_ + +#include + +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Render the spot color channels. +std::unique_ptr GetSpotColorStage(size_t spot_c, + const float* spot_color); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_SPOT_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_to_linear.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_to_linear.cc new file mode 100644 index 000000000..bf79481e4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_to_linear.cc @@ -0,0 +1,200 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_to_linear.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_to_linear.cc" +#include +#include + +#include "lib/jxl/dec_tone_mapping-inl.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::IfThenZeroElse; + +template +struct PerChannelOp { + explicit PerChannelOp(Op op) : op(op) {} + template + void Transform(D d, T* r, T* g, T* b) const { + *r = op.Transform(d, *r); + *g = op.Transform(d, *g); + *b = op.Transform(d, *b); + } + + Op op; +}; +template +PerChannelOp MakePerChannelOp(Op&& op) { + return PerChannelOp(std::forward(op)); +} + +struct OpLinear { + template + T Transform(D d, const T& encoded) const { + return encoded; + } +}; + +struct OpRgb { + template + T Transform(D d, const T& encoded) const { + return TF_SRGB().DisplayFromEncoded(encoded); + } +}; + +struct OpPq { + template + T Transform(D d, const T& encoded) const { + return TF_PQ().DisplayFromEncoded(d, encoded); + } +}; + +struct OpHlg { + explicit OpHlg(const float luminances[3], const float intensity_target) + : hlg_ootf_(HlgOOTF::FromSceneLight( + /*display_luminance=*/intensity_target, luminances)) {} + + template + void Transform(D d, T* r, T* g, T* b) const { + for (T* val : {r, g, b}) { + float vals[MaxLanes(d)]; + Store(*val, d, vals); + for (size_t i = 0; i < Lanes(d); ++i) { + vals[i] = TF_HLG().DisplayFromEncoded(vals[i]); + } + *val = Load(d, vals); + } + hlg_ootf_.Apply(r, g, b); + } + HlgOOTF hlg_ootf_; +}; + +struct Op709 { + template + T Transform(D d, const T& encoded) const { + return TF_709().DisplayFromEncoded(d, encoded); + } +}; + +struct OpGamma { + const float gamma; + template + T Transform(D d, const T& encoded) const { + return IfThenZeroElse(Le(encoded, Set(d, 1e-5f)), + FastPowf(d, encoded, Set(d, gamma))); + } +}; + +struct OpInvalid { + template + void Transform(D d, T* r, T* g, T* b) const {} +}; + +template +class ToLinearStage : public RenderPipelineStage { + public: + explicit ToLinearStage(Op op) + : RenderPipelineStage(RenderPipelineStage::Settings()), + op_(std::move(op)) {} + + explicit ToLinearStage() + : RenderPipelineStage(RenderPipelineStage::Settings()), valid_(false) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("ToLinear"); + + const HWY_FULL(float) d; + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + auto r = LoadU(d, row0 + x); + auto g = LoadU(d, row1 + x); + auto b = LoadU(d, row2 + x); + op_.Transform(d, &r, &g, &b); + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "ToLinear"; } + + private: + Status IsInitialized() const override { return valid_; } + + Op op_; + bool valid_ = true; +}; + +template +std::unique_ptr> MakeToLinearStage(Op&& op) { + return jxl::make_unique>(std::forward(op)); +} + +std::unique_ptr GetToLinearStage( + const OutputEncodingInfo& output_encoding_info) { + if (output_encoding_info.color_encoding.tf.IsLinear()) { + return MakeToLinearStage(MakePerChannelOp(OpLinear())); + } else if (output_encoding_info.color_encoding.tf.IsSRGB()) { + return MakeToLinearStage(MakePerChannelOp(OpRgb())); + } else if (output_encoding_info.color_encoding.tf.IsPQ()) { + return MakeToLinearStage(MakePerChannelOp(OpPq())); + } else if (output_encoding_info.color_encoding.tf.IsHLG()) { + return MakeToLinearStage(OpHlg(output_encoding_info.luminances, + output_encoding_info.orig_intensity_target)); + } else if (output_encoding_info.color_encoding.tf.Is709()) { + return MakeToLinearStage(MakePerChannelOp(Op709())); + } else if (output_encoding_info.color_encoding.tf.IsGamma() || + output_encoding_info.color_encoding.tf.IsDCI()) { + return MakeToLinearStage( + MakePerChannelOp(OpGamma{1.f / output_encoding_info.inverse_gamma})); + } else { + return jxl::make_unique>(); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetToLinearStage); + +std::unique_ptr GetToLinearStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetToLinearStage)(output_encoding_info); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_to_linear.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_to_linear.h new file mode 100644 index 000000000..ccee7b09f --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_to_linear.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_TO_LINEAR_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_TO_LINEAR_H_ + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from `output_encoding_info.color_encoding` to +// linear. +std::unique_ptr GetToLinearStage( + const OutputEncodingInfo& output_encoding_info); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_TO_LINEAR_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_tone_mapping.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_tone_mapping.cc new file mode 100644 index 000000000..7609534a5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_tone_mapping.cc @@ -0,0 +1,151 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_tone_mapping.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_tone_mapping.cc" +#include +#include + +#include "lib/jxl/dec_tone_mapping-inl.h" +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class ToneMappingStage : public RenderPipelineStage { + public: + explicit ToneMappingStage(OutputEncodingInfo output_encoding_info) + : RenderPipelineStage(RenderPipelineStage::Settings()), + output_encoding_info_(std::move(output_encoding_info)) { + if (output_encoding_info_.desired_intensity_target == + output_encoding_info_.orig_intensity_target) { + // No tone mapping requested. + return; + } + if (output_encoding_info_.orig_color_encoding.tf.IsPQ() && + output_encoding_info_.desired_intensity_target < + output_encoding_info_.orig_intensity_target) { + tone_mapper_ = jxl::make_unique( + /*source_range=*/std::pair( + 0, output_encoding_info_.orig_intensity_target), + /*target_range=*/ + std::pair( + 0, output_encoding_info_.desired_intensity_target), + output_encoding_info_.luminances); + } else if (output_encoding_info_.orig_color_encoding.tf.IsHLG() && + !output_encoding_info_.color_encoding.tf.IsHLG()) { + hlg_ootf_ = jxl::make_unique( + /*source_luminance=*/output_encoding_info_.orig_intensity_target, + /*target_luminance=*/output_encoding_info_.desired_intensity_target, + output_encoding_info_.luminances); + } + + if (output_encoding_info_.color_encoding.tf.IsPQ() && + (tone_mapper_ || hlg_ootf_)) { + to_intensity_target_ = + 10000.f / output_encoding_info_.orig_intensity_target; + from_desired_intensity_target_ = + output_encoding_info_.desired_intensity_target / 10000.f; + } + } + + bool IsNeeded() const { return tone_mapper_ || hlg_ootf_; } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("ToneMapping"); + + if (!(tone_mapper_ || hlg_ootf_)) return; + + const HWY_FULL(float) d; + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + auto r = LoadU(d, row0 + x); + auto g = LoadU(d, row1 + x); + auto b = LoadU(d, row2 + x); + if (tone_mapper_ || hlg_ootf_) { + r = Mul(r, Set(d, to_intensity_target_)); + g = Mul(g, Set(d, to_intensity_target_)); + b = Mul(b, Set(d, to_intensity_target_)); + if (tone_mapper_) { + tone_mapper_->ToneMap(&r, &g, &b); + } else { + JXL_ASSERT(hlg_ootf_); + hlg_ootf_->Apply(&r, &g, &b); + } + if (tone_mapper_ || hlg_ootf_->WarrantsGamutMapping()) { + GamutMap(&r, &g, &b, output_encoding_info_.luminances); + } + r = Mul(r, Set(d, from_desired_intensity_target_)); + g = Mul(g, Set(d, from_desired_intensity_target_)); + b = Mul(b, Set(d, from_desired_intensity_target_)); + } + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "ToneMapping"; } + + private: + using ToneMapper = Rec2408ToneMapper; + OutputEncodingInfo output_encoding_info_; + std::unique_ptr tone_mapper_; + std::unique_ptr hlg_ootf_; + // When the target colorspace is PQ, 1 represents 10000 nits instead of + // orig_intensity_target. This temporarily changes this if the tone mappers + // require it. + float to_intensity_target_ = 1.f; + float from_desired_intensity_target_ = 1.f; +}; + +std::unique_ptr GetToneMappingStage( + const OutputEncodingInfo& output_encoding_info) { + auto stage = jxl::make_unique(output_encoding_info); + if (!stage->IsNeeded()) return nullptr; + return stage; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetToneMappingStage); + +std::unique_ptr GetToneMappingStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetToneMappingStage)(output_encoding_info); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_tone_mapping.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_tone_mapping.h new file mode 100644 index 000000000..99824f851 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_tone_mapping.h @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_TONE_MAPPING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_TONE_MAPPING_H_ +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Tone maps the image if appropriate. It must be in linear space and +// `output_encoding_info.luminances` must contain the luminance for the +// primaries of that space. It must also be encoded such that (1, 1, 1) +// represents `output_encoding_info.orig_intensity_target` nits, unless +// `output_encoding_info.color_encoding.tf.IsPQ()`, in which case (1, 1, 1) must +// represent 10000 nits. This corresponds to what XYBStage outputs. After this +// stage, (1, 1, 1) will represent +// `output_encoding_info.desired_intensity_target` nits, except in the PQ +// special case in which it remains 10000. +// +// If no tone mapping is necessary, this will return nullptr. +std::unique_ptr GetToneMappingStage( + const OutputEncodingInfo& output_encoding_info); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_TONE_MAPPING_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_upsampling.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_upsampling.cc new file mode 100644 index 000000000..a75e25986 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_upsampling.cc @@ -0,0 +1,187 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_upsampling.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_upsampling.cc" +#include +#include + +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Min; +using hwy::HWY_NAMESPACE::MulAdd; + +class UpsamplingStage : public RenderPipelineStage { + public: + explicit UpsamplingStage(const CustomTransformData& ups_factors, size_t c, + size_t shift) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/shift, /*border=*/2)), + c_(c) { + const float* weights = shift == 1 ? ups_factors.upsampling2_weights + : shift == 2 ? ups_factors.upsampling4_weights + : ups_factors.upsampling8_weights; + size_t N = 1 << (shift - 1); + for (size_t i = 0; i < 5 * N; i++) { + for (size_t j = 0; j < 5 * N; j++) { + size_t y = std::min(i, j); + size_t x = std::max(i, j); + kernel_[j / 5][i / 5][j % 5][i % 5] = + weights[5 * N * y - y * (y - 1) / 2 + x - y]; + } + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("Upsampling"); + static HWY_FULL(float) df; + size_t shift = settings_.shift_x; + size_t N = 1 << shift; + const size_t xsize_v = RoundUpTo(xsize, Lanes(df)); + for (ssize_t iy = -2; iy <= 2; iy++) { + msan::UnpoisonMemory(GetInputRow(input_rows, c_, iy) + xsize + 2, + sizeof(float) * (xsize_v - xsize)); + } + JXL_ASSERT(xextra == 0); + ssize_t x0 = 0; + ssize_t x1 = xsize; + if (N == 2) { + ProcessRowImpl<2>(input_rows, output_rows, x0, x1); + } + if (N == 4) { + ProcessRowImpl<4>(input_rows, output_rows, x0, x1); + } + if (N == 8) { + ProcessRowImpl<8>(input_rows, output_rows, x0, x1); + } + for (size_t oy = 0; oy < N; oy++) { + float* dst_row = GetOutputRow(output_rows, c_, oy); + msan::PoisonMemory(dst_row + xsize * N, + sizeof(float) * (xsize_v - xsize) * N); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c == c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Upsample"; } + + private: + template + JXL_INLINE float Kernel(size_t x, size_t y, ssize_t ix, ssize_t iy) const { + ix += 2; + iy += 2; + if (N == 2) { + return kernel_[0][0][y % 2 ? 4 - iy : iy][x % 2 ? 4 - ix : ix]; + } + if (N == 4) { + return kernel_[y % 4 < 2 ? y % 2 : 1 - y % 2] + [x % 4 < 2 ? x % 2 : 1 - x % 2][y % 4 < 2 ? iy : 4 - iy] + [x % 4 < 2 ? ix : 4 - ix]; + } + if (N == 8) { + return kernel_[y % 8 < 4 ? y % 4 : 3 - y % 4] + [x % 8 < 4 ? x % 4 : 3 - x % 4][y % 8 < 4 ? iy : 4 - iy] + [x % 8 < 4 ? ix : 4 - ix]; + } + JXL_ABORT("Invalid upsample"); + } + + template + void ProcessRowImpl(const RowInfo& input_rows, const RowInfo& output_rows, + ssize_t x0, ssize_t x1) const { + static HWY_FULL(float) df; + using V = hwy::HWY_NAMESPACE::Vec; + V ups0, ups1, ups2, ups3, ups4, ups5, ups6, ups7; + (void)ups2, (void)ups3, (void)ups4, (void)ups5, (void)ups6, (void)ups7; + V* ups[N]; + if (N >= 2) { + ups[0] = &ups0; + ups[1] = &ups1; + } + if (N >= 4) { + ups[2] = &ups2; + ups[3] = &ups3; + } + if (N == 8) { + ups[4] = &ups4; + ups[5] = &ups5; + ups[6] = &ups6; + ups[7] = &ups7; + } + for (size_t oy = 0; oy < N; oy++) { + float* dst_row = GetOutputRow(output_rows, c_, oy); + for (ssize_t x = x0; x < x1; x += Lanes(df)) { + for (size_t ox = 0; ox < N; ox++) { + auto result = Zero(df); + auto min = LoadU(df, GetInputRow(input_rows, c_, 0) + x); + auto max = min; + for (ssize_t iy = -2; iy <= 2; iy++) { + for (ssize_t ix = -2; ix <= 2; ix++) { + auto v = LoadU(df, GetInputRow(input_rows, c_, iy) + x + ix); + result = MulAdd(Set(df, Kernel(ox, oy, ix, iy)), v, result); + min = Min(v, min); + max = Max(v, max); + } + } + // Avoid overshooting. + *ups[ox] = Clamp(result, min, max); + } + if (N == 2) { + StoreInterleaved(df, ups0, ups1, dst_row + x * N); + } + if (N == 4) { + StoreInterleaved(df, ups0, ups1, ups2, ups3, dst_row + x * N); + } + if (N == 8) { + StoreInterleaved(df, ups0, ups1, ups2, ups3, ups4, ups5, ups6, ups7, + dst_row + x * N); + } + } + } + } + + size_t c_; + float kernel_[4][4][5][5]; +}; + +std::unique_ptr GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift) { + return jxl::make_unique(ups_factors, c, shift); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetUpsamplingStage); + +std::unique_ptr GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift) { + JXL_ASSERT(shift != 0); + JXL_ASSERT(shift <= 3); + return HWY_DYNAMIC_DISPATCH(GetUpsamplingStage)(ups_factors, c, shift); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_upsampling.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_upsampling.h new file mode 100644 index 000000000..7d5defd23 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_upsampling.h @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_UPSAMPLING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_UPSAMPLING_H_ +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Upsamples the given channel by the given factor. +std::unique_ptr GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_UPSAMPLING_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_write.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_write.cc new file mode 100644 index 000000000..71c4d97d5 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_write.cc @@ -0,0 +1,409 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_write.h" + +#include "lib/jxl/alpha.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/sanitizers.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_write.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::NearestInt; +using hwy::HWY_NAMESPACE::U8FromU32; + +template +void StoreRGBA(D d, V r, V g, V b, V a, bool alpha, size_t n, size_t extra, + uint8_t* buf) { +#if HWY_TARGET == HWY_SCALAR + buf[0] = r.raw; + buf[1] = g.raw; + buf[2] = b.raw; + if (alpha) { + buf[3] = a.raw; + } +#elif HWY_TARGET == HWY_NEON + if (alpha) { + uint8x8x4_t data = {r.raw, g.raw, b.raw, a.raw}; + if (extra >= 8) { + vst4_u8(buf, data); + } else { + uint8_t tmp[8 * 4]; + vst4_u8(tmp, data); + memcpy(buf, tmp, n * 4); + } + } else { + uint8x8x3_t data = {r.raw, g.raw, b.raw}; + if (extra >= 8) { + vst3_u8(buf, data); + } else { + uint8_t tmp[8 * 3]; + vst3_u8(tmp, data); + memcpy(buf, tmp, n * 3); + } + } +#else + // TODO(veluca): implement this for x86. + size_t mul = alpha ? 4 : 3; + HWY_ALIGN uint8_t bytes[16]; + StoreU(r, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[mul * i] = bytes[i]; + } + StoreU(g, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[mul * i + 1] = bytes[i]; + } + StoreU(b, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[mul * i + 2] = bytes[i]; + } + if (alpha) { + StoreU(a, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[4 * i + 3] = bytes[i]; + } + } +#endif +} + +class WriteToU8Stage : public RenderPipelineStage { + public: + WriteToU8Stage(uint8_t* rgb, size_t stride, size_t height, bool rgba, + bool has_alpha, size_t alpha_c) + : RenderPipelineStage(RenderPipelineStage::Settings()), + rgb_(rgb), + stride_(stride), + height_(height), + rgba_(rgba), + has_alpha_(has_alpha), + alpha_c_(alpha_c) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + if (ypos >= height_) return; + JXL_DASSERT(xextra == 0); + size_t bytes = rgba_ ? 4 : 3; + const float* JXL_RESTRICT row_in_r = GetInputRow(input_rows, 0, 0); + const float* JXL_RESTRICT row_in_g = GetInputRow(input_rows, 1, 0); + const float* JXL_RESTRICT row_in_b = GetInputRow(input_rows, 2, 0); + const float* JXL_RESTRICT row_in_a = + has_alpha_ ? GetInputRow(input_rows, alpha_c_, 0) : nullptr; + size_t base_ptr = ypos * stride_ + bytes * (xpos - xextra); + using D = HWY_CAPPED(float, 4); + const D d; + D::Rebind du; + auto zero = Zero(d); + auto one = Set(d, 1.0f); + auto mul = Set(d, 255.0f); + + ssize_t x1 = RoundUpTo(xsize, Lanes(d)); + + msan::UnpoisonMemory(row_in_r + xsize, sizeof(float) * (x1 - xsize)); + msan::UnpoisonMemory(row_in_g + xsize, sizeof(float) * (x1 - xsize)); + msan::UnpoisonMemory(row_in_b + xsize, sizeof(float) * (x1 - xsize)); + if (row_in_a) { + msan::UnpoisonMemory(row_in_a + xsize, sizeof(float) * (x1 - xsize)); + } + + for (ssize_t x = 0; x < x1; x += Lanes(d)) { + auto rf = Mul(Clamp(zero, LoadU(d, row_in_r + x), one), mul); + auto gf = Mul(Clamp(zero, LoadU(d, row_in_g + x), one), mul); + auto bf = Mul(Clamp(zero, LoadU(d, row_in_b + x), one), mul); + auto af = row_in_a ? Mul(Clamp(zero, LoadU(d, row_in_a + x), one), mul) + : Set(d, 255.0f); + auto r8 = U8FromU32(BitCast(du, NearestInt(rf))); + auto g8 = U8FromU32(BitCast(du, NearestInt(gf))); + auto b8 = U8FromU32(BitCast(du, NearestInt(bf))); + auto a8 = U8FromU32(BitCast(du, NearestInt(af))); + size_t n = xsize - x; + if (JXL_LIKELY(n >= Lanes(d))) { + StoreRGBA(D::Rebind(), r8, g8, b8, a8, rgba_, Lanes(d), n, + rgb_ + base_ptr + bytes * x); + } else { + StoreRGBA(D::Rebind(), r8, g8, b8, a8, rgba_, n, n, + rgb_ + base_ptr + bytes * x); + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 || (has_alpha_ && c == alpha_c_) + ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "WriteToU8"; } + + private: + uint8_t* rgb_; + size_t stride_; + size_t height_; + bool rgba_; + bool has_alpha_; + size_t alpha_c_; + std::vector opaque_alpha_; +}; + +std::unique_ptr GetWriteToU8Stage(uint8_t* rgb, + size_t stride, + size_t height, bool rgba, + bool has_alpha, + size_t alpha_c) { + return jxl::make_unique(rgb, stride, height, rgba, has_alpha, + alpha_c); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { + +HWY_EXPORT(GetWriteToU8Stage); + +namespace { +class WriteToImageBundleStage : public RenderPipelineStage { + public: + explicit WriteToImageBundleStage(ImageBundle* image_bundle, + ColorEncoding color_encoding) + : RenderPipelineStage(RenderPipelineStage::Settings()), + image_bundle_(image_bundle), + color_encoding_(std::move(color_encoding)) {} + + void SetInputSizes( + const std::vector>& input_sizes) override { +#if JXL_ENABLE_ASSERT + JXL_ASSERT(input_sizes.size() >= 3); + for (size_t c = 1; c < input_sizes.size(); c++) { + JXL_ASSERT(input_sizes[c].first == input_sizes[0].first); + JXL_ASSERT(input_sizes[c].second == input_sizes[0].second); + } +#endif + // TODO(eustas): what should we do in the case of "want only ECs"? + image_bundle_->SetFromImage( + Image3F(input_sizes[0].first, input_sizes[0].second), color_encoding_); + // TODO(veluca): consider not reallocating ECs if not needed. + image_bundle_->extra_channels().clear(); + for (size_t c = 3; c < input_sizes.size(); c++) { + image_bundle_->extra_channels().emplace_back(input_sizes[c].first, + input_sizes[c].second); + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < 3; c++) { + memcpy(image_bundle_->color()->PlaneRow(c, ypos) + xpos - xextra, + GetInputRow(input_rows, c, 0) - xextra, + sizeof(float) * (xsize + 2 * xextra)); + } + for (size_t ec = 0; ec < image_bundle_->extra_channels().size(); ec++) { + JXL_ASSERT(image_bundle_->extra_channels()[ec].xsize() >= + xpos + xsize + xextra); + memcpy(image_bundle_->extra_channels()[ec].Row(ypos) + xpos - xextra, + GetInputRow(input_rows, 3 + ec, 0) - xextra, + sizeof(float) * (xsize + 2 * xextra)); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInput; + } + + const char* GetName() const override { return "WriteIB"; } + + private: + ImageBundle* image_bundle_; + ColorEncoding color_encoding_; +}; + +class WriteToImage3FStage : public RenderPipelineStage { + public: + explicit WriteToImage3FStage(Image3F* image) + : RenderPipelineStage(RenderPipelineStage::Settings()), image_(image) {} + + void SetInputSizes( + const std::vector>& input_sizes) override { +#if JXL_ENABLE_ASSERT + JXL_ASSERT(input_sizes.size() >= 3); + for (size_t c = 1; c < 3; ++c) { + JXL_ASSERT(input_sizes[c].first == input_sizes[0].first); + JXL_ASSERT(input_sizes[c].second == input_sizes[0].second); + } +#endif + *image_ = Image3F(input_sizes[0].first, input_sizes[0].second); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < 3; c++) { + memcpy(image_->PlaneRow(c, ypos) + xpos - xextra, + GetInputRow(input_rows, c, 0) - xextra, + sizeof(float) * (xsize + 2 * xextra)); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "WriteI3F"; } + + private: + Image3F* image_; +}; + +class WriteToPixelCallbackStage : public RenderPipelineStage { + public: + WriteToPixelCallbackStage(const PixelCallback& pixel_callback, size_t width, + size_t height, bool rgba, bool has_alpha, + bool unpremul_alpha, size_t alpha_c) + : RenderPipelineStage(RenderPipelineStage::Settings()), + pixel_callback_(pixel_callback), + width_(width), + height_(height), + rgba_(rgba), + has_alpha_(has_alpha), + unpremul_alpha_(unpremul_alpha), + alpha_c_(alpha_c), + opaque_alpha_(kMaxPixelsPerCall, 1.0f) {} + + WriteToPixelCallbackStage(const WriteToPixelCallbackStage&) = delete; + WriteToPixelCallbackStage& operator=(const WriteToPixelCallbackStage&) = + delete; + WriteToPixelCallbackStage(WriteToPixelCallbackStage&&) = delete; + WriteToPixelCallbackStage& operator=(WriteToPixelCallbackStage&&) = delete; + + ~WriteToPixelCallbackStage() override { + if (run_opaque_) { + pixel_callback_.destroy(run_opaque_); + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + JXL_DASSERT(run_opaque_); + if (ypos >= height_) return; + const float* line_buffers[4]; + for (size_t c = 0; c < 3; c++) { + line_buffers[c] = GetInputRow(input_rows, c, 0) - xextra; + } + if (has_alpha_) { + line_buffers[3] = GetInputRow(input_rows, alpha_c_, 0) - xextra; + } else { + // No xextra offset; opaque_alpha_ is a way to set all values to 1.0f. + line_buffers[3] = opaque_alpha_.data(); + } + + // TODO(veluca): SIMD. + ssize_t limit = std::min(xextra + xsize, width_ - xpos); + for (ssize_t x0 = -xextra; x0 < limit; x0 += kMaxPixelsPerCall) { + size_t j = 0; + size_t ix = 0; + float* JXL_RESTRICT temp = + reinterpret_cast(temp_[thread_id].get()); + for (; ix < kMaxPixelsPerCall && ssize_t(ix) + x0 < limit; ix++) { + temp[j++] = line_buffers[0][ix]; + temp[j++] = line_buffers[1][ix]; + temp[j++] = line_buffers[2][ix]; + if (rgba_) { + temp[j++] = line_buffers[3][ix]; + } + } + if (has_alpha_ && rgba_ && unpremul_alpha_) { + // TODO(szabadka) SIMDify (possibly in a separate pipeline stage). + UnpremultiplyAlpha(temp, ix); + } + pixel_callback_.run(run_opaque_, thread_id, xpos + x0, ypos, ix, temp); + for (size_t c = 0; c < 3; c++) line_buffers[c] += kMaxPixelsPerCall; + if (has_alpha_) line_buffers[3] += kMaxPixelsPerCall; + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 || (has_alpha_ && c == alpha_c_) + ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "WritePixelCB"; } + + private: + Status PrepareForThreads(size_t num_threads) override { + run_opaque_ = + pixel_callback_.Init(num_threads, /*num_pixels=*/kMaxPixelsPerCall); + JXL_RETURN_IF_ERROR(run_opaque_ != nullptr); + temp_.resize(num_threads); + for (CacheAlignedUniquePtr& temp : temp_) { + temp = AllocateArray(sizeof(float) * kMaxPixelsPerCall * (rgba_ ? 4 : 3)); + } + return true; + } + + static constexpr size_t kMaxPixelsPerCall = 1024; + PixelCallback pixel_callback_; + void* run_opaque_ = nullptr; + size_t width_; + size_t height_; + bool rgba_; + bool has_alpha_; + bool unpremul_alpha_; + size_t alpha_c_; + std::vector opaque_alpha_; + std::vector temp_; +}; + +} // namespace + +std::unique_ptr GetWriteToImageBundleStage( + ImageBundle* image_bundle, ColorEncoding color_encoding) { + return jxl::make_unique(image_bundle, + std::move(color_encoding)); +} + +std::unique_ptr GetWriteToImage3FStage(Image3F* image) { + return jxl::make_unique(image); +} + +std::unique_ptr GetWriteToU8Stage(uint8_t* rgb, + size_t stride, + size_t height, bool rgba, + bool has_alpha, + size_t alpha_c) { + return HWY_DYNAMIC_DISPATCH(GetWriteToU8Stage)(rgb, stride, height, rgba, + has_alpha, alpha_c); +} + +std::unique_ptr GetWriteToPixelCallbackStage( + const PixelCallback& pixel_callback, size_t width, size_t height, bool rgba, + bool has_alpha, bool unpremul_alpha, size_t alpha_c) { + return jxl::make_unique( + pixel_callback, width, height, rgba, has_alpha, unpremul_alpha, alpha_c); +} + +} // namespace jxl + +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_write.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_write.h new file mode 100644 index 000000000..b942fd664 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_write.h @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_WRITE_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_WRITE_H_ + +#include + +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +std::unique_ptr GetWriteToImageBundleStage( + ImageBundle* image_bundle, ColorEncoding color_encoding); + +// Gets a stage to write color channels to an Image3F. +std::unique_ptr GetWriteToImage3FStage(Image3F* image); + +// Gets a stage to write to a uint8 buffer. +std::unique_ptr GetWriteToU8Stage(uint8_t* rgb, + size_t stride, + size_t height, bool rgba, + bool has_alpha, + size_t alpha_c); + +// Gets a stage to write to a pixel callback. +std::unique_ptr GetWriteToPixelCallbackStage( + const PixelCallback& pixel_callback, size_t width, size_t height, bool rgba, + bool has_alpha, bool unpremul_alpha, size_t alpha_c); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_WRITE_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_xyb.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_xyb.cc new file mode 100644 index 000000000..0022a6112 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_xyb.cc @@ -0,0 +1,152 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_xyb.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_xyb.cc" +#include +#include + +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class XYBStage : public RenderPipelineStage { + public: + explicit XYBStage(const OpsinParams& opsin_params) + : RenderPipelineStage(RenderPipelineStage::Settings()), + opsin_params_(opsin_params) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("UndoXYB"); + + const HWY_FULL(float) d; + JXL_ASSERT(xextra == 0); + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + // TODO(eustas): when using frame origin, addresses might be unaligned; + // making them aligned will void performance penalty. + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + const auto in_opsin_x = LoadU(d, row0 + x); + const auto in_opsin_y = LoadU(d, row1 + x); + const auto in_opsin_b = LoadU(d, row2 + x); + auto r = Undefined(d); + auto g = Undefined(d); + auto b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params_, &r, &g, + &b); + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "XYB"; } + + private: + const OpsinParams opsin_params_; +}; + +std::unique_ptr GetXYBStage( + const OpsinParams& opsin_params) { + return jxl::make_unique(opsin_params); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetXYBStage); + +std::unique_ptr GetXYBStage( + const OpsinParams& opsin_params) { + return HWY_DYNAMIC_DISPATCH(GetXYBStage)(opsin_params); +} + +namespace { +class FastXYBStage : public RenderPipelineStage { + public: + FastXYBStage(uint8_t* rgb, size_t stride, size_t width, size_t height, + bool rgba, bool has_alpha, size_t alpha_c) + : RenderPipelineStage(RenderPipelineStage::Settings()), + rgb_(rgb), + stride_(stride), + width_(width), + height_(height), + rgba_(rgba), + has_alpha_(has_alpha), + alpha_c_(alpha_c) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + if (ypos >= height_) return; + JXL_ASSERT(xextra == 0); + const float* xyba[4] = { + GetInputRow(input_rows, 0, 0), GetInputRow(input_rows, 1, 0), + GetInputRow(input_rows, 2, 0), + has_alpha_ ? GetInputRow(input_rows, alpha_c_, 0) : nullptr}; + uint8_t* out_buf = rgb_ + stride_ * ypos + (rgba_ ? 4 : 3) * xpos; + FastXYBTosRGB8(xyba, out_buf, rgba_, + xsize + xpos <= width_ ? xsize : width_ - xpos); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 || (has_alpha_ && c == alpha_c_) + ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "FastXYB"; } + + private: + uint8_t* rgb_; + size_t stride_; + size_t width_; + size_t height_; + bool rgba_; + bool has_alpha_; + size_t alpha_c_; + std::vector opaque_alpha_; +}; + +} // namespace + +std::unique_ptr GetFastXYBTosRGB8Stage( + uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba, + bool has_alpha, size_t alpha_c) { + JXL_ASSERT(HasFastXYBTosRGB8()); + return make_unique(rgb, stride, width, height, rgba, has_alpha, + alpha_c); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_xyb.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_xyb.h new file mode 100644 index 000000000..2bc507519 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_xyb.h @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_XYB_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_XYB_H_ +#include + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from XYB to linear with appropriate primaries. +std::unique_ptr GetXYBStage( + const OpsinParams& output_encoding_info); + +// Gets a stage to convert with fixed point arithmetic from XYB to sRGB8 and +// write to a uint8 buffer. +std::unique_ptr GetFastXYBTosRGB8Stage( + uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba, + bool has_alpha, size_t alpha_c); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_XYB_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_ycbcr.cc b/media/libjxl/src/lib/jxl/render_pipeline/stage_ycbcr.cc new file mode 100644 index 000000000..5cba4a7d4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_ycbcr.cc @@ -0,0 +1,85 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_ycbcr.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_ycbcr.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::MulAdd; + +class kYCbCrStage : public RenderPipelineStage { + public: + kYCbCrStage() : RenderPipelineStage(RenderPipelineStage::Settings()) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("UndoYCbCr"); + + const HWY_FULL(float) df; + + // Full-range BT.601 as defined by JFIF Clause 7: + // https://www.itu.int/rec/T-REC-T.871-201105-I/en + const auto c128 = Set(df, 128.0f / 255); + const auto crcr = Set(df, 1.402f); + const auto cgcb = Set(df, -0.114f * 1.772f / 0.587f); + const auto cgcr = Set(df, -0.299f * 1.402f / 0.587f); + const auto cbcb = Set(df, 1.772f); + + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // TODO(eustas): when using frame origin, addresses might be unaligned; + // making them aligned will void performance penalty. + for (size_t x = 0; x < xsize; x += Lanes(df)) { + const auto y_vec = Add(LoadU(df, row1 + x), c128); + const auto cb_vec = LoadU(df, row0 + x); + const auto cr_vec = LoadU(df, row2 + x); + const auto r_vec = MulAdd(crcr, cr_vec, y_vec); + const auto g_vec = MulAdd(cgcr, cr_vec, MulAdd(cgcb, cb_vec, y_vec)); + const auto b_vec = MulAdd(cbcb, cb_vec, y_vec); + StoreU(r_vec, df, row0 + x); + StoreU(g_vec, df, row1 + x); + StoreU(b_vec, df, row2 + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "YCbCr"; } +}; + +std::unique_ptr GetYCbCrStage() { + return jxl::make_unique(); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetYCbCrStage); + +std::unique_ptr GetYCbCrStage() { + return HWY_DYNAMIC_DISPATCH(GetYCbCrStage)(); +} + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/render_pipeline/stage_ycbcr.h b/media/libjxl/src/lib/jxl/render_pipeline/stage_ycbcr.h new file mode 100644 index 000000000..9320c9723 --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/stage_ycbcr.h @@ -0,0 +1,25 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_YCBCR_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_YCBCR_H_ +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from YCbCr to RGB. +std::unique_ptr GetYCbCrStage(); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_YCBCR_H_ diff --git a/media/libjxl/src/lib/jxl/render_pipeline/test_render_pipeline_stages.h b/media/libjxl/src/lib/jxl/render_pipeline/test_render_pipeline_stages.h new file mode 100644 index 000000000..789a52f8b --- /dev/null +++ b/media/libjxl/src/lib/jxl/render_pipeline/test_render_pipeline_stages.h @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include + +#include +#include +#include + +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +class UpsampleXSlowStage : public RenderPipelineStage { + public: + UpsampleXSlowStage() + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftX(1, 1)) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < input_rows.size(); c++) { + const float* row = GetInputRow(input_rows, c, 0); + float* row_out = GetOutputRow(output_rows, c, 0); + for (int64_t x = -xextra; x < (int64_t)(xsize + xextra); x++) { + float xp = *(row + x - 1); + float xc = *(row + x); + float xn = *(row + x + 1); + float xout0 = xp * 0.25f + xc * 0.75f; + float xout1 = xc * 0.75f + xn * 0.25f; + *(row_out + 2 * x + 0) = xout0; + *(row_out + 2 * x + 1) = xout1; + } + } + } + + const char* GetName() const override { return "TEST::UpsampleXSlowStage"; } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInOut; + } +}; + +class UpsampleYSlowStage : public RenderPipelineStage { + public: + UpsampleYSlowStage() + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftY(1, 1)) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < input_rows.size(); c++) { + const float* rowp = GetInputRow(input_rows, c, -1); + const float* rowc = GetInputRow(input_rows, c, 0); + const float* rown = GetInputRow(input_rows, c, 1); + float* row_out0 = GetOutputRow(output_rows, c, 0); + float* row_out1 = GetOutputRow(output_rows, c, 1); + for (int64_t x = -xextra; x < (int64_t)(xsize + xextra); x++) { + float xp = *(rowp + x); + float xc = *(rowc + x); + float xn = *(rown + x); + float yout0 = xp * 0.25f + xc * 0.75f; + float yout1 = xc * 0.75f + xn * 0.25f; + *(row_out0 + x) = yout0; + *(row_out1 + x) = yout1; + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInOut; + } + + const char* GetName() const override { return "TEST::UpsampleYSlowStage"; } +}; + +class Check0FinalStage : public RenderPipelineStage { + public: + Check0FinalStage() : RenderPipelineStage(RenderPipelineStage::Settings()) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < input_rows.size(); c++) { + for (size_t x = 0; x < xsize; x++) { + JXL_CHECK(fabsf(GetInputRow(input_rows, c, 0)[x]) < 1e-8); + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInput; + } + const char* GetName() const override { return "TEST::Check0FinalStage"; } +}; + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/roundtrip_test.cc b/media/libjxl/src/lib/jxl/roundtrip_test.cc new file mode 100644 index 000000000..d05717266 --- /dev/null +++ b/media/libjxl/src/lib/jxl/roundtrip_test.cc @@ -0,0 +1,848 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include // std::abs +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "jxl/codestream_header.h" +#include "jxl/decode.h" +#include "jxl/decode_cxx.h" +#include "jxl/encode.h" +#include "jxl/encode_cxx.h" +#include "jxl/types.h" +#include "lib/extras/codec.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_comparator.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace { + +// Converts a test image to a CodecInOut. +// icc_profile can be empty to automatically deduce profile from the pixel +// format, or filled in to force this ICC profile +jxl::CodecInOut ConvertTestImage(const std::vector& buf, + const size_t xsize, const size_t ysize, + const JxlPixelFormat& pixel_format, + const jxl::PaddedBytes& icc_profile) { + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + + bool is_gray = pixel_format.num_channels < 3; + bool has_alpha = + pixel_format.num_channels == 2 || pixel_format.num_channels == 4; + + io.metadata.m.color_encoding.SetColorSpace(is_gray ? jxl::ColorSpace::kGray + : jxl::ColorSpace::kRGB); + if (has_alpha) { + // Note: alpha > 16 not yet supported by the C++ codec + switch (pixel_format.data_type) { + case JXL_TYPE_UINT8: + io.metadata.m.SetAlphaBits(8); + break; + case JXL_TYPE_UINT16: + case JXL_TYPE_FLOAT: + case JXL_TYPE_FLOAT16: + io.metadata.m.SetAlphaBits(16); + break; + default: + EXPECT_TRUE(false) << "Roundtrip tests for data type " + << pixel_format.data_type << " not yet implemented."; + } + } + size_t bitdepth = 0; + bool float_in = false; + switch (pixel_format.data_type) { + case JXL_TYPE_FLOAT: + bitdepth = 32; + float_in = true; + io.metadata.m.SetFloat32Samples(); + break; + case JXL_TYPE_FLOAT16: + bitdepth = 16; + float_in = true; + io.metadata.m.SetFloat16Samples(); + break; + case JXL_TYPE_UINT8: + bitdepth = 8; + float_in = false; + io.metadata.m.SetUintSamples(8); + break; + case JXL_TYPE_UINT16: + bitdepth = 16; + float_in = false; + io.metadata.m.SetUintSamples(16); + break; + default: + EXPECT_TRUE(false) << "Roundtrip tests for data type " + << pixel_format.data_type << " not yet implemented."; + } + jxl::ColorEncoding color_encoding; + if (!icc_profile.empty()) { + jxl::PaddedBytes icc_profile_copy(icc_profile); + EXPECT_TRUE(color_encoding.SetICC(std::move(icc_profile_copy))); + } else if (pixel_format.data_type == JXL_TYPE_FLOAT) { + color_encoding = jxl::ColorEncoding::LinearSRGB(is_gray); + } else { + color_encoding = jxl::ColorEncoding::SRGB(is_gray); + } + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(buf.data(), buf.size()), xsize, ysize, + color_encoding, pixel_format.num_channels, + /*alpha_is_premultiplied=*/false, + /*bits_per_sample=*/bitdepth, pixel_format.endianness, + /*pool=*/nullptr, &io.Main(), float_in, + /*align=*/0)); + return io; +} + +template +T ConvertTestPixel(const float val); + +template <> +float ConvertTestPixel(const float val) { + return val; +} + +template <> +uint16_t ConvertTestPixel(const float val) { + return (uint16_t)(val * UINT16_MAX); +} + +template <> +uint8_t ConvertTestPixel(const float val) { + return (uint8_t)(val * UINT8_MAX); +} + +// Returns a test image. +template +std::vector GetTestImage(const size_t xsize, const size_t ysize, + const JxlPixelFormat& pixel_format) { + std::vector pixels(xsize * ysize * pixel_format.num_channels); + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + for (size_t chan = 0; chan < pixel_format.num_channels; chan++) { + float val; + switch (chan % 4) { + case 0: + val = static_cast(y) / static_cast(ysize); + break; + case 1: + val = static_cast(x) / static_cast(xsize); + break; + case 2: + val = static_cast(x + y) / static_cast(xsize + ysize); + break; + case 3: + val = static_cast(x * y) / static_cast(xsize * ysize); + break; + } + pixels[(y * xsize + x) * pixel_format.num_channels + chan] = + ConvertTestPixel(val); + } + } + } + std::vector bytes(pixels.size() * sizeof(T)); + memcpy(bytes.data(), pixels.data(), sizeof(T) * pixels.size()); + return bytes; +} + +void EncodeWithEncoder(JxlEncoder* enc, std::vector* compressed) { + compressed->resize(64); + uint8_t* next_out = compressed->data(); + size_t avail_out = compressed->size() - (next_out - compressed->data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed->data(); + compressed->resize(compressed->size() * 2); + next_out = compressed->data() + offset; + avail_out = compressed->size() - offset; + } + } + compressed->resize(next_out - compressed->data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); +} + +// Generates some pixels using using some dimensions and pixel_format, +// compresses them, and verifies that the decoded version is similar to the +// original pixels. +// TODO(firsching): change this to be a parameterized test, like in +// decode_test.cc +template +void VerifyRoundtripCompression( + const size_t xsize, const size_t ysize, + const JxlPixelFormat& input_pixel_format, + const JxlPixelFormat& output_pixel_format, const bool lossless, + const bool use_container, const uint32_t resampling = 1, + const bool already_downsampled = false, + const std::vector>& + extra_channels = {}) { + size_t orig_xsize = xsize; + size_t orig_ysize = ysize; + if (already_downsampled) { + orig_xsize = jxl::DivCeil(xsize, resampling); + orig_ysize = jxl::DivCeil(ysize, resampling); + } + + JxlPixelFormat extra_channel_pixel_format = input_pixel_format; + extra_channel_pixel_format.num_channels = 1; + const std::vector extra_channel_bytes = + GetTestImage(xsize, ysize, extra_channel_pixel_format); + const std::vector original_bytes = + GetTestImage(orig_xsize, orig_ysize, input_pixel_format); + jxl::CodecInOut original_io = ConvertTestImage( + original_bytes, orig_xsize, orig_ysize, input_pixel_format, {}); + + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc, use_container)); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &input_pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = lossless; + uint32_t num_channels = input_pixel_format.num_channels; + size_t has_interleaved_alpha = num_channels == 2 || num_channels == 4; + JxlPixelFormat output_pixel_format_with_extra_channel_alpha = + output_pixel_format; + + // In the case where we have an alpha channel, but it is provided as an extra + // channel and not interleaved, we do two things here: + // 1. modify the original_io to have the correct alpha channel + // 2. change the output_format_with_extra_alpha to have an alpha channel + bool alpha_in_extra_channels_vector = false; + for (const auto& extra_channel : extra_channels) { + if (extra_channel.first == JXL_CHANNEL_ALPHA) { + alpha_in_extra_channels_vector = true; + } + } + if (alpha_in_extra_channels_vector && !has_interleaved_alpha) { + jxl::ImageF alpha_channel(xsize, ysize); + + EXPECT_EQ( + jxl::ConvertFromExternal( + jxl::Span(extra_channel_bytes.data(), + extra_channel_bytes.size()), + xsize, ysize, basic_info.bits_per_sample, + input_pixel_format.endianness, /*pool=*/nullptr, &alpha_channel, + /*float_in=*/input_pixel_format.data_type == JXL_TYPE_FLOAT, + /*align=*/0), + true); + + original_io.metadata.m.SetAlphaBits(basic_info.bits_per_sample); + original_io.Main().SetAlpha(std::move(alpha_channel), false); + output_pixel_format_with_extra_channel_alpha.num_channels++; + } + // Those are the num_extra_channels including a potential alpha channel. + basic_info.num_extra_channels = extra_channels.size() + has_interleaved_alpha; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + EXPECT_EQ(enc->metadata.m.num_extra_channels, + extra_channels.size() + has_interleaved_alpha); + JxlColorEncoding color_encoding; + if (input_pixel_format.data_type == JXL_TYPE_FLOAT) { + JxlColorEncodingSetToLinearSRGB( + &color_encoding, + /*is_gray=*/input_pixel_format.num_channels < 3); + } else { + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/input_pixel_format.num_channels < 3); + } + + std::vector channel_infos; + for (const auto& extra_channel : extra_channels) { + auto channel_type = extra_channel.first; + JxlExtraChannelInfo channel_info; + JxlEncoderInitExtraChannelInfo(channel_type, &channel_info); + channel_info.bits_per_sample = (lossless ? basic_info.bits_per_sample : 8); + channel_info.exponent_bits_per_sample = + (lossless ? basic_info.exponent_bits_per_sample : 0); + channel_infos.push_back(channel_info); + } + for (size_t index = 0; index < channel_infos.size(); index++) { + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelInfo(enc, index + has_interleaved_alpha, + &channel_infos[index])); + std::string name = extra_channels[index].second; + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelName(enc, index + has_interleaved_alpha, + name.c_str(), name.length())); + } + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + JxlEncoderSetFrameLossless(frame_settings, lossless); + if (resampling > 1) { + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_RESAMPLING, resampling)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED, + already_downsampled)); + } + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &input_pixel_format, + (void*)original_bytes.data(), + original_bytes.size())); + EXPECT_EQ(frame_settings->enc->input_queue.back() + .frame->frame.extra_channels() + .size(), + has_interleaved_alpha + extra_channels.size()); + EXPECT_EQ(frame_settings->enc->input_queue.empty(), false); + for (size_t index = 0; index < channel_infos.size(); index++) { + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelBuffer( + frame_settings, &input_pixel_format, + (void*)extra_channel_bytes.data(), extra_channel_bytes.size(), + index + has_interleaved_alpha)); + } + JxlEncoderCloseInput(enc); + EXPECT_EQ(frame_settings->enc->input_queue.back() + .frame->frame.extra_channels() + .size(), + has_interleaved_alpha + extra_channels.size()); + std::vector compressed; + EncodeWithEncoder(enc, &compressed); + JxlEncoderDestroy(enc); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)); + + JxlDecoderSetInput(dec, next_in, avail_in); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize( + dec, &output_pixel_format_with_extra_channel_alpha, &buffer_size)); + if (&input_pixel_format == &output_pixel_format_with_extra_channel_alpha && + !already_downsampled) { + EXPECT_EQ(buffer_size, original_bytes.size()); + } + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_EQ(extra_channels.size() + has_interleaved_alpha, + info.num_extra_channels); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + size_t icc_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &output_pixel_format_with_extra_channel_alpha, + JXL_COLOR_PROFILE_TARGET_DATA, &icc_profile_size)); + jxl::PaddedBytes icc_profile(icc_profile_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile( + dec, &output_pixel_format, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile.data(), icc_profile.size())); + + std::vector decoded_bytes(buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer( + dec, &output_pixel_format_with_extra_channel_alpha, + decoded_bytes.data(), decoded_bytes.size())); + std::vector> extra_channel_decoded_bytes( + info.num_extra_channels - has_interleaved_alpha); + + for (size_t index = has_interleaved_alpha; index < info.num_extra_channels; + index++) { + JxlExtraChannelInfo channel_info; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelInfo(dec, index, &channel_info)); + EXPECT_EQ(channel_info.type, + extra_channels[index - has_interleaved_alpha].first); + std::string input_name = + extra_channels[index - has_interleaved_alpha].second; + const size_t name_length = channel_info.name_length; + EXPECT_EQ(input_name.size(), name_length); + std::vector output_name(name_length + 1); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelName(dec, index, output_name.data(), + output_name.size())); + EXPECT_EQ(0, + memcmp(input_name.data(), output_name.data(), input_name.size())); + size_t extra_buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderExtraChannelBufferSize(dec, &output_pixel_format, + &extra_buffer_size, index)); + std::vector extra_decoded_bytes(extra_buffer_size); + extra_channel_decoded_bytes[index - has_interleaved_alpha] = + std::move(extra_decoded_bytes); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSetExtraChannelBuffer( + dec, &output_pixel_format, + extra_channel_decoded_bytes[index - has_interleaved_alpha].data(), + extra_channel_decoded_bytes[index - has_interleaved_alpha].size(), + index)); + } + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + // Check if there are no further errors after getting the full image, e.g. + // check that the final codestream box is actually marked as last. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); + + jxl::CodecInOut decoded_io = ConvertTestImage( + decoded_bytes, xsize, ysize, output_pixel_format_with_extra_channel_alpha, + icc_profile); + + if (already_downsampled) { + jxl::Image3F* color = decoded_io.Main().color(); + jxl::DownsampleImage(color, resampling); + if (decoded_io.Main().HasAlpha()) { + jxl::ImageF* alpha = decoded_io.Main().alpha(); + jxl::DownsampleImage(alpha, resampling); + } + decoded_io.SetSize(color->xsize(), color->ysize()); + } + + jxl::ButteraugliParams ba; + float butteraugli_score = + ButteraugliDistance(original_io, decoded_io, ba, jxl::GetJxlCms(), + /*distmap=*/nullptr, nullptr); + if (lossless && !already_downsampled) { + EXPECT_LE(butteraugli_score, 0.0f); + } else { + EXPECT_LE(butteraugli_score, 2.0f); + } + JxlPixelFormat extra_channel_output_pixel_format = output_pixel_format; + extra_channel_output_pixel_format.num_channels = 1; + for (auto& extra_channel : extra_channel_decoded_bytes) { + EXPECT_EQ(extra_channel.size(), extra_channel_bytes.size()); + if (lossless) { + EXPECT_EQ(jxl::test::ComparePixels(extra_channel.data(), + extra_channel_bytes.data(), xsize, + ysize, extra_channel_pixel_format, + extra_channel_output_pixel_format), + 0u); + EXPECT_EQ(extra_channel, extra_channel_bytes); + } + } +} + +} // namespace + +TEST(RoundtripTest, FloatFrameRoundtripTest) { + std::vector>> + extra_channels_cases = {{}, + {{JXL_CHANNEL_ALPHA, "my extra alpha channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}}, + {{JXL_CHANNEL_DEPTH, "depth"}, + {JXL_CHANNEL_SELECTION_MASK, "mask"}, + {JXL_CHANNEL_BLACK, "black"}, + {JXL_CHANNEL_CFA, "my cfa channel"}, + {JXL_CHANNEL_OPTIONAL, "optional channel"}}, + {{JXL_CHANNEL_DEPTH, "very deep"}}}; + for (int use_container = 0; use_container < 2; use_container++) { + for (int lossless = 0; lossless < 2; lossless++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + for (auto& extra_channels : extra_channels_cases) { + uint32_t has_alpha = static_cast(num_channels % 2 == 0); + uint32_t total_extra_channels = has_alpha + extra_channels.size(); + // There's no support (yet) for lossless extra float + // channels, so we don't test it. + if (total_extra_channels == 0 || !lossless) { + JxlPixelFormat pixel_format = JxlPixelFormat{ + num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression( + 63, 129, pixel_format, pixel_format, (bool)lossless, + (bool)use_container, 1, false, extra_channels); + } + } + } + } + } +} + +TEST(RoundtripTest, Uint16FrameRoundtripTest) { + std::vector>> + extra_channels_cases = {{}, + {{JXL_CHANNEL_ALPHA, "my extra alpha channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}, + {JXL_CHANNEL_BLACK, "k_channel"}}, + {{JXL_CHANNEL_DEPTH, "very deep"}}}; + for (int use_container = 0; use_container < 2; use_container++) { + for (int lossless = 0; lossless < 2; lossless++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + for (auto& extra_channels : extra_channels_cases) { + JxlPixelFormat pixel_format = JxlPixelFormat{ + num_channels, JXL_TYPE_UINT16, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression( + 63, 129, pixel_format, pixel_format, (bool)lossless, + (bool)use_container, 1, false, extra_channels); + } + } + } + } +} + +TEST(RoundtripTest, Uint8FrameRoundtripTest) { + std::vector>> + extra_channels_cases = {{}, + {{JXL_CHANNEL_THERMAL, "temperature"}}, + {{JXL_CHANNEL_ALPHA, "my extra alpha channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}, + {JXL_CHANNEL_BLACK, "k_channel"}}, + {{JXL_CHANNEL_DEPTH, "very deep"}}}; + for (int use_container = 0; use_container < 2; use_container++) { + for (int lossless = 0; lossless < 2; lossless++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + for (auto& extra_channels : extra_channels_cases) { + JxlPixelFormat pixel_format = JxlPixelFormat{ + num_channels, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression( + 63, 129, pixel_format, pixel_format, (bool)lossless, + (bool)use_container, 1, false, extra_channels); + } + } + } + } +} + +TEST(RoundtripTest, TestNonlinearSrgbAsXybEncoded) { + for (int use_container = 0; use_container < 2; use_container++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + JxlPixelFormat pixel_format_in = + JxlPixelFormat{num_channels, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + JxlPixelFormat pixel_format_out = + JxlPixelFormat{num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression( + 63, 129, pixel_format_in, pixel_format_out, + /*lossless=*/false, (bool)use_container, {}); + } + } +} + +TEST(RoundtripTest, Resampling) { + JxlPixelFormat pixel_format = + JxlPixelFormat{3, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression(63, 129, pixel_format, pixel_format, + /*lossless=*/false, + /*use_container=*/false, 2, + /*already_downsampled=*/false); + + // TODO(lode): also make this work for odd sizes. This requires a fix in + // enc_frame.cc to not set custom_size_or_origin to true due to even/odd + // mismatch. + VerifyRoundtripCompression(64, 128, pixel_format, pixel_format, + /*lossless=*/true, + /*use_container=*/false, 2, + /*already_downsampled=*/true); +} + +TEST(RoundtripTest, ExtraBoxesTest) { + JxlPixelFormat pixel_format = + JxlPixelFormat{4, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + const size_t xsize = 61; + const size_t ysize = 71; + + const std::vector original_bytes = + GetTestImage(xsize, ysize, pixel_format); + jxl::CodecInOut original_io = + ConvertTestImage(original_bytes, xsize, ysize, pixel_format, {}); + + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc, true)); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + JxlColorEncoding color_encoding; + if (pixel_format.data_type == JXL_TYPE_FLOAT) { + JxlColorEncodingSetToLinearSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + } else { + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + } + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + JxlEncoderSetFrameLossless(frame_settings, false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + (void*)original_bytes.data(), + original_bytes.size())); + JxlEncoderCloseInput(enc); + + std::vector compressed; + EncodeWithEncoder(enc, &compressed); + JxlEncoderDestroy(enc); + + std::vector extra_data(1023); + jxl::AppendBoxHeader(jxl::MakeBoxType("crud"), extra_data.size(), false, + &compressed); + compressed.insert(compressed.end(), extra_data.begin(), extra_data.end()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)); + + JxlDecoderSetInput(dec, next_in, avail_in); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &pixel_format, &buffer_size)); + EXPECT_EQ(buffer_size, original_bytes.size()); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + size_t icc_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, &pixel_format, + JXL_COLOR_PROFILE_TARGET_DATA, + &icc_profile_size)); + jxl::PaddedBytes icc_profile(icc_profile_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile( + dec, &pixel_format, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile.data(), icc_profile.size())); + + std::vector decoded_bytes(buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer(dec, &pixel_format, + decoded_bytes.data(), + decoded_bytes.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); + + jxl::CodecInOut decoded_io = + ConvertTestImage(decoded_bytes, xsize, ysize, pixel_format, icc_profile); + + jxl::ButteraugliParams ba; + float butteraugli_score = + ButteraugliDistance(original_io, decoded_io, ba, jxl::GetJxlCms(), + /*distmap=*/nullptr, nullptr); + EXPECT_LE(butteraugli_score, 2.0f); +} + +static const unsigned char kEncodedTestProfile[] = { + 0x1f, 0x8b, 0x1, 0x13, 0x10, 0x0, 0x0, 0x0, 0x20, 0x4c, 0xcc, 0x3, + 0xe7, 0xa0, 0xa5, 0xa2, 0x90, 0xa4, 0x27, 0xe8, 0x79, 0x1d, 0xe3, 0x26, + 0x57, 0x54, 0xef, 0x0, 0xe8, 0x97, 0x2, 0xce, 0xa1, 0xd7, 0x85, 0x16, + 0xb4, 0x29, 0x94, 0x58, 0xf2, 0x56, 0xc0, 0x76, 0xea, 0x23, 0xec, 0x7c, + 0x73, 0x51, 0x41, 0x40, 0x23, 0x21, 0x95, 0x4, 0x75, 0x12, 0xc9, 0xcc, + 0x16, 0xbd, 0xb6, 0x99, 0xad, 0xf8, 0x75, 0x35, 0xb6, 0x42, 0xae, 0xae, + 0xae, 0x86, 0x56, 0xf8, 0xcc, 0x16, 0x30, 0xb3, 0x45, 0xad, 0xd, 0x40, + 0xd6, 0xd1, 0xd6, 0x99, 0x40, 0xbe, 0xe2, 0xdc, 0x31, 0x7, 0xa6, 0xb9, + 0x27, 0x92, 0x38, 0x0, 0x3, 0x5e, 0x2c, 0xbe, 0xe6, 0xfb, 0x19, 0xbf, + 0xf3, 0x6d, 0xbc, 0x4d, 0x64, 0xe5, 0xba, 0x76, 0xde, 0x31, 0x65, 0x66, + 0x14, 0xa6, 0x3a, 0xc5, 0x8f, 0xb1, 0xb4, 0xba, 0x1f, 0xb1, 0xb8, 0xd4, + 0x75, 0xba, 0x18, 0x86, 0x95, 0x3c, 0x26, 0xf6, 0x25, 0x62, 0x53, 0xfd, + 0x9c, 0x94, 0x76, 0xf6, 0x95, 0x2c, 0xb1, 0xfd, 0xdc, 0xc0, 0xe4, 0x3f, + 0xb3, 0xff, 0x67, 0xde, 0xd5, 0x94, 0xcc, 0xb0, 0x83, 0x2f, 0x28, 0x93, + 0x92, 0x3, 0xa1, 0x41, 0x64, 0x60, 0x62, 0x70, 0x80, 0x87, 0xaf, 0xe7, + 0x60, 0x4a, 0x20, 0x23, 0xb3, 0x11, 0x7, 0x38, 0x38, 0xd4, 0xa, 0x66, + 0xb5, 0x93, 0x41, 0x90, 0x19, 0x17, 0x18, 0x60, 0xa5, 0xb, 0x7a, 0x24, + 0xaa, 0x20, 0x81, 0xac, 0xa9, 0xa1, 0x70, 0xa6, 0x12, 0x8a, 0x4a, 0xa3, + 0xa0, 0xf9, 0x9a, 0x97, 0xe7, 0xa8, 0xac, 0x8, 0xa8, 0xc4, 0x2a, 0x86, + 0xa7, 0x69, 0x1e, 0x67, 0xe6, 0xbe, 0xa4, 0xd3, 0xff, 0x91, 0x61, 0xf6, + 0x8a, 0xe6, 0xb5, 0xb3, 0x61, 0x9f, 0x19, 0x17, 0x98, 0x27, 0x6b, 0xe9, + 0x8, 0x98, 0xe1, 0x21, 0x4a, 0x9, 0xb5, 0xd7, 0xca, 0xfa, 0x94, 0xd0, + 0x69, 0x1a, 0xeb, 0x52, 0x1, 0x4e, 0xf5, 0xf6, 0xdf, 0x7f, 0xe7, 0x29, + 0x70, 0xee, 0x4, 0xda, 0x2f, 0xa4, 0xff, 0xfe, 0xbb, 0x6f, 0xa8, 0xff, + 0xfe, 0xdb, 0xaf, 0x8, 0xf6, 0x72, 0xa1, 0x40, 0x5d, 0xf0, 0x2d, 0x8, + 0x82, 0x5b, 0x87, 0xbd, 0x10, 0x8, 0xe9, 0x7, 0xee, 0x4b, 0x80, 0xda, + 0x4a, 0x4, 0xc5, 0x5e, 0xa0, 0xb7, 0x1e, 0x60, 0xb0, 0x59, 0x76, 0x60, + 0xb, 0x2e, 0x19, 0x8a, 0x2e, 0x1c, 0xe6, 0x6, 0x20, 0xb8, 0x64, 0x18, + 0x2a, 0xcf, 0x51, 0x94, 0xd4, 0xee, 0xc3, 0xfe, 0x39, 0x74, 0xd4, 0x2b, + 0x48, 0xc9, 0x83, 0x4c, 0x9b, 0xd0, 0x4c, 0x35, 0x10, 0xe3, 0x9, 0xf7, + 0x72, 0xf0, 0x7a, 0xe, 0xbf, 0x7d, 0x36, 0x2e, 0x19, 0x7e, 0x3f, 0xc, + 0xf7, 0x93, 0xe7, 0xf4, 0x1d, 0x32, 0xc6, 0xb0, 0x89, 0xad, 0xe0, 0x28, + 0xc1, 0xa7, 0x59, 0xe3, 0x0, +}; + +TEST(RoundtripTest, TestICCProfile) { + // JxlEncoderSetICCProfile parses the ICC profile, so a valid profile is + // needed. The profile should be passed correctly through the roundtrip. + jxl::BitReader reader(jxl::Span(kEncodedTestProfile, + sizeof(kEncodedTestProfile))); + jxl::PaddedBytes icc; + ASSERT_TRUE(ReadICC(&reader, &icc)); + ASSERT_TRUE(reader.Close()); + + JxlPixelFormat format = + JxlPixelFormat{3, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + + size_t xsize = 25; + size_t ysize = 37; + const std::vector original_bytes = + GetTestImage(xsize, ysize, format); + + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetICCProfile(enc, icc.data(), icc.size())); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &format, + (void*)original_bytes.data(), + original_bytes.size())); + JxlEncoderCloseInput(enc); + + std::vector compressed; + EncodeWithEncoder(enc, &compressed); + JxlEncoderDestroy(enc); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)); + + JxlDecoderSetInput(dec, next_in, avail_in); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(buffer_size, original_bytes.size()); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + size_t dec_icc_size; + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, &dec_icc_size)); + EXPECT_EQ(icc.size(), dec_icc_size); + jxl::PaddedBytes dec_icc(dec_icc_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile(dec, &format, + JXL_COLOR_PROFILE_TARGET_ORIGINAL, + dec_icc.data(), dec_icc.size())); + + std::vector decoded_bytes(buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, decoded_bytes.data(), + decoded_bytes.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(icc, dec_icc); + + JxlDecoderDestroy(dec); +} + +#if JPEGXL_ENABLE_JPEG // Loading .jpg files requires libjpeg support. +TEST(RoundtripTest, JXL_TRANSCODE_JPEG_TEST(TestJPEGReconstruction)) { + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path); + jxl::CodecInOut orig_io; + ASSERT_TRUE( + SetFromBytes(jxl::Span(orig), &orig_io, /*pool=*/nullptr)); + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc.get(), JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderStoreJPEGMetadata(enc.get(), JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddJPEGFrame(frame_settings, orig.data(), orig.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector compressed; + EncodeWithEncoder(enc.get(), &compressed); + + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec.get(), JXL_DEC_JPEG_RECONSTRUCTION | JXL_DEC_FULL_IMAGE)); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec.get())); + std::vector reconstructed_buffer(128); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data(), + reconstructed_buffer.size())); + size_t used = 0; + JxlDecoderStatus dec_process_result = JXL_DEC_JPEG_NEED_MORE_OUTPUT; + while (dec_process_result == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + reconstructed_buffer.resize(reconstructed_buffer.size() * 2); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data() + used, + reconstructed_buffer.size() - used)); + dec_process_result = JxlDecoderProcessInput(dec.get()); + } + ASSERT_EQ(JXL_DEC_FULL_IMAGE, dec_process_result); + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + ASSERT_EQ(used, orig.size()); + EXPECT_EQ(0, memcmp(reconstructed_buffer.data(), orig.data(), used)); +} +#endif // JPEGXL_ENABLE_JPEG diff --git a/media/libjxl/src/lib/jxl/sanitizers.h b/media/libjxl/src/lib/jxl/sanitizers.h new file mode 100644 index 000000000..ce0bd8dc6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/sanitizers.h @@ -0,0 +1,242 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_SANITIZERS_H_ +#define LIB_JXL_SANITIZERS_H_ + +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/sanitizer_definitions.h" +#include "lib/jxl/image.h" + +#if JXL_MEMORY_SANITIZER +#include + +#include +#include +#include + +#include "lib/jxl/base/status.h" +#include "sanitizer/msan_interface.h" +#endif + +namespace jxl { +namespace msan { + +#if JXL_MEMORY_SANITIZER + +// Chosen so that kSanitizerSentinel is four copies of kSanitizerSentinelByte. +constexpr uint8_t kSanitizerSentinelByte = 0x48; +constexpr float kSanitizerSentinel = 205089.125f; + +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonMemory(const volatile void* m, + size_t size) { + __msan_poison(m, size); +} + +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonMemory(const volatile void* m, + size_t size) { + __msan_unpoison(m, size); +} + +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonCStr(const char* c) { + do { + UnpoisonMemory(c, 1); + } while (*c++); +} + +static JXL_INLINE JXL_MAYBE_UNUSED void MemoryIsInitialized( + const volatile void* m, size_t size) { + __msan_check_mem_is_initialized(m, size); +} + +// Mark all the bytes of an image (including padding) as poisoned bytes. +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const PlaneBase& im) { + PoisonMemory(im.bytes(), im.bytes_per_row() * im.ysize()); +} + +template +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const Image3& im) { + PoisonImage(im.Plane(0)); + PoisonImage(im.Plane(1)); + PoisonImage(im.Plane(2)); +} + +// Print the uninitialized regions of an image. +template +static JXL_INLINE JXL_MAYBE_UNUSED void PrintImageUninitialized( + const Plane& im) { + fprintf(stderr, + "Uninitialized regions for image of size %" PRIu64 "x%" PRIu64 ":\n", + static_cast(im.xsize()), static_cast(im.ysize())); + + // A segment of uninitialized pixels in a row, in the format [first, second). + typedef std::pair PixelSegment; + + // Helper class to merge and print a list of rows of PixelSegment that may be + // the same over big ranges of rows. This compacts the output to ranges of + // rows like "[y0, y1): [x0, x1) [x2, x3)". + class RowsMerger { + public: + // Add a new row the list of rows. If the row is the same as the previous + // one it will be merged showing a range of rows [y0, y1), but if the new + // row is different the current range of rows (if any) will be printed and a + // new one will be started. + void AddRow(size_t y, std::vector&& new_row) { + if (start_y_ != -1 && new_row != segments_) { + PrintRow(y); + } + if (new_row.empty()) { + // Skip ranges with no uninitialized pixels. + start_y_ = -1; + segments_.clear(); + return; + } + if (start_y_ == -1) { + start_y_ = y; + segments_ = std::move(new_row); + } + } + + // Print the contents of the range of rows [start_y_, end_y) if any. + void PrintRow(size_t end_y) { + if (start_y_ == -1) return; + if (segments_.empty()) { + start_y_ = -1; + return; + } + if (end_y - start_y_ > 1) { + fprintf(stderr, " y=[%" PRId64 ", %" PRIu64 "):", + static_cast(start_y_), static_cast(end_y)); + } else { + fprintf(stderr, " y=[%" PRId64 "]:", static_cast(start_y_)); + } + for (const auto& seg : segments_) { + if (seg.first + 1 == seg.second) { + fprintf(stderr, " [%" PRId64 "]", static_cast(seg.first)); + } else { + fprintf(stderr, " [%" PRId64 ", %" PRIu64 ")", + static_cast(seg.first), + static_cast(seg.second)); + } + } + fprintf(stderr, "\n"); + start_y_ = -1; + } + + private: + std::vector segments_; + // Row number of the first row in the range of rows that have |segments| as + // the undefined segments. + ssize_t start_y_ = -1; + } rows_merger; + + class SegmentsMerger { + public: + void AddValue(size_t x) { + if (row.empty() || row.back().second != x) { + row.emplace_back(x, x + 1); + } else { + row.back().second = x + 1; + } + } + + std::vector row; + }; + + for (size_t y = 0; y < im.ysize(); y++) { + auto* row = im.Row(y); + SegmentsMerger seg_merger; + size_t x = 0; + while (x < im.xsize()) { + intptr_t ret = + __msan_test_shadow(row + x, (im.xsize() - x) * sizeof(row[0])); + if (ret < 0) break; + size_t next_x = x + ret / sizeof(row[0]); + seg_merger.AddValue(next_x); + x = next_x + 1; + } + rows_merger.AddRow(y, std::move(seg_merger.row)); + } + rows_merger.PrintRow(im.ysize()); +} + +// Check that all the pixels in the provided rect of the image are initialized +// (not poisoned). If any of the values is poisoned it will abort. +template +static JXL_INLINE JXL_MAYBE_UNUSED void CheckImageInitialized( + const Plane& im, const Rect& r, size_t c, const char* message) { + JXL_ASSERT(r.x0() <= im.xsize()); + JXL_ASSERT(r.x0() + r.xsize() <= im.xsize()); + JXL_ASSERT(r.y0() <= im.ysize()); + JXL_ASSERT(r.y0() + r.ysize() <= im.ysize()); + for (size_t y = r.y0(); y < r.y0() + r.ysize(); y++) { + const auto* row = im.Row(y); + intptr_t ret = __msan_test_shadow(row + r.x0(), sizeof(*row) * r.xsize()); + if (ret != -1) { + JXL_DEBUG( + 1, + "Checking an image of %" PRIu64 " x %" PRIu64 ", rect x0=%" PRIu64 + ", y0=%" PRIu64 + ", " + "xsize=%" PRIu64 ", ysize=%" PRIu64, + static_cast(im.xsize()), static_cast(im.ysize()), + static_cast(r.x0()), static_cast(r.y0()), + static_cast(r.xsize()), static_cast(r.ysize())); + size_t x = ret / sizeof(*row); + JXL_DEBUG(1, + "CheckImageInitialized failed at x=%" PRIu64 ", y=%" PRIu64 + ", c=%" PRIu64 ": %s", + static_cast(r.x0() + x), static_cast(y), + static_cast(c), message ? message : ""); + PrintImageUninitialized(im); + } + // This will report an error if memory is not initialized. + __msan_check_mem_is_initialized(row + r.x0(), sizeof(*row) * r.xsize()); + } +} + +template +static JXL_INLINE JXL_MAYBE_UNUSED void CheckImageInitialized( + const Image3& im, const Rect& r, const char* message) { + for (size_t c = 0; c < 3; c++) { + std::string str_message(message); + str_message += " c=" + std::to_string(c); + CheckImageInitialized(im.Plane(c), r, c, str_message.c_str()); + } +} + +#define JXL_CHECK_IMAGE_INITIALIZED(im, r) \ + ::jxl::msan::CheckImageInitialized(im, r, "im=" #im ", r=" #r); + +#define JXL_CHECK_PLANE_INITIALIZED(im, r, c) \ + ::jxl::msan::CheckImageInitialized(im, r, c, "im=" #im ", r=" #r ", c=" #c); + +#else // JXL_MEMORY_SANITIZER + +// In non-msan mode these functions don't use volatile since it is not needed +// for the empty functions. + +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonMemory(const void*, size_t) {} +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonMemory(const void*, size_t) {} +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonCStr(const char*) {} +static JXL_INLINE JXL_MAYBE_UNUSED void MemoryIsInitialized(const void*, + size_t) {} + +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const PlaneBase& im) {} +template +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const Plane& im) {} + +#define JXL_CHECK_IMAGE_INITIALIZED(im, r) +#define JXL_CHECK_PLANE_INITIALIZED(im, r, c) + +#endif + +} // namespace msan +} // namespace jxl + +#endif // LIB_JXL_SANITIZERS_H_ diff --git a/media/libjxl/src/lib/jxl/simd_util-inl.h b/media/libjxl/src/lib/jxl/simd_util-inl.h new file mode 100644 index 000000000..77b207ffe --- /dev/null +++ b/media/libjxl/src/lib/jxl/simd_util-inl.h @@ -0,0 +1,349 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Misc utilities for SIMD operations + +#if defined(LIB_JXL_SIMD_UTIL_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_SIMD_UTIL_INL_H_ +#undef LIB_JXL_SIMD_UTIL_INL_H_ +#else +#define LIB_JXL_SIMD_UTIL_INL_H_ +#endif + +#include + +#include "lib/jxl/base/compiler_specific.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +#if HWY_CAP_GE512 +using hwy::HWY_NAMESPACE::Half; +using hwy::HWY_NAMESPACE::Vec; +template +HWY_INLINE Vec>> Quarter(const DF df, V v) { + using HF = Half; + using HHF = Half; + auto half = i >= 2 ? UpperHalf(HF(), v) : LowerHalf(HF(), v); + return i & 1 ? UpperHalf(HHF(), half) : LowerHalf(HHF(), half); +} + +template +HWY_INLINE Vec Concat4(const DF df, V v0, V v1, V v2, V v3) { + using HF = Half; + return Combine(DF(), Combine(HF(), v3, v2), Combine(HF(), v1, v0)); +} + +#endif + +// Stores v0[0], v1[0], v0[1], v1[1], ... to mem, in this order. Mem must be +// aligned. +template +void StoreInterleaved(const DF df, V v0, V v1, T* mem) { + static_assert(sizeof(T) == 4, "only use StoreInterleaved for 4-byte types"); +#if HWY_TARGET == HWY_SCALAR + Store(v0, df, mem); + Store(v1, df, mem + 1); +#elif !HWY_CAP_GE256 + Store(InterleaveLower(df, v0, v1), df, mem); + Store(InterleaveUpper(df, v0, v1), df, mem + Lanes(df)); +#else + if (!HWY_CAP_GE512 || Lanes(df) == 8) { + auto t0 = InterleaveLower(df, v0, v1); + auto t1 = InterleaveUpper(df, v0, v1); + Store(ConcatLowerLower(df, t1, t0), df, mem); + Store(ConcatUpperUpper(df, t1, t0), df, mem + Lanes(df)); + } else { +#if HWY_CAP_GE512 + auto t0 = InterleaveLower(df, v0, v1); + auto t1 = InterleaveUpper(df, v0, v1); + Store(Concat4(df, Quarter<0>(df, t0), Quarter<0>(df, t1), + Quarter<1>(df, t0), Quarter<1>(df, t1)), + df, mem); + Store(Concat4(df, Quarter<2>(df, t0), Quarter<2>(df, t1), + Quarter<3>(df, t0), Quarter<3>(df, t1)), + df, mem + Lanes(df)); +#endif + } +#endif +} + +// Stores v0[0], v1[0], v2[0], v3[0], v0[1] ... to mem, in this order. Mem must +// be aligned. +template +void StoreInterleaved(const DF df, V v0, V v1, V v2, V v3, T* mem) { + static_assert(sizeof(T) == 4, "only use StoreInterleaved for 4-byte types"); +#if HWY_TARGET == HWY_SCALAR + Store(v0, df, mem); + Store(v1, df, mem + 1); + Store(v2, df, mem + 2); + Store(v3, df, mem + 3); +#elif !HWY_CAP_GE256 + auto t0 = InterleaveLower(df, v0, v2); + auto t1 = InterleaveLower(df, v1, v3); + auto t2 = InterleaveUpper(df, v0, v2); + auto t3 = InterleaveUpper(df, v1, v3); + Store(InterleaveLower(df, t0, t1), df, mem); + Store(InterleaveUpper(df, t0, t1), df, mem + Lanes(df)); + Store(InterleaveLower(df, t2, t3), df, mem + 2 * Lanes(df)); + Store(InterleaveUpper(df, t2, t3), df, mem + 3 * Lanes(df)); +#elif !HWY_CAP_GE512 + auto t0 = InterleaveLower(df, v0, v2); + auto t1 = InterleaveLower(df, v1, v3); + auto t2 = InterleaveUpper(df, v0, v2); + auto t3 = InterleaveUpper(df, v1, v3); + + auto m0 = InterleaveLower(df, t0, t1); + auto m1 = InterleaveUpper(df, t0, t1); + auto m2 = InterleaveLower(df, t2, t3); + auto m3 = InterleaveUpper(df, t2, t3); + + Store(ConcatLowerLower(df, m1, m0), df, mem); + Store(ConcatLowerLower(df, m3, m2), df, mem + Lanes(df)); + Store(ConcatUpperUpper(df, m1, m0), df, mem + 2 * Lanes(df)); + Store(ConcatUpperUpper(df, m3, m2), df, mem + 3 * Lanes(df)); +#else + auto t0 = InterleaveLower(df, v0, v2); + auto t1 = InterleaveLower(df, v1, v3); + auto t2 = InterleaveUpper(df, v0, v2); + auto t3 = InterleaveUpper(df, v1, v3); + + auto m0 = InterleaveLower(df, t0, t1); + auto m1 = InterleaveUpper(df, t0, t1); + auto m2 = InterleaveLower(df, t2, t3); + auto m3 = InterleaveUpper(df, t2, t3); + + Store(Concat4(df, Quarter<0>(df, m0), Quarter<0>(df, m1), Quarter<0>(df, m2), + Quarter<0>(df, m3)), + df, mem); + Store(Concat4(df, Quarter<1>(df, m0), Quarter<1>(df, m1), Quarter<1>(df, m2), + Quarter<1>(df, m3)), + df, mem + Lanes(df)); + Store(Concat4(df, Quarter<2>(df, m0), Quarter<2>(df, m1), Quarter<2>(df, m2), + Quarter<2>(df, m3)), + df, mem + 2 * Lanes(df)); + Store(Concat4(df, Quarter<3>(df, m0), Quarter<3>(df, m1), Quarter<3>(df, m2), + Quarter<3>(df, m3)), + df, mem + 3 * Lanes(df)); +#endif +} + +// Stores v0[0], v1[0], v2[0], v3[0], v4[0], v5[0], v6[0], v7[0], v0[1] ... to +// mem, in this order. Mem must be aligned. +template +void StoreInterleaved(const DF df, V v0, V v1, V v2, V v3, V v4, V v5, V v6, + V v7, float* mem) { +#if HWY_TARGET == HWY_SCALAR + Store(v0, df, mem); + Store(v1, df, mem + 1); + Store(v2, df, mem + 2); + Store(v3, df, mem + 3); + Store(v4, df, mem + 4); + Store(v5, df, mem + 5); + Store(v6, df, mem + 6); + Store(v7, df, mem + 7); +#elif !HWY_CAP_GE256 + auto t0 = InterleaveLower(df, v0, v4); + auto t1 = InterleaveLower(df, v1, v5); + auto t2 = InterleaveLower(df, v2, v6); + auto t3 = InterleaveLower(df, v3, v7); + auto t4 = InterleaveUpper(df, v0, v4); + auto t5 = InterleaveUpper(df, v1, v5); + auto t6 = InterleaveUpper(df, v2, v6); + auto t7 = InterleaveUpper(df, v3, v7); + + auto w0 = InterleaveLower(df, t0, t2); + auto w1 = InterleaveLower(df, t1, t3); + auto w2 = InterleaveUpper(df, t0, t2); + auto w3 = InterleaveUpper(df, t1, t3); + auto w4 = InterleaveLower(df, t4, t6); + auto w5 = InterleaveLower(df, t5, t7); + auto w6 = InterleaveUpper(df, t4, t6); + auto w7 = InterleaveUpper(df, t5, t7); + + Store(InterleaveLower(df, w0, w1), df, mem); + Store(InterleaveUpper(df, w0, w1), df, mem + Lanes(df)); + Store(InterleaveLower(df, w2, w3), df, mem + 2 * Lanes(df)); + Store(InterleaveUpper(df, w2, w3), df, mem + 3 * Lanes(df)); + Store(InterleaveLower(df, w4, w5), df, mem + 4 * Lanes(df)); + Store(InterleaveUpper(df, w4, w5), df, mem + 5 * Lanes(df)); + Store(InterleaveLower(df, w6, w7), df, mem + 6 * Lanes(df)); + Store(InterleaveUpper(df, w6, w7), df, mem + 7 * Lanes(df)); +#elif !HWY_CAP_GE512 + auto t0 = InterleaveLower(df, v0, v4); + auto t1 = InterleaveLower(df, v1, v5); + auto t2 = InterleaveLower(df, v2, v6); + auto t3 = InterleaveLower(df, v3, v7); + auto t4 = InterleaveUpper(df, v0, v4); + auto t5 = InterleaveUpper(df, v1, v5); + auto t6 = InterleaveUpper(df, v2, v6); + auto t7 = InterleaveUpper(df, v3, v7); + + auto w0 = InterleaveLower(df, t0, t2); + auto w1 = InterleaveLower(df, t1, t3); + auto w2 = InterleaveUpper(df, t0, t2); + auto w3 = InterleaveUpper(df, t1, t3); + auto w4 = InterleaveLower(df, t4, t6); + auto w5 = InterleaveLower(df, t5, t7); + auto w6 = InterleaveUpper(df, t4, t6); + auto w7 = InterleaveUpper(df, t5, t7); + + auto m0 = InterleaveLower(df, w0, w1); + auto m1 = InterleaveUpper(df, w0, w1); + auto m2 = InterleaveLower(df, w2, w3); + auto m3 = InterleaveUpper(df, w2, w3); + auto m4 = InterleaveLower(df, w4, w5); + auto m5 = InterleaveUpper(df, w4, w5); + auto m6 = InterleaveLower(df, w6, w7); + auto m7 = InterleaveUpper(df, w6, w7); + + Store(ConcatLowerLower(df, m1, m0), df, mem); + Store(ConcatLowerLower(df, m3, m2), df, mem + Lanes(df)); + Store(ConcatLowerLower(df, m5, m4), df, mem + 2 * Lanes(df)); + Store(ConcatLowerLower(df, m7, m6), df, mem + 3 * Lanes(df)); + Store(ConcatUpperUpper(df, m1, m0), df, mem + 4 * Lanes(df)); + Store(ConcatUpperUpper(df, m3, m2), df, mem + 5 * Lanes(df)); + Store(ConcatUpperUpper(df, m5, m4), df, mem + 6 * Lanes(df)); + Store(ConcatUpperUpper(df, m7, m6), df, mem + 7 * Lanes(df)); +#else + auto t0 = InterleaveLower(df, v0, v4); + auto t1 = InterleaveLower(df, v1, v5); + auto t2 = InterleaveLower(df, v2, v6); + auto t3 = InterleaveLower(df, v3, v7); + auto t4 = InterleaveUpper(df, v0, v4); + auto t5 = InterleaveUpper(df, v1, v5); + auto t6 = InterleaveUpper(df, v2, v6); + auto t7 = InterleaveUpper(df, v3, v7); + + auto w0 = InterleaveLower(df, t0, t2); + auto w1 = InterleaveLower(df, t1, t3); + auto w2 = InterleaveUpper(df, t0, t2); + auto w3 = InterleaveUpper(df, t1, t3); + auto w4 = InterleaveLower(df, t4, t6); + auto w5 = InterleaveLower(df, t5, t7); + auto w6 = InterleaveUpper(df, t4, t6); + auto w7 = InterleaveUpper(df, t5, t7); + + auto m0 = InterleaveLower(df, w0, w1); + auto m1 = InterleaveUpper(df, w0, w1); + auto m2 = InterleaveLower(df, w2, w3); + auto m3 = InterleaveUpper(df, w2, w3); + auto m4 = InterleaveLower(df, w4, w5); + auto m5 = InterleaveUpper(df, w4, w5); + auto m6 = InterleaveLower(df, w6, w7); + auto m7 = InterleaveUpper(df, w6, w7); + + Store(Concat4(df, Quarter<0>(df, m0), Quarter<0>(df, m1), Quarter<0>(df, m2), + Quarter<0>(df, m3)), + df, mem); + Store(Concat4(df, Quarter<0>(df, m4), Quarter<0>(df, m5), Quarter<0>(df, m6), + Quarter<0>(df, m7)), + df, mem + Lanes(df)); + Store(Concat4(df, Quarter<1>(df, m0), Quarter<1>(df, m1), Quarter<1>(df, m2), + Quarter<1>(df, m3)), + df, mem + 2 * Lanes(df)); + Store(Concat4(df, Quarter<1>(df, m4), Quarter<1>(df, m5), Quarter<1>(df, m6), + Quarter<1>(df, m7)), + df, mem + 3 * Lanes(df)); + Store(Concat4(df, Quarter<2>(df, m0), Quarter<2>(df, m1), Quarter<2>(df, m2), + Quarter<2>(df, m3)), + df, mem + 4 * Lanes(df)); + Store(Concat4(df, Quarter<2>(df, m4), Quarter<2>(df, m5), Quarter<2>(df, m6), + Quarter<2>(df, m7)), + df, mem + 5 * Lanes(df)); + Store(Concat4(df, Quarter<3>(df, m0), Quarter<3>(df, m1), Quarter<3>(df, m2), + Quarter<3>(df, m3)), + df, mem + 6 * Lanes(df)); + Store(Concat4(df, Quarter<3>(df, m4), Quarter<3>(df, m5), Quarter<3>(df, m6), + Quarter<3>(df, m7)), + df, mem + 7 * Lanes(df)); +#endif +} + +#if HWY_CAP_GE256 +JXL_INLINE void Transpose8x8Block(const int32_t* JXL_RESTRICT from, + int32_t* JXL_RESTRICT to, size_t fromstride) { + const HWY_CAPPED(int32_t, 8) d; + auto i0 = Load(d, from); + auto i1 = Load(d, from + 1 * fromstride); + auto i2 = Load(d, from + 2 * fromstride); + auto i3 = Load(d, from + 3 * fromstride); + auto i4 = Load(d, from + 4 * fromstride); + auto i5 = Load(d, from + 5 * fromstride); + auto i6 = Load(d, from + 6 * fromstride); + auto i7 = Load(d, from + 7 * fromstride); + + const auto q0 = InterleaveLower(d, i0, i2); + const auto q1 = InterleaveLower(d, i1, i3); + const auto q2 = InterleaveUpper(d, i0, i2); + const auto q3 = InterleaveUpper(d, i1, i3); + const auto q4 = InterleaveLower(d, i4, i6); + const auto q5 = InterleaveLower(d, i5, i7); + const auto q6 = InterleaveUpper(d, i4, i6); + const auto q7 = InterleaveUpper(d, i5, i7); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + const auto r4 = InterleaveLower(d, q4, q5); + const auto r5 = InterleaveUpper(d, q4, q5); + const auto r6 = InterleaveLower(d, q6, q7); + const auto r7 = InterleaveUpper(d, q6, q7); + + i0 = ConcatLowerLower(d, r4, r0); + i1 = ConcatLowerLower(d, r5, r1); + i2 = ConcatLowerLower(d, r6, r2); + i3 = ConcatLowerLower(d, r7, r3); + i4 = ConcatUpperUpper(d, r4, r0); + i5 = ConcatUpperUpper(d, r5, r1); + i6 = ConcatUpperUpper(d, r6, r2); + i7 = ConcatUpperUpper(d, r7, r3); + + Store(i0, d, to); + Store(i1, d, to + 1 * 8); + Store(i2, d, to + 2 * 8); + Store(i3, d, to + 3 * 8); + Store(i4, d, to + 4 * 8); + Store(i5, d, to + 5 * 8); + Store(i6, d, to + 6 * 8); + Store(i7, d, to + 7 * 8); +} +#elif HWY_TARGET != HWY_SCALAR +JXL_INLINE void Transpose8x8Block(const int32_t* JXL_RESTRICT from, + int32_t* JXL_RESTRICT to, size_t fromstride) { + const HWY_CAPPED(int32_t, 4) d; + for (size_t n = 0; n < 8; n += 4) { + for (size_t m = 0; m < 8; m += 4) { + auto p0 = Load(d, from + n * fromstride + m); + auto p1 = Load(d, from + (n + 1) * fromstride + m); + auto p2 = Load(d, from + (n + 2) * fromstride + m); + auto p3 = Load(d, from + (n + 3) * fromstride + m); + const auto q0 = InterleaveLower(d, p0, p2); + const auto q1 = InterleaveLower(d, p1, p3); + const auto q2 = InterleaveUpper(d, p0, p2); + const auto q3 = InterleaveUpper(d, p1, p3); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + Store(r0, d, to + m * 8 + n); + Store(r1, d, to + (1 + m) * 8 + n); + Store(r2, d, to + (2 + m) * 8 + n); + Store(r3, d, to + (3 + m) * 8 + n); + } + } +} + +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_SIMD_UTIL_INL_H_ diff --git a/media/libjxl/src/lib/jxl/simd_util_test.cc b/media/libjxl/src/lib/jxl/simd_util_test.cc new file mode 100644 index 000000000..b81f5d127 --- /dev/null +++ b/media/libjxl/src/lib/jxl/simd_util_test.cc @@ -0,0 +1,84 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/simd_util_test.cc" +#include + +#include "lib/jxl/simd_util-inl.h" + +// Test utils +#include +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +HWY_NOINLINE void TestInterleave2() { + HWY_FULL(float) d; + auto vec1 = Iota(d, 0 * 128.0); + auto vec2 = Iota(d, 1 * 128.0); + HWY_ALIGN float mem[MaxLanes(d) * 2]; + StoreInterleaved(d, vec1, vec2, mem); + for (size_t i = 0; i < Lanes(d); i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ(mem[2 * i + j], j * 128 + i) << "i: " << i << " j: " << j; + } + } +} +HWY_NOINLINE void TestInterleave4() { + HWY_FULL(float) d; + auto vec1 = Iota(d, 0 * 128.0); + auto vec2 = Iota(d, 1 * 128.0); + auto vec3 = Iota(d, 2 * 128.0); + auto vec4 = Iota(d, 3 * 128.0); + HWY_ALIGN float mem[MaxLanes(d) * 4]; + StoreInterleaved(d, vec1, vec2, vec3, vec4, mem); + for (size_t i = 0; i < Lanes(d); i++) { + for (size_t j = 0; j < 4; j++) { + EXPECT_EQ(mem[4 * i + j], j * 128 + i) << "i: " << i << " j: " << j; + } + } +} +HWY_NOINLINE void TestInterleave8() { + HWY_FULL(float) d; + auto vec1 = Iota(d, 0 * 128.0); + auto vec2 = Iota(d, 1 * 128.0); + auto vec3 = Iota(d, 2 * 128.0); + auto vec4 = Iota(d, 3 * 128.0); + auto vec5 = Iota(d, 4 * 128.0); + auto vec6 = Iota(d, 5 * 128.0); + auto vec7 = Iota(d, 6 * 128.0); + auto vec8 = Iota(d, 7 * 128.0); + HWY_ALIGN float mem[MaxLanes(d) * 8]; + StoreInterleaved(d, vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, mem); + for (size_t i = 0; i < Lanes(d); i++) { + for (size_t j = 0; j < 8; j++) { + EXPECT_EQ(mem[8 * i + j], j * 128 + i) << "i: " << i << " j: " << j; + } + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class SimdUtilTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(SimdUtilTargetTest); + +HWY_EXPORT_AND_TEST_P(SimdUtilTargetTest, TestInterleave2); +HWY_EXPORT_AND_TEST_P(SimdUtilTargetTest, TestInterleave4); +HWY_EXPORT_AND_TEST_P(SimdUtilTargetTest, TestInterleave8); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/size_constraints.h b/media/libjxl/src/lib/jxl/size_constraints.h new file mode 100644 index 000000000..20787b164 --- /dev/null +++ b/media/libjxl/src/lib/jxl/size_constraints.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_SIZE_CONSTRAINTS_H_ +#define LIB_JXL_SIZE_CONSTRAINTS_H_ + +#include + +namespace jxl { + +struct SizeConstraints { + // Upper limit on pixel dimensions/area, enforced by VerifyDimensions + // (called from decoders). Fuzzers set smaller values to limit memory use. + uint32_t dec_max_xsize = 0xFFFFFFFFu; + uint32_t dec_max_ysize = 0xFFFFFFFFu; + uint64_t dec_max_pixels = 0xFFFFFFFFu; // Might be up to ~0ull +}; + +} // namespace jxl + +#endif // LIB_JXL_SIZE_CONSTRAINTS_H_ diff --git a/media/libjxl/src/lib/jxl/speed_tier_test.cc b/media/libjxl/src/lib/jxl/speed_tier_test.cc new file mode 100644 index 000000000..9c120fe54 --- /dev/null +++ b/media/libjxl/src/lib/jxl/speed_tier_test.cc @@ -0,0 +1,109 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { + +struct SpeedTierTestParams { + explicit SpeedTierTestParams(const SpeedTier speed_tier, + const bool shrink8 = false) + : speed_tier(speed_tier), shrink8(shrink8) {} + SpeedTier speed_tier; + bool shrink8; +}; + +std::ostream& operator<<(std::ostream& os, SpeedTierTestParams params) { + auto previous_flags = os.flags(); + os << std::boolalpha; + os << "SpeedTierTestParams{" << SpeedTierName(params.speed_tier) + << ", /*shrink8=*/" << params.shrink8 << "}"; + os.flags(previous_flags); + return os; +} + +class SpeedTierTest : public testing::TestWithParam {}; + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + SpeedTierTestInstantiation, SpeedTierTest, + testing::Values(SpeedTierTestParams{SpeedTier::kCheetah, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kCheetah, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kThunder, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kThunder, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kLightning, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kLightning, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kFalcon, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kFalcon, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kHare, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kHare, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kWombat, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kWombat, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kSquirrel, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kSquirrel, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kKitten, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kKitten, + /*shrink8=*/false}, + // Only downscaled image for Tortoise mode. + SpeedTierTestParams{SpeedTier::kTortoise, + /*shrink8=*/true})); + +TEST_P(SpeedTierTest, Roundtrip) { + const PaddedBytes orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ThreadPoolInternal pool(8); + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + const SpeedTierTestParams& params = GetParam(); + + if (params.shrink8) { + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + } + + CompressParams cparams; + cparams.speed_tier = params.speed_tier; + + CodecInOut io2; + test::Roundtrip(&io, cparams, {}, nullptr, &io2); + + // Can be 2.2 in non-hare mode. + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, GetJxlCms(), + /*distmap=*/nullptr, /*pool=*/nullptr), + 2.8); +} +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/splines.cc b/media/libjxl/src/lib/jxl/splines.cc new file mode 100644 index 000000000..edaaf2738 --- /dev/null +++ b/media/libjxl/src/lib/jxl/splines.cc @@ -0,0 +1,683 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/splines.h" + +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/opsin_params.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/splines.cc" +#include +#include + +#include "lib/jxl/fast_math-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::MulSub; +using hwy::HWY_NAMESPACE::Sqrt; +using hwy::HWY_NAMESPACE::Sub; + +// Given a set of DCT coefficients, this returns the result of performing cosine +// interpolation on the original samples. +float ContinuousIDCT(const float dct[32], const float t) { + // We compute here the DCT-3 of the `dct` vector, rescaled by a factor of + // sqrt(32). This is such that an input vector vector {x, 0, ..., 0} produces + // a constant result of x. dct[0] was scaled in Dequantize() to allow uniform + // treatment of all the coefficients. + constexpr float kMultipliers[32] = { + kPi / 32 * 0, kPi / 32 * 1, kPi / 32 * 2, kPi / 32 * 3, kPi / 32 * 4, + kPi / 32 * 5, kPi / 32 * 6, kPi / 32 * 7, kPi / 32 * 8, kPi / 32 * 9, + kPi / 32 * 10, kPi / 32 * 11, kPi / 32 * 12, kPi / 32 * 13, kPi / 32 * 14, + kPi / 32 * 15, kPi / 32 * 16, kPi / 32 * 17, kPi / 32 * 18, kPi / 32 * 19, + kPi / 32 * 20, kPi / 32 * 21, kPi / 32 * 22, kPi / 32 * 23, kPi / 32 * 24, + kPi / 32 * 25, kPi / 32 * 26, kPi / 32 * 27, kPi / 32 * 28, kPi / 32 * 29, + kPi / 32 * 30, kPi / 32 * 31, + }; + HWY_CAPPED(float, 32) df; + auto result = Zero(df); + const auto tandhalf = Set(df, t + 0.5f); + for (int i = 0; i < 32; i += Lanes(df)) { + auto cos_arg = Mul(LoadU(df, kMultipliers + i), tandhalf); + auto cos = FastCosf(df, cos_arg); + auto local_res = Mul(LoadU(df, dct + i), cos); + result = MulAdd(Set(df, kSqrt2), local_res, result); + } + return GetLane(SumOfLanes(df, result)); +} + +template +void DrawSegment(DF df, const SplineSegment& segment, const bool add, + const size_t y, const size_t x, float* JXL_RESTRICT rows[3]) { + Rebind di; + const auto inv_sigma = Set(df, segment.inv_sigma); + const auto half = Set(df, 0.5f); + const auto one_over_2s2 = Set(df, 0.353553391f); + const auto sigma_over_4_times_intensity = + Set(df, segment.sigma_over_4_times_intensity); + const auto dx = Sub(ConvertTo(df, Iota(di, x)), Set(df, segment.center_x)); + const auto dy = Set(df, y - segment.center_y); + const auto sqd = MulAdd(dx, dx, Mul(dy, dy)); + const auto distance = Sqrt(sqd); + const auto one_dimensional_factor = + Sub(FastErff(df, Mul(MulAdd(distance, half, one_over_2s2), inv_sigma)), + FastErff(df, Mul(MulSub(distance, half, one_over_2s2), inv_sigma))); + auto local_intensity = + Mul(sigma_over_4_times_intensity, + Mul(one_dimensional_factor, one_dimensional_factor)); + for (size_t c = 0; c < 3; ++c) { + const auto cm = Set(df, add ? segment.color[c] : -segment.color[c]); + const auto in = LoadU(df, rows[c] + x); + StoreU(MulAdd(cm, local_intensity, in), df, rows[c] + x); + } +} + +void DrawSegment(const SplineSegment& segment, const bool add, const size_t y, + const ssize_t x0, ssize_t x1, float* JXL_RESTRICT rows[3]) { + ssize_t x = + std::max(x0, segment.center_x - segment.maximum_distance + 0.5f); + // one-past-the-end + x1 = + std::min(x1, segment.center_x + segment.maximum_distance + 1.5f); + HWY_FULL(float) df; + for (; x + static_cast(Lanes(df)) <= x1; x += Lanes(df)) { + DrawSegment(df, segment, add, y, x, rows); + } + for (; x < x1; ++x) { + DrawSegment(HWY_CAPPED(float, 1)(), segment, add, y, x, rows); + } +} + +void ComputeSegments(const Spline::Point& center, const float intensity, + const float color[3], const float sigma, + std::vector& segments, + std::vector>& segments_by_y, + size_t* pixel_limit) { + // In worst case zero-sized dot spans over 2 rows / columns. + constexpr const float kThinDotSpan = 2.0f; + // Sanity check sigma, inverse sigma and intensity + if (!(std::isfinite(sigma) && sigma != 0.0f && std::isfinite(1.0f / sigma) && + std::isfinite(intensity))) { + // Even no-draw should still be accounted. + *pixel_limit -= std::min(*pixel_limit, kThinDotSpan * kThinDotSpan); + return; + } +#if JXL_HIGH_PRECISION + constexpr float kDistanceExp = 5; +#else + // About 30% faster. + constexpr float kDistanceExp = 3; +#endif + // We cap from below colors to at least 0.01. + float max_color = 0.01f; + for (size_t c = 0; c < 3; c++) { + max_color = std::max(max_color, std::abs(color[c] * intensity)); + } + // Distance beyond which max_color*intensity*exp(-d^2 / (2 * sigma^2)) drops + // below 10^-kDistanceExp. + const float maximum_distance = + std::sqrt(-2 * sigma * sigma * + (std::log(0.1) * kDistanceExp - std::log(max_color))); + SplineSegment segment; + segment.center_y = center.y; + segment.center_x = center.x; + memcpy(segment.color, color, sizeof(segment.color)); + segment.inv_sigma = 1.0f / sigma; + segment.sigma_over_4_times_intensity = .25f * sigma * intensity; + segment.maximum_distance = maximum_distance; + float cost = 2.0f * maximum_distance + kThinDotSpan; + // Check cost^2 fits size_t. + if (cost >= static_cast(1 << 15)) { + // Too much to rasterize. + *pixel_limit = 0; + return; + } + size_t area_cost = static_cast(cost * cost); + if (area_cost > *pixel_limit) { + *pixel_limit = 0; + return; + } + // TODO(eustas): perhaps we should charge less: (y1 - y0) <= cost + *pixel_limit -= area_cost; + ssize_t y0 = center.y - maximum_distance + .5f; + ssize_t y1 = center.y + maximum_distance + 1.5f; // one-past-the-end + for (ssize_t y = std::max(y0, 0); y < y1; y++) { + segments_by_y.emplace_back(y, segments.size()); + } + segments.push_back(segment); +} + +void DrawSegments(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_rect, + const bool add, const SplineSegment* segments, + const size_t* segment_indices, + const size_t* segment_y_start) { + JXL_ASSERT(image_rect.ysize() == 1); + float* JXL_RESTRICT rows[3] = {row_x - image_rect.x0(), + row_y - image_rect.x0(), + row_b - image_rect.x0()}; + size_t y = image_rect.y0(); + for (size_t i = segment_y_start[y]; i < segment_y_start[y + 1]; i++) { + DrawSegment(segments[segment_indices[i]], add, y, image_rect.x0(), + image_rect.x0() + image_rect.xsize(), rows); + } +} + +void SegmentsFromPoints( + const Spline& spline, + const std::vector>& points_to_draw, + const float arc_length, std::vector& segments, + std::vector>& segments_by_y, + size_t* pixel_limit) { + const float inv_arc_length = 1.0f / arc_length; + int k = 0; + for (const auto& point_to_draw : points_to_draw) { + const Spline::Point& point = point_to_draw.first; + const float multiplier = point_to_draw.second; + const float progress_along_arc = + std::min(1.f, (k * kDesiredRenderingDistance) * inv_arc_length); + ++k; + float color[3]; + for (size_t c = 0; c < 3; ++c) { + color[c] = + ContinuousIDCT(spline.color_dct[c], (32 - 1) * progress_along_arc); + } + const float sigma = + ContinuousIDCT(spline.sigma_dct, (32 - 1) * progress_along_arc); + ComputeSegments(point, multiplier, color, sigma, segments, segments_by_y, + pixel_limit); + if (*pixel_limit == 0) { + return; + } + } +} +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(SegmentsFromPoints); +HWY_EXPORT(DrawSegments); + +namespace { + +// It is not in spec, but reasonable limit to avoid overflows. +template +Status ValidateSplinePointPos(const T& x, const T& y) { + constexpr T kSplinePosLimit = 1u << 23; + if ((x >= kSplinePosLimit) || (x <= -kSplinePosLimit) || + (y >= kSplinePosLimit) || (y <= -kSplinePosLimit)) { + return JXL_FAILURE("Spline coordinates out of bounds"); + } + return true; +} + +// Maximum number of spline control points per frame is +// std::min(kMaxNumControlPoints, xsize * ysize / 2) +constexpr size_t kMaxNumControlPoints = 1u << 20u; +constexpr size_t kMaxNumControlPointsPerPixelRatio = 2; + +float AdjustedQuant(const int32_t adjustment) { + return (adjustment >= 0) ? (1.f + .125f * adjustment) + : 1.f / (1.f - .125f * adjustment); +} + +float InvAdjustedQuant(const int32_t adjustment) { + return (adjustment >= 0) ? 1.f / (1.f + .125f * adjustment) + : (1.f - .125f * adjustment); +} + +// X, Y, B, sigma. +static constexpr float kChannelWeight[] = {0.0042f, 0.075f, 0.07f, .3333f}; + +Status DecodeAllStartingPoints(std::vector* const points, + BitReader* const br, ANSSymbolReader* reader, + const std::vector& context_map, + const size_t num_splines) { + points->clear(); + points->reserve(num_splines); + int64_t last_x = 0; + int64_t last_y = 0; + for (size_t i = 0; i < num_splines; i++) { + int64_t x = + reader->ReadHybridUint(kStartingPositionContext, br, context_map); + int64_t y = + reader->ReadHybridUint(kStartingPositionContext, br, context_map); + if (i != 0) { + x = UnpackSigned(x) + last_x; + y = UnpackSigned(y) + last_y; + } + JXL_RETURN_IF_ERROR(ValidateSplinePointPos(x, y)); + points->emplace_back(static_cast(x), static_cast(y)); + last_x = x; + last_y = y; + } + return true; +} + +struct Vector { + float x, y; + Vector operator-() const { return {-x, -y}; } + Vector operator+(const Vector& other) const { + return {x + other.x, y + other.y}; + } + float SquaredNorm() const { return x * x + y * y; } +}; +Vector operator*(const float k, const Vector& vec) { + return {k * vec.x, k * vec.y}; +} + +Spline::Point operator+(const Spline::Point& p, const Vector& vec) { + return {p.x + vec.x, p.y + vec.y}; +} +Spline::Point operator-(const Spline::Point& p, const Vector& vec) { + return p + -vec; +} +Vector operator-(const Spline::Point& a, const Spline::Point& b) { + return {a.x - b.x, a.y - b.y}; +} + +// TODO(eustas): avoid making a copy of "points". +void DrawCentripetalCatmullRomSpline(std::vector points, + std::vector& result) { + if (points.empty()) return; + if (points.size() == 1) { + result.push_back(points[0]); + return; + } + // Number of points to compute between each control point. + static constexpr int kNumPoints = 16; + result.reserve((points.size() - 1) * kNumPoints + 1); + points.insert(points.begin(), points[0] + (points[0] - points[1])); + points.push_back(points[points.size() - 1] + + (points[points.size() - 1] - points[points.size() - 2])); + // points has at least 4 elements at this point. + for (size_t start = 0; start < points.size() - 3; ++start) { + // 4 of them are used, and we draw from p[1] to p[2]. + const Spline::Point* const p = &points[start]; + result.push_back(p[1]); + float d[3]; + float t[4]; + t[0] = 0; + for (int k = 0; k < 3; ++k) { + // TODO(eustas): for each segment delta is calculated 3 times... + // TODO(eustas): restrict d[k] with reasonable limit and spec it. + d[k] = std::sqrt(hypotf(p[k + 1].x - p[k].x, p[k + 1].y - p[k].y)); + t[k + 1] = t[k] + d[k]; + } + for (int i = 1; i < kNumPoints; ++i) { + const float tt = d[0] + (static_cast(i) / kNumPoints) * d[1]; + Spline::Point a[3]; + for (int k = 0; k < 3; ++k) { + // TODO(eustas): reciprocal multiplication would be faster. + a[k] = p[k] + ((tt - t[k]) / d[k]) * (p[k + 1] - p[k]); + } + Spline::Point b[2]; + for (int k = 0; k < 2; ++k) { + b[k] = a[k] + ((tt - t[k]) / (d[k] + d[k + 1])) * (a[k + 1] - a[k]); + } + result.push_back(b[0] + ((tt - t[1]) / d[1]) * (b[1] - b[0])); + } + } + result.push_back(points[points.size() - 2]); +} + +// Move along the line segments defined by `points`, `kDesiredRenderingDistance` +// pixels at a time, and call `functor` with each point and the actual distance +// to the previous point (which will always be kDesiredRenderingDistance except +// possibly for the very last point). +// TODO(eustas): this method always adds the last point, but never the first +// (unless those are one); I believe both ends matter. +template +bool ForEachEquallySpacedPoint(const Points& points, const Functor& functor) { + JXL_ASSERT(!points.empty()); + Spline::Point current = points.front(); + functor(current, kDesiredRenderingDistance); + auto next = points.begin(); + while (next != points.end()) { + const Spline::Point* previous = ¤t; + float arclength_from_previous = 0.f; + for (;;) { + if (next == points.end()) { + return functor(*previous, arclength_from_previous); + } + const float arclength_to_next = + std::sqrt((*next - *previous).SquaredNorm()); + if (arclength_from_previous + arclength_to_next >= + kDesiredRenderingDistance) { + current = + *previous + ((kDesiredRenderingDistance - arclength_from_previous) / + arclength_to_next) * + (*next - *previous); + if (!functor(current, kDesiredRenderingDistance)) { + return false; + } + break; + } + arclength_from_previous += arclength_to_next; + previous = &*next; + ++next; + } + } + return true; +} + +} // namespace + +QuantizedSpline::QuantizedSpline(const Spline& original, + const int32_t quantization_adjustment, + const float y_to_x, const float y_to_b) { + JXL_ASSERT(!original.control_points.empty()); + control_points_.reserve(original.control_points.size() - 1); + const Spline::Point& starting_point = original.control_points.front(); + int previous_x = static_cast(roundf(starting_point.x)), + previous_y = static_cast(roundf(starting_point.y)); + int previous_delta_x = 0, previous_delta_y = 0; + for (auto it = original.control_points.begin() + 1; + it != original.control_points.end(); ++it) { + const int new_x = static_cast(roundf(it->x)); + const int new_y = static_cast(roundf(it->y)); + const int new_delta_x = new_x - previous_x; + const int new_delta_y = new_y - previous_y; + control_points_.emplace_back(new_delta_x - previous_delta_x, + new_delta_y - previous_delta_y); + previous_delta_x = new_delta_x; + previous_delta_y = new_delta_y; + previous_x = new_x; + previous_y = new_y; + } + + const auto to_int = [](float v) -> int { + return static_cast(roundf(v)); + }; + + const auto quant = AdjustedQuant(quantization_adjustment); + const auto inv_quant = InvAdjustedQuant(quantization_adjustment); + for (int c : {1, 0, 2}) { + float factor = (c == 0) ? y_to_x : (c == 1) ? 0 : y_to_b; + for (int i = 0; i < 32; ++i) { + const float dct_factor = (i == 0) ? kSqrt2 : 1.0f; + const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f; + auto restored_y = + color_dct_[1][i] * inv_dct_factor * kChannelWeight[1] * inv_quant; + auto decorellated = original.color_dct[c][i] - factor * restored_y; + color_dct_[c][i] = + to_int(decorellated * dct_factor * quant / kChannelWeight[c]); + } + } + for (int i = 0; i < 32; ++i) { + const float dct_factor = (i == 0) ? kSqrt2 : 1.0f; + sigma_dct_[i] = + to_int(original.sigma_dct[i] * dct_factor * quant / kChannelWeight[3]); + } +} + +Status QuantizedSpline::Dequantize(const Spline::Point& starting_point, + const int32_t quantization_adjustment, + const float y_to_x, const float y_to_b, + Spline& result) const { + result.control_points.clear(); + result.control_points.reserve(control_points_.size() + 1); + float px = roundf(starting_point.x); + float py = roundf(starting_point.y); + JXL_RETURN_IF_ERROR(ValidateSplinePointPos(px, py)); + int current_x = static_cast(px); + int current_y = static_cast(py); + result.control_points.push_back(Spline::Point{static_cast(current_x), + static_cast(current_y)}); + int current_delta_x = 0, current_delta_y = 0; + for (const auto& point : control_points_) { + current_delta_x += point.first; + current_delta_y += point.second; + JXL_RETURN_IF_ERROR( + ValidateSplinePointPos(current_delta_x, current_delta_y)); + current_x += current_delta_x; + current_y += current_delta_y; + JXL_RETURN_IF_ERROR(ValidateSplinePointPos(current_x, current_y)); + result.control_points.push_back(Spline::Point{ + static_cast(current_x), static_cast(current_y)}); + } + + const auto inv_quant = InvAdjustedQuant(quantization_adjustment); + for (int c = 0; c < 3; ++c) { + for (int i = 0; i < 32; ++i) { + const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f; + result.color_dct[c][i] = + color_dct_[c][i] * inv_dct_factor * kChannelWeight[c] * inv_quant; + } + } + for (int i = 0; i < 32; ++i) { + result.color_dct[0][i] += y_to_x * result.color_dct[1][i]; + result.color_dct[2][i] += y_to_b * result.color_dct[1][i]; + } + for (int i = 0; i < 32; ++i) { + const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f; + result.sigma_dct[i] = + sigma_dct_[i] * inv_dct_factor * kChannelWeight[3] * inv_quant; + } + + return true; +} + +Status QuantizedSpline::Decode(const std::vector& context_map, + ANSSymbolReader* const decoder, + BitReader* const br, + const size_t max_control_points, + size_t* total_num_control_points) { + const size_t num_control_points = + decoder->ReadHybridUint(kNumControlPointsContext, br, context_map); + *total_num_control_points += num_control_points; + if (*total_num_control_points > max_control_points) { + return JXL_FAILURE("Too many control points: %" PRIuS, + *total_num_control_points); + } + control_points_.resize(num_control_points); + // Maximal image dimension. + constexpr int64_t kDeltaLimit = 1u << 30; + for (std::pair& control_point : control_points_) { + control_point.first = UnpackSigned( + decoder->ReadHybridUint(kControlPointsContext, br, context_map)); + control_point.second = UnpackSigned( + decoder->ReadHybridUint(kControlPointsContext, br, context_map)); + // Check delta-deltas are not outrageous; it is not in spec, but there is + // no reason to allow larger values. + if ((control_point.first >= kDeltaLimit) || + (control_point.first <= -kDeltaLimit) || + (control_point.second >= kDeltaLimit) || + (control_point.second <= -kDeltaLimit)) { + return JXL_FAILURE("Spline delta-delta is out of bounds"); + } + } + + const auto decode_dct = [decoder, br, &context_map](int dct[32]) -> Status { + for (int i = 0; i < 32; ++i) { + dct[i] = + UnpackSigned(decoder->ReadHybridUint(kDCTContext, br, context_map)); + } + return true; + }; + for (int c = 0; c < 3; ++c) { + JXL_RETURN_IF_ERROR(decode_dct(color_dct_[c])); + } + JXL_RETURN_IF_ERROR(decode_dct(sigma_dct_)); + return true; +} + +void Splines::Clear() { + quantization_adjustment_ = 0; + splines_.clear(); + starting_points_.clear(); + segments_.clear(); + segment_indices_.clear(); + segment_y_start_.clear(); +} + +Status Splines::Decode(jxl::BitReader* br, const size_t num_pixels) { + std::vector context_map; + ANSCode code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumSplineContexts, &code, &context_map)); + ANSSymbolReader decoder(&code, br); + const size_t num_splines = + 1 + decoder.ReadHybridUint(kNumSplinesContext, br, context_map); + size_t max_control_points = std::min( + kMaxNumControlPoints, num_pixels / kMaxNumControlPointsPerPixelRatio); + if (num_splines > max_control_points) { + return JXL_FAILURE("Too many splines: %" PRIuS, num_splines); + } + JXL_RETURN_IF_ERROR(DecodeAllStartingPoints(&starting_points_, br, &decoder, + context_map, num_splines)); + + quantization_adjustment_ = UnpackSigned( + decoder.ReadHybridUint(kQuantizationAdjustmentContext, br, context_map)); + + splines_.clear(); + splines_.reserve(num_splines); + size_t num_control_points = num_splines; + for (size_t i = 0; i < num_splines; ++i) { + QuantizedSpline spline; + JXL_RETURN_IF_ERROR(spline.Decode(context_map, &decoder, br, + max_control_points, &num_control_points)); + splines_.push_back(std::move(spline)); + } + + JXL_RETURN_IF_ERROR(decoder.CheckANSFinalState()); + + if (!HasAny()) { + return JXL_FAILURE("Decoded splines but got none"); + } + + return true; +} + +void Splines::AddTo(Image3F* const opsin, const Rect& opsin_rect, + const Rect& image_rect) const { + return Apply(opsin, opsin_rect, image_rect); +} +void Splines::AddToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_row) const { + return ApplyToRow(row_x, row_y, row_b, image_row); +} + +void Splines::SubtractFrom(Image3F* const opsin) const { + return Apply(opsin, Rect(*opsin), Rect(*opsin)); +} + +Status Splines::InitializeDrawCache(const size_t image_xsize, + const size_t image_ysize, + const ColorCorrelationMap& cmap) { + // TODO(veluca): avoid storing segments that are entirely outside image + // boundaries. + segments_.clear(); + segment_indices_.clear(); + segment_y_start_.clear(); + std::vector> segments_by_y; + Spline spline; + // TODO(eustas): not in the spec; limit spline pixels with image area. + float pixel_limit = 16.0f * image_xsize * image_ysize + (1 << 16); + // Apply some extra cap to avoid overflows. + constexpr size_t kHardPixelLimit = 1u << 30; + size_t px_limit = (pixel_limit < static_cast(kHardPixelLimit)) + ? static_cast(pixel_limit) + : kHardPixelLimit; + std::vector intermediate_points; + for (size_t i = 0; i < splines_.size(); ++i) { + JXL_RETURN_IF_ERROR( + splines_[i].Dequantize(starting_points_[i], quantization_adjustment_, + cmap.YtoXRatio(0), cmap.YtoBRatio(0), spline)); + if (std::adjacent_find(spline.control_points.begin(), + spline.control_points.end()) != + spline.control_points.end()) { + // Otherwise division by zero might occur. Once control points coincide, + // the direction of curve is undefined... + return JXL_FAILURE( + "identical successive control points in spline %" PRIuS, i); + } + std::vector> points_to_draw; + const auto add_point = [&](const Spline::Point& point, + const float multiplier) -> bool { + points_to_draw.emplace_back(point, multiplier); + return (points_to_draw.size() <= px_limit); + }; + intermediate_points.clear(); + DrawCentripetalCatmullRomSpline(spline.control_points, intermediate_points); + if (!ForEachEquallySpacedPoint(intermediate_points, add_point)) { + return JXL_FAILURE("Too many pixels covered with splines"); + } + const float arc_length = + (points_to_draw.size() - 2) * kDesiredRenderingDistance + + points_to_draw.back().second; + if (arc_length <= 0.f) { + // This spline wouldn't have any effect. + continue; + } + HWY_DYNAMIC_DISPATCH(SegmentsFromPoints) + (spline, points_to_draw, arc_length, segments_, segments_by_y, &px_limit); + if (px_limit == 0) { + return JXL_FAILURE("Too many pixels covered with splines"); + } + } + // TODO(eustas): consider linear sorting here. + std::sort(segments_by_y.begin(), segments_by_y.end()); + segment_indices_.resize(segments_by_y.size()); + segment_y_start_.resize(image_ysize + 1); + for (size_t i = 0; i < segments_by_y.size(); i++) { + segment_indices_[i] = segments_by_y[i].second; + size_t y = segments_by_y[i].first; + if (y < image_ysize) { + segment_y_start_[y + 1]++; + } + } + for (size_t y = 0; y < image_ysize; y++) { + segment_y_start_[y + 1] += segment_y_start_[y]; + } + return true; +} + +template +void Splines::ApplyToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, + const Rect& image_row) const { + if (segments_.empty()) return; + JXL_ASSERT(image_row.ysize() == 1); + for (size_t iy = 0; iy < image_row.ysize(); iy++) { + HWY_DYNAMIC_DISPATCH(DrawSegments) + (row_x, row_y, row_b, image_row.Line(iy), add, segments_.data(), + segment_indices_.data(), segment_y_start_.data()); + } +} + +template +void Splines::Apply(Image3F* const opsin, const Rect& opsin_rect, + const Rect& image_rect) const { + if (segments_.empty()) return; + for (size_t iy = 0; iy < image_rect.ysize(); iy++) { + const size_t y0 = opsin_rect.Line(iy).y0(); + const size_t x0 = opsin_rect.x0(); + ApplyToRow(opsin->PlaneRow(0, y0) + x0, opsin->PlaneRow(1, y0) + x0, + opsin->PlaneRow(2, y0) + x0, image_rect.Line(iy)); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/media/libjxl/src/lib/jxl/splines.h b/media/libjxl/src/lib/jxl/splines.h new file mode 100644 index 000000000..9d2b1a46a --- /dev/null +++ b/media/libjxl/src/lib/jxl/splines.h @@ -0,0 +1,149 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_SPLINES_H_ +#define LIB_JXL_SPLINES_H_ + +#include +#include + +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image.h" + +namespace jxl { + +static constexpr float kDesiredRenderingDistance = 1.f; + +enum SplineEntropyContexts : size_t { + kQuantizationAdjustmentContext = 0, + kStartingPositionContext, + kNumSplinesContext, + kNumControlPointsContext, + kControlPointsContext, + kDCTContext, + kNumSplineContexts +}; + +struct Spline { + struct Point { + Point() : x(0.0f), y(0.0f) {} + Point(float x, float y) : x(x), y(y) {} + float x, y; + bool operator==(const Point& other) const { + return std::fabs(x - other.x) < 1e-3f && std::fabs(y - other.y) < 1e-3f; + } + }; + std::vector control_points; + // X, Y, B. + float color_dct[3][32]; + // Splines are draws by normalized Gaussian splatting. This controls the + // Gaussian's parameter along the spline. + float sigma_dct[32]; +}; + +class QuantizedSplineEncoder; + +class QuantizedSpline { + public: + QuantizedSpline() = default; + explicit QuantizedSpline(const Spline& original, + int32_t quantization_adjustment, float y_to_x, + float y_to_b); + + Status Dequantize(const Spline::Point& starting_point, + int32_t quantization_adjustment, float y_to_x, float y_to_b, + Spline& result) const; + + Status Decode(const std::vector& context_map, + ANSSymbolReader* decoder, BitReader* br, + size_t max_control_points, size_t* total_num_control_points); + + private: + friend class QuantizedSplineEncoder; + + std::vector> + control_points_; // Double delta-encoded. + int color_dct_[3][32] = {}; + int sigma_dct_[32] = {}; +}; + +// A single "drawable unit" of a spline, i.e. a line of the region in which we +// render each Gaussian. The structure doesn't actually depend on the exact +// row, which allows reuse for different y values (which are tracked +// separately). +struct SplineSegment { + float center_x, center_y; + float maximum_distance; + float inv_sigma; + float sigma_over_4_times_intensity; + float color[3]; +}; + +class Splines { + public: + Splines() = default; + explicit Splines(const int32_t quantization_adjustment, + std::vector splines, + std::vector starting_points) + : quantization_adjustment_(quantization_adjustment), + splines_(std::move(splines)), + starting_points_(std::move(starting_points)) {} + + bool HasAny() const { return !splines_.empty(); } + + void Clear(); + + Status Decode(BitReader* br, size_t num_pixels); + + void AddTo(Image3F* opsin, const Rect& opsin_rect, + const Rect& image_rect) const; + void AddToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_row) const; + void SubtractFrom(Image3F* opsin) const; + + const std::vector& QuantizedSplines() const { + return splines_; + } + const std::vector& StartingPoints() const { + return starting_points_; + } + + int32_t GetQuantizationAdjustment() const { return quantization_adjustment_; } + + Status InitializeDrawCache(size_t image_xsize, size_t image_ysize, + const ColorCorrelationMap& cmap); + + private: + template + void ApplyToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_row) const; + template + void Apply(Image3F* opsin, const Rect& opsin_rect, + const Rect& image_rect) const; + + // If positive, quantization weights are multiplied by 1 + this/8, which + // increases precision. If negative, they are divided by 1 - this/8. If 0, + // they are unchanged. + int32_t quantization_adjustment_ = 0; + std::vector splines_; + std::vector starting_points_; + std::vector segments_; + std::vector segment_indices_; + std::vector segment_y_start_; +}; + +} // namespace jxl + +#endif // LIB_JXL_SPLINES_H_ diff --git a/media/libjxl/src/lib/jxl/splines_gbench.cc b/media/libjxl/src/lib/jxl/splines_gbench.cc new file mode 100644 index 000000000..78ff6d41c --- /dev/null +++ b/media/libjxl/src/lib/jxl/splines_gbench.cc @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/splines.h" + +namespace jxl { +namespace { + +constexpr int kQuantizationAdjustment = 0; +const ColorCorrelationMap* const cmap = new ColorCorrelationMap; +const float kYToX = cmap->YtoXRatio(0); +const float kYToB = cmap->YtoBRatio(0); + +void BM_Splines(benchmark::State& state) { + const size_t n = state.range(); + + std::vector spline_data = { + {/*control_points=*/{ + {9, 54}, {118, 159}, {97, 3}, {10, 40}, {150, 25}, {120, 300}}, + /*color_dct=*/ + {{0.03125f, 0.00625f, 0.003125f}, {1.f, 0.321875f}, {1.f, 0.24375f}}, + /*sigma_dct=*/{0.3125f, 0.f, 0.f, 0.0625f}}}; + std::vector quantized_splines; + std::vector starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + + Image3F drawing_area(320, 320); + ZeroFillImage(&drawing_area); + for (auto _ : state) { + for (size_t i = 0; i < n; ++i) { + JXL_CHECK(splines.InitializeDrawCache(drawing_area.xsize(), + drawing_area.ysize(), *cmap)); + splines.AddTo(&drawing_area, Rect(drawing_area), Rect(drawing_area)); + } + } + + state.SetItemsProcessed(n * state.iterations()); +} + +BENCHMARK(BM_Splines)->Range(1, 1 << 10); + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/splines_test.cc b/media/libjxl/src/lib/jxl/splines_test.cc new file mode 100644 index 000000000..09b2dd562 --- /dev/null +++ b/media/libjxl/src/lib/jxl/splines_test.cc @@ -0,0 +1,345 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/splines.h" + +#include "lib/extras/codec.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_splines.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { + +std::ostream& operator<<(std::ostream& os, const Spline::Point& p) { + return os << "(" << p.x << ", " << p.y << ")"; +} + +std::ostream& operator<<(std::ostream& os, const Spline& spline) { + return os << "(spline with " << spline.control_points.size() + << " control points)"; +} + +namespace { + +using ::testing::AllOf; +using ::testing::Field; +using ::testing::FloatNear; +using ::testing::Pointwise; + +constexpr int kQuantizationAdjustment = 0; +const ColorCorrelationMap* const cmap = new ColorCorrelationMap; +const float kYToX = cmap->YtoXRatio(0); +const float kYToB = cmap->YtoBRatio(0); + +constexpr float kTolerance = 0.003125; + +std::vector DequantizeSplines(const Splines& splines) { + const auto& quantized_splines = splines.QuantizedSplines(); + const auto& starting_points = splines.StartingPoints(); + JXL_CHECK(quantized_splines.size() == starting_points.size()); + + std::vector dequantized; + for (size_t i = 0; i < quantized_splines.size(); ++i) { + dequantized.emplace_back(); + JXL_CHECK(quantized_splines[i].Dequantize(starting_points[i], + kQuantizationAdjustment, kYToX, + kYToB, dequantized.back())); + } + return dequantized; +} + +MATCHER(ControlPointIs, "") { + const Spline::Point& actual = std::get<0>(arg); + const Spline::Point& expected = std::get<1>(arg); + return testing::ExplainMatchResult( + AllOf(Field(&Spline::Point::x, FloatNear(expected.x, kTolerance)), + Field(&Spline::Point::y, FloatNear(expected.y, kTolerance))), + actual, result_listener); +} + +MATCHER(ControlPointsMatch, "") { + const Spline& actual = std::get<0>(arg); + const Spline& expected = std::get<1>(arg); + return testing::ExplainMatchResult( + Field(&Spline::control_points, + Pointwise(ControlPointIs(), expected.control_points)), + actual, result_listener); +} + +MATCHER(SplinesMatch, "") { + const Spline& actual = std::get<0>(arg); + const Spline& expected = std::get<1>(arg); + if (!testing::ExplainMatchResult(ControlPointsMatch(), arg, + result_listener)) { + return false; + } + for (int i = 0; i < 3; ++i) { + size_t color_dct_size = + sizeof(expected.color_dct[i]) / sizeof(expected.color_dct[i][0]); + for (size_t j = 0; j < color_dct_size; j++) { + testing::StringMatchResultListener color_dct_listener; + if (!testing::ExplainMatchResult( + FloatNear(expected.color_dct[i][j], kTolerance), + actual.color_dct[i][j], &color_dct_listener)) { + *result_listener << ", where color_dct[" << i << "][" << j + << "] don't match, " << color_dct_listener.str(); + return false; + } + } + } + size_t sigma_dct_size = + sizeof(expected.sigma_dct) / sizeof(expected.sigma_dct[0]); + for (size_t i = 0; i < sigma_dct_size; i++) { + testing::StringMatchResultListener sigma_listener; + if (!testing::ExplainMatchResult( + FloatNear(expected.sigma_dct[i], kTolerance), actual.sigma_dct[i], + &sigma_listener)) { + *result_listener << ", where sigma_dct[" << i << "] don't match, " + << sigma_listener.str(); + return false; + } + } + return true; +} + +} // namespace + +TEST(SplinesTest, Serialization) { + std::vector spline_data = { + {/*control_points=*/{ + {109, 54}, {218, 159}, {80, 3}, {110, 274}, {94, 185}, {17, 277}}, + /*color_dct=*/ + {{36.3, 39.7, 23.2, 67.5, 4.4, 71.5, 62.3, 32.3, 92.2, 10.1, 10.8, + 9.2, 6.1, 10.5, 79.1, 7, 24.6, 90.8, 5.5, 84, 43.8, 49, + 33.5, 78.9, 54.5, 77.9, 62.1, 51.4, 36.4, 14.3, 83.7, 35.4}, + {9.4, 53.4, 9.5, 74.9, 72.7, 26.7, 7.9, 0.9, 84.9, 23.2, 26.5, + 31.1, 91, 11.7, 74.1, 39.3, 23.7, 82.5, 4.8, 2.7, 61.2, 96.4, + 13.7, 66.7, 62.9, 82.4, 5.9, 98.7, 21.5, 7.9, 51.7, 63.1}, + {48, 39.3, 6.9, 26.3, 33.3, 6.2, 1.7, 98.9, 59.9, 59.6, 95, + 61.3, 82.7, 53, 6.1, 30.4, 34.7, 96.9, 93.4, 17, 38.8, 80.8, + 63, 18.6, 43.6, 32.3, 61, 20.2, 24.3, 28.3, 69.1, 62.4}}, + /*sigma_dct=*/{32.7, 21.5, 44.4, 1.8, 45.8, 90.6, 29.3, 59.2, + 23.7, 85.2, 84.8, 27.2, 42.1, 84.1, 50.6, 17.6, + 93.7, 4.9, 2.6, 69.8, 94.9, 52, 24.3, 18.8, + 12.1, 95.7, 28.5, 81.4, 89.9, 31.4, 74.8, 52}}, + {/*control_points=*/{{172, 309}, + {196, 277}, + {42, 238}, + {114, 350}, + {307, 290}, + {316, 269}, + {124, 66}, + {233, 267}}, + /*color_dct=*/ + {{15, 28.9, 22, 6.6, 41.8, 83, 8.6, 56.8, 68.9, 9.7, 5.4, + 19.8, 70.8, 90, 52.5, 65.2, 7.8, 23.5, 26.4, 72.2, 64.7, 87.1, + 1.3, 67.5, 46, 68.4, 65.4, 35.5, 29.1, 13, 41.6, 23.9}, + {47.7, 79.4, 62.7, 29.1, 96.8, 18.5, 17.6, 15.2, 80.5, 56, 96.2, + 59.9, 26.7, 96.1, 92.3, 42.1, 35.8, 54, 23.2, 55, 76, 35.8, + 58.4, 88.7, 2.4, 78.1, 95.6, 27.5, 6.6, 78.5, 24.1, 69.8}, + {43.8, 96.5, 0.9, 95.1, 49.1, 71.2, 25.1, 33.6, 75.2, 95, 82.1, + 19.7, 10.5, 44.9, 50, 93.3, 83.5, 99.5, 64.6, 54, 3.5, 99.7, + 45.3, 82.1, 22.4, 37.9, 60, 32.2, 12.6, 4.6, 65.5, 96.4}}, + /*sigma_dct=*/{72.5, 2.6, 41.7, 2.2, 39.7, 79.1, 69.6, 19.9, + 92.3, 71.5, 41.9, 62.1, 30, 49.4, 70.3, 45.3, + 62.5, 47.2, 46.7, 41.2, 90.8, 46.8, 91.2, 55, + 8.1, 69.6, 25.4, 84.7, 61.7, 27.6, 3.7, 46.9}}, + {/*control_points=*/{{100, 186}, + {257, 97}, + {170, 49}, + {25, 169}, + {309, 104}, + {232, 237}, + {385, 101}, + {122, 168}, + {26, 300}, + {390, 88}}, + /*color_dct=*/ + {{16.9, 64.8, 4.2, 10.6, 23.5, 17, 79.3, 5.7, 60.4, 16.6, 94.9, + 63.7, 87.6, 10.5, 3.8, 61.1, 22.9, 81.9, 80.4, 40.5, 45.9, 25.4, + 39.8, 30, 50.2, 90.4, 27.9, 93.7, 65.1, 48.2, 22.3, 43.9}, + {24.9, 66, 3.5, 90.2, 97.1, 15.8, 35.6, 0.6, 68, 39.6, 24.4, + 85.9, 57.7, 77.6, 47.5, 67.9, 4.3, 5.4, 91.2, 58.5, 0.1, 52.2, + 3.5, 47.8, 63.2, 43.5, 85.8, 35.8, 50.2, 35.9, 19.2, 48.2}, + {82.8, 44.9, 76.4, 39.5, 94.1, 14.3, 89.8, 10, 10.5, 74.5, 56.3, + 65.8, 7.8, 23.3, 52.8, 99.3, 56.8, 46, 76.7, 13.5, 67, 22.4, + 29.9, 43.3, 70.3, 26, 74.3, 53.9, 62, 19.1, 49.3, 46.7}}, + /*sigma_dct=*/{83.5, 1.7, 25.1, 18.7, 46.5, 75.3, 28, 62.3, + 50.3, 23.3, 85.6, 96, 45.8, 33.1, 33.4, 52.9, + 26.3, 58.5, 19.6, 70, 92.6, 22.5, 57, 21.6, + 76.8, 87.5, 22.9, 66.3, 35.7, 35.6, 56.8, 67.2}}, + }; + + std::vector quantized_splines; + std::vector starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + const std::vector quantized_spline_data = DequantizeSplines(splines); + EXPECT_THAT(quantized_spline_data, + Pointwise(ControlPointsMatch(), spline_data)); + + BitWriter writer; + EncodeSplines(splines, &writer, kLayerSplines, HistogramParams(), nullptr); + writer.ZeroPadToByte(); + const size_t bits_written = writer.BitsWritten(); + + printf("Wrote %" PRIuS " bits of splines.\n", bits_written); + + BitReader reader(writer.GetSpan()); + Splines decoded_splines; + ASSERT_TRUE(decoded_splines.Decode(&reader, /*num_pixels=*/1000)); + ASSERT_TRUE(reader.JumpToByteBoundary()); + EXPECT_EQ(reader.TotalBitsConsumed(), bits_written); + ASSERT_TRUE(reader.Close()); + + const std::vector decoded_spline_data = + DequantizeSplines(decoded_splines); + EXPECT_THAT(decoded_spline_data, + Pointwise(SplinesMatch(), quantized_spline_data)); +} + +#ifdef JXL_CRASH_ON_ERROR +TEST(SplinesTest, DISABLED_TooManySplinesTest) { +#else +TEST(SplinesTest, TooManySplinesTest) { +#endif + // This is more than the limit for 1000 pixels. + const size_t kNumSplines = 300; + + std::vector quantized_splines; + std::vector starting_points; + for (size_t i = 0; i < kNumSplines; i++) { + Spline spline = { + /*control_points=*/{{1.f + i, 2}, {10.f + i, 25}, {30.f + i, 300}}, + /*color_dct=*/ + {{1.f, 0.2f, 0.1f}, {35.7f, 10.3f}, {35.7f, 7.8f}}, + /*sigma_dct=*/{10.f, 0.f, 0.f, 2.f}}; + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + BitWriter writer; + EncodeSplines(splines, &writer, kLayerSplines, + HistogramParams(SpeedTier::kFalcon, 1), nullptr); + writer.ZeroPadToByte(); + // Re-read splines. + BitReader reader(writer.GetSpan()); + Splines decoded_splines; + EXPECT_FALSE(decoded_splines.Decode(&reader, /*num_pixels=*/1000)); + EXPECT_TRUE(reader.Close()); +} + +#ifdef JXL_CRASH_ON_ERROR +TEST(SplinesTest, DISABLED_DuplicatePoints) { +#else +TEST(SplinesTest, DuplicatePoints) { +#endif + std::vector control_points{ + {9, 54}, {118, 159}, {97, 3}, // Repeated. + {97, 3}, {10, 40}, {150, 25}, {120, 300}}; + Spline spline{control_points, + /*color_dct=*/ + {{1.f, 0.2f, 0.1f}, {35.7f, 10.3f}, {35.7f, 7.8f}}, + /*sigma_dct=*/{10.f, 0.f, 0.f, 2.f}}; + std::vector spline_data{spline}; + std::vector quantized_splines; + std::vector starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + + Image3F image(320, 320); + ZeroFillImage(&image); + EXPECT_FALSE( + splines.InitializeDrawCache(image.xsize(), image.ysize(), *cmap)); +} + +TEST(SplinesTest, Drawing) { + CodecInOut io_expected; + const PaddedBytes orig = ReadTestData("jxl/splines.pfm"); + ASSERT_TRUE(SetFromBytes(Span(orig), &io_expected, + /*pool=*/nullptr)); + + std::vector control_points{{9, 54}, {118, 159}, {97, 3}, + {10, 40}, {150, 25}, {120, 300}}; + // Use values that survive quant/decorellation roundtrip. + const Spline spline{ + control_points, + /*color_dct=*/ + {{0.4989345073699951171875000f, 0.4997999966144561767578125f}, + {0.4772970676422119140625000f, 0.f, 0.5250000357627868652343750f}, + {-0.0176776945590972900390625f, 0.4900000095367431640625000f, + 0.5250000357627868652343750f}}, + /*sigma_dct=*/ + {0.9427147507667541503906250f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.6665999889373779296875000f}}; + std::vector spline_data = {spline}; + std::vector quantized_splines; + std::vector starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + + Image3F image(320, 320); + ZeroFillImage(&image); + ASSERT_TRUE(splines.InitializeDrawCache(image.xsize(), image.ysize(), *cmap)); + splines.AddTo(&image, Rect(image), Rect(image)); + + CodecInOut io_actual; + io_actual.SetFromImage(CopyImage(image), ColorEncoding::SRGB()); + ASSERT_TRUE( + io_actual.TransformTo(io_expected.Main().c_current(), GetJxlCms())); + + VerifyRelativeError(*io_expected.Main().color(), *io_actual.Main().color(), + 1e-2f, 1e-1f); +} + +TEST(SplinesTest, ClearedEveryFrame) { + CodecInOut io_expected; + const PaddedBytes bytes_expected = + ReadTestData("jxl/spline_on_first_frame.png"); + ASSERT_TRUE(SetFromBytes(Span(bytes_expected), &io_expected, + /*pool=*/nullptr)); + CodecInOut io_actual; + const PaddedBytes bytes_actual = + ReadTestData("jxl/spline_on_first_frame.jxl"); + ASSERT_TRUE(test::DecodeFile({}, bytes_actual, &io_actual, + /*pool=*/nullptr)); + + ASSERT_TRUE(io_actual.TransformTo(ColorEncoding::SRGB(), GetJxlCms())); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < io_actual.ysize(); ++y) { + float* const JXL_RESTRICT row = io_actual.Main().color()->PlaneRow(c, y); + for (size_t x = 0; x < io_actual.xsize(); ++x) { + row[x] = Clamp1(row[x], 0.f, 1.f); + } + } + } + VerifyRelativeError(*io_expected.Main().color(), *io_actual.Main().color(), + 1e-2f, 1e-1f); +} + +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/test_image.h b/media/libjxl/src/lib/jxl/test_image.h new file mode 100644 index 000000000..009344368 --- /dev/null +++ b/media/libjxl/src/lib/jxl/test_image.h @@ -0,0 +1,103 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TEST_IMAGE_H_ +#define LIB_JXL_TEST_IMAGE_H_ + +#include + +#include + +#include "lib/jxl/base/random.h" + +namespace jxl { +namespace test { + +// Returns a test image with some autogenerated pixel content, using 16 bits per +// channel, big endian order, 1 to 4 channels +// The seed parameter allows to create images with different pixel content. +std::vector GetSomeTestImage(size_t xsize, size_t ysize, + size_t num_channels, uint16_t seed) { + // Cause more significant image difference for successive seeds. + Rng generator(seed); + + // Returns random integer in interval [0, max_value) + auto rng = [&generator](size_t max_value) -> size_t { + return generator.UniformU(0, max_value); + }; + + // Dark background gradient color + uint16_t r0 = rng(32768); + uint16_t g0 = rng(32768); + uint16_t b0 = rng(32768); + uint16_t a0 = rng(32768); + uint16_t r1 = rng(32768); + uint16_t g1 = rng(32768); + uint16_t b1 = rng(32768); + uint16_t a1 = rng(32768); + + // Circle with different color + size_t circle_x = rng(xsize); + size_t circle_y = rng(ysize); + size_t circle_r = rng(std::min(xsize, ysize)); + + // Rectangle with random noise + size_t rect_x0 = rng(xsize); + size_t rect_y0 = rng(ysize); + size_t rect_x1 = rng(xsize); + size_t rect_y1 = rng(ysize); + if (rect_x1 < rect_x0) std::swap(rect_x0, rect_y1); + if (rect_y1 < rect_y0) std::swap(rect_y0, rect_y1); + + size_t num_pixels = xsize * ysize; + // 16 bits per channel, big endian, 4 channels + std::vector pixels(num_pixels * num_channels * 2); + // Create pixel content to test, actual content does not matter as long as it + // can be compared after roundtrip. + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + uint16_t r = r0 * (ysize - y - 1) / ysize + r1 * y / ysize; + uint16_t g = g0 * (ysize - y - 1) / ysize + g1 * y / ysize; + uint16_t b = b0 * (ysize - y - 1) / ysize + b1 * y / ysize; + uint16_t a = a0 * (ysize - y - 1) / ysize + a1 * y / ysize; + // put some shape in there for visual debugging + if ((x - circle_x) * (x - circle_x) + (y - circle_y) * (y - circle_y) < + circle_r * circle_r) { + r = (65535 - x * y) ^ seed; + g = (x << 8) + y + seed; + b = (y << 8) + x * seed; + a = 32768 + x * 256 - y; + } else if (x > rect_x0 && x < rect_x1 && y > rect_y0 && y < rect_y1) { + r = rng(65536); + g = rng(65536); + b = rng(65536); + a = rng(65536); + } + size_t i = (y * xsize + x) * 2 * num_channels; + pixels[i + 0] = (r >> 8); + pixels[i + 1] = (r & 255); + if (num_channels >= 2) { + // This may store what is called 'g' in the alpha channel of a 2-channel + // image, but that's ok since the content is arbitrary + pixels[i + 2] = (g >> 8); + pixels[i + 3] = (g & 255); + } + if (num_channels >= 3) { + pixels[i + 4] = (b >> 8); + pixels[i + 5] = (b & 255); + } + if (num_channels >= 4) { + pixels[i + 6] = (a >> 8); + pixels[i + 7] = (a & 255); + } + } + } + return pixels; +} + +} // namespace test +} // namespace jxl + +#endif // LIB_JXL_TEST_IMAGE_H_ diff --git a/media/libjxl/src/lib/jxl/test_utils.h b/media/libjxl/src/lib/jxl/test_utils.h new file mode 100644 index 000000000..b55cc3d20 --- /dev/null +++ b/media/libjxl/src/lib/jxl/test_utils.h @@ -0,0 +1,610 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TEST_UTILS_H_ +#define LIB_JXL_TEST_UTILS_H_ + +// Macros and functions useful for tests. + +// gmock unconditionally redefines those macros (to wrong values). +// Lets include it only here and mitigate the problem. +#pragma push_macro("PRIdS") +#pragma push_macro("PRIuS") +#include "gmock/gmock.h" +#pragma pop_macro("PRIuS") +#pragma pop_macro("PRIdS") + +#include "gtest/gtest.h" +#include "jxl/codestream_header.h" +#include "jxl/encode.h" +#include "lib/extras/dec/jxl.h" +#include "lib/extras/packed_image_convert.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" // JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/test_image.h" + +#ifdef JXL_DISABLE_SLOW_TESTS +#define JXL_SLOW_TEST(X) DISABLED_##X +#else +#define JXL_SLOW_TEST(X) X +#endif // JXL_DISABLE_SLOW_TESTS + +#if JPEGXL_ENABLE_TRANSCODE_JPEG +#define JXL_TRANSCODE_JPEG_TEST(X) X +#else +#define JXL_TRANSCODE_JPEG_TEST(X) DISABLED_##X +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +#ifdef THREAD_SANITIZER +#define JXL_TSAN_SLOW_TEST(X) DISABLED_##X +#else +#define JXL_TSAN_SLOW_TEST(X) X +#endif // THREAD_SANITIZER + +// googletest before 1.10 didn't define INSTANTIATE_TEST_SUITE_P() but instead +// used INSTANTIATE_TEST_CASE_P which is now deprecated. +#ifdef INSTANTIATE_TEST_SUITE_P +#define JXL_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_SUITE_P +#else +#define JXL_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_CASE_P +#endif + +// Ensures that we don't make our test bounds too lax, effectively disabling the +// tests. +MATCHER_P(IsSlightlyBelow, max, "") { return max * 0.75 <= arg && arg <= max; } + +namespace jxl { +namespace test { + +void JxlBasicInfoSetFromPixelFormat(JxlBasicInfo* basic_info, + const JxlPixelFormat* pixel_format) { + JxlEncoderInitBasicInfo(basic_info); + switch (pixel_format->data_type) { + case JXL_TYPE_FLOAT: + basic_info->bits_per_sample = 32; + basic_info->exponent_bits_per_sample = 8; + break; + case JXL_TYPE_FLOAT16: + basic_info->bits_per_sample = 16; + basic_info->exponent_bits_per_sample = 5; + break; + case JXL_TYPE_UINT8: + basic_info->bits_per_sample = 8; + basic_info->exponent_bits_per_sample = 0; + break; + case JXL_TYPE_UINT16: + basic_info->bits_per_sample = 16; + basic_info->exponent_bits_per_sample = 0; + break; + default: + JXL_ABORT("Unhandled JxlDataType"); + } + if (pixel_format->num_channels < 3) { + basic_info->num_color_channels = 1; + } else { + basic_info->num_color_channels = 3; + } + if (pixel_format->num_channels == 2 || pixel_format->num_channels == 4) { + basic_info->alpha_exponent_bits = basic_info->exponent_bits_per_sample; + basic_info->alpha_bits = basic_info->bits_per_sample; + basic_info->num_extra_channels = 1; + } else { + basic_info->alpha_exponent_bits = 0; + basic_info->alpha_bits = 0; + } +} + +MATCHER_P(MatchesPrimariesAndTransferFunction, color_encoding, "") { + return (arg.ICC() == color_encoding.ICC() || + (arg.primaries == color_encoding.primaries && + arg.tf.IsSame(color_encoding.tf))); +} + +MATCHER(MatchesPrimariesAndTransferFunction, "") { + return testing::ExplainMatchResult( + MatchesPrimariesAndTransferFunction(std::get<1>(arg)), std::get<0>(arg), + result_listener); +} + +template +Status DecodeFile(extras::JXLDecompressParams dparams, const Source& file, + CodecInOut* JXL_RESTRICT io, ThreadPool* pool) { + if (pool && !dparams.runner_opaque) { + dparams.runner = pool->runner(); + dparams.runner_opaque = pool->runner_opaque(); + } + extras::PackedPixelFile ppf; + JXL_RETURN_IF_ERROR(DecodeImageJXL(file.data(), file.size(), dparams, + /*decoded_bytes=*/nullptr, &ppf)); + JXL_RETURN_IF_ERROR(ConvertPackedPixelFileToCodecInOut(ppf, pool, io)); + return true; +} + +// Returns compressed size [bytes]. +size_t Roundtrip(const CodecInOut* io, const CompressParams& cparams, + extras::JXLDecompressParams dparams, ThreadPool* pool, + CodecInOut* JXL_RESTRICT io2, AuxOut* aux_out = nullptr) { + PaddedBytes compressed; + + std::vector original_metadata_encodings; + std::vector original_current_encodings; + for (const ImageBundle& ib : io->frames) { + // Remember original encoding, will be returned by decoder. + original_metadata_encodings.push_back(ib.metadata()->color_encoding); + // c_current should not change during encoding. + original_current_encodings.push_back(ib.c_current()); + } + + std::unique_ptr enc_state = + jxl::make_unique(); + EXPECT_TRUE(EncodeFile(cparams, io, enc_state.get(), &compressed, GetJxlCms(), + aux_out, pool)); + + std::vector metadata_encodings_1; + for (const ImageBundle& ib1 : io->frames) { + metadata_encodings_1.push_back(ib1.metadata()->color_encoding); + } + + // Should still be in the same color space after encoding. + EXPECT_THAT(metadata_encodings_1, + testing::Pointwise(MatchesPrimariesAndTransferFunction(), + original_metadata_encodings)); + + EXPECT_TRUE(DecodeFile(dparams, compressed, io2, pool)); + + std::vector metadata_encodings_2; + std::vector current_encodings_2; + for (const ImageBundle& ib2 : io2->frames) { + metadata_encodings_2.push_back(ib2.metadata()->color_encoding); + current_encodings_2.push_back(ib2.c_current()); + } + + EXPECT_THAT(io2->frames, testing::SizeIs(io->frames.size())); + // We always produce the original color encoding if a color transform hook is + // set. + EXPECT_THAT(current_encodings_2, + testing::Pointwise(MatchesPrimariesAndTransferFunction(), + original_current_encodings)); + + // Decoder returns the originals passed to the encoder. + EXPECT_THAT(metadata_encodings_2, + testing::Pointwise(MatchesPrimariesAndTransferFunction(), + original_metadata_encodings)); + + return compressed.size(); +} + +void CoalesceGIFAnimationWithAlpha(CodecInOut* io) { + ImageBundle canvas = io->frames[0].Copy(); + for (size_t i = 1; i < io->frames.size(); i++) { + const ImageBundle& frame = io->frames[i]; + ImageBundle rendered = canvas.Copy(); + for (size_t y = 0; y < frame.ysize(); y++) { + float* row0 = + rendered.color()->PlaneRow(0, frame.origin.y0 + y) + frame.origin.x0; + float* row1 = + rendered.color()->PlaneRow(1, frame.origin.y0 + y) + frame.origin.x0; + float* row2 = + rendered.color()->PlaneRow(2, frame.origin.y0 + y) + frame.origin.x0; + float* rowa = + rendered.alpha()->Row(frame.origin.y0 + y) + frame.origin.x0; + const float* row0f = frame.color().PlaneRow(0, y); + const float* row1f = frame.color().PlaneRow(1, y); + const float* row2f = frame.color().PlaneRow(2, y); + const float* rowaf = frame.alpha().Row(y); + for (size_t x = 0; x < frame.xsize(); x++) { + if (rowaf[x] != 0) { + row0[x] = row0f[x]; + row1[x] = row1f[x]; + row2[x] = row2f[x]; + rowa[x] = rowaf[x]; + } + } + } + if (frame.use_for_next_frame) { + canvas = rendered.Copy(); + } + io->frames[i] = std::move(rendered); + } +} + +// A POD descriptor of a ColorEncoding. Only used in tests as the return value +// of AllEncodings(). +struct ColorEncodingDescriptor { + ColorSpace color_space; + WhitePoint white_point; + Primaries primaries; + TransferFunction tf; + RenderingIntent rendering_intent; +}; + +static inline ColorEncoding ColorEncodingFromDescriptor( + const ColorEncodingDescriptor& desc) { + ColorEncoding c; + c.SetColorSpace(desc.color_space); + c.white_point = desc.white_point; + c.primaries = desc.primaries; + c.tf.SetTransferFunction(desc.tf); + c.rendering_intent = desc.rendering_intent; + JXL_CHECK(c.CreateICC()); + return c; +} + +// Define the operator<< for tests. +static inline ::std::ostream& operator<<(::std::ostream& os, + const ColorEncodingDescriptor& c) { + return os << "ColorEncoding/" << Description(ColorEncodingFromDescriptor(c)); +} + +// Returns ColorEncodingDescriptors, which are only used in tests. To obtain a +// ColorEncoding object call ColorEncodingFromDescriptor and then call +// ColorEncoding::CreateProfile() on that object to generate a profile. +std::vector AllEncodings() { + std::vector all_encodings; + all_encodings.reserve(300); + ColorEncoding c; + + for (ColorSpace cs : Values()) { + if (cs == ColorSpace::kUnknown || cs == ColorSpace::kXYB) continue; + c.SetColorSpace(cs); + + for (WhitePoint wp : Values()) { + if (wp == WhitePoint::kCustom) continue; + if (c.ImplicitWhitePoint() && c.white_point != wp) continue; + c.white_point = wp; + + for (Primaries primaries : Values()) { + if (primaries == Primaries::kCustom) continue; + if (!c.HasPrimaries()) continue; + c.primaries = primaries; + + for (TransferFunction tf : Values()) { + if (tf == TransferFunction::kUnknown) continue; + if (c.tf.SetImplicit() && + (c.tf.IsGamma() || c.tf.GetTransferFunction() != tf)) { + continue; + } + c.tf.SetTransferFunction(tf); + + for (RenderingIntent ri : Values()) { + ColorEncodingDescriptor cdesc; + cdesc.color_space = cs; + cdesc.white_point = wp; + cdesc.primaries = primaries; + cdesc.tf = tf; + cdesc.rendering_intent = ri; + all_encodings.push_back(cdesc); + } + } + } + } + } + + return all_encodings; +} + +// Returns a CodecInOut based on the buf, xsize, ysize, and the assumption +// that the buffer was created using `GetSomeTestImage`. +jxl::CodecInOut SomeTestImageToCodecInOut(const std::vector& buf, + size_t num_channels, size_t xsize, + size_t ysize) { + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB( + /*is_gray=*/num_channels == 1 || num_channels == 2); + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(buf.data(), buf.size()), xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/num_channels < 3), num_channels, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, JXL_BIG_ENDIAN, + /*pool=*/nullptr, + /*ib=*/&io.Main(), /*float_in=*/false, 0)); + return io; +} + +bool Near(double expected, double value, double max_dist) { + double dist = expected > value ? expected - value : value - expected; + return dist <= max_dist; +} + +// Loads a Big-Endian float +float LoadBEFloat(const uint8_t* p) { + uint32_t u = LoadBE32(p); + float result; + memcpy(&result, &u, 4); + return result; +} + +// Loads a Little-Endian float +float LoadLEFloat(const uint8_t* p) { + uint32_t u = LoadLE32(p); + float result; + memcpy(&result, &u, 4); + return result; +} + +// Based on highway scalar implementation, for testing +float LoadFloat16(uint16_t bits16) { + const uint32_t sign = bits16 >> 15; + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = (1.0f / 16384) * (mantissa * (1.0f / 1024)); + return sign ? -subnormal : subnormal; + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + + float result; + memcpy(&result, &bits32, 4); + return result; +} + +float LoadLEFloat16(const uint8_t* p) { + uint16_t bits16 = LoadLE16(p); + return LoadFloat16(bits16); +} + +float LoadBEFloat16(const uint8_t* p) { + uint16_t bits16 = LoadBE16(p); + return LoadFloat16(bits16); +} + +size_t GetPrecision(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + // Floating point mantissa precision + return 24; + case JXL_TYPE_FLOAT16: + return 11; + default: + JXL_ABORT("Unhandled JxlDataType"); + } +} + +size_t GetDataBits(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + return 32; + case JXL_TYPE_FLOAT16: + return 16; + default: + JXL_ABORT("Unhandled JxlDataType"); + } +} + +// Procedure to convert pixels to double precision, not efficient, but +// well-controlled for testing. It uses double, to be able to represent all +// precisions needed for the maximum data types the API supports: uint32_t +// integers, and, single precision float. The values are in range 0-1 for SDR. +std::vector ConvertToRGBA32(const uint8_t* pixels, size_t xsize, + size_t ysize, const JxlPixelFormat& format, + double factor = 0.0) { + std::vector result(xsize * ysize * 4); + size_t num_channels = format.num_channels; + bool gray = num_channels == 1 || num_channels == 2; + bool alpha = num_channels == 2 || num_channels == 4; + + size_t stride = + xsize * jxl::DivCeil(GetDataBits(format.data_type) * num_channels, + jxl::kBitsPerByte); + if (format.align > 1) stride = jxl::RoundUpTo(stride, format.align); + + if (format.data_type == JXL_TYPE_UINT8) { + // Multiplier to bring to 0-1.0 range + double mul = factor > 0.0 ? factor : 1.0 / 255.0; + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels; + double r = pixels[i]; + double g = gray ? r : pixels[i + 1]; + double b = gray ? r : pixels[i + 2]; + double a = alpha ? pixels[i + num_channels - 1] : 255; + result[j + 0] = r * mul; + result[j + 1] = g * mul; + result[j + 2] = b * mul; + result[j + 3] = a * mul; + } + } + } else if (format.data_type == JXL_TYPE_UINT16) { + // Multiplier to bring to 0-1.0 range + double mul = factor > 0.0 ? factor : 1.0 / 65535.0; + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels * 2; + double r, g, b, a; + if (format.endianness == JXL_BIG_ENDIAN) { + r = (pixels[i + 0] << 8) + pixels[i + 1]; + g = gray ? r : (pixels[i + 2] << 8) + pixels[i + 3]; + b = gray ? r : (pixels[i + 4] << 8) + pixels[i + 5]; + a = alpha ? (pixels[i + num_channels * 2 - 2] << 8) + + pixels[i + num_channels * 2 - 1] + : 65535; + } else { + r = (pixels[i + 1] << 8) + pixels[i + 0]; + g = gray ? r : (pixels[i + 3] << 8) + pixels[i + 2]; + b = gray ? r : (pixels[i + 5] << 8) + pixels[i + 4]; + a = alpha ? (pixels[i + num_channels * 2 - 1] << 8) + + pixels[i + num_channels * 2 - 2] + : 65535; + } + result[j + 0] = r * mul; + result[j + 1] = g * mul; + result[j + 2] = b * mul; + result[j + 3] = a * mul; + } + } + } else if (format.data_type == JXL_TYPE_FLOAT) { + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels * 4; + double r, g, b, a; + if (format.endianness == JXL_BIG_ENDIAN) { + r = LoadBEFloat(pixels + i); + g = gray ? r : LoadBEFloat(pixels + i + 4); + b = gray ? r : LoadBEFloat(pixels + i + 8); + a = alpha ? LoadBEFloat(pixels + i + num_channels * 4 - 4) : 1.0; + } else { + r = LoadLEFloat(pixels + i); + g = gray ? r : LoadLEFloat(pixels + i + 4); + b = gray ? r : LoadLEFloat(pixels + i + 8); + a = alpha ? LoadLEFloat(pixels + i + num_channels * 4 - 4) : 1.0; + } + result[j + 0] = r; + result[j + 1] = g; + result[j + 2] = b; + result[j + 3] = a; + } + } + } else if (format.data_type == JXL_TYPE_FLOAT16) { + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels * 2; + double r, g, b, a; + if (format.endianness == JXL_BIG_ENDIAN) { + r = LoadBEFloat16(pixels + i); + g = gray ? r : LoadBEFloat16(pixels + i + 2); + b = gray ? r : LoadBEFloat16(pixels + i + 4); + a = alpha ? LoadBEFloat16(pixels + i + num_channels * 2 - 2) : 1.0; + } else { + r = LoadLEFloat16(pixels + i); + g = gray ? r : LoadLEFloat16(pixels + i + 2); + b = gray ? r : LoadLEFloat16(pixels + i + 4); + a = alpha ? LoadLEFloat16(pixels + i + num_channels * 2 - 2) : 1.0; + } + result[j + 0] = r; + result[j + 1] = g; + result[j + 2] = b; + result[j + 3] = a; + } + } + } else { + JXL_ASSERT(false); // Unsupported type + } + return result; +} +// Returns amount of pixels which differ between the two pictures. Image b is +// the image after roundtrip after roundtrip, image a before roundtrip. There +// are more strict requirements for the alpha channel and grayscale values of +// the output image. +size_t ComparePixels(const uint8_t* a, const uint8_t* b, size_t xsize, + size_t ysize, const JxlPixelFormat& format_a, + const JxlPixelFormat& format_b, + double threshold_multiplier = 1.0) { + // Convert both images to equal full precision for comparison. + std::vector a_full = ConvertToRGBA32(a, xsize, ysize, format_a); + std::vector b_full = ConvertToRGBA32(b, xsize, ysize, format_b); + bool gray_a = format_a.num_channels < 3; + bool gray_b = format_b.num_channels < 3; + bool alpha_a = !(format_a.num_channels & 1); + bool alpha_b = !(format_b.num_channels & 1); + size_t bits_a = GetPrecision(format_a.data_type); + size_t bits_b = GetPrecision(format_b.data_type); + size_t bits = std::min(bits_a, bits_b); + // How much distance is allowed in case of pixels with lower bit depths, given + // that the double precision float images use range 0-1.0. + // E.g. in case of 1-bit this is 0.5 since 0.499 must map to 0 and 0.501 must + // map to 1. + double precision = 0.5 * threshold_multiplier / ((1ull << bits) - 1ull); + if (format_a.data_type == JXL_TYPE_FLOAT16 || + format_b.data_type == JXL_TYPE_FLOAT16) { + // Lower the precision for float16, because it currently looks like the + // scalar and wasm implementations of hwy have 1 less bit of precision + // than the x86 implementations. + // TODO(lode): Set the required precision back to 11 bits when possible. + precision = 0.5 * threshold_multiplier / ((1ull << (bits - 1)) - 1ull); + } + size_t numdiff = 0; + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + size_t i = (y * xsize + x) * 4; + bool ok = true; + if (gray_a || gray_b) { + if (!Near(a_full[i + 0], b_full[i + 0], precision)) ok = false; + // If the input was grayscale and the output not, then the output must + // have all channels equal. + if (gray_a && b_full[i + 0] != b_full[i + 1] && + b_full[i + 2] != b_full[i + 2]) { + ok = false; + } + } else { + if (!Near(a_full[i + 0], b_full[i + 0], precision) || + !Near(a_full[i + 1], b_full[i + 1], precision) || + !Near(a_full[i + 2], b_full[i + 2], precision)) { + ok = false; + } + } + if (alpha_a && alpha_b) { + if (!Near(a_full[i + 3], b_full[i + 3], precision)) ok = false; + } else { + // If the input had no alpha channel, the output should be opaque + // after roundtrip. + if (alpha_b && !Near(1.0, b_full[i + 3], precision)) ok = false; + } + if (!ok) numdiff++; + } + } + return numdiff; +} +double DistanceRMS(const uint8_t* a, const uint8_t* b, size_t xsize, + size_t ysize, const JxlPixelFormat& format) { + // Convert both images to equal full precision for comparison. + std::vector a_full = ConvertToRGBA32(a, xsize, ysize, format); + std::vector b_full = ConvertToRGBA32(b, xsize, ysize, format); + double sum = 0.0; + for (size_t y = 0; y < ysize; y++) { + double row_sum = 0.0; + for (size_t x = 0; x < xsize; x++) { + size_t i = (y * xsize + x) * 4; + for (size_t c = 0; c < format.num_channels; ++c) { + double diff = a_full[i + c] - b_full[i + c]; + row_sum += diff * diff; + } + } + sum += row_sum; + } + sum /= (xsize * ysize); + return sqrt(sum); +} +} // namespace test + +bool operator==(const jxl::PaddedBytes& a, const jxl::PaddedBytes& b) { + if (a.size() != b.size()) return false; + if (memcmp(a.data(), b.data(), a.size()) != 0) return false; + return true; +} + +// Allow using EXPECT_EQ on jxl::PaddedBytes +bool operator!=(const jxl::PaddedBytes& a, const jxl::PaddedBytes& b) { + return !(a == b); +} +} // namespace jxl + +#endif // LIB_JXL_TEST_UTILS_H_ diff --git a/media/libjxl/src/lib/jxl/testdata.h b/media/libjxl/src/lib/jxl/testdata.h new file mode 100644 index 000000000..d387219bb --- /dev/null +++ b/media/libjxl/src/lib/jxl/testdata.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TESTDATA_H_ +#define LIB_JXL_TESTDATA_H_ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "lib/jxl/base/file_io.h" + +namespace jxl { + +static inline PaddedBytes ReadTestData(const std::string& filename) { + std::string full_path = std::string(TEST_DATA_PATH "/") + filename; + PaddedBytes data; + JXL_CHECK(ReadFile(full_path, &data)); + return data; +} + +} // namespace jxl + +#endif // LIB_JXL_TESTDATA_H_ diff --git a/media/libjxl/src/lib/jxl/tf_gbench.cc b/media/libjxl/src/lib/jxl/tf_gbench.cc new file mode 100644 index 000000000..9c010d460 --- /dev/null +++ b/media/libjxl/src/lib/jxl/tf_gbench.cc @@ -0,0 +1,143 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/image_ops.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/tf_gbench.cc" +#include +#include + +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +#define RUN_BENCHMARK(F) \ + constexpr size_t kNum = 1 << 12; \ + HWY_FULL(float) d; \ + /* Three parallel runs, as this will run on R, G and B. */ \ + auto sum1 = Zero(d); \ + auto sum2 = Zero(d); \ + auto sum3 = Zero(d); \ + for (auto _ : state) { \ + auto x = Set(d, 1e-5); \ + auto v1 = Set(d, 1e-5); \ + auto v2 = Set(d, 1.1e-5); \ + auto v3 = Set(d, 1.2e-5); \ + for (size_t i = 0; i < kNum; i++) { \ + sum1 += F(d, v1); \ + sum2 += F(d, v2); \ + sum3 += F(d, v3); \ + v1 += x; \ + v2 += x; \ + v3 += x; \ + } \ + } \ + /* floats per second */ \ + state.SetItemsProcessed(kNum* state.iterations() * Lanes(d) * 3); \ + benchmark::DoNotOptimize(sum1 + sum2 + sum3); + +#define RUN_BENCHMARK_SCALAR(F) \ + constexpr size_t kNum = 1 << 12; \ + /* Three parallel runs, as this will run on R, G and B. */ \ + float sum1 = 0, sum2 = 0, sum3 = 0; \ + for (auto _ : state) { \ + float x = 1e-5; \ + float v1 = 1e-5; \ + float v2 = 1.1e-5; \ + float v3 = 1.2e-5; \ + for (size_t i = 0; i < kNum; i++) { \ + sum1 += F(v1); \ + sum2 += F(v2); \ + sum3 += F(v3); \ + v1 += x; \ + v2 += x; \ + v3 += x; \ + } \ + } \ + /* floats per second */ \ + state.SetItemsProcessed(kNum* state.iterations() * 3); \ + benchmark::DoNotOptimize(sum1 + sum2 + sum3); + +HWY_NOINLINE void BM_FastSRGB(benchmark::State& state) { + RUN_BENCHMARK(FastLinearToSRGB); +} + +HWY_NOINLINE void BM_TFSRGB(benchmark::State& state) { + RUN_BENCHMARK(TF_SRGB().EncodedFromDisplay); +} + +HWY_NOINLINE void BM_PQDFE(benchmark::State& state) { + RUN_BENCHMARK(TF_PQ().DisplayFromEncoded); +} + +HWY_NOINLINE void BM_PQEFD(benchmark::State& state) { + RUN_BENCHMARK(TF_PQ().EncodedFromDisplay); +} + +HWY_NOINLINE void BM_PQSlowDFE(benchmark::State& state) { + RUN_BENCHMARK_SCALAR(TF_PQ().DisplayFromEncoded); +} + +HWY_NOINLINE void BM_PQSlowEFD(benchmark::State& state) { + RUN_BENCHMARK_SCALAR(TF_PQ().EncodedFromDisplay); +} +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { + +HWY_EXPORT(BM_FastSRGB); +HWY_EXPORT(BM_TFSRGB); +HWY_EXPORT(BM_PQDFE); +HWY_EXPORT(BM_PQEFD); +HWY_EXPORT(BM_PQSlowDFE); +HWY_EXPORT(BM_PQSlowEFD); + +float SRGB_pow(float x) { + return x < 0.0031308f ? 12.92f * x : 1.055f * powf(x, 1.0f / 2.4f) - 0.055f; +} + +void BM_FastSRGB(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_FastSRGB)(state); +} +void BM_TFSRGB(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_TFSRGB)(state); +} +void BM_PQDFE(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQDFE)(state); +} +void BM_PQEFD(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQEFD)(state); +} +void BM_PQSlowDFE(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQSlowDFE)(state); +} +void BM_PQSlowEFD(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQSlowEFD)(state); +} + +void BM_SRGB_pow(benchmark::State& state) { RUN_BENCHMARK_SCALAR(SRGB_pow); } + +BENCHMARK(BM_FastSRGB); +BENCHMARK(BM_TFSRGB); +BENCHMARK(BM_SRGB_pow); +BENCHMARK(BM_PQDFE); +BENCHMARK(BM_PQEFD); +BENCHMARK(BM_PQSlowDFE); +BENCHMARK(BM_PQSlowEFD); + +} // namespace +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl/toc.cc b/media/libjxl/src/lib/jxl/toc.cc new file mode 100644 index 000000000..24cdd02c6 --- /dev/null +++ b/media/libjxl/src/lib/jxl/toc.cc @@ -0,0 +1,106 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/toc.h" + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/fields.h" + +namespace jxl { +size_t MaxBits(const size_t num_sizes) { + const size_t entry_bits = U32Coder::MaxEncodedBits(kTocDist) * num_sizes; + // permutation bit (not its tokens!), padding, entries, padding. + return 1 + kBitsPerByte + entry_bits + kBitsPerByte; +} + +Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector* JXL_RESTRICT sizes, + std::vector* JXL_RESTRICT permutation) { + if (toc_entries > 65536) { + // Prevent out of memory if invalid JXL codestream causes a bogus amount + // of toc_entries such as 2720436919446 to be computed. + // TODO(lode): verify whether 65536 is a reasonable upper bound + return JXL_FAILURE("too many toc entries"); + } + + sizes->clear(); + sizes->resize(toc_entries); + if (reader->TotalBitsConsumed() >= reader->TotalBytes() * kBitsPerByte) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, "Not enough bytes for TOC"); + } + const auto check_bit_budget = [&](size_t num_entries) -> Status { + // U32Coder reads 2 bits to recognize variant and kTocDist cheapest variant + // is Bits(10), this way at least 12 bits are required per toc-entry. + size_t minimal_bit_cost = num_entries * (2 + 10); + size_t bit_budget = reader->TotalBytes() * 8; + size_t expenses = reader->TotalBitsConsumed(); + if ((expenses <= bit_budget) && + (minimal_bit_cost <= bit_budget - expenses)) { + return true; + } + return JXL_STATUS(StatusCode::kNotEnoughBytes, "Not enough bytes for TOC"); + }; + + JXL_DASSERT(toc_entries > 0); + if (reader->ReadFixedBits<1>() == 1) { + JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries)); + permutation->resize(toc_entries); + JXL_RETURN_IF_ERROR(DecodePermutation(/*skip=*/0, toc_entries, + permutation->data(), reader)); + } + JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries)); + for (size_t i = 0; i < toc_entries; ++i) { + (*sizes)[i] = U32Coder::Read(kTocDist, reader); + } + JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + JXL_RETURN_IF_ERROR(check_bit_budget(0)); + return true; +} + +Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector* JXL_RESTRICT offsets, + std::vector* JXL_RESTRICT sizes, + uint64_t* total_size) { + std::vector permutation; + JXL_RETURN_IF_ERROR(ReadToc(toc_entries, reader, sizes, &permutation)); + + offsets->clear(); + offsets->resize(toc_entries); + + // Prefix sum starting with 0 and ending with the offset of the last group + uint64_t offset = 0; + for (size_t i = 0; i < toc_entries; ++i) { + if (offset + (*sizes)[i] < offset) { + return JXL_FAILURE("group offset overflow"); + } + (*offsets)[i] = offset; + offset += (*sizes)[i]; + } + if (total_size) { + *total_size = offset; + } + + if (!permutation.empty()) { + std::vector permuted_offsets; + std::vector permuted_sizes; + permuted_offsets.reserve(toc_entries); + permuted_sizes.reserve(toc_entries); + for (coeff_order_t index : permutation) { + permuted_offsets.push_back((*offsets)[index]); + permuted_sizes.push_back((*sizes)[index]); + } + std::swap(*offsets, permuted_offsets); + std::swap(*sizes, permuted_sizes); + } + + return true; +} +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/toc.h b/media/libjxl/src/lib/jxl/toc.h new file mode 100644 index 000000000..a97197ad4 --- /dev/null +++ b/media/libjxl/src/lib/jxl/toc.h @@ -0,0 +1,55 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TOC_H_ +#define LIB_JXL_TOC_H_ + +#include +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +// (2+bits) = 2,3,4 bytes so encoders can patch TOC after encoding. +// 30 is sufficient for 4K channels of uncompressed 16-bit samples. +constexpr U32Enc kTocDist(Bits(10), BitsOffset(14, 1024), BitsOffset(22, 17408), + BitsOffset(30, 4211712)); + +size_t MaxBits(const size_t num_sizes); + +// TODO(veluca): move these to FrameDimensions. +static JXL_INLINE size_t AcGroupIndex(size_t pass, size_t group, + size_t num_groups, size_t num_dc_groups, + bool has_ac_global) { + return 1 + num_dc_groups + static_cast(has_ac_global) + + pass * num_groups + group; +} + +static JXL_INLINE size_t NumTocEntries(size_t num_groups, size_t num_dc_groups, + size_t num_passes, bool has_ac_global) { + if (num_groups == 1 && num_passes == 1) return 1; + return AcGroupIndex(0, 0, num_groups, num_dc_groups, has_ac_global) + + num_groups * num_passes; +} + +Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector* JXL_RESTRICT sizes, + std::vector* JXL_RESTRICT permutation); + +Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector* JXL_RESTRICT offsets, + std::vector* JXL_RESTRICT sizes, + uint64_t* total_size); + +} // namespace jxl + +#endif // LIB_JXL_TOC_H_ diff --git a/media/libjxl/src/lib/jxl/toc_test.cc b/media/libjxl/src/lib/jxl/toc_test.cc new file mode 100644 index 000000000..2f3bf5b23 --- /dev/null +++ b/media/libjxl/src/lib/jxl/toc_test.cc @@ -0,0 +1,92 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/toc.h" + +#include "gtest/gtest.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_toc.h" + +namespace jxl { +namespace { + +void Roundtrip(size_t num_entries, bool permute, Rng* rng) { + // Generate a random permutation. + std::vector permutation(num_entries); + std::vector inv_permutation(num_entries); + for (size_t i = 0; i < num_entries; i++) { + permutation[i] = i; + inv_permutation[i] = i; + } + if (permute) { + rng->Shuffle(permutation.data(), permutation.size()); + for (size_t i = 0; i < num_entries; i++) { + inv_permutation[permutation[i]] = i; + } + } + + // Generate num_entries groups of random (byte-aligned) length + std::vector group_codes(num_entries); + for (BitWriter& writer : group_codes) { + const size_t max_bits = (*rng)() & 0xFFF; + BitWriter::Allotment allotment(&writer, max_bits + kBitsPerByte); + size_t i = 0; + for (; i + BitWriter::kMaxBitsPerCall < max_bits; + i += BitWriter::kMaxBitsPerCall) { + writer.Write(BitWriter::kMaxBitsPerCall, 0); + } + for (; i < max_bits; i += 1) { + writer.Write(/*n_bits=*/1, 0); + } + writer.ZeroPadToByte(); + AuxOut aux_out; + ReclaimAndCharge(&writer, &allotment, 0, &aux_out); + } + + BitWriter writer; + AuxOut aux_out; + ASSERT_TRUE(WriteGroupOffsets(group_codes, permute ? &permutation : nullptr, + &writer, &aux_out)); + + BitReader reader(writer.GetSpan()); + std::vector group_offsets; + std::vector group_sizes; + uint64_t total_size; + ASSERT_TRUE(ReadGroupOffsets(num_entries, &reader, &group_offsets, + &group_sizes, &total_size)); + ASSERT_EQ(num_entries, group_offsets.size()); + ASSERT_EQ(num_entries, group_sizes.size()); + EXPECT_TRUE(reader.Close()); + + uint64_t prefix_sum = 0; + for (size_t i = 0; i < num_entries; ++i) { + EXPECT_EQ(prefix_sum, group_offsets[inv_permutation[i]]); + + EXPECT_EQ(0u, group_codes[i].BitsWritten() % kBitsPerByte); + prefix_sum += group_codes[i].BitsWritten() / kBitsPerByte; + + if (i + 1 < num_entries) { + EXPECT_EQ( + group_offsets[inv_permutation[i]] + group_sizes[inv_permutation[i]], + group_offsets[inv_permutation[i + 1]]); + } + } + EXPECT_EQ(prefix_sum, total_size); +} + +TEST(TocTest, Test) { + Rng rng(0); + for (size_t num_entries = 1; num_entries < 10; ++num_entries) { + for (bool permute : std::vector{false, true}) { + Roundtrip(num_entries, permute, &rng); + } + } +} + +} // namespace +} // namespace jxl diff --git a/media/libjxl/src/lib/jxl/transfer_functions-inl.h b/media/libjxl/src/lib/jxl/transfer_functions-inl.h new file mode 100644 index 000000000..9f4c10c76 --- /dev/null +++ b/media/libjxl/src/lib/jxl/transfer_functions-inl.h @@ -0,0 +1,413 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Transfer functions for color encodings. + +#if defined(LIB_JXL_TRANSFER_FUNCTIONS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_TRANSFER_FUNCTIONS_INL_H_ +#undef LIB_JXL_TRANSFER_FUNCTIONS_INL_H_ +#else +#define LIB_JXL_TRANSFER_FUNCTIONS_INL_H_ +#endif + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/rational_polynomial-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::AndNot; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::Sqrt; +using hwy::HWY_NAMESPACE::TableLookupBytes; + +// Definitions for BT.2100-2 transfer functions (used inside/outside SIMD): +// "display" is linear light (nits) normalized to [0, 1]. +// "encoded" is a nonlinear encoding (e.g. PQ) in [0, 1]. +// "scene" is a linear function of photon counts, normalized to [0, 1]. + +// Despite the stated ranges, we need unbounded transfer functions: see +// http://www.littlecms.com/CIC18_UnboundedCMM.pdf. Inputs can be negative or +// above 1 due to chromatic adaptation. To avoid severe round-trip errors caused +// by clamping, we mirror negative inputs via copysign (f(-x) = -f(x), see +// https://developer.apple.com/documentation/coregraphics/cgcolorspace/1644735-extendedsrgb) +// and extend the function domains above 1. + +// Hybrid Log-Gamma. +class TF_HLG { + public: + // EOTF. e = encoded. + JXL_INLINE double DisplayFromEncoded(const double e) const { + return OOTF(InvOETF(e)); + } + + // Inverse EOTF. d = display. + JXL_INLINE double EncodedFromDisplay(const double d) const { + return OETF(InvOOTF(d)); + } + + // Maximum error 5e-7. + template + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + const V below_div12 = Sqrt(Mul(Set(d, 3.0f), x)); + const V e = + MulAdd(Set(d, kA * 0.693147181f), + FastLog2f(d, MulAdd(Set(d, 12), x, Set(d, -kB))), Set(d, kC)); + const V magnitude = IfThenElse(Le(x, Set(d, kDiv12)), below_div12, e); + return Or(AndNot(kSign, magnitude), original_sign); + } + + private: + // OETF (defines the HLG approach). s = scene, returns encoded. + JXL_INLINE double OETF(double s) const { + if (s == 0.0) return 0.0; + const double original_sign = s; + s = std::abs(s); + + if (s <= kDiv12) return copysignf(std::sqrt(3.0 * s), original_sign); + + const double e = kA * std::log(12 * s - kB) + kC; + JXL_ASSERT(e > 0.0); + return copysignf(e, original_sign); + } + + // e = encoded, returns scene. + JXL_INLINE double InvOETF(double e) const { + if (e == 0.0) return 0.0; + const double original_sign = e; + e = std::abs(e); + + if (e <= 0.5) return copysignf(e * e * (1.0 / 3), original_sign); + + const double s = (std::exp((e - kC) * kRA) + kB) * kDiv12; + JXL_ASSERT(s >= 0); + return copysignf(s, original_sign); + } + + // s = scene, returns display. + JXL_INLINE double OOTF(const double s) const { + // The actual (red channel) OOTF is RD = alpha * YS^(gamma-1) * RS, where + // YS = 0.2627 * RS + 0.6780 * GS + 0.0593 * BS. Let alpha = 1 so we return + // "display" (normalized [0, 1]) instead of nits. Our transfer function + // interface does not allow a dependency on YS. Fortunately, the system + // gamma at 334 nits is 1.0, so this reduces to RD = RS. + return s; + } + + // d = display, returns scene. + JXL_INLINE double InvOOTF(const double d) const { + return d; // see OOTF(). + } + + static constexpr double kA = 0.17883277; + static constexpr double kRA = 1.0 / kA; + static constexpr double kB = 1 - 4 * kA; + static constexpr double kC = 0.5599107295; + static constexpr double kDiv12 = 1.0 / 12; +}; + +class TF_709 { + public: + JXL_INLINE double EncodedFromDisplay(const double d) const { + if (d < kThresh) return kMulLow * d; + return kMulHi * std::pow(d, kPowHi) + kSub; + } + + // Maximum error 1e-6. + template + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + auto low = Mul(Set(d, kMulLow), x); + auto hi = + MulAdd(Set(d, kMulHi), FastPowf(d, x, Set(d, kPowHi)), Set(d, kSub)); + return IfThenElse(Le(x, Set(d, kThresh)), low, hi); + } + + template + JXL_INLINE V DisplayFromEncoded(D d, V x) const { + auto low = Mul(Set(d, kInvMulLow), x); + auto hi = FastPowf(d, MulAdd(x, Set(d, kInvMulHi), Set(d, kInvAdd)), + Set(d, kInvPowHi)); + return IfThenElse(Lt(x, Set(d, kInvThresh)), low, hi); + } + + private: + static constexpr double kThresh = 0.018; + static constexpr double kMulLow = 4.5; + static constexpr double kMulHi = 1.099; + static constexpr double kPowHi = 0.45; + static constexpr double kSub = -0.099; + + static constexpr double kInvThresh = 0.081; + static constexpr double kInvMulLow = 1 / 4.5; + static constexpr double kInvMulHi = 1 / 1.099; + static constexpr double kInvPowHi = 1 / 0.45; + static constexpr double kInvAdd = 0.099 * kInvMulHi; +}; + +// Perceptual Quantization +class TF_PQ { + public: + // EOTF (defines the PQ approach). e = encoded. + JXL_INLINE double DisplayFromEncoded(double e) const { + if (e == 0.0) return 0.0; + const double original_sign = e; + e = std::abs(e); + + const double xp = std::pow(e, 1.0 / kM2); + const double num = std::max(xp - kC1, 0.0); + const double den = kC2 - kC3 * xp; + JXL_DASSERT(den != 0.0); + const double d = std::pow(num / den, 1.0 / kM1); + JXL_DASSERT(d >= 0.0); // Equal for e ~= 1E-9 + return copysignf(d, original_sign); + } + + // Maximum error 3e-6 + template + JXL_INLINE V DisplayFromEncoded(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + // 4-over-4-degree rational polynomial approximation on x+x*x. This improves + // the maximum error by about 5x over a rational polynomial for x. + auto xpxx = MulAdd(x, x, x); + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + HWY_REP4(2.62975656e-04f), HWY_REP4(-6.23553089e-03f), + HWY_REP4(7.38602301e-01f), HWY_REP4(2.64553172e+00f), + HWY_REP4(5.50034862e-01f), + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + HWY_REP4(4.21350107e+02f), HWY_REP4(-4.28736818e+02f), + HWY_REP4(1.74364667e+02f), HWY_REP4(-3.39078883e+01f), + HWY_REP4(2.67718770e+00f), + }; + auto magnitude = EvalRationalPolynomial(d, xpxx, p, q); + return Or(AndNot(kSign, magnitude), original_sign); + } + + // Inverse EOTF. d = display. + JXL_INLINE double EncodedFromDisplay(double d) const { + if (d == 0.0) return 0.0; + const double original_sign = d; + d = std::abs(d); + + const double xp = std::pow(d, kM1); + const double num = kC1 + xp * kC2; + const double den = 1.0 + xp * kC3; + const double e = std::pow(num / den, kM2); + JXL_DASSERT(e > 0.0); + return copysignf(e, original_sign); + } + + // Maximum error 7e-7. + template + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + // 4-over-4-degree rational polynomial approximation on x**0.25, with two + // different polynomials above and below 1e-4. + auto xto025 = Sqrt(Sqrt(x)); + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + HWY_REP4(1.351392e-02f), HWY_REP4(-1.095778e+00f), + HWY_REP4(5.522776e+01f), HWY_REP4(1.492516e+02f), + HWY_REP4(4.838434e+01f), + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + HWY_REP4(1.012416e+00f), HWY_REP4(2.016708e+01f), + HWY_REP4(9.263710e+01f), HWY_REP4(1.120607e+02f), + HWY_REP4(2.590418e+01f), + }; + + HWY_ALIGN constexpr float plo[(4 + 1) * 4] = { + HWY_REP4(9.863406e-06f), HWY_REP4(3.881234e-01f), + HWY_REP4(1.352821e+02f), HWY_REP4(6.889862e+04f), + HWY_REP4(-2.864824e+05f), + }; + HWY_ALIGN constexpr float qlo[(4 + 1) * 4] = { + HWY_REP4(3.371868e+01f), HWY_REP4(1.477719e+03f), + HWY_REP4(1.608477e+04f), HWY_REP4(-4.389884e+04f), + HWY_REP4(-2.072546e+05f), + }; + + auto magnitude = IfThenElse(Lt(x, Set(d, 1e-4f)), + EvalRationalPolynomial(d, xto025, plo, qlo), + EvalRationalPolynomial(d, xto025, p, q)); + return Or(AndNot(kSign, magnitude), original_sign); + } + + private: + static constexpr double kM1 = 2610.0 / 16384; + static constexpr double kM2 = (2523.0 / 4096) * 128; + static constexpr double kC1 = 3424.0 / 4096; + static constexpr double kC2 = (2413.0 / 4096) * 32; + static constexpr double kC3 = (2392.0 / 4096) * 32; +}; + +// sRGB +class TF_SRGB { + public: + template + JXL_INLINE V DisplayFromEncoded(V x) const { + const HWY_FULL(float) d; + const HWY_FULL(uint32_t) du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + + // TODO(janwas): range reduction + // Computed via af_cheb_rational (k=100); replicated 4x. + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + 2.200248328e-04f, 2.200248328e-04f, 2.200248328e-04f, 2.200248328e-04f, + 1.043637593e-02f, 1.043637593e-02f, 1.043637593e-02f, 1.043637593e-02f, + 1.624820318e-01f, 1.624820318e-01f, 1.624820318e-01f, 1.624820318e-01f, + 7.961564959e-01f, 7.961564959e-01f, 7.961564959e-01f, 7.961564959e-01f, + 8.210152774e-01f, 8.210152774e-01f, 8.210152774e-01f, 8.210152774e-01f, + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + 2.631846970e-01f, 2.631846970e-01f, 2.631846970e-01f, + 2.631846970e-01f, 1.076976492e+00f, 1.076976492e+00f, + 1.076976492e+00f, 1.076976492e+00f, 4.987528350e-01f, + 4.987528350e-01f, 4.987528350e-01f, 4.987528350e-01f, + -5.512498495e-02f, -5.512498495e-02f, -5.512498495e-02f, + -5.512498495e-02f, 6.521209011e-03f, 6.521209011e-03f, + 6.521209011e-03f, 6.521209011e-03f, + }; + const V linear = Mul(x, Set(d, kLowDivInv)); + const V poly = EvalRationalPolynomial(d, x, p, q); + const V magnitude = + IfThenElse(Gt(x, Set(d, kThreshSRGBToLinear)), poly, linear); + return Or(AndNot(kSign, magnitude), original_sign); + } + + // Error ~5e-07 + template + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + + // Computed via af_cheb_rational (k=100); replicated 4x. + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + -5.135152395e-04f, -5.135152395e-04f, -5.135152395e-04f, + -5.135152395e-04f, 5.287254571e-03f, 5.287254571e-03f, + 5.287254571e-03f, 5.287254571e-03f, 3.903842876e-01f, + 3.903842876e-01f, 3.903842876e-01f, 3.903842876e-01f, + 1.474205315e+00f, 1.474205315e+00f, 1.474205315e+00f, + 1.474205315e+00f, 7.352629620e-01f, 7.352629620e-01f, + 7.352629620e-01f, 7.352629620e-01f, + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + 1.004519624e-02f, 1.004519624e-02f, 1.004519624e-02f, 1.004519624e-02f, + 3.036675394e-01f, 3.036675394e-01f, 3.036675394e-01f, 3.036675394e-01f, + 1.340816930e+00f, 1.340816930e+00f, 1.340816930e+00f, 1.340816930e+00f, + 9.258482155e-01f, 9.258482155e-01f, 9.258482155e-01f, 9.258482155e-01f, + 2.424867759e-02f, 2.424867759e-02f, 2.424867759e-02f, 2.424867759e-02f, + }; + const V linear = Mul(x, Set(d, kLowDiv)); + const V poly = EvalRationalPolynomial(d, Sqrt(x), p, q); + const V magnitude = + IfThenElse(Gt(x, Set(d, kThreshLinearToSRGB)), poly, linear); + return Or(AndNot(kSign, magnitude), original_sign); + } + + private: + static constexpr float kThreshSRGBToLinear = 0.04045f; + static constexpr float kThreshLinearToSRGB = 0.0031308f; + static constexpr float kLowDiv = 12.92f; + static constexpr float kLowDivInv = 1.0f / kLowDiv; +}; + +// Linear to sRGB conversion with error of at most 1.2e-4. +template +V FastLinearToSRGB(D d, V v) { + const hwy::HWY_NAMESPACE::Rebind du; + const hwy::HWY_NAMESPACE::Rebind di; + // Convert to 0.25 - 0.5 range. + auto v025_05 = BitCast( + d, And(Or(BitCast(du, v), Set(du, 0x3e800000)), Set(du, 0x3effffff))); + // third degree polynomial approximation between 0.25 and 0.5 + // of 1.055/2^(7/2.4) * x^(1/2.4) * 0.5. A degree 4 polynomial only improves + // accuracy by about 3x. + auto d1 = MulAdd(v025_05, Set(d, 0.059914046f), Set(d, -0.108894556f)); + auto d2 = MulAdd(d1, v025_05, Set(d, 0.107963754f)); + auto pow = MulAdd(d2, v025_05, Set(d, 0.018092343f)); + // Compute extra multiplier depending on exponent. Valid exponent range for + // [0.0031308f, 1.0) is 0...8 after subtracting 118. + // The next three constants contain a representation of the powers of + // 2**(1/2.4) = 2**(5/12) times two; in particular, bits from 26 to 31 are + // always the same and in k2to512powers_basebits, and the two arrays contain + // the next groups of 8 bits. This ends up being a 22-bit representation (with + // a mantissa of 13 bits). The choice of polynomial to approximate is such + // that the multiplication factor has the highest 5 bits constant, and that + // the factor for the lowest possible exponent is a power of two (thus making + // the additional bits 0, which is used to correctly merge back together the + // floats). + constexpr uint32_t k2to512powers_basebits = 0x40000000; + HWY_ALIGN constexpr uint8_t k2to512powers_25to18bits[16] = { + 0x0, 0xa, 0x19, 0x26, 0x32, 0x41, 0x4d, 0x5c, + 0x68, 0x75, 0x83, 0x8f, 0xa0, 0xaa, 0xb9, 0xc6, + }; + HWY_ALIGN constexpr uint8_t k2to512powers_17to10bits[16] = { + 0x0, 0xb7, 0x4, 0xd, 0xcb, 0xe7, 0x41, 0x68, + 0x51, 0xd1, 0xeb, 0xf2, 0x0, 0xb7, 0x4, 0xd, + }; + // Note that vld1q_s8_x2 on ARM seems to actually be slower. +#if HWY_TARGET != HWY_SCALAR + using hwy::HWY_NAMESPACE::ShiftLeft; + using hwy::HWY_NAMESPACE::ShiftRight; + // Every lane of exp is now (if cast to byte) {0, 0, 0, }. + auto exp = Sub(ShiftRight<23>(BitCast(di, v)), Set(di, 118)); + auto pow25to18bits = TableLookupBytes( + LoadDup128(di, + reinterpret_cast(k2to512powers_25to18bits)), + exp); + auto pow17to10bits = TableLookupBytes( + LoadDup128(di, + reinterpret_cast(k2to512powers_17to10bits)), + exp); + // Now, pow* contain {0, 0, 0, }. Here + // we take advantage of the fact that each table has its position 0 equal to + // 0. + // We can now just reassemble the float. + auto mul = BitCast( + d, Or(Or(ShiftLeft<18>(pow25to18bits), ShiftLeft<10>(pow17to10bits)), + Set(di, k2to512powers_basebits))); +#else + // Fallback for scalar. + uint32_t exp = ((BitCast(di, v).raw >> 23) - 118) & 0xf; + auto mul = BitCast(d, Set(di, (k2to512powers_25to18bits[exp] << 18) | + (k2to512powers_17to10bits[exp] << 10) | + k2to512powers_basebits)); +#endif + return IfThenElse(Lt(v, Set(d, 0.0031308f)), Mul(v, Set(d, 12.92f)), + MulAdd(pow, mul, Set(d, -0.055))); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_TRANSFER_FUNCTIONS_INL_H_ diff --git a/media/libjxl/src/lib/jxl/transpose-inl.h b/media/libjxl/src/lib/jxl/transpose-inl.h new file mode 100644 index 000000000..467442073 --- /dev/null +++ b/media/libjxl/src/lib/jxl/transpose-inl.h @@ -0,0 +1,203 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Block transpose for DCT/IDCT + +#if defined(LIB_JXL_TRANSPOSE_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_TRANSPOSE_INL_H_ +#undef LIB_JXL_TRANSPOSE_INL_H_ +#else +#define LIB_JXL_TRANSPOSE_INL_H_ +#endif + +#include + +#include +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dct_block-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +#ifndef JXL_INLINE_TRANSPOSE +// Workaround for issue #42 - (excessive?) inlining causes invalid codegen. +#if defined(__arm__) +#define JXL_INLINE_TRANSPOSE HWY_NOINLINE +#else +#define JXL_INLINE_TRANSPOSE HWY_INLINE +#endif +#endif // JXL_INLINE_TRANSPOSE + +// Simple wrapper that ensures that a function will not be inlined. +template +JXL_NOINLINE void NoInlineWrapper(const T& f, const Args&... args) { + return f(args...); +} + +template +struct TransposeSimdTag {}; + +// TODO(veluca): it's not super useful to have this in the SIMD namespace. +template +JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag, + const From& from, const To& to, + size_t ROWSp, size_t COLSp) { + size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0; + size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0; + for (size_t n = 0; n < ROWS; ++n) { + for (size_t m = 0; m < COLS; ++m) { + to.Write(from.Read(n, m), m, n); + } + } +} + +// TODO(veluca): AVX3? +#if HWY_CAP_GE256 +constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) { + return ROWS % 8 == 0 && COLS % 8 == 0; +} + +template +JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag, + const From& from, const To& to, + size_t ROWSp, size_t COLSp) { + size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0; + size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0; + static_assert(MaxLanes(BlockDesc<8>()) == 8, "Invalid descriptor size"); + static_assert(ROWS_or_0 % 8 == 0, "Invalid number of rows"); + static_assert(COLS_or_0 % 8 == 0, "Invalid number of columns"); + for (size_t n = 0; n < ROWS; n += 8) { + for (size_t m = 0; m < COLS; m += 8) { + const BlockDesc<8> d; + auto i0 = from.LoadPart(d, n + 0, m + 0); + auto i1 = from.LoadPart(d, n + 1, m + 0); + auto i2 = from.LoadPart(d, n + 2, m + 0); + auto i3 = from.LoadPart(d, n + 3, m + 0); + auto i4 = from.LoadPart(d, n + 4, m + 0); + auto i5 = from.LoadPart(d, n + 5, m + 0); + auto i6 = from.LoadPart(d, n + 6, m + 0); + auto i7 = from.LoadPart(d, n + 7, m + 0); + // Surprisingly, this straightforward implementation (24 cycles on port5) + // is faster than load128+insert and LoadDup128+ConcatUpperLower+blend. + const auto q0 = InterleaveLower(d, i0, i2); + const auto q1 = InterleaveLower(d, i1, i3); + const auto q2 = InterleaveUpper(d, i0, i2); + const auto q3 = InterleaveUpper(d, i1, i3); + const auto q4 = InterleaveLower(d, i4, i6); + const auto q5 = InterleaveLower(d, i5, i7); + const auto q6 = InterleaveUpper(d, i4, i6); + const auto q7 = InterleaveUpper(d, i5, i7); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + const auto r4 = InterleaveLower(d, q4, q5); + const auto r5 = InterleaveUpper(d, q4, q5); + const auto r6 = InterleaveLower(d, q6, q7); + const auto r7 = InterleaveUpper(d, q6, q7); + + i0 = ConcatLowerLower(d, r4, r0); + i1 = ConcatLowerLower(d, r5, r1); + i2 = ConcatLowerLower(d, r6, r2); + i3 = ConcatLowerLower(d, r7, r3); + i4 = ConcatUpperUpper(d, r4, r0); + i5 = ConcatUpperUpper(d, r5, r1); + i6 = ConcatUpperUpper(d, r6, r2); + i7 = ConcatUpperUpper(d, r7, r3); + to.StorePart(d, i0, m + 0, n + 0); + to.StorePart(d, i1, m + 1, n + 0); + to.StorePart(d, i2, m + 2, n + 0); + to.StorePart(d, i3, m + 3, n + 0); + to.StorePart(d, i4, m + 4, n + 0); + to.StorePart(d, i5, m + 5, n + 0); + to.StorePart(d, i6, m + 6, n + 0); + to.StorePart(d, i7, m + 7, n + 0); + } + } +} +#elif HWY_TARGET != HWY_SCALAR +constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) { + return ROWS % 4 == 0 && COLS % 4 == 0; +} + +template +JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag, + const From& from, const To& to, + size_t ROWSp, size_t COLSp) { + size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0; + size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0; + static_assert(MaxLanes(BlockDesc<4>()) == 4, "Invalid descriptor size"); + static_assert(ROWS_or_0 % 4 == 0, "Invalid number of rows"); + static_assert(COLS_or_0 % 4 == 0, "Invalid number of columns"); + for (size_t n = 0; n < ROWS; n += 4) { + for (size_t m = 0; m < COLS; m += 4) { + const BlockDesc<4> d; + const auto p0 = from.LoadPart(d, n + 0, m + 0); + const auto p1 = from.LoadPart(d, n + 1, m + 0); + const auto p2 = from.LoadPart(d, n + 2, m + 0); + const auto p3 = from.LoadPart(d, n + 3, m + 0); + + const auto q0 = InterleaveLower(d, p0, p2); + const auto q1 = InterleaveLower(d, p1, p3); + const auto q2 = InterleaveUpper(d, p0, p2); + const auto q3 = InterleaveUpper(d, p1, p3); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + + to.StorePart(d, r0, m + 0, n + 0); + to.StorePart(d, r1, m + 1, n + 0); + to.StorePart(d, r2, m + 2, n + 0); + to.StorePart(d, r3, m + 3, n + 0); + } + } +} +#else +constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) { return false; } +#endif + +template +struct Transpose { + template + static void Run(const From& from, const To& to) { + // This does not guarantee anything, just saves from the most stupid + // mistakes. + JXL_DASSERT(from.Address(0, 0) != to.Address(0, 0)); + TransposeSimdTag tag; + GenericTransposeBlock(tag, from, to, N, M); + } +}; + +// Avoid inlining and unrolling transposes for large blocks. +template +struct Transpose< + N, M, typename std::enable_if<(N >= 8 && M >= 8 && N * M >= 512)>::type> { + template + static void Run(const From& from, const To& to) { + // This does not guarantee anything, just saves from the most stupid + // mistakes. + JXL_DASSERT(from.Address(0, 0) != to.Address(0, 0)); + TransposeSimdTag tag; + constexpr void (*transpose)(TransposeSimdTag, + const From&, const To&, size_t, size_t) = + GenericTransposeBlock<0, 0, From, To>; + NoInlineWrapper(transpose, tag, from, to, N, M); + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_TRANSPOSE_INL_H_ diff --git a/media/libjxl/src/lib/jxl/version.h.in b/media/libjxl/src/lib/jxl/version.h.in new file mode 100644 index 000000000..d077abec7 --- /dev/null +++ b/media/libjxl/src/lib/jxl/version.h.in @@ -0,0 +1,39 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file version.h + * @brief libjxl version information + */ + +#ifndef JXL_VERSION_H_ +#define JXL_VERSION_H_ + +#define JPEGXL_MAJOR_VERSION @JPEGXL_MAJOR_VERSION@ ///< JPEG XL Major version +#define JPEGXL_MINOR_VERSION @JPEGXL_MINOR_VERSION@ ///< JPEG XL Minor version +#define JPEGXL_PATCH_VERSION @JPEGXL_PATCH_VERSION@ ///< JPEG XL Patch version + +/** Can be used to conditionally compile code for a specific JXL version + * @param[maj] major version + * @param[min] minor version + * + * @code + * #if JPEGXL_NUMERIC_VERSION < JPEGXL_COMPUTE_NUMERIC_VERSION(0,8,0) + * // use old/deprecated api + * #else + * // use current api + * #endif + * @endcode + */ +#define JPEGXL_COMPUTE_NUMERIC_VERSION(major,minor,patch) ((major<<24) | (minor<<16) | (patch<<8) | 0) + +/* Numeric representation of the version */ +#define JPEGXL_NUMERIC_VERSION JPEGXL_COMPUTE_NUMERIC_VERSION(JPEGXL_MAJOR_VERSION,JPEGXL_MINOR_VERSION,JPEGXL_PATCH_VERSION) + +#endif /* JXL_VERSION_H_ */ + +/** @}*/ diff --git a/media/libjxl/src/lib/jxl/xorshift128plus-inl.h b/media/libjxl/src/lib/jxl/xorshift128plus-inl.h new file mode 100644 index 000000000..a473d591f --- /dev/null +++ b/media/libjxl/src/lib/jxl/xorshift128plus-inl.h @@ -0,0 +1,103 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast but weak random generator. + +#if defined(LIB_JXL_XORSHIFT128PLUS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_XORSHIFT128PLUS_INL_H_ +#undef LIB_JXL_XORSHIFT128PLUS_INL_H_ +#else +#define LIB_JXL_XORSHIFT128PLUS_INL_H_ +#endif + +#include + +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Xor; + +// Adapted from https://github.com/vpxyz/xorshift/blob/master/xorshift128plus/ +// (MIT-license) +class Xorshift128Plus { + public: + // 8 independent generators (= single iteration for AVX-512) + enum { N = 8 }; + + explicit HWY_MAYBE_UNUSED Xorshift128Plus(const uint64_t seed) { + // Init state using SplitMix64 generator + s0_[0] = SplitMix64(seed + 0x9E3779B97F4A7C15ull); + s1_[0] = SplitMix64(s0_[0]); + for (size_t i = 1; i < N; ++i) { + s0_[i] = SplitMix64(s1_[i - 1]); + s1_[i] = SplitMix64(s0_[i]); + } + } + + HWY_MAYBE_UNUSED Xorshift128Plus(const uint32_t seed1, const uint32_t seed2, + const uint32_t seed3, const uint32_t seed4) { + // Init state using SplitMix64 generator + s0_[0] = SplitMix64(((static_cast(seed1) << 32) + seed2) + + 0x9E3779B97F4A7C15ull); + s1_[0] = SplitMix64(((static_cast(seed3) << 32) + seed4) + + 0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < N; ++i) { + s0_[i] = SplitMix64(s0_[i - 1]); + s1_[i] = SplitMix64(s1_[i - 1]); + } + } + + HWY_INLINE HWY_MAYBE_UNUSED void Fill(uint64_t* HWY_RESTRICT random_bits) { +#if HWY_CAP_INTEGER64 + const HWY_FULL(uint64_t) d; + for (size_t i = 0; i < N; i += Lanes(d)) { + auto s1 = Load(d, s0_ + i); + const auto s0 = Load(d, s1_ + i); + const auto bits = Add(s1, s0); // b, c + Store(s0, d, s0_ + i); + s1 = Xor(s1, ShiftLeft<23>(s1)); + Store(bits, d, random_bits + i); + s1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + Store(s1, d, s1_ + i); + } +#else + for (size_t i = 0; i < N; ++i) { + auto s1 = s0_[i]; + const auto s0 = s1_[i]; + const auto bits = s1 + s0; // b, c + s0_[i] = s0; + s1 ^= s1 << 23; + random_bits[i] = bits; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s1_[i] = s1; + } +#endif + } + + private: + static uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + HWY_ALIGN uint64_t s0_[N]; + HWY_ALIGN uint64_t s1_[N]; +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_XORSHIFT128PLUS_INL_H_ diff --git a/media/libjxl/src/lib/jxl/xorshift128plus_test.cc b/media/libjxl/src/lib/jxl/xorshift128plus_test.cc new file mode 100644 index 000000000..7514f0e4a --- /dev/null +++ b/media/libjxl/src/lib/jxl/xorshift128plus_test.cc @@ -0,0 +1,378 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/xorshift128plus_test.cc" +#include +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/xorshift128plus-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; + +// Define to nonzero in order to print the (new) golden outputs. +#define PRINT_RESULTS 0 + +const size_t kVectors = 64; + +#if PRINT_RESULTS + +template +void Print(const uint64_t (&result)[kNumLanes]) { + printf("{ "); + for (int i = 0; i < kNumLanes; ++i) { + if (i != 0) { + printf(", "); + } + printf("0x%016llXull", result[i]); + } + printf("},\n"); +} + +#else // PRINT_RESULTS + +const uint64_t kExpected[kVectors][Xorshift128Plus::N] = { + {0x6E901576D477CBB1ull, 0xE9E53789195DA2A2ull, 0xB681F6DDA5E0AE99ull, + 0x8EFD18CE21FD6896ull, 0xA898A80DF75CF532ull, 0x50CEB2C9E2DE7E32ull, + 0x3CA7C2FEB25C0DD0ull, 0xA4D0866B80B4D836ull}, + {0x8CD6A1E6233D3A26ull, 0x3D4603ADE98B112Dull, 0xDC427AF674019E36ull, + 0xE28B4D230705AC53ull, 0x7297E9BBA88783DDull, 0x34D3D23CFCD9B41Aull, + 0x5A223615ADBE96B8ull, 0xE5EB529027CFBD01ull}, + {0xC1894CF00DFAC6A2ull, 0x18EDF8AE9085E404ull, 0x8E936625296B4CCDull, + 0x31971EF3A14A899Bull, 0xBE87535FCE0BF26Aull, 0x576F7A752BC6649Full, + 0xA44CBADCE0C6B937ull, 0x3DBA819BB17A353Aull}, + {0x27CE38DFCC1C5EB6ull, 0x920BEB5606340256ull, 0x3986CBC40C9AFC2Cull, + 0xE22BCB3EEB1E191Eull, 0x6E1FCDD3602A8FBAull, 0x052CB044E5415A29ull, + 0x46266646EFB9ECD7ull, 0x8F44914618D29335ull}, + {0xDD30AEDF72A362C5ull, 0xBC1D824E16BB98F4ull, 0x9EA6009C2AA3D2F1ull, + 0xF65C0FBBE17AF081ull, 0x22424D06A8738991ull, 0x8A62763F2B7611D2ull, + 0x2F3E89F722637939ull, 0x84D338BEF50AFD50ull}, + {0x00F46494898E2B0Bull, 0x81239DC4FB8E8003ull, 0x414AD93EC5773FE7ull, + 0x791473C450E4110Full, 0x87F127BF68C959ACull, 0x6429282D695EF67Bull, + 0x661082E11546CBA8ull, 0x5815D53FA5436BFDull}, + {0xB3DEADAB9BE6E0F9ull, 0xAA1B7B8F7CED0202ull, 0x4C5ED437699D279Eull, + 0xA4471727F1CB39D3ull, 0xE439DA193F802F70ull, 0xF89401BB04FA6493ull, + 0x3B08045A4FE898BAull, 0x32137BFE98227950ull}, + {0xFBAE4A092897FEF3ull, 0x0639F6CE56E71C8Eull, 0xF0AD6465C07F0C1Eull, + 0xFF8E28563361DCE5ull, 0xC2013DB7F86BC6B9ull, 0x8EFCC0503330102Full, + 0x3F6B767EA5C4DA40ull, 0xB9864B950B2232E1ull}, + {0x76EB58DE8E5EC22Aull, 0x9BBBF49A18B32F4Full, 0xC8405F02B2B2FAB9ull, + 0xC3E122A5F146BC34ull, 0xC90BB046660F5765ull, 0xB933981310DBECCFull, + 0x5A2A7BFC9126FD1Cull, 0x8BB388C94DF87901ull}, + {0x753EB89AD63EF3C3ull, 0xF24AAF40C89D65ADull, 0x23F68931C1A6AA6Dull, + 0xF47E79BF702C6DD0ull, 0xA3AD113244EE7EAEull, 0xD42CBEA28F793DC3ull, + 0xD896FCF1820F497Cull, 0x042B86D2818948C1ull}, + {0x8F2A4FC5A4265763ull, 0xEC499E6F95EAA10Cull, 0xE3786D4ECCD0DEB5ull, + 0xC725C53D3AC4CC43ull, 0x065A4ACBBF83610Eull, 0x35C61C9FEF167129ull, + 0x7B720AEAA7D70048ull, 0x14206B841377D039ull}, + {0xAD27D78BF96055F6ull, 0x5F43B20FF47ADCD4ull, 0xE184C2401E2BF71Eull, + 0x30B263D78990045Dull, 0xC22F00EBFF9BA201ull, 0xAE7F86522B53A562ull, + 0x2853312BC039F0A4ull, 0x868D619E6549C3C8ull}, + {0xFD5493D8AE9A8371ull, 0x773D5E224DF61B3Bull, 0x5377C54FBB1A8280ull, + 0xCAD4DE3B8265CAFAull, 0xCDF3F19C91EBD5F6ull, 0xC8EA0F182D73BD78ull, + 0x220502D593433FF1ull, 0xB81205E612DC31B1ull}, + {0x8F32A39EAEDA4C70ull, 0x1D4B0914AA4DAC7Full, 0x56EF1570F3A8B405ull, + 0x29812CB17404A592ull, 0x97A2AAF69CAE90F2ull, 0x12BF5E02778BBFE5ull, + 0x9D4B55AD42A05FD2ull, 0x06C2BAB5E6086620ull}, + {0x8DB4B9648302B253ull, 0xD756AD9E3AEA12C7ull, 0x68709B7F11D4B188ull, + 0x7CC299DDCD707A4Bull, 0x97B860C370A7661Dull, 0xCECD314FC20E64F5ull, + 0x55F412CDFB4C7EC3ull, 0x55EE97591193B525ull}, + {0xCF70F3ACA96E6254ull, 0x022FEDECA2E09F46ull, 0x686823DB60AE1ECFull, + 0xFD36190D3739830Eull, 0x74E1C09027F68120ull, 0xB5883A835C093842ull, + 0x93E1EFB927E9E4E3ull, 0xB2721E249D7E5EBEull}, + {0x69B6E21C44188CB8ull, 0x5D6CFB853655A7AAull, 0x3E001A0B425A66DCull, + 0x8C57451103A5138Full, 0x7BF8B4BE18EAB402ull, 0x494102EB8761A365ull, + 0xB33796A9F6A81F0Eull, 0x10005AB3BCCFD960ull}, + {0xB2CF25740AE965DCull, 0x6F7C1DF7EF53D670ull, 0x648DD6087AC2251Eull, + 0x040955D9851D487Dull, 0xBD550FC7E21A7F66ull, 0x57408F484DEB3AB5ull, + 0x481E24C150B506C1ull, 0x72C0C3EAF91A40D6ull}, + {0x1997A481858A5D39ull, 0x539718F4BEF50DC1ull, 0x2EC4DC4787E7E368ull, + 0xFF1CE78879419845ull, 0xE219A93DD6F6DD30ull, 0x85328618D02FEC1Aull, + 0xC86E02D969181B20ull, 0xEBEC8CD8BBA34E6Eull}, + {0x28B55088A16CE947ull, 0xDD25AC11E6350195ull, 0xBD1F176694257B1Cull, + 0x09459CCF9FCC9402ull, 0xF8047341E386C4E4ull, 0x7E8E9A9AD984C6C0ull, + 0xA4661E95062AA092ull, 0x70A9947005ED1152ull}, + {0x4C01CF75DBE98CCDull, 0x0BA076CDFC7373B9ull, 0x6C5E7A004B57FB59ull, + 0x336B82297FD3BC56ull, 0x7990C0BE74E8D60Full, 0xF0275CC00EC5C8C8ull, + 0x6CF29E682DFAD2E9ull, 0xFA4361524BD95D72ull}, + {0x631D2A19FF62F018ull, 0x41C43863B985B3FAull, 0xE052B2267038EFD9ull, + 0xE2A535FAC575F430ull, 0xE004EEA90B1FF5B8ull, 0x42DFE2CA692A1F26ull, + 0x90FB0BFC9A189ECCull, 0x4484102BD3536BD0ull}, + {0xD027134E9ACCA5A5ull, 0xBBAB4F966D476A9Bull, 0x713794A96E03D693ull, + 0x9F6335E6B94CD44Aull, 0xC5090C80E7471617ull, 0x6D9C1B0C87B58E33ull, + 0x1969CE82E31185A5ull, 0x2099B97E87754EBEull}, + {0x60EBAF4ED934350Full, 0xC26FBF0BA5E6ECFFull, 0x9E54150F0312EC57ull, + 0x0973B48364ED0041ull, 0x800A523241426CFCull, 0x03AB5EC055F75989ull, + 0x8CF315935DEEB40Aull, 0x83D3FC0190BD1409ull}, + {0x26D35394CF720A51ull, 0xCE9EAA15243CBAFEull, 0xE2B45FBAF21B29E0ull, + 0xDB92E98EDE73F9E0ull, 0x79B16F5101C26387ull, 0x1AC15959DE88C86Full, + 0x387633AEC6D6A580ull, 0xA6FC05807BFC5EB8ull}, + {0x2D26C8E47C6BADA9ull, 0x820E6EC832D52D73ull, 0xB8432C3E0ED0EE5Bull, + 0x0F84B3C4063AAA87ull, 0xF393E4366854F651ull, 0x749E1B4D2366A567ull, + 0x805EACA43480D004ull, 0x244EBF3AA54400A5ull}, + {0xBFDC3763AA79F75Aull, 0x9E3A74CC751F41DBull, 0xF401302A149DBC55ull, + 0x6B25F7973D7BF7BCull, 0x13371D34FDBC3DAEull, 0xC5E1998C8F484DCDull, + 0x7031B8AE5C364464ull, 0x3847F0C4F3DA2C25ull}, + {0x24C6387D2C0F1225ull, 0x77CCE960255C67A4ull, 0x21A0947E497B10EBull, + 0xBB5DB73A825A9D7Eull, 0x26294A41999E553Dull, 0x3953E0089F87D925ull, + 0x3DAE6E5D4E5EAAFEull, 0x74B545460341A7AAull}, + {0x710E5EB08A7DB820ull, 0x7E43C4E77CAEA025ull, 0xD4C91529C8B060C1ull, + 0x09AE26D8A7B0CA29ull, 0xAB9F356BB360A772ull, 0xB68834A25F19F6E9ull, + 0x79B8D9894C5734E2ull, 0xC6847E7C8FFD265Full}, + {0x10C4BCB06A5111E6ull, 0x57CB50955B6A2516ull, 0xEF53C87798B6995Full, + 0xAB38E15BBD8D0197ull, 0xA51C6106EFF73C93ull, 0x83D7F0E2270A7134ull, + 0x0923FD330397FCE5ull, 0xF9DE54EDFE58FB45ull}, + {0x07D44833ACCD1A94ull, 0xAAD3C9E945E2F9F3ull, 0xABF4C879B876AA37ull, + 0xF29C69A21B301619ull, 0x2DDCE959111C788Bull, 0x7CEDB48F8AC1729Bull, + 0x93F3BA9A02B659BEull, 0xF20A87FF17933CBEull}, + {0x8E96EBE93180CFE6ull, 0x94CAA12873937079ull, 0x05F613D9380D4189ull, + 0xBCAB40C1DC79F38Aull, 0x0AD8907B7C61D19Eull, 0x88534E189D103910ull, + 0x2DB2FAABA160AB8Full, 0xA070E7506B06F15Cull}, + {0x6FB1FCDAFFEF87A9ull, 0xE735CF25337A090Dull, 0x172C6EDCEFEF1825ull, + 0x76957EA49EF0542Dull, 0x819BF4CD250F7C49ull, 0xD6FF23E4AD00C4D4ull, + 0xE79673C1EC358FF0ull, 0xAC9C048144337938ull}, + {0x4C5387FF258B3AF4ull, 0xEDB68FAEC2CB1AA3ull, 0x02A624E67B4E1DA4ull, + 0x5C44797A38E08AF2ull, 0x36546A70E9411B4Bull, 0x47C17B24D2FD9675ull, + 0x101957AAA020CA26ull, 0x47A1619D4779F122ull}, + {0xF84B8BCDC92D9A3Cull, 0x951D7D2C74B3066Bull, 0x7AC287C06EDDD9B2ull, + 0x4C38FC476608D38Full, 0x224D793B19CB4BCDull, 0x835A255899BF1A41ull, + 0x4AD250E9F62DB4ABull, 0xD9B44F4B58781096ull}, + {0xABBAF99A8EB5C6B8ull, 0xFB568E900D3A9F56ull, 0x11EDF63D23C5DF11ull, + 0xA9C3011D3FA7C5A8ull, 0xAEDD3CF11AFFF725ull, 0xABCA472B5F1EDD6Bull, + 0x0600B6BB5D879804ull, 0xDB4DE007F22191A0ull}, + {0xD76CC9EFF0CE9392ull, 0xF5E0A772B59BA49Aull, 0x7D1AE1ED0C1261B5ull, + 0x79224A33B5EA4F4Aull, 0x6DD825D80C40EA60ull, 0x47FC8E747E51C953ull, + 0x695C05F72888BF98ull, 0x1A012428440B9015ull}, + {0xD754DD61F9B772BFull, 0xC4A2FCF4C0F9D4EBull, 0x461167CDF67A24A2ull, + 0x434748490EBCB9D4ull, 0x274DD9CDCA5781DEull, 0x36BAC63BA9A85209ull, + 0x30324DAFDA36B70Full, 0x337570DB4FE6DAB3ull}, + {0xF46CBDD57C551546ull, 0x8E02507E676DA3E3ull, 0xD826245A8C15406Dull, + 0xDFB38A5B71113B72ull, 0x5EA38454C95B16B5ull, 0x28C054FB87ABF3E1ull, + 0xAA2724C0BA1A8096ull, 0xECA83EC980304F2Full}, + {0x6AA76EC294EB3303ull, 0x42D4CDB2A8032E3Bull, 0x7999EDF75DCD8735ull, + 0xB422BFFE696CCDCCull, 0x8F721461FD7CCDFEull, 0x148E1A5814FDE253ull, + 0x4DC941F4375EF8FFull, 0x27B2A9E0EB5B49CFull}, + {0xCEA592EF9343EBE1ull, 0xF7D38B5FA7698903ull, 0x6CCBF352203FEAB6ull, + 0x830F3095FCCDA9C5ull, 0xDBEEF4B81B81C8F4ull, 0x6D7EB9BCEECA5CF9ull, + 0xC58ABB0FBE436C69ull, 0xE4B97E6DB2041A4Bull}, + {0x7E40FC772978AF14ull, 0xCDDA4BBAE28354A1ull, 0xE4F993B832C32613ull, + 0xD3608093C68A4B35ull, 0x9A3B60E01BEE3699ull, 0x03BEF248F3288713ull, + 0x70B9294318F3E9B4ull, 0x8D2ABB913B8610DEull}, + {0x37F209128E7D8B2Cull, 0x81D2AB375BD874BCull, 0xA716A1B7373F7408ull, + 0x0CEE97BEC4706540ull, 0xA40C5FD9CDBC1512ull, 0x73CAF6C8918409E7ull, + 0x45E11BCEDF0BBAA1ull, 0x612C612BFF6E6605ull}, + {0xF8ECB14A12D0F649ull, 0xDA683CD7C01BA1ACull, 0xA2203F7510E124C1ull, + 0x7F83E52E162F3C78ull, 0x77D2BB73456ACADBull, 0x37FC34FC840BBA6Full, + 0x3076BC7D4C6EBC1Full, 0x4F514123632B5FA9ull}, + {0x44D789DED935E884ull, 0xF8291591E09FEC9Full, 0xD9CED2CF32A2E4B7ull, + 0x95F70E1EB604904Aull, 0xDE438FE43C14F6ABull, 0x4C8D23E4FAFCF8D8ull, + 0xC716910A3067EB86ull, 0x3D6B7915315095D3ull}, + {0x3170FDBADAB92095ull, 0x8F1963933FC5650Bull, 0x72F94F00ABECFEABull, + 0x6E3AE826C6AAB4CEull, 0xA677A2BF31068258ull, 0x9660CDC4F363AF10ull, + 0xD81A15A152379EF1ull, 0x5D7D285E1080A3F9ull}, + {0xDAD5DDFF9A2249B3ull, 0x6F9721D926103FAEull, 0x1418CBB83FFA349Aull, + 0xE71A30AD48C012B2ull, 0xBE76376C63751132ull, 0x3496467ACA713AE6ull, + 0x8D7EC01369F991A3ull, 0xD8C73A88B96B154Eull}, + {0x8B5D9C74AEB4833Aull, 0xF914FB3F867B912Full, 0xB894EA034936B1DCull, + 0x8A16D21BE51C4F5Bull, 0x31FF048ED582D98Eull, 0xB95AB2F4DC65B820ull, + 0x04082B9170561AF7ull, 0xA215610A5DC836FAull}, + {0xB2ADE592C092FAACull, 0x7A1E683BCBF13294ull, 0xC7A4DBF86858C096ull, + 0x3A49940F97BFF316ull, 0xCAE5C06B82C46703ull, 0xC7F413A0F951E2BDull, + 0x6665E7BB10EB5916ull, 0x86F84A5A94EDE319ull}, + {0x4EA199D8FAA79CA3ull, 0xDFA26E5BF1981704ull, 0x0F5E081D37FA4E01ull, + 0x9CB632F89CD675CDull, 0x4A09DB89D48C0304ull, 0x88142742EA3C7672ull, + 0xAC4F149E6D2E9BDBull, 0x6D9E1C23F8B1C6C6ull}, + {0xD58BE47B92DEC0E9ull, 0x8E57573645E34328ull, 0x4CC094CCB5FB5126ull, + 0x5F1D66AF6FB40E3Cull, 0x2BA15509132D3B00ull, 0x0D6545646120E567ull, + 0x3CF680C45C223666ull, 0x96B28E32930179DAull}, + {0x5900C45853AC7990ull, 0x61881E3E8B7FF169ull, 0x4DE5F835DF2230FFull, + 0x4427A9E7932F73FFull, 0x9B641BAD379A8C8Dull, 0xDF271E5BF98F4E5Cull, + 0xDFDA16DB830FF5EEull, 0x371C7E7CFB89C0E9ull}, + {0x4410A8576247A250ull, 0x6AD2DA12B45AC0D9ull, 0x18DFC72AAC85EECCull, + 0x06FC8BB2A0EF25C8ull, 0xEB287619C85E6118ull, 0x19553ECA67F25A2Cull, + 0x3B9557F1DCEC5BAAull, 0x7BAD9E8B710D1079ull}, + {0x34F365D66BD22B28ull, 0xE6E124B9F10F835Dull, 0x0573C38ABF2B24DCull, + 0xD32E6AF10A0125AEull, 0x383590ACEA979519ull, 0x8376ED7A39E28205ull, + 0xF0B7F184DCBDA435ull, 0x062A203390E31794ull}, + {0xA2AFFD7E41918760ull, 0x7F90FC1BD0819C86ull, 0x5033C08E5A969533ull, + 0x2707AF5C6D039590ull, 0x57BBD5980F17DF9Cull, 0xD3FE6E61D763268Aull, + 0x9E0A0AE40F335A3Bull, 0x43CF4EB0A99613C5ull}, + {0xD4D2A397CE1A7C2Eull, 0x3DF7CE7CC3212DADull, 0x0880F0D5D356C75Aull, + 0xA8AFC44DD03B1346ull, 0x79263B46C13A29E0ull, 0x11071B3C0ED58E7Aull, + 0xED46DC9F538406BFull, 0x2C94974F2B94843Dull}, + {0xE246E13C39AB5D5Eull, 0xAC1018489D955B20ull, 0x8601B558771852B8ull, + 0x110BD4C06DB40173ull, 0x738FC8A18CCA0EBBull, 0x6673E09BE0EA76E5ull, + 0x024BC7A0C7527877ull, 0x45E6B4652E2EC34Eull}, + {0xD1ED26A1A375CDC8ull, 0xAABC4E896A617CB8ull, 0x0A9C9E8E57D753C6ull, + 0xA3774A75FEB4C30Eull, 0x30B816C01C93E49Eull, 0xF405BABC06D2408Cull, + 0xCC0CE6B4CE788ABCull, 0x75E7922D0447956Cull}, + {0xD07C1676A698BC95ull, 0x5F9AEA4840E2D860ull, 0xD5FC10D58BDF6F02ull, + 0xF190A2AD4BC2EEA7ull, 0x0C24D11F51726931ull, 0xDB646899A16B6512ull, + 0x7BC10670047B1DD8ull, 0x2413A5ABCD45F092ull}, + {0x4E66892190CFD923ull, 0xF10162440365EC8Eull, 0x158ACA5A6A2280AEull, + 0x0D60ED11C0224166ull, 0x7CD2E9A71B9D7488ull, 0x450D7289706AB2A3ull, + 0x88FAE34EC9A0D7DCull, 0x96FF9103575A97DAull}, + {0x77990FAC6046C446ull, 0xB174B5FB30C76676ull, 0xE352CE3EB56CF82Aull, + 0xC6039B6873A9A082ull, 0xE3F80F3AE333148Aull, 0xB853BA24BA3539B9ull, + 0xE8863E52ECCB0C74ull, 0x309B4CC1092CC245ull}, + {0xBC2B70BEE8388D9Full, 0xE48D92AE22216DCEull, 0xF15F3BF3E2C15D8Full, + 0x1DD964D4812D8B24ull, 0xD56AF02FB4665E4Cull, 0x98002200595BD9A3ull, + 0x049246D50BB8FA12ull, 0x1B542DF485B579B9ull}, + {0x2347409ADFA8E497ull, 0x36015C2211D62498ull, 0xE9F141F32EB82690ull, + 0x1F839912D0449FB9ull, 0x4E4DCFFF2D02D97Cull, 0xF8A03AB4C0F625C9ull, + 0x0605F575795DAC5Cull, 0x4746C9BEA0DDA6B1ull}, + {0xCA5BB519ECE7481Bull, 0xFD496155E55CA945ull, 0xF753B9DBB1515F81ull, + 0x50549E8BAC0F70E7ull, 0x8614FB0271E21C60ull, 0x60C72947EB0F0070ull, + 0xA6511C10AEE742B6ull, 0x48FB48F2CACCB43Eull}}; + +#endif // PRINT_RESULTS + +// Ensures Xorshift128+ returns consistent and unchanging values. +void TestGolden() { + HWY_ALIGN Xorshift128Plus rng(12345); + for (uint64_t vector = 0; vector < kVectors; ++vector) { + HWY_ALIGN uint64_t lanes[Xorshift128Plus::N]; + rng.Fill(lanes); +#if PRINT_RESULTS + Print(lanes); +#else + for (size_t i = 0; i < Xorshift128Plus::N; ++i) { + ASSERT_EQ(kExpected[vector][i], lanes[i]) + << "Where vector=" << vector << " i=" << i; + } +#endif + } +} + +// Output changes when given different seeds +void TestSeedChanges() { + HWY_ALIGN uint64_t lanes[Xorshift128Plus::N]; + + std::vector first; + constexpr size_t kNumSeeds = 16384; + first.reserve(kNumSeeds); + + // All 14-bit seeds + for (size_t seed = 0; seed < kNumSeeds; ++seed) { + HWY_ALIGN Xorshift128Plus rng(seed); + + rng.Fill(lanes); + first.push_back(lanes[0]); + } + + // All outputs are unique + ASSERT_EQ(kNumSeeds, first.size()); + std::sort(first.begin(), first.end()); + first.erase(std::unique(first.begin(), first.end()), first.end()); + EXPECT_EQ(kNumSeeds, first.size()); +} + +void TestFloat() { + ThreadPoolInternal pool(8); + +#ifdef JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 256; +#else // JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 4096; +#endif // JXL_DISABLE_SLOW_TESTS + EXPECT_TRUE(RunOnPool( + &pool, 0, kMaxSeed, ThreadPool::NoInit, + [](const uint32_t seed, size_t /*thread*/) { + HWY_ALIGN Xorshift128Plus rng(seed); + + const HWY_FULL(uint32_t) du; + const HWY_FULL(float) df; + HWY_ALIGN uint64_t batch[Xorshift128Plus::N]; + HWY_ALIGN float lanes[MaxLanes(df)]; + double sum = 0.0; + size_t count = 0; + const size_t kReps = 2000; + for (size_t reps = 0; reps < kReps; ++reps) { + rng.Fill(batch); + for (size_t i = 0; i < Xorshift128Plus::N * 2; i += Lanes(df)) { + const auto bits = + Load(du, reinterpret_cast(batch) + i); + // 1.0 + 23 random mantissa bits = [1, 2) + const auto rand12 = + BitCast(df, Or(ShiftRight<9>(bits), Set(du, 0x3F800000))); + const auto rand01 = Sub(rand12, Set(df, 1.0f)); + Store(rand01, df, lanes); + for (float lane : lanes) { + sum += lane; + count += 1; + EXPECT_LE(lane, 1.0f); + EXPECT_GE(lane, 0.0f); + } + } + } + + // Verify average (uniform distribution) + EXPECT_NEAR(0.5, sum / count, 0.00702); + }, + "TestXorShift")); +} + +// Not more than one 64-bit zero +void TestNotZero() { + ThreadPoolInternal pool(8); + +#ifdef JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 500; +#else // JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 2000; +#endif // JXL_DISABLE_SLOW_TESTS + EXPECT_TRUE(RunOnPool( + &pool, 0, kMaxSeed, ThreadPool::NoInit, + [](const uint32_t task, size_t /*thread*/) { + HWY_ALIGN uint64_t lanes[Xorshift128Plus::N]; + + HWY_ALIGN Xorshift128Plus rng(task); + size_t num_zero = 0; + for (size_t vectors = 0; vectors < 10000; ++vectors) { + rng.Fill(lanes); + for (uint64_t lane : lanes) { + num_zero += static_cast(lane == 0); + } + } + EXPECT_LE(num_zero, 1u); + }, + "TestNotZero")); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class Xorshift128Test : public hwy::TestWithParamTarget {}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(Xorshift128Test); + +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestNotZero); +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestGolden); +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestSeedChanges); +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestFloat); + +} // namespace jxl +#endif diff --git a/media/libjxl/src/lib/jxl_benchmark.cmake b/media/libjxl/src/lib/jxl_benchmark.cmake new file mode 100644 index 000000000..f0535d7db --- /dev/null +++ b/media/libjxl/src/lib/jxl_benchmark.cmake @@ -0,0 +1,45 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# All files ending in "_gbench.cc" are considered Google benchmark files and +# should be listed here. +set(JPEGXL_INTERNAL_SOURCES_GBENCH + extras/tone_mapping_gbench.cc + jxl/dec_external_image_gbench.cc + jxl/enc_external_image_gbench.cc + jxl/gauss_blur_gbench.cc + jxl/splines_gbench.cc + jxl/tf_gbench.cc +) + +# benchmark.h doesn't work in our MINGW set up since it ends up including the +# wrong stdlib header. We don't run gbench on MINGW targets anyway. +if(NOT MINGW) + +# This is the Google benchmark project (https://github.com/google/benchmark). +find_package(benchmark QUIET) + +if(benchmark_FOUND) + if(JPEGXL_STATIC AND NOT MINGW) + # benchmark::benchmark hardcodes the librt.so which obviously doesn't + # compile in static mode. + set_target_properties(benchmark::benchmark PROPERTIES + INTERFACE_LINK_LIBRARIES "Threads::Threads;-lrt") + endif() + + # Compiles all the benchmark files into a single binary. Individual benchmarks + # can be run with --benchmark_filter. + add_executable(jxl_gbench "${JPEGXL_INTERNAL_SOURCES_GBENCH}" gbench_main.cc) + + target_compile_definitions(jxl_gbench PRIVATE + -DTEST_DATA_PATH="${JPEGXL_TEST_DATA_PATH}") + target_link_libraries(jxl_gbench + jxl_extras-static + jxl-static + benchmark::benchmark + ) +endif() # benchmark_FOUND + +endif() # MINGW diff --git a/media/libjxl/src/lib/jxl_extras.cmake b/media/libjxl/src/lib/jxl_extras.cmake new file mode 100644 index 000000000..fd801fc6a --- /dev/null +++ b/media/libjxl/src/lib/jxl_extras.cmake @@ -0,0 +1,219 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set(JPEGXL_EXTRAS_SOURCES + extras/codec.cc + extras/codec.h + extras/dec/color_description.cc + extras/dec/color_description.h + extras/dec/color_hints.cc + extras/dec/color_hints.h + extras/dec/decode.cc + extras/dec/decode.h + extras/dec/jxl.cc + extras/dec/jxl.h + extras/dec/pgx.cc + extras/dec/pgx.h + extras/dec/pnm.cc + extras/dec/pnm.h + extras/enc/encode.cc + extras/enc/encode.h + extras/enc/npy.cc + extras/enc/npy.h + extras/enc/pgx.cc + extras/enc/pgx.h + extras/enc/pnm.cc + extras/enc/pnm.h + extras/exif.cc + extras/exif.h + extras/hlg.cc + extras/hlg.h + extras/packed_image.h + extras/packed_image_convert.cc + extras/packed_image_convert.h + extras/render_hdr.cc + extras/render_hdr.h + extras/time.cc + extras/time.h + extras/tone_mapping.cc + extras/tone_mapping.h +) + +set(JPEGXL_EXTRAS_CODEC_SOURCES + extras/dec/color_description.cc + extras/dec/color_description.h + extras/dec/color_hints.cc + extras/dec/color_hints.h + extras/dec/decode.cc + extras/dec/decode.h + extras/dec/jxl.cc + extras/dec/jxl.h + extras/dec/pgx.cc + extras/dec/pgx.h + extras/dec/pnm.cc + extras/dec/pnm.h + extras/enc/encode.cc + extras/enc/encode.h + extras/enc/npy.cc + extras/enc/npy.h + extras/enc/pgx.cc + extras/enc/pgx.h + extras/enc/pnm.cc + extras/enc/pnm.h + extras/exif.cc + extras/exif.h + extras/packed_image.h + extras/time.cc + extras/time.h +) + +add_library(jxl_extras_codec-obj OBJECT "${JPEGXL_EXTRAS_CODEC_SOURCES}") +target_compile_options(jxl_extras_codec-obj PRIVATE "${JPEGXL_INTERNAL_FLAGS}") +target_compile_definitions(jxl_extras_codec-obj PRIVATE -DJXL_EXPORT=) +set_property(TARGET jxl_extras_codec-obj PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(jxl_extras_codec-obj PUBLIC + ${PROJECT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include + $ +) +set(JXL_EXTRAS_CODEC_INTERNAL_LIBRARIES) +set(JXL_EXTRAS_CODEC_PUBLIC_COMPILE_DEFINITIONS) + +# We only define a static library for jxl_extras since it uses internal parts +# of jxl library which are not accessible from outside the library in the +# shared library case. +add_library(jxl_extras-static STATIC EXCLUDE_FROM_ALL + "${JPEGXL_EXTRAS_SOURCES}") +target_compile_options(jxl_extras-static PRIVATE "${JPEGXL_INTERNAL_FLAGS}") +set_property(TARGET jxl_extras-static PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(jxl_extras-static PUBLIC "${PROJECT_SOURCE_DIR}") +target_link_libraries(jxl_extras-static PUBLIC + jxl-static + jxl_threads-static +) + +find_package(GIF 5.1) +if(GIF_FOUND) + target_sources(jxl_extras_codec-obj PRIVATE + extras/dec/gif.cc + extras/dec/gif.h + ) + target_include_directories(jxl_extras_codec-obj PRIVATE "${GIF_INCLUDE_DIRS}") + list(APPEND JXL_EXTRAS_CODEC_INTERNAL_LIBRARIES ${GIF_LIBRARIES}) + list(APPEND JXL_EXTRAS_CODEC_PUBLIC_DEFINITIONS -DJPEGXL_ENABLE_GIF=1) + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libgif-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.libgif COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +endif() + +find_package(JPEG) +if(JPEG_FOUND) + target_sources(jxl_extras_codec-obj PRIVATE + extras/dec/jpg.cc + extras/dec/jpg.h + extras/enc/jpg.cc + extras/enc/jpg.h + ) + target_include_directories(jxl_extras_codec-obj PRIVATE "${JPEG_INCLUDE_DIRS}") + list(APPEND JXL_EXTRAS_CODEC_INTERNAL_LIBRARIES ${JPEG_LIBRARIES}) + list(APPEND JXL_EXTRAS_CODEC_PUBLIC_DEFINITIONS -DJPEGXL_ENABLE_JPEG=1) + target_sources(jxl_extras-static PRIVATE + extras/dec/jpg.cc + extras/dec/jpg.h + extras/enc/jpg.cc + extras/enc/jpg.h + ) + target_include_directories(jxl_extras-static PRIVATE "${JPEG_INCLUDE_DIRS}") + target_link_libraries(jxl_extras-static PRIVATE ${JPEG_LIBRARIES}) + target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_JPEG=1) + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libjpeg-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.libjpeg COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +endif() + +if(NOT JPEGXL_BUNDLE_LIBPNG) + find_package(PNG) +endif() +if(PNG_FOUND) + target_sources(jxl_extras_codec-obj PRIVATE + extras/dec/apng.cc + extras/dec/apng.h + extras/enc/apng.cc + extras/enc/apng.h + ) + target_include_directories(jxl_extras_codec-obj PRIVATE "${PNG_INCLUDE_DIRS}") + list(APPEND JXL_EXTRAS_CODEC_INTERNAL_LIBRARIES ${PNG_LIBRARIES}) + list(APPEND JXL_EXTRAS_CODEC_PUBLIC_DEFINITIONS -DJPEGXL_ENABLE_APNG=1) + target_sources(jxl_extras-static PRIVATE + extras/dec/apng.cc + extras/dec/apng.h + extras/enc/apng.cc + extras/enc/apng.h + ) + target_include_directories(jxl_extras-static PUBLIC "${PNG_INCLUDE_DIRS}") + target_link_libraries(jxl_extras-static PUBLIC ${PNG_LIBRARIES}) + target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_APNG=1) + configure_file(extras/LICENSE.apngdis + ${PROJECT_BINARY_DIR}/LICENSE.apngdis COPYONLY) +endif() + +if (JPEGXL_ENABLE_SJPEG) + target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_SJPEG=1) + target_link_libraries(jxl_extras-static PRIVATE sjpeg) +endif () + +if (JPEGXL_ENABLE_OPENEXR) +pkg_check_modules(OpenEXR IMPORTED_TARGET OpenEXR) +if (OpenEXR_FOUND) + target_sources(jxl_extras_codec-obj PRIVATE + extras/dec/exr.cc + extras/dec/exr.h + extras/enc/exr.cc + extras/enc/exr.h + ) + list(APPEND JXL_EXTRAS_CODEC_PUBLIC_DEFINITIONS -DJPEGXL_ENABLE_EXR=1) + target_include_directories(jxl_extras_codec-obj PRIVATE "${OpenEXR_INCLUDE_DIRS}") + list(APPEND JXL_EXTRAS_CODEC_INTERNAL_LIBRARIES PkgConfig::OpenEXR) + target_sources(jxl_extras-static PRIVATE + extras/dec/exr.cc + extras/dec/exr.h + extras/enc/exr.cc + extras/enc/exr.h + ) + target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_EXR=1) + target_link_libraries(jxl_extras-static PRIVATE PkgConfig::OpenEXR) + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libopenexr-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.libopenexr COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR + # OpenEXR generates exceptions, so we need exception support to catch them. + # Actually those flags counteract the ones set in JPEGXL_INTERNAL_FLAGS. + if (NOT WIN32) + set_source_files_properties(extras/dec/exr.cc extras/enc/exr.cc PROPERTIES COMPILE_FLAGS -fexceptions) + if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set_source_files_properties(extras/dec/exr.cc extras/enc/exr.cc PROPERTIES COMPILE_FLAGS -fcxx-exceptions) + endif() + endif() +endif() # OpenEXR_FOUND +endif() # JPEGXL_ENABLE_OPENEXR + +target_compile_definitions(jxl_extras_codec-obj PRIVATE ${JXL_EXTRAS_CODEC_PUBLIC_DEFINITIONS}) + +### Static library. +add_library(jxl_extras_codec-static STATIC $) +target_compile_definitions(jxl_extras_codec-static PUBLIC ${JXL_EXTRAS_CODEC_PUBLIC_DEFINITIONS}) +target_link_libraries(jxl_extras_codec-static PRIVATE ${JXL_EXTRAS_CODEC_INTERNAL_LIBRARIES} jxl) + +### Shared library. +if (BUILD_SHARED_LIBS) +add_library(jxl_extras_codec SHARED $) +target_compile_definitions(jxl_extras_codec PUBLIC ${JXL_EXTRAS_CODEC_PUBLIC_DEFINITIONS}) +target_link_libraries(jxl_extras_codec PRIVATE ${JXL_EXTRAS_CODEC_INTERNAL_LIBRARIES} jxl) +else() +add_library(jxl_extras_codec ALIAS jxl_extras_codec-static) +endif() # BUILD_SHARED_LIBS diff --git a/media/libjxl/src/lib/jxl_profiler.cmake b/media/libjxl/src/lib/jxl_profiler.cmake new file mode 100644 index 000000000..8faa62620 --- /dev/null +++ b/media/libjxl/src/lib/jxl_profiler.cmake @@ -0,0 +1,31 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set(JPEGXL_PROFILER_SOURCES + profiler/profiler.cc + profiler/profiler.h + profiler/tsc_timer.h +) + +### Static library. +add_library(jxl_profiler STATIC ${JPEGXL_PROFILER_SOURCES}) +target_link_libraries(jxl_profiler PUBLIC hwy) + +target_compile_options(jxl_profiler PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(jxl_profiler PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET jxl_profiler PROPERTY POSITION_INDEPENDENT_CODE ON) + +target_include_directories(jxl_profiler + PRIVATE "${PROJECT_SOURCE_DIR}") + +set_target_properties(jxl_profiler PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 +) + +# Make every library linking against the jxl_profiler define this macro to +# enable the profiler. +target_compile_definitions(jxl_profiler + PUBLIC -DPROFILER_ENABLED=1) diff --git a/media/libjxl/src/lib/jxl_tests.cmake b/media/libjxl/src/lib/jxl_tests.cmake new file mode 100644 index 000000000..c858ae97b --- /dev/null +++ b/media/libjxl/src/lib/jxl_tests.cmake @@ -0,0 +1,144 @@ +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set(TEST_FILES + extras/codec_test.cc + extras/dec/color_description_test.cc + extras/dec/pgx_test.cc + jxl/ac_strategy_test.cc + jxl/alpha_test.cc + jxl/ans_common_test.cc + jxl/ans_test.cc + jxl/bit_reader_test.cc + jxl/bits_test.cc + jxl/blending_test.cc + jxl/butteraugli_test.cc + jxl/byte_order_test.cc + jxl/coeff_order_test.cc + jxl/color_encoding_internal_test.cc + jxl/color_management_test.cc + jxl/convolve_test.cc + jxl/data_parallel_test.cc + jxl/dct_test.cc + jxl/decode_test.cc + jxl/enc_external_image_test.cc + jxl/enc_photon_noise_test.cc + jxl/encode_test.cc + jxl/entropy_coder_test.cc + jxl/fast_dct_test.cc + jxl/fast_math_test.cc + jxl/fields_test.cc + jxl/gaborish_test.cc + jxl/gamma_correct_test.cc + jxl/gauss_blur_test.cc + jxl/gradient_test.cc + jxl/iaca_test.cc + jxl/icc_codec_test.cc + jxl/image_bundle_test.cc + jxl/image_ops_test.cc + jxl/jxl_test.cc + jxl/lehmer_code_test.cc + jxl/linalg_test.cc + jxl/modular_test.cc + jxl/opsin_image_test.cc + jxl/opsin_inverse_test.cc + jxl/optimize_test.cc + jxl/padded_bytes_test.cc + jxl/passes_test.cc + jxl/patch_dictionary_test.cc + jxl/preview_test.cc + jxl/quant_weights_test.cc + jxl/quantizer_test.cc + jxl/rational_polynomial_test.cc + jxl/render_pipeline/render_pipeline_test.cc + jxl/roundtrip_test.cc + jxl/simd_util_test.cc + jxl/speed_tier_test.cc + jxl/splines_test.cc + jxl/toc_test.cc + jxl/xorshift128plus_test.cc + threads/thread_parallel_runner_test.cc + ### Files before this line are handled by build_cleaner.py + # TODO(deymo): Move this to tools/ + ../tools/box/box_test.cc + ../tools/djxl_fuzzer_test.cc +) + +# Test-only library code. +set(TESTLIB_FILES + jxl/codec_y4m_testonly.cc + jxl/codec_y4m_testonly.h + jxl/dct_for_test.h + jxl/dec_transforms_testonly.cc + jxl/dec_transforms_testonly.h + jxl/fake_parallel_runner_testonly.h + jxl/image_test_utils.h + jxl/test_image.h + jxl/test_utils.h + jxl/testdata.h +) + +find_package(GTest) + +# Library with test-only code shared between all tests. +add_library(jxl_testlib-static STATIC ${TESTLIB_FILES}) + target_compile_options(jxl_testlib-static PRIVATE + ${JPEGXL_INTERNAL_FLAGS} + ${JPEGXL_COVERAGE_FLAGS} + ) +target_compile_definitions(jxl_testlib-static PUBLIC + -DTEST_DATA_PATH="${JPEGXL_TEST_DATA_PATH}") +target_include_directories(jxl_testlib-static PUBLIC + "${PROJECT_SOURCE_DIR}" +) +target_link_libraries(jxl_testlib-static hwy jxl-static) + +# Individual test binaries: +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests) +foreach (TESTFILE IN LISTS TEST_FILES) + # The TESTNAME is the name without the extension or directory. + get_filename_component(TESTNAME ${TESTFILE} NAME_WE) + if(TESTFILE STREQUAL ../tools/djxl_fuzzer_test.cc) + add_executable(${TESTNAME} ${TESTFILE} ../tools/djxl_fuzzer.cc) + else() + add_executable(${TESTNAME} ${TESTFILE}) + endif() + if(JPEGXL_EMSCRIPTEN) + # The emscripten linking step takes too much memory and crashes during the + # wasm-opt step when using -O2 optimization level + set_target_properties(${TESTNAME} PROPERTIES LINK_FLAGS "\ + -O1 \ + -s USE_LIBPNG=1 \ + -s TOTAL_MEMORY=1536MB \ + -s SINGLE_FILE=1 \ + -s PROXY_TO_PTHREAD \ + -s EXIT_RUNTIME=1 \ + -s USE_PTHREADS=1 \ + -s NODERAWFS=1 \ + ") + endif() + target_compile_options(${TESTNAME} PRIVATE + ${JPEGXL_INTERNAL_FLAGS} + # Add coverage flags to the test binary so code in the private headers of + # the library is also instrumented when running tests that execute it. + ${JPEGXL_COVERAGE_FLAGS} + ) + target_link_libraries(${TESTNAME} + box + jxl_extras-static + jxl_testlib-static + gmock + GTest::GTest + GTest::Main + ) + # Output test targets in the test directory. + set_target_properties(${TESTNAME} PROPERTIES PREFIX "tests/") + if (WIN32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set_target_properties(${TESTNAME} PROPERTIES COMPILE_FLAGS "-Wno-error") + endif () + if(CMAKE_VERSION VERSION_LESS "3.10.3") + gtest_discover_tests(${TESTNAME} TIMEOUT 240) + else () + gtest_discover_tests(${TESTNAME} DISCOVERY_TIMEOUT 240) + endif () +endforeach () diff --git a/media/libjxl/src/lib/jxl_threads.cmake b/media/libjxl/src/lib/jxl_threads.cmake new file mode 100644 index 000000000..006e71eb6 --- /dev/null +++ b/media/libjxl/src/lib/jxl_threads.cmake @@ -0,0 +1,128 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +find_package(Threads REQUIRED) + +set(JPEGXL_THREADS_SOURCES + threads/resizable_parallel_runner.cc + threads/thread_parallel_runner.cc + threads/thread_parallel_runner_internal.cc + threads/thread_parallel_runner_internal.h +) + +### Define the jxl_threads shared or static target library. The ${target} +# parameter should already be created with add_library(), but this function +# sets all the remaining common properties. +function(_set_jxl_threads _target) + +target_compile_options(${_target} PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(${_target} PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET ${_target} PROPERTY POSITION_INDEPENDENT_CODE ON) + +target_include_directories(${_target} + PRIVATE + "${PROJECT_SOURCE_DIR}" + PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include") + +target_link_libraries(${_target} + PUBLIC ${JPEGXL_COVERAGE_FLAGS} Threads::Threads +) + +set_target_properties(${_target} PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 + DEFINE_SYMBOL JXL_THREADS_INTERNAL_LIBRARY_BUILD +) + +# Always install the library as jxl_threads.{a,so} file without the "-static" +# suffix, except in Windows. +if (NOT WIN32 OR MINGW) + set_target_properties(${_target} PROPERTIES OUTPUT_NAME "jxl_threads") +endif() +install(TARGETS ${_target} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +endfunction() + + +### Static library. +add_library(jxl_threads-static STATIC ${JPEGXL_THREADS_SOURCES}) +_set_jxl_threads(jxl_threads-static) + +# Make jxl_threads symbols neither imported nor exported when using the static +# library. These will have hidden visibility anyway in the static library case +# in unix. +target_compile_definitions(jxl_threads-static + PUBLIC -DJXL_THREADS_STATIC_DEFINE) + + +### Public shared library. +if (BUILD_SHARED_LIBS) +add_library(jxl_threads SHARED ${JPEGXL_THREADS_SOURCES}) +_set_jxl_threads(jxl_threads) + +set_target_properties(jxl_threads PROPERTIES + VERSION ${JPEGXL_LIBRARY_VERSION} + SOVERSION ${JPEGXL_LIBRARY_SOVERSION} + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") + + set_target_properties(jxl_threads PROPERTIES + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl.version) + if(APPLE) + set_property(TARGET ${target} APPEND_STRING PROPERTY + LINK_FLAGS "-Wl,-exported_symbols_list,${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl_osx.syms") + elseif(WIN32) + # Nothing needed here, we use __declspec(dllexport) (jxl_threads_export.h) + else() + set_property(TARGET jxl_threads APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl.version") + endif() # APPLE + +# Compile the shared library such that the JXL_THREADS_EXPORT symbols are +# exported. Users of the library will not set this flag and therefore import +# those symbols. +target_compile_definitions(jxl_threads + PRIVATE -DJXL_THREADS_INTERNAL_LIBRARY_BUILD) + +# Generate the jxl/jxl_threads_export.h header, we only need to generate it once +# but we can use it from both libraries. +generate_export_header(jxl_threads + BASE_NAME JXL_THREADS + EXPORT_FILE_NAME include/jxl/jxl_threads_export.h) +else() +add_library(jxl_threads ALIAS jxl_threads-static) +# When not building the shared library generate the jxl_threads_export.h header +# only based on the static target. +generate_export_header(jxl_threads-static + BASE_NAME JXL_THREADS + EXPORT_FILE_NAME include/jxl/jxl_threads_export.h) +endif() # BUILD_SHARED_LIBS + + +### Add a pkg-config file for libjxl_threads. + +# Allow adding prefix if CMAKE_INSTALL_INCLUDEDIR not absolute. +if(IS_ABSOLUTE "${CMAKE_INSTALL_INCLUDEDIR}") + set(PKGCONFIG_TARGET_INCLUDES "${CMAKE_INSTALL_INCLUDEDIR}") +else() + set(PKGCONFIG_TARGET_INCLUDES "\${prefix}/${CMAKE_INSTALL_INCLUDEDIR}") +endif() +# Allow adding prefix if CMAKE_INSTALL_LIBDIR not absolute. +if(IS_ABSOLUTE "${CMAKE_INSTALL_LIBDIR}") + set(PKGCONFIG_TARGET_LIBS "${CMAKE_INSTALL_LIBDIR}") +else() + set(PKGCONFIG_TARGET_LIBS "\${exec_prefix}/${CMAKE_INSTALL_LIBDIR}") +endif() + +set(JPEGXL_THREADS_LIBRARY_REQUIRES "") +configure_file("${CMAKE_CURRENT_SOURCE_DIR}/threads/libjxl_threads.pc.in" + "libjxl_threads.pc" @ONLY) +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/libjxl_threads.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") diff --git a/media/libjxl/src/lib/lib.gni b/media/libjxl/src/lib/lib.gni new file mode 100644 index 000000000..2914de991 --- /dev/null +++ b/media/libjxl/src/lib/lib.gni @@ -0,0 +1,501 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Source files definitions for GN-based build systems. + +# Library version macros +libjxl_version_defines = [ + "JPEGXL_MAJOR_VERSION=0", + "JPEGXL_MINOR_VERSION=7", + "JPEGXL_PATCH_VERSION=0", +] + +libjxl_public_headers = [ + "include/jxl/butteraugli.h", + "include/jxl/butteraugli_cxx.h", + "include/jxl/cms_interface.h", + "include/jxl/codestream_header.h", + "include/jxl/color_encoding.h", + "include/jxl/decode.h", + "include/jxl/decode_cxx.h", + "include/jxl/encode.h", + "include/jxl/encode_cxx.h", + "include/jxl/memory_manager.h", + "include/jxl/parallel_runner.h", + "include/jxl/types.h", +] + +libjxl_dec_sources = [ + "jxl/ac_context.h", + "jxl/ac_strategy.cc", + "jxl/ac_strategy.h", + "jxl/alpha.cc", + "jxl/alpha.h", + "jxl/ans_common.cc", + "jxl/ans_common.h", + "jxl/ans_params.h", + "jxl/aux_out.cc", + "jxl/aux_out.h", + "jxl/aux_out_fwd.h", + "jxl/base/arch_macros.h", + "jxl/base/bits.h", + "jxl/base/byte_order.h", + "jxl/base/cache_aligned.cc", + "jxl/base/cache_aligned.h", + "jxl/base/compiler_specific.h", + "jxl/base/data_parallel.cc", + "jxl/base/data_parallel.h", + "jxl/base/file_io.h", + "jxl/base/iaca.h", + "jxl/base/os_macros.h", + "jxl/base/override.h", + "jxl/base/padded_bytes.cc", + "jxl/base/padded_bytes.h", + "jxl/base/printf_macros.h", + "jxl/base/profiler.h", + "jxl/base/random.cc", + "jxl/base/random.h", + "jxl/base/sanitizer_definitions.h", + "jxl/base/scope_guard.h", + "jxl/base/span.h", + "jxl/base/status.h", + "jxl/base/thread_pool_internal.h", + "jxl/blending.cc", + "jxl/blending.h", + "jxl/box_content_decoder.cc", + "jxl/box_content_decoder.h", + "jxl/chroma_from_luma.cc", + "jxl/chroma_from_luma.h", + "jxl/codec_in_out.h", + "jxl/coeff_order.cc", + "jxl/coeff_order.h", + "jxl/coeff_order_fwd.h", + "jxl/color_encoding_internal.cc", + "jxl/color_encoding_internal.h", + "jxl/color_management.cc", + "jxl/color_management.h", + "jxl/common.h", + "jxl/compressed_dc.cc", + "jxl/compressed_dc.h", + "jxl/convolve-inl.h", + "jxl/convolve.h", + "jxl/convolve_separable5.cc", + "jxl/convolve_separable7.cc", + "jxl/convolve_slow.cc", + "jxl/convolve_symmetric3.cc", + "jxl/convolve_symmetric5.cc", + "jxl/dct-inl.h", + "jxl/dct_block-inl.h", + "jxl/dct_scales.cc", + "jxl/dct_scales.h", + "jxl/dct_util.h", + "jxl/dec_ans.cc", + "jxl/dec_ans.h", + "jxl/dec_bit_reader.h", + "jxl/dec_cache.cc", + "jxl/dec_cache.h", + "jxl/dec_context_map.cc", + "jxl/dec_context_map.h", + "jxl/dec_external_image.cc", + "jxl/dec_external_image.h", + "jxl/dec_frame.cc", + "jxl/dec_frame.h", + "jxl/dec_group.cc", + "jxl/dec_group.h", + "jxl/dec_group_border.cc", + "jxl/dec_group_border.h", + "jxl/dec_huffman.cc", + "jxl/dec_huffman.h", + "jxl/dec_modular.cc", + "jxl/dec_modular.h", + "jxl/dec_noise.cc", + "jxl/dec_noise.h", + "jxl/dec_patch_dictionary.cc", + "jxl/dec_patch_dictionary.h", + "jxl/dec_tone_mapping-inl.h", + "jxl/dec_transforms-inl.h", + "jxl/dec_xyb-inl.h", + "jxl/dec_xyb.cc", + "jxl/dec_xyb.h", + "jxl/decode.cc", + "jxl/decode_to_jpeg.cc", + "jxl/decode_to_jpeg.h", + "jxl/enc_bit_writer.cc", + "jxl/enc_bit_writer.h", + "jxl/entropy_coder.cc", + "jxl/entropy_coder.h", + "jxl/epf.cc", + "jxl/epf.h", + "jxl/exif.h", + "jxl/fast_dct-inl.h", + "jxl/fast_dct.cc", + "jxl/fast_dct.h", + "jxl/fast_dct128-inl.h", + "jxl/fast_dct16-inl.h", + "jxl/fast_dct256-inl.h", + "jxl/fast_dct32-inl.h", + "jxl/fast_dct64-inl.h", + "jxl/fast_dct8-inl.h", + "jxl/fast_math-inl.h", + "jxl/field_encodings.h", + "jxl/fields.cc", + "jxl/fields.h", + "jxl/frame_header.cc", + "jxl/frame_header.h", + "jxl/gauss_blur.cc", + "jxl/gauss_blur.h", + "jxl/headers.cc", + "jxl/headers.h", + "jxl/huffman_table.cc", + "jxl/huffman_table.h", + "jxl/icc_codec.cc", + "jxl/icc_codec.h", + "jxl/icc_codec_common.cc", + "jxl/icc_codec_common.h", + "jxl/image.cc", + "jxl/image.h", + "jxl/image_bundle.cc", + "jxl/image_bundle.h", + "jxl/image_metadata.cc", + "jxl/image_metadata.h", + "jxl/image_ops.h", + "jxl/jpeg/dec_jpeg_data.cc", + "jxl/jpeg/dec_jpeg_data.h", + "jxl/jpeg/dec_jpeg_data_writer.cc", + "jxl/jpeg/dec_jpeg_data_writer.h", + "jxl/jpeg/dec_jpeg_output_chunk.h", + "jxl/jpeg/dec_jpeg_serialization_state.h", + "jxl/jpeg/jpeg_data.cc", + "jxl/jpeg/jpeg_data.h", + "jxl/jxl_inspection.h", + "jxl/lehmer_code.h", + "jxl/linalg.h", + "jxl/loop_filter.cc", + "jxl/loop_filter.h", + "jxl/luminance.cc", + "jxl/luminance.h", + "jxl/memory_manager_internal.cc", + "jxl/memory_manager_internal.h", + "jxl/modular/encoding/context_predict.h", + "jxl/modular/encoding/dec_ma.cc", + "jxl/modular/encoding/dec_ma.h", + "jxl/modular/encoding/encoding.cc", + "jxl/modular/encoding/encoding.h", + "jxl/modular/encoding/ma_common.h", + "jxl/modular/modular_image.cc", + "jxl/modular/modular_image.h", + "jxl/modular/options.h", + "jxl/modular/transform/palette.h", + "jxl/modular/transform/rct.cc", + "jxl/modular/transform/rct.h", + "jxl/modular/transform/squeeze.cc", + "jxl/modular/transform/squeeze.h", + "jxl/modular/transform/transform.cc", + "jxl/modular/transform/transform.h", + "jxl/noise.h", + "jxl/opsin_params.cc", + "jxl/opsin_params.h", + "jxl/passes_state.cc", + "jxl/passes_state.h", + "jxl/patch_dictionary_internal.h", + "jxl/quant_weights.cc", + "jxl/quant_weights.h", + "jxl/quantizer-inl.h", + "jxl/quantizer.cc", + "jxl/quantizer.h", + "jxl/rational_polynomial-inl.h", + "jxl/render_pipeline/low_memory_render_pipeline.cc", + "jxl/render_pipeline/low_memory_render_pipeline.h", + "jxl/render_pipeline/render_pipeline.cc", + "jxl/render_pipeline/render_pipeline.h", + "jxl/render_pipeline/render_pipeline_stage.h", + "jxl/render_pipeline/simple_render_pipeline.cc", + "jxl/render_pipeline/simple_render_pipeline.h", + "jxl/render_pipeline/stage_blending.cc", + "jxl/render_pipeline/stage_blending.h", + "jxl/render_pipeline/stage_chroma_upsampling.cc", + "jxl/render_pipeline/stage_chroma_upsampling.h", + "jxl/render_pipeline/stage_epf.cc", + "jxl/render_pipeline/stage_epf.h", + "jxl/render_pipeline/stage_from_linear.cc", + "jxl/render_pipeline/stage_from_linear.h", + "jxl/render_pipeline/stage_gaborish.cc", + "jxl/render_pipeline/stage_gaborish.h", + "jxl/render_pipeline/stage_noise.cc", + "jxl/render_pipeline/stage_noise.h", + "jxl/render_pipeline/stage_patches.cc", + "jxl/render_pipeline/stage_patches.h", + "jxl/render_pipeline/stage_splines.cc", + "jxl/render_pipeline/stage_splines.h", + "jxl/render_pipeline/stage_spot.cc", + "jxl/render_pipeline/stage_spot.h", + "jxl/render_pipeline/stage_to_linear.cc", + "jxl/render_pipeline/stage_to_linear.h", + "jxl/render_pipeline/stage_tone_mapping.cc", + "jxl/render_pipeline/stage_tone_mapping.h", + "jxl/render_pipeline/stage_upsampling.cc", + "jxl/render_pipeline/stage_upsampling.h", + "jxl/render_pipeline/stage_write.cc", + "jxl/render_pipeline/stage_write.h", + "jxl/render_pipeline/stage_xyb.cc", + "jxl/render_pipeline/stage_xyb.h", + "jxl/render_pipeline/stage_ycbcr.cc", + "jxl/render_pipeline/stage_ycbcr.h", + "jxl/render_pipeline/test_render_pipeline_stages.h", + "jxl/sanitizers.h", + "jxl/simd_util-inl.h", + "jxl/size_constraints.h", + "jxl/splines.cc", + "jxl/splines.h", + "jxl/toc.cc", + "jxl/toc.h", + "jxl/transfer_functions-inl.h", + "jxl/transpose-inl.h", + "jxl/xorshift128plus-inl.h", +] + +libjxl_enc_sources = [ + "jxl/butteraugli/butteraugli.cc", + "jxl/butteraugli/butteraugli.h", + "jxl/butteraugli_wrapper.cc", + "jxl/enc_ac_strategy.cc", + "jxl/enc_ac_strategy.h", + "jxl/enc_adaptive_quantization.cc", + "jxl/enc_adaptive_quantization.h", + "jxl/enc_ans.cc", + "jxl/enc_ans.h", + "jxl/enc_ans_params.h", + "jxl/enc_ar_control_field.cc", + "jxl/enc_ar_control_field.h", + "jxl/enc_butteraugli_comparator.cc", + "jxl/enc_butteraugli_comparator.h", + "jxl/enc_butteraugli_pnorm.cc", + "jxl/enc_butteraugli_pnorm.h", + "jxl/enc_cache.cc", + "jxl/enc_cache.h", + "jxl/enc_chroma_from_luma.cc", + "jxl/enc_chroma_from_luma.h", + "jxl/enc_cluster.cc", + "jxl/enc_cluster.h", + "jxl/enc_coeff_order.cc", + "jxl/enc_coeff_order.h", + "jxl/enc_color_management.cc", + "jxl/enc_color_management.h", + "jxl/enc_comparator.cc", + "jxl/enc_comparator.h", + "jxl/enc_context_map.cc", + "jxl/enc_context_map.h", + "jxl/enc_detect_dots.cc", + "jxl/enc_detect_dots.h", + "jxl/enc_dot_dictionary.cc", + "jxl/enc_dot_dictionary.h", + "jxl/enc_entropy_coder.cc", + "jxl/enc_entropy_coder.h", + "jxl/enc_external_image.cc", + "jxl/enc_external_image.h", + "jxl/enc_file.cc", + "jxl/enc_file.h", + "jxl/enc_frame.cc", + "jxl/enc_frame.h", + "jxl/enc_gamma_correct.h", + "jxl/enc_group.cc", + "jxl/enc_group.h", + "jxl/enc_heuristics.cc", + "jxl/enc_heuristics.h", + "jxl/enc_huffman.cc", + "jxl/enc_huffman.h", + "jxl/enc_icc_codec.cc", + "jxl/enc_icc_codec.h", + "jxl/enc_image_bundle.cc", + "jxl/enc_image_bundle.h", + "jxl/enc_jxl_skcms.h", + "jxl/enc_modular.cc", + "jxl/enc_modular.h", + "jxl/enc_noise.cc", + "jxl/enc_noise.h", + "jxl/enc_params.h", + "jxl/enc_patch_dictionary.cc", + "jxl/enc_patch_dictionary.h", + "jxl/enc_photon_noise.cc", + "jxl/enc_photon_noise.h", + "jxl/enc_quant_weights.cc", + "jxl/enc_quant_weights.h", + "jxl/enc_splines.cc", + "jxl/enc_splines.h", + "jxl/enc_toc.cc", + "jxl/enc_toc.h", + "jxl/enc_transforms-inl.h", + "jxl/enc_transforms.cc", + "jxl/enc_transforms.h", + "jxl/enc_xyb.cc", + "jxl/enc_xyb.h", + "jxl/encode.cc", + "jxl/encode_internal.h", + "jxl/gaborish.cc", + "jxl/gaborish.h", + "jxl/huffman_tree.cc", + "jxl/huffman_tree.h", + "jxl/jpeg/enc_jpeg_data.cc", + "jxl/jpeg/enc_jpeg_data.h", + "jxl/jpeg/enc_jpeg_data_reader.cc", + "jxl/jpeg/enc_jpeg_data_reader.h", + "jxl/jpeg/enc_jpeg_huffman_decode.cc", + "jxl/jpeg/enc_jpeg_huffman_decode.h", + "jxl/linalg.cc", + "jxl/modular/encoding/enc_debug_tree.cc", + "jxl/modular/encoding/enc_debug_tree.h", + "jxl/modular/encoding/enc_encoding.cc", + "jxl/modular/encoding/enc_encoding.h", + "jxl/modular/encoding/enc_ma.cc", + "jxl/modular/encoding/enc_ma.h", + "jxl/modular/transform/enc_palette.cc", + "jxl/modular/transform/enc_palette.h", + "jxl/modular/transform/enc_rct.cc", + "jxl/modular/transform/enc_rct.h", + "jxl/modular/transform/enc_squeeze.cc", + "jxl/modular/transform/enc_squeeze.h", + "jxl/modular/transform/enc_transform.cc", + "jxl/modular/transform/enc_transform.h", + "jxl/optimize.cc", + "jxl/optimize.h", + "jxl/progressive_split.cc", + "jxl/progressive_split.h", +] + +libjxl_gbench_sources = [ + "extras/tone_mapping_gbench.cc", + "jxl/dec_external_image_gbench.cc", + "jxl/enc_external_image_gbench.cc", + "jxl/gauss_blur_gbench.cc", + "jxl/splines_gbench.cc", + "jxl/tf_gbench.cc", +] + +libjxl_tests_sources = [ + "jxl/ac_strategy_test.cc", + "jxl/alpha_test.cc", + "jxl/ans_common_test.cc", + "jxl/ans_test.cc", + "jxl/bit_reader_test.cc", + "jxl/bits_test.cc", + "jxl/blending_test.cc", + "jxl/butteraugli_test.cc", + "jxl/byte_order_test.cc", + "jxl/coeff_order_test.cc", + "jxl/color_encoding_internal_test.cc", + "jxl/color_management_test.cc", + "jxl/convolve_test.cc", + "jxl/data_parallel_test.cc", + "jxl/dct_test.cc", + "jxl/decode_test.cc", + "jxl/enc_external_image_test.cc", + "jxl/enc_photon_noise_test.cc", + "jxl/encode_test.cc", + "jxl/entropy_coder_test.cc", + "jxl/fast_dct_test.cc", + "jxl/fast_math_test.cc", + "jxl/fields_test.cc", + "jxl/gaborish_test.cc", + "jxl/gamma_correct_test.cc", + "jxl/gauss_blur_test.cc", + "jxl/gradient_test.cc", + "jxl/iaca_test.cc", + "jxl/icc_codec_test.cc", + "jxl/image_bundle_test.cc", + "jxl/image_ops_test.cc", + "jxl/jxl_test.cc", + "jxl/lehmer_code_test.cc", + "jxl/linalg_test.cc", + "jxl/modular_test.cc", + "jxl/opsin_image_test.cc", + "jxl/opsin_inverse_test.cc", + "jxl/optimize_test.cc", + "jxl/padded_bytes_test.cc", + "jxl/passes_test.cc", + "jxl/patch_dictionary_test.cc", + "jxl/preview_test.cc", + "jxl/quant_weights_test.cc", + "jxl/quantizer_test.cc", + "jxl/rational_polynomial_test.cc", + "jxl/render_pipeline/render_pipeline_test.cc", + "jxl/roundtrip_test.cc", + "jxl/simd_util_test.cc", + "jxl/speed_tier_test.cc", + "jxl/splines_test.cc", + "jxl/toc_test.cc", + "jxl/xorshift128plus_test.cc", +] + +# Test-only library code. +libjxl_testlib_sources = [ + "jxl/codec_y4m_testonly.cc", + "jxl/codec_y4m_testonly.h", + "jxl/dct_for_test.h", + "jxl/dec_transforms_testonly.cc", + "jxl/dec_transforms_testonly.h", + "jxl/fake_parallel_runner_testonly.h", + "jxl/image_test_utils.h", + "jxl/test_image.h", + "jxl/test_utils.h", + "jxl/testdata.h", +] + +libjxl_extras_sources = [ + "extras/codec.cc", + "extras/codec.h", + "extras/dec/color_description.cc", + "extras/dec/color_description.h", + "extras/dec/color_hints.cc", + "extras/dec/color_hints.h", + "extras/dec/decode.cc", + "extras/dec/decode.h", + "extras/dec/jxl.cc", + "extras/dec/jxl.h", + "extras/dec/pgx.cc", + "extras/dec/pgx.h", + "extras/dec/pnm.cc", + "extras/dec/pnm.h", + "extras/enc/encode.cc", + "extras/enc/encode.h", + "extras/enc/npy.cc", + "extras/enc/npy.h", + "extras/enc/pgx.cc", + "extras/enc/pgx.h", + "extras/enc/pnm.cc", + "extras/enc/pnm.h", + "extras/exif.cc", + "extras/exif.h", + "extras/hlg.cc", + "extras/hlg.h", + "extras/packed_image.h", + "extras/packed_image_convert.cc", + "extras/packed_image_convert.h", + "extras/render_hdr.cc", + "extras/render_hdr.h", + "extras/time.cc", + "extras/time.h", + "extras/tone_mapping.cc", + "extras/tone_mapping.h", +] + +libjxl_threads_sources = [ + "threads/resizable_parallel_runner.cc", + "threads/thread_parallel_runner.cc", + "threads/thread_parallel_runner_internal.cc", + "threads/thread_parallel_runner_internal.h", +] + +libjxl_threads_public_headers = [ + "include/jxl/resizable_parallel_runner.h", + "include/jxl/resizable_parallel_runner_cxx.h", + "include/jxl/thread_parallel_runner.h", + "include/jxl/thread_parallel_runner_cxx.h", +] + +libjxl_profiler_sources = [ + "profiler/profiler.cc", + "profiler/profiler.h", + "profiler/tsc_timer.h", +] diff --git a/media/libjxl/src/lib/profiler/profiler.cc b/media/libjxl/src/lib/profiler/profiler.cc new file mode 100644 index 000000000..c72656ee8 --- /dev/null +++ b/media/libjxl/src/lib/profiler/profiler.cc @@ -0,0 +1,536 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/profiler.h" + +#if PROFILER_ENABLED + +#include +#include +#include // memcpy + +#include // sort +#include +#include // PRIu64 +#include +#include +#include + +// Optionally use SIMD in StreamCacheLine if available. +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/profiler/profiler.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace profiler { +namespace HWY_NAMESPACE { + +// Overwrites `to` without loading it into cache (read-for-ownership). +// Copies 64 bytes from/to naturally aligned addresses. +void StreamCacheLine(const Packet* HWY_RESTRICT from, Packet* HWY_RESTRICT to) { +#if HWY_TARGET == HWY_SCALAR + hwy::CopyBytes<64>(from, to); +#else + const HWY_CAPPED(uint64_t, 2) d; + HWY_FENCE; + const uint64_t* HWY_RESTRICT from64 = reinterpret_cast(from); + const auto v0 = Load(d, from64 + 0); + const auto v1 = Load(d, from64 + 2); + const auto v2 = Load(d, from64 + 4); + const auto v3 = Load(d, from64 + 6); + // Fences prevent the compiler from reordering loads/stores, which may + // interfere with write-combining. + HWY_FENCE; + uint64_t* HWY_RESTRICT to64 = reinterpret_cast(to); + Stream(v0, d, to64 + 0); + Stream(v1, d, to64 + 2); + Stream(v2, d, to64 + 4); + Stream(v3, d, to64 + 6); + HWY_FENCE; +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace profiler +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace profiler { + +HWY_EXPORT(StreamCacheLine); + +namespace { + +// How many mebibytes to allocate (if PROFILER_ENABLED) per thread that +// enters at least one zone. Once this buffer is full, the thread will analyze +// packets (two per zone), which introduces observer overhead. +#ifndef PROFILER_THREAD_STORAGE +#define PROFILER_THREAD_STORAGE 32ULL +#endif + +#define PROFILER_PRINT_OVERHEAD 0 + +// Upper bounds for fixed-size data structures (guarded via HWY_ASSERT): +constexpr size_t kMaxDepth = 64; // Maximum nesting of zones. +constexpr size_t kMaxZones = 256; // Total number of zones. + +// Stack of active (entered but not exited) zones. POD, uninitialized. +// Used to deduct child duration from the parent's self time. +struct ActiveZone { + const char* name; + uint64_t entry_timestamp; + uint64_t child_total; +}; + +// Totals for all Zones with the same name. POD, must be zero-initialized. +struct ZoneTotals { + uint64_t total_duration; + const char* name; + uint64_t num_calls; +}; + +template +inline T ClampedSubtract(const T minuend, const T subtrahend) { + if (subtrahend > minuend) { + return 0; + } + return minuend - subtrahend; +} + +} // namespace + +// Per-thread call graph (stack) and ZoneTotals for each zone. +class Results { + public: + Results() { + // Zero-initialize all accumulators (avoids a check for num_zones_ == 0). + memset(zones_, 0, sizeof(zones_)); + } + + // Used for computing overhead when this thread encounters its first Zone. + // This has no observable effect apart from increasing "analyze_elapsed_". + uint64_t ZoneDuration(const Packet* packets) { + HWY_ASSERT(depth_ == 0); + HWY_ASSERT(num_zones_ == 0); + AnalyzePackets(packets, 2); + const uint64_t duration = zones_[0].total_duration; + zones_[0].num_calls = 0; + zones_[0].total_duration = 0; + HWY_ASSERT(depth_ == 0); + num_zones_ = 0; + return duration; + } + + void SetSelfOverhead(const uint64_t self_overhead) { + self_overhead_ = self_overhead; + } + + void SetChildOverhead(const uint64_t child_overhead) { + child_overhead_ = child_overhead; + } + + // Draw all required information from the packets, which can be discarded + // afterwards. Called whenever this thread's storage is full. + void AnalyzePackets(const Packet* HWY_RESTRICT packets, + const size_t num_packets) { + // Ensures prior weakly-ordered streaming stores are globally visible. + hwy::FlushStream(); + + const uint64_t t0 = TicksBefore(); + + for (size_t i = 0; i < num_packets; ++i) { + const uint64_t timestamp = packets[i].timestamp; + // Entering a zone + if (packets[i].name != nullptr) { + HWY_ASSERT(depth_ < kMaxDepth); + zone_stack_[depth_].name = packets[i].name; + zone_stack_[depth_].entry_timestamp = timestamp; + zone_stack_[depth_].child_total = 0; + ++depth_; + continue; + } + + HWY_ASSERT(depth_ != 0); + const ActiveZone& active = zone_stack_[depth_ - 1]; + const uint64_t duration = timestamp - active.entry_timestamp; + const uint64_t self_duration = ClampedSubtract( + duration, self_overhead_ + child_overhead_ + active.child_total); + + UpdateOrAdd(active.name, 1, self_duration); + --depth_; + + // "Deduct" the nested time from its parent's self_duration. + if (depth_ != 0) { + zone_stack_[depth_ - 1].child_total += duration + child_overhead_; + } + } + + const uint64_t t1 = TicksAfter(); + analyze_elapsed_ += t1 - t0; + } + + // Incorporates results from another thread. Call after all threads have + // exited any zones. + void Assimilate(const Results& other) { + const uint64_t t0 = TicksBefore(); + HWY_ASSERT(depth_ == 0); + HWY_ASSERT(other.depth_ == 0); + + for (size_t i = 0; i < other.num_zones_; ++i) { + const ZoneTotals& zone = other.zones_[i]; + UpdateOrAdd(zone.name, zone.num_calls, zone.total_duration); + } + const uint64_t t1 = TicksAfter(); + analyze_elapsed_ += t1 - t0 + other.analyze_elapsed_; + } + + // Single-threaded. + void Print() { + const uint64_t t0 = TicksBefore(); + MergeDuplicates(); + + // Sort by decreasing total (self) cost. + std::sort(zones_, zones_ + num_zones_, + [](const ZoneTotals& r1, const ZoneTotals& r2) { + return r1.total_duration > r2.total_duration; + }); + + uint64_t total_visible_duration = 0; + for (size_t i = 0; i < num_zones_; ++i) { + const ZoneTotals& r = zones_[i]; + if (r.name[0] != '@') { + total_visible_duration += r.total_duration; + printf("%-40s: %10" PRIu64 " x %15" PRIu64 "= %15" PRIu64 "\n", r.name, + r.num_calls, r.total_duration / r.num_calls, r.total_duration); + } + } + + const uint64_t t1 = TicksAfter(); + analyze_elapsed_ += t1 - t0; + printf("Total clocks during analysis: %" PRIu64 "\n", analyze_elapsed_); + printf("Total clocks measured: %" PRIu64 "\n", total_visible_duration); + } + + // Single-threaded. Clears all results as if no zones had been recorded. + void Reset() { + analyze_elapsed_ = 0; + HWY_ASSERT(depth_ == 0); + num_zones_ = 0; + memset(zone_stack_, 0, sizeof(zone_stack_)); + memset(zones_, 0, sizeof(zones_)); + } + + private: + // Updates ZoneTotals of the same name, or inserts a new one if this thread + // has not yet seen that name. Uses a self-organizing list data structure, + // which avoids dynamic memory allocations and is faster than unordered_map. + void UpdateOrAdd(const char* name, const uint64_t num_calls, + const uint64_t duration) { + // Special case for first zone: (maybe) update, without swapping. + if (zones_[0].name == name) { + zones_[0].total_duration += duration; + zones_[0].num_calls += num_calls; + return; + } + + // Look for a zone with the same name. + for (size_t i = 1; i < num_zones_; ++i) { + if (zones_[i].name == name) { + zones_[i].total_duration += duration; + zones_[i].num_calls += num_calls; + // Swap with predecessor (more conservative than move to front, + // but at least as successful). + std::swap(zones_[i - 1], zones_[i]); + return; + } + } + + // Not found; create a new ZoneTotals. + HWY_ASSERT(num_zones_ < kMaxZones); + ZoneTotals* HWY_RESTRICT zone = zones_ + num_zones_; + zone->name = name; + zone->num_calls = num_calls; + zone->total_duration = duration; + ++num_zones_; + } + + // Each instantiation of a function template seems to get its own copy of + // __func__ and GCC doesn't merge them. An N^2 search for duplicates is + // acceptable because we only expect a few dozen zones. + void MergeDuplicates() { + for (size_t i = 0; i < num_zones_; ++i) { + // Add any subsequent duplicates to num_calls and total_duration. + for (size_t j = i + 1; j < num_zones_;) { + if (!strcmp(zones_[i].name, zones_[j].name)) { + zones_[i].num_calls += zones_[j].num_calls; + zones_[i].total_duration += zones_[j].total_duration; + // Fill hole with last item. + zones_[j] = zones_[--num_zones_]; + } else { // Name differed, try next ZoneTotals. + ++j; + } + } + } + } + + uint64_t analyze_elapsed_ = 0; + uint64_t self_overhead_ = 0; + uint64_t child_overhead_ = 0; + + size_t depth_ = 0; // Number of active zones <= kMaxDepth. + size_t num_zones_ = 0; // Number of unique zones <= kMaxZones. + + // After other members to avoid large pointer offsets. + alignas(64) ActiveZone zone_stack_[kMaxDepth]; // Last = newest + alignas(64) ZoneTotals zones_[kMaxZones]; // Self-organizing list +}; + +ThreadSpecific::ThreadSpecific() + : max_packets_(PROFILER_THREAD_STORAGE << 16), // MiB / sizeof(Packet) + packets_(hwy::AllocateAligned(max_packets_)), + num_packets_(0), + results_(hwy::MakeUniqueAligned()) {} + +ThreadSpecific::~ThreadSpecific() {} + +void ThreadSpecific::FlushBuffer() { + if (num_packets_ + kBufferCapacity > max_packets_) { + results_->AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + // This buffering halves observer overhead and decreases the overall + // runtime by about 3%. + HWY_DYNAMIC_DISPATCH(StreamCacheLine) + (buffer_, packets_.get() + num_packets_); + num_packets_ += kBufferCapacity; + buffer_size_ = 0; +} + +void ThreadSpecific::AnalyzeRemainingPackets() { + // Storage full => empty it. + if (num_packets_ + buffer_size_ > max_packets_) { + results_->AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + + // Move buffer to storage + memcpy(packets_.get() + num_packets_, buffer_, buffer_size_ * sizeof(Packet)); + num_packets_ += buffer_size_; + buffer_size_ = 0; + + results_->AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; +} + +namespace { + +class HalfSampleMode { + public: + // Returns mode. "sorted" must be in ascending order. + template + T operator()(const T* const HWY_RESTRICT sorted, + const size_t num_values) const { + int64_t center = num_values / 2; + int64_t width = num_values; + + // Zoom in on modal intervals of decreasing width. Stop before we reach + // width=1, i.e. single values, for which there is no "slope". + while (width > 2) { + // Round up so we can still reach the outer edges of odd widths. + width = (width + 1) / 2; + + center = CenterOfIntervalWithMinSlope(sorted, num_values, center, width); + } + + return sorted[center]; // mode := middle value in modal interval. + } + + private: + // Returns center of the densest region [c-radius, c+radius]. + template + static HWY_INLINE int64_t CenterOfIntervalWithMinSlope( + const T* HWY_RESTRICT sorted, const int64_t total_values, + const int64_t center, const int64_t width) { + const int64_t radius = (width + 1) / 2; + + auto compute_slope = [radius, total_values, sorted]( + int64_t c, int64_t* actual_center = nullptr) { + // For symmetry, check 2*radius+1 values, i.e. [min, max]. + const int64_t min = std::max(c - radius, int64_t(0)); + const int64_t max = std::min(c + radius, total_values - 1); + HWY_ASSERT(min < max); + HWY_ASSERT(sorted[min] <= + sorted[max] + std::numeric_limits::epsilon()); + const float dx = max - min + 1; + const float slope = (sorted[max] - sorted[min]) / dx; + + if (actual_center != nullptr) { + // c may be out of bounds, so return center of the clamped bounds. + *actual_center = (min + max + 1) / 2; + } + return slope; + }; + + // First find min_slope for all centers. + float min_slope = std::numeric_limits::max(); + for (int64_t c = center - radius; c <= center + radius; ++c) { + min_slope = std::min(min_slope, compute_slope(c)); + } + + // Candidates := centers with slope ~= min_slope. + std::vector candidates; + for (int64_t c = center - radius; c <= center + radius; ++c) { + int64_t actual_center; + const float slope = compute_slope(c, &actual_center); + if (slope <= min_slope * 1.001f) { + candidates.push_back(actual_center); + } + } + + // Keep the median. + HWY_ASSERT(!candidates.empty()); + if (candidates.size() == 1) return candidates[0]; + std::nth_element(candidates.begin(), + candidates.begin() + candidates.size() / 2, + candidates.end()); + return candidates[candidates.size() / 2]; + } +}; + +} // namespace + +void ThreadSpecific::ComputeOverhead() { + // Delay after capturing timestamps before/after the actual zone runs. Even + // with frequency throttling disabled, this has a multimodal distribution, + // including 32, 34, 48, 52, 59, 62. + uint64_t self_overhead; + { + const size_t kNumSamples = 32; + uint32_t samples[kNumSamples]; + for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) { + const size_t kNumDurations = 1024; + uint32_t durations[kNumDurations]; + + for (size_t idx_duration = 0; idx_duration < kNumDurations; + ++idx_duration) { + { // + PROFILER_ZONE("Dummy Zone (never shown)"); + } + const uint64_t duration = results_->ZoneDuration(buffer_); + buffer_size_ = 0; + durations[idx_duration] = static_cast(duration); + HWY_ASSERT(num_packets_ == 0); + } + std::sort(durations, durations + kNumDurations); + samples[idx_sample] = HalfSampleMode()(durations, kNumDurations); + } + // Median. + std::sort(samples, samples + kNumSamples); + self_overhead = samples[kNumSamples / 2]; +#if PROFILER_PRINT_OVERHEAD + printf("Overhead: %" PRIu64 "\n", static_cast(self_overhead)); +#endif + results_->SetSelfOverhead(self_overhead); + } + + // Delay before capturing start timestamp / after end timestamp. + const size_t kNumSamples = 32; + uint32_t samples[kNumSamples]; + for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) { + const size_t kNumDurations = 16; + uint32_t durations[kNumDurations]; + for (size_t idx_duration = 0; idx_duration < kNumDurations; + ++idx_duration) { + const size_t kReps = 10000; + // Analysis time should not be included => must fit within buffer. + HWY_ASSERT(kReps * 2 < max_packets_); + hwy::FlushStream(); + const uint64_t t0 = TicksBefore(); + for (size_t i = 0; i < kReps; ++i) { + PROFILER_ZONE("Dummy"); + } + hwy::FlushStream(); + const uint64_t t1 = TicksAfter(); + HWY_ASSERT(num_packets_ + buffer_size_ == kReps * 2); + buffer_size_ = 0; + num_packets_ = 0; + const uint64_t avg_duration = (t1 - t0 + kReps / 2) / kReps; + durations[idx_duration] = + static_cast(ClampedSubtract(avg_duration, self_overhead)); + } + std::sort(durations, durations + kNumDurations); + samples[idx_sample] = HalfSampleMode()(durations, kNumDurations); + } + std::sort(samples, samples + kNumSamples); + const uint64_t child_overhead = samples[9 * kNumSamples / 10]; +#if PROFILER_PRINT_OVERHEAD + printf("Child overhead: %" PRIu64 "\n", + static_cast(child_overhead)); +#endif + results_->SetChildOverhead(child_overhead); +} + +namespace { + +// Could be a static member of Zone, but that would expose in header. +std::atomic& GetHead() { + static std::atomic head_{nullptr}; // Owning + return head_; +} + +} // namespace + +// Thread-safe. +ThreadSpecific* Zone::InitThreadSpecific() { + ThreadSpecific* thread_specific = + hwy::MakeUniqueAligned().release(); + + // Insert into unordered list + std::atomic& head = GetHead(); + ThreadSpecific* old_head = head.load(std::memory_order_relaxed); + thread_specific->SetNext(old_head); + while (!head.compare_exchange_weak(old_head, thread_specific, + std::memory_order_release, + std::memory_order_relaxed)) { + thread_specific->SetNext(old_head); + // TODO(janwas): pause + } + + // ComputeOverhead also creates a Zone, so this needs to be set before that + // to prevent infinite recursion. + GetThreadSpecific() = thread_specific; + + thread_specific->ComputeOverhead(); + return thread_specific; +} + +// Single-threaded. +/*static*/ void Zone::PrintResults() { + ThreadSpecific* head = GetHead().load(std::memory_order_relaxed); + ThreadSpecific* p = head; + while (p) { + p->AnalyzeRemainingPackets(); + + // Combine all threads into a single Result. + if (p != head) { + head->GetResults().Assimilate(p->GetResults()); + p->GetResults().Reset(); + } + + p = p->GetNext(); + } + + if (head != nullptr) { + head->GetResults().Print(); + head->GetResults().Reset(); + } +} + +} // namespace profiler + +#endif // HWY_ONCE +#endif // PROFILER_ENABLED diff --git a/media/libjxl/src/lib/profiler/profiler.h b/media/libjxl/src/lib/profiler/profiler.h new file mode 100644 index 000000000..c71f63cb3 --- /dev/null +++ b/media/libjxl/src/lib/profiler/profiler.h @@ -0,0 +1,165 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_PROFILER_PROFILER_H_ +#define LIB_PROFILER_PROFILER_H_ + +// High precision, low overhead time measurements. Returns exact call counts and +// total elapsed time for user-defined 'zones' (code regions, i.e. C++ scopes). +// +// Usage: instrument regions of interest: { PROFILER_ZONE("name"); /*code*/ } or +// void FuncToMeasure() { PROFILER_FUNC; /*code*/ }. +// After all threads have exited any zones, invoke PROFILER_PRINT_RESULTS() to +// print call counts and average durations [CPU cycles] to stdout, sorted in +// descending order of total duration. + +// If zero, this file has no effect and no measurements will be recorded. +#ifndef PROFILER_ENABLED +#define PROFILER_ENABLED 0 +#endif +#if PROFILER_ENABLED + +#include +#include + +#include +#include + +#include "lib/profiler/tsc_timer.h" + +#if HWY_COMPILER_MSVC +#define PROFILER_PUBLIC +#else +#define PROFILER_PUBLIC __attribute__((visibility("default"))) +#endif + +namespace profiler { + +// Represents zone entry/exit events. POD. +#pragma pack(push, 1) +struct Packet { + // Computing a hash or string table is likely too expensive, and offsets + // from other libraries' string literals can be too large to combine them and + // a full-resolution timestamp into 64 bits. + uint64_t timestamp; + const char* name; // nullptr for exit packets +#if UINTPTR_MAX <= 0xFFFFFFFFu + uint32_t padding; +#endif +}; +#pragma pack(pop) +static_assert(sizeof(Packet) == 16, "Wrong Packet size"); + +class Results; // pImpl + +// Per-thread packet storage, dynamically allocated and aligned. +class ThreadSpecific { + static constexpr size_t kBufferCapacity = 64 / sizeof(Packet); + + public: + PROFILER_PUBLIC explicit ThreadSpecific(); + PROFILER_PUBLIC ~ThreadSpecific(); + + // Depends on Zone => defined out of line. + PROFILER_PUBLIC void ComputeOverhead(); + + HWY_INLINE void WriteEntry(const char* name) { Write(name, TicksBefore()); } + HWY_INLINE void WriteExit() { Write(nullptr, TicksAfter()); } + + PROFILER_PUBLIC void AnalyzeRemainingPackets(); + + // Accessors instead of public member for well-defined data layout. + void SetNext(ThreadSpecific* next) { next_ = next; } + ThreadSpecific* GetNext() const { return next_; } + + Results& GetResults() { return *results_; } + + private: + PROFILER_PUBLIC void FlushBuffer(); + + // Write packet to buffer/storage, emptying them as needed. + void Write(const char* name, const uint64_t timestamp) { + if (buffer_size_ == kBufferCapacity) { // Full + FlushBuffer(); + } + buffer_[buffer_size_].name = name; + buffer_[buffer_size_].timestamp = timestamp; + ++buffer_size_; + } + + // Write-combining buffer to avoid cache pollution. Must be the first + // non-static member to ensure cache-line alignment. + Packet buffer_[kBufferCapacity]; + size_t buffer_size_ = 0; + + // Contiguous storage for zone enter/exit packets. + const size_t max_packets_; + hwy::AlignedFreeUniquePtr packets_; + size_t num_packets_; + + // Linked list of all threads. + ThreadSpecific* next_ = nullptr; // Owned, never released. + + hwy::AlignedUniquePtr results_; +}; + +// RAII zone enter/exit recorder constructed by PROFILER_ZONE; also +// responsible for initializing ThreadSpecific. +class Zone { + public: + HWY_NOINLINE explicit Zone(const char* name) { + HWY_FENCE; + ThreadSpecific* HWY_RESTRICT thread_specific = GetThreadSpecific(); + if (HWY_UNLIKELY(thread_specific == nullptr)) { + thread_specific = InitThreadSpecific(); + } + + thread_specific->WriteEntry(name); + } + + HWY_NOINLINE ~Zone() { GetThreadSpecific()->WriteExit(); } + + // Call exactly once after all threads have exited all zones. + PROFILER_PUBLIC static void PrintResults(); + + private: + // Returns reference to the thread's ThreadSpecific pointer (initially null). + // Function-local static avoids needing a separate definition. + static ThreadSpecific*& GetThreadSpecific() { + static thread_local ThreadSpecific* thread_specific; + return thread_specific; + } + + // Non time-critical. + PROFILER_PUBLIC ThreadSpecific* InitThreadSpecific(); +}; + +// Creates a zone starting from here until the end of the current scope. +// Timestamps will be recorded when entering and exiting the zone. +// To ensure the name pointer remains valid, we require it to be a string +// literal (by merging with ""). We also compare strings by address. +#define PROFILER_ZONE(name) \ + HWY_FENCE; \ + const ::profiler::Zone zone("" name); \ + HWY_FENCE + +// Creates a zone for an entire function (when placed at its beginning). +// Shorter/more convenient than ZONE. +#define PROFILER_FUNC \ + HWY_FENCE; \ + const ::profiler::Zone zone(__func__); \ + HWY_FENCE + +#define PROFILER_PRINT_RESULTS ::profiler::Zone::PrintResults + +} // namespace profiler + +#else // !PROFILER_ENABLED +#define PROFILER_ZONE(name) +#define PROFILER_FUNC +#define PROFILER_PRINT_RESULTS() +#endif + +#endif // LIB_PROFILER_PROFILER_H_ diff --git a/media/libjxl/src/lib/profiler/tsc_timer.h b/media/libjxl/src/lib/profiler/tsc_timer.h new file mode 100644 index 000000000..9387f4195 --- /dev/null +++ b/media/libjxl/src/lib/profiler/tsc_timer.h @@ -0,0 +1,170 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_PROFILER_TSC_TIMER_H_ +#define LIB_PROFILER_TSC_TIMER_H_ + +// High-resolution (~10 ns) timestamps, using fences to prevent reordering and +// ensure exactly the desired regions are measured. + +#include +#include // clock_gettime + +#if defined(_WIN32) || defined(_WIN64) +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#ifndef NOGDI +#define NOGDI +#endif // NOGDI +#include +// Undef macros to avoid collisions +#undef LoadFence +#endif + +#if defined(__APPLE__) +#include +#include +#endif + +#if defined(__HAIKU__) +#include +#endif + +#include +#include +#include // LoadFence + +namespace profiler { + +// Ticks := platform-specific timer values (CPU cycles on x86). Must be +// unsigned to guarantee wraparound on overflow. +using Ticks = uint64_t; + +// TicksBefore/After return absolute timestamps and must be placed immediately +// before and after the region to measure. We provide separate Before/After +// functions because they use different fences. +// +// Background: RDTSC is not 'serializing'; earlier instructions may complete +// after it, and/or later instructions may complete before it. 'Fences' ensure +// regions' elapsed times are independent of such reordering. The only +// documented unprivileged serializing instruction is CPUID, which acts as a +// full fence (no reordering across it in either direction). Unfortunately +// the latency of CPUID varies wildly (perhaps made worse by not initializing +// its EAX input). Because it cannot reliably be deducted from the region's +// elapsed time, it must not be included in the region to measure (i.e. +// between the two RDTSC). +// +// The newer RDTSCP is sometimes described as serializing, but it actually +// only serves as a half-fence with release semantics. Although all +// instructions in the region will complete before the final timestamp is +// captured, subsequent instructions may leak into the region and increase the +// elapsed time. Inserting another fence after the final RDTSCP would prevent +// such reordering without affecting the measured region. +// +// Fortunately, such a fence exists. The LFENCE instruction is only documented +// to delay later loads until earlier loads are visible. However, Intel's +// reference manual says it acts as a full fence (waiting until all earlier +// instructions have completed, and delaying later instructions until it +// completes). AMD assigns the same behavior to MFENCE. +// +// We need a fence before the initial RDTSC to prevent earlier instructions +// from leaking into the region, and arguably another after RDTSC to avoid +// region instructions from completing before the timestamp is recorded. +// When surrounded by fences, the additional RDTSCP half-fence provides no +// benefit, so the initial timestamp can be recorded via RDTSC, which has +// lower overhead than RDTSCP because it does not read TSC_AUX. In summary, +// we define Before = LFENCE/RDTSC/LFENCE; After = RDTSCP/LFENCE. +// +// Using Before+Before leads to higher variance and overhead than After+After. +// However, After+After includes an LFENCE in the region measurements, which +// adds a delay dependent on earlier loads. The combination of Before+After +// is faster than Before+Before and more consistent than After+After because +// the first LFENCE already delayed subsequent loads before the measured +// region. This combination seems not to have been considered in prior work: +// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c +// +// Note: performance counters can measure 'exact' instructions-retired or +// (unhalted) cycle counts. The RDPMC instruction is not serializing and also +// requires fences. Unfortunately, it is not accessible on all OSes and we +// prefer to avoid kernel-mode drivers. Performance counters are also affected +// by several under/over-count errata, so we use the TSC instead. + +// Returns a 64-bit timestamp in unit of 'ticks'; to convert to seconds, +// divide by InvariantTicksPerSecond. +static HWY_INLINE HWY_MAYBE_UNUSED Ticks TicksBefore() { + Ticks t; +#if HWY_ARCH_PPC && defined(__GLIBC__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + hwy::LoadFence(); + HWY_FENCE; + t = __rdtsc(); + hwy::LoadFence(); + HWY_FENCE; +#elif HWY_ARCH_X86_64 + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rdx", "memory", "cc"); +#elif HWY_ARCH_RVV + asm volatile("rdcycle %0" : "=r"(t)); +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER counter; + (void)QueryPerformanceCounter(&counter); + t = counter.QuadPart; +#elif defined(__APPLE__) + t = mach_absolute_time(); +#elif defined(__HAIKU__) + t = system_time_nsecs(); // since boot +#else // POSIX + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + t = static_cast(ts.tv_sec * 1000000000LL + ts.tv_nsec); +#endif + return t; +} + +static HWY_INLINE HWY_MAYBE_UNUSED Ticks TicksAfter() { + Ticks t; +#if HWY_ARCH_PPC && defined(__GLIBC__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + HWY_FENCE; + unsigned aux; + t = __rdtscp(&aux); + hwy::LoadFence(); + HWY_FENCE; +#elif HWY_ARCH_X86_64 + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rcx", "rdx", "memory", "cc"); +#else + t = TicksBefore(); // no difference on other platforms. +#endif + return t; +} + +} // namespace profiler + +#endif // LIB_PROFILER_TSC_TIMER_H_ diff --git a/media/libjxl/src/lib/threads/libjxl_threads.pc.in b/media/libjxl/src/lib/threads/libjxl_threads.pc.in new file mode 100644 index 000000000..50b937a84 --- /dev/null +++ b/media/libjxl/src/lib/threads/libjxl_threads.pc.in @@ -0,0 +1,13 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=@PKGCONFIG_TARGET_LIBS@ +includedir=@PKGCONFIG_TARGET_INCLUDES@ + +Name: libjxl_threads +Description: JPEG XL multi-thread runner using std::threads. +Version: @JPEGXL_LIBRARY_VERSION@ +Requires.private: @JPEGXL_THREADS_LIBRARY_REQUIRES@ +Libs: -L${libdir} -ljxl_threads +Libs.private: -lm +Cflags: -I${includedir} +Cflags.private: -DJXL_THREADS_STATIC_DEFINE diff --git a/media/libjxl/src/lib/threads/resizable_parallel_runner.cc b/media/libjxl/src/lib/threads/resizable_parallel_runner.cc new file mode 100644 index 000000000..1208a3856 --- /dev/null +++ b/media/libjxl/src/lib/threads/resizable_parallel_runner.cc @@ -0,0 +1,195 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/resizable_parallel_runner.h" + +#include +#include +#include +#include +#include +#include + +namespace jpegxl { +namespace { + +// A thread pool that allows changing the number of threads it runs. It also +// runs tasks on the calling thread, which can work better on schedulers for +// heterogeneous architectures. +struct ResizeableParallelRunner { + void SetNumThreads(size_t num) { + if (num > 0) { + num -= 1; + } + { + std::unique_lock l(state_mutex_); + num_desired_workers_ = num; + workers_can_proceed_.notify_all(); + } + if (workers_.size() < num) { + for (size_t i = workers_.size(); i < num; i++) { + workers_.emplace_back([this, i]() { WorkerBody(i); }); + } + } + if (workers_.size() > num) { + for (size_t i = num; i < workers_.size(); i++) { + workers_[i].join(); + } + workers_.resize(num); + } + } + + ~ResizeableParallelRunner() { SetNumThreads(0); } + + JxlParallelRetCode Run(void* jxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start, + uint32_t end) { + if (start + 1 == end) { + JxlParallelRetCode ret = init(jxl_opaque, 1); + if (ret != 0) return ret; + + func(jxl_opaque, start, 0); + return ret; + } + + size_t num_workers = std::min(workers_.size() + 1, end - start); + JxlParallelRetCode ret = init(jxl_opaque, num_workers); + if (ret != 0) { + return ret; + } + + { + std::unique_lock l(state_mutex_); + // Avoid waking up more workers than needed. + max_running_workers_ = end - start - 1; + next_task_ = start; + end_task_ = end; + func_ = func; + jxl_opaque_ = jxl_opaque; + work_available_ = true; + num_running_workers_++; + workers_can_proceed_.notify_all(); + } + + DequeueTasks(0); + + while (true) { + std::unique_lock l(state_mutex_); + if (num_running_workers_ == 0) break; + work_done_.wait(l); + } + + return ret; + } + + private: + void WorkerBody(size_t worker_id) { + while (true) { + { + std::unique_lock l(state_mutex_); + // Worker pool was reduced, resize down. + if (worker_id >= num_desired_workers_) { + return; + } + // Nothing to do this time. + if (!work_available_ || worker_id >= max_running_workers_) { + workers_can_proceed_.wait(l); + continue; + } + num_running_workers_++; + } + DequeueTasks(worker_id + 1); + } + } + + void DequeueTasks(size_t thread_id) { + while (true) { + uint32_t task = next_task_++; + if (task >= end_task_) { + std::unique_lock l(state_mutex_); + num_running_workers_--; + work_available_ = false; + if (num_running_workers_ == 0) { + work_done_.notify_all(); + } + break; + } + func_(jxl_opaque_, task, thread_id); + } + } + + // Checks when the worker has something to do, which can be one of: + // - quitting (when worker_id >= num_desired_workers_) + // - having work available for them (work_available_ is true and worker_id >= + // max_running_workers_) + std::condition_variable workers_can_proceed_; + + // Workers are done, and the main thread can proceed (num_running_workers_ == + // 0) + std::condition_variable work_done_; + + std::vector workers_; + + // Protects all the remaining variables, except for func_, jxl_opaque_ and + // end_task_ (for which only the write by the main thread is protected, and + // subsequent uses by workers happen-after it) and next_task_ (which is + // atomic). + std::mutex state_mutex_; + + // Range of tasks still need to be done. + std::atomic next_task_; + uint32_t end_task_; + + // Function to run and its argument. + JxlParallelRunFunction func_; + void* jxl_opaque_; // not owned + + // Variables that control the workers: + // - work_available_ is set to true after a call to Run() and to false at the + // end of it. + // - num_desired_workers_ represents the number of workers that should be + // present. + // - max_running_workers_ represents the number of workers that should be + // executing tasks. + // - num_running_workers_ represents the number of workers that are executing + // tasks. + size_t num_desired_workers_ = 0; + size_t max_running_workers_ = 0; + size_t num_running_workers_ = 0; + bool work_available_ = false; +}; +} // namespace +} // namespace jpegxl + +extern "C" { +JXL_THREADS_EXPORT JxlParallelRetCode JxlResizableParallelRunner( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) { + return static_cast(runner_opaque) + ->Run(jpegxl_opaque, init, func, start_range, end_range); +} + +JXL_THREADS_EXPORT void* JxlResizableParallelRunnerCreate( + const JxlMemoryManager* memory_manager) { + return new jpegxl::ResizeableParallelRunner(); +} + +JXL_THREADS_EXPORT void JxlResizableParallelRunnerSetThreads( + void* runner_opaque, size_t num_threads) { + static_cast(runner_opaque) + ->SetNumThreads(num_threads); +} + +JXL_THREADS_EXPORT void JxlResizableParallelRunnerDestroy(void* runner_opaque) { + delete static_cast(runner_opaque); +} + +JXL_THREADS_EXPORT uint32_t +JxlResizableParallelRunnerSuggestThreads(uint64_t xsize, uint64_t ysize) { + // ~one thread per group. + return std::min(std::thread::hardware_concurrency(), + xsize * ysize / (256 * 256)); +} +} diff --git a/media/libjxl/src/lib/threads/thread_parallel_runner.cc b/media/libjxl/src/lib/threads/thread_parallel_runner.cc new file mode 100644 index 000000000..0d5b962d9 --- /dev/null +++ b/media/libjxl/src/lib/threads/thread_parallel_runner.cc @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/thread_parallel_runner.h" + +#include + +#include "lib/threads/thread_parallel_runner_internal.h" + +namespace { + +// Default JxlMemoryManager using malloc and free for the jpegxl_threads +// library. Same as the default JxlMemoryManager for the jpegxl library +// itself. + +// Default alloc and free functions. +void* ThreadMemoryManagerDefaultAlloc(void* opaque, size_t size) { + return malloc(size); +} + +void ThreadMemoryManagerDefaultFree(void* opaque, void* address) { + free(address); +} + +// Initializes the memory manager instance with the passed one. The +// MemoryManager passed in |memory_manager| may be NULL or contain NULL +// functions which will be initialized with the default ones. If either alloc +// or free are NULL, then both must be NULL, otherwise this function returns an +// error. +bool ThreadMemoryManagerInit(JxlMemoryManager* self, + const JxlMemoryManager* memory_manager) { + if (memory_manager) { + *self = *memory_manager; + } else { + memset(self, 0, sizeof(*self)); + } + if (!self->alloc != !self->free) { + return false; + } + if (!self->alloc) self->alloc = ThreadMemoryManagerDefaultAlloc; + if (!self->free) self->free = ThreadMemoryManagerDefaultFree; + + return true; +} + +void* ThreadMemoryManagerAlloc(const JxlMemoryManager* memory_manager, + size_t size) { + return memory_manager->alloc(memory_manager->opaque, size); +} + +void ThreadMemoryManagerFree(const JxlMemoryManager* memory_manager, + void* address) { + return memory_manager->free(memory_manager->opaque, address); +} + +} // namespace + +JxlParallelRetCode JxlThreadParallelRunner( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) { + return jpegxl::ThreadParallelRunner::Runner( + runner_opaque, jpegxl_opaque, init, func, start_range, end_range); +} + +/// Starts the given number of worker threads and blocks until they are ready. +/// "num_worker_threads" defaults to one per hyperthread. If zero, all tasks +/// run on the main thread. +void* JxlThreadParallelRunnerCreate(const JxlMemoryManager* memory_manager, + size_t num_worker_threads) { + JxlMemoryManager local_memory_manager; + if (!ThreadMemoryManagerInit(&local_memory_manager, memory_manager)) + return nullptr; + + void* alloc = ThreadMemoryManagerAlloc(&local_memory_manager, + sizeof(jpegxl::ThreadParallelRunner)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + jpegxl::ThreadParallelRunner* runner = + new (alloc) jpegxl::ThreadParallelRunner(num_worker_threads); + runner->memory_manager = local_memory_manager; + + return runner; +} + +void JxlThreadParallelRunnerDestroy(void* runner_opaque) { + jpegxl::ThreadParallelRunner* runner = + reinterpret_cast(runner_opaque); + if (runner) { + JxlMemoryManager local_memory_manager = runner->memory_manager; + // Call destructor directly since custom free function is used. + runner->~ThreadParallelRunner(); + ThreadMemoryManagerFree(&local_memory_manager, runner); + } +} + +// Get default value for num_worker_threads parameter of +// InitJxlThreadParallelRunner. +size_t JxlThreadParallelRunnerDefaultNumWorkerThreads() { + return std::thread::hardware_concurrency(); +} diff --git a/media/libjxl/src/lib/threads/thread_parallel_runner_internal.cc b/media/libjxl/src/lib/threads/thread_parallel_runner_internal.cc new file mode 100644 index 000000000..2b05ad992 --- /dev/null +++ b/media/libjxl/src/lib/threads/thread_parallel_runner_internal.cc @@ -0,0 +1,213 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/threads/thread_parallel_runner_internal.h" + +#include + +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) +#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace +#endif // defined(*_SANITIZER) + +#include "jxl/thread_parallel_runner.h" +#include "lib/jxl/base/profiler.h" + +namespace { + +// Exits the program after printing a stack trace when possible. +bool Abort() { +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + // If compiled with any sanitizer print a stack trace. This call doesn't crash + // the program, instead the trap below will crash it also allowing gdb to + // break there. + __sanitizer_print_stack_trace(); +#endif // defined(*_SANITIZER) + +#ifdef _MSC_VER + __debugbreak(); + abort(); +#else + __builtin_trap(); +#endif +} + +// Does not guarantee running the code, use only for debug mode checks. +#if JXL_ENABLE_ASSERT +#define JXL_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + Abort(); \ + } \ + } while (0) +#else +#define JXL_ASSERT(condition) \ + do { \ + } while (0) +#endif +} // namespace + +namespace jpegxl { + +// static +JxlParallelRetCode ThreadParallelRunner::Runner( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) { + ThreadParallelRunner* self = + static_cast(runner_opaque); + if (start_range > end_range) return -1; + if (start_range == end_range) return 0; + + int ret = init(jpegxl_opaque, std::max(self->num_worker_threads_, 1)); + if (ret != 0) return ret; + + // Use a sequential run when num_worker_threads_ is zero since we have no + // worker threads. + if (self->num_worker_threads_ == 0) { + const size_t thread = 0; + for (uint32_t task = start_range; task < end_range; ++task) { + func(jpegxl_opaque, task, thread); + } + return 0; + } + + if (self->depth_.fetch_add(1, std::memory_order_acq_rel) != 0) { + return -1; // Must not re-enter. + } + + const WorkerCommand worker_command = + (static_cast(start_range) << 32) + end_range; + // Ensure the inputs do not result in a reserved command. + JXL_ASSERT(worker_command != kWorkerWait); + JXL_ASSERT(worker_command != kWorkerOnce); + JXL_ASSERT(worker_command != kWorkerExit); + + self->data_func_ = func; + self->jpegxl_opaque_ = jpegxl_opaque; + self->num_reserved_.store(0, std::memory_order_relaxed); + + self->StartWorkers(worker_command); + self->WorkersReadyBarrier(); + + if (self->depth_.fetch_add(-1, std::memory_order_acq_rel) != 1) { + return -1; + } + return 0; +} + +// static +void ThreadParallelRunner::RunRange(ThreadParallelRunner* self, + const WorkerCommand command, + const int thread) { + const uint32_t begin = command >> 32; + const uint32_t end = command & 0xFFFFFFFF; + const uint32_t num_tasks = end - begin; + const uint32_t num_worker_threads = self->num_worker_threads_; + + // OpenMP introduced several "schedule" strategies: + // "single" (static assignment of exactly one chunk per thread): slower. + // "dynamic" (allocates k tasks at a time): competitive for well-chosen k. + // "guided" (allocates k tasks, decreases k): computing k = remaining/n + // is faster than halving k each iteration. We prefer this strategy + // because it avoids user-specified parameters. + + for (;;) { +#if 0 + // dynamic + const uint32_t my_size = std::max(num_tasks / (num_worker_threads * 4), 1); +#else + // guided + const uint32_t num_reserved = + self->num_reserved_.load(std::memory_order_relaxed); + // It is possible that more tasks are reserved than ready to run. + const uint32_t num_remaining = + num_tasks - std::min(num_reserved, num_tasks); + const uint32_t my_size = + std::max(num_remaining / (num_worker_threads * 4), 1u); +#endif + const uint32_t my_begin = begin + self->num_reserved_.fetch_add( + my_size, std::memory_order_relaxed); + const uint32_t my_end = std::min(my_begin + my_size, begin + num_tasks); + // Another thread already reserved the last task. + if (my_begin >= my_end) { + break; + } + for (uint32_t task = my_begin; task < my_end; ++task) { + self->data_func_(self->jpegxl_opaque_, task, thread); + } + } +} + +// static +void ThreadParallelRunner::ThreadFunc(ThreadParallelRunner* self, + const int thread) { + // Until kWorkerExit command received: + for (;;) { + std::unique_lock lock(self->mutex_); + // Notify main thread that this thread is ready. + if (++self->workers_ready_ == self->num_threads_) { + self->workers_ready_cv_.notify_one(); + } + RESUME_WAIT: + // Wait for a command. + self->worker_start_cv_.wait(lock); + const WorkerCommand command = self->worker_start_command_; + switch (command) { + case kWorkerWait: // spurious wakeup: + goto RESUME_WAIT; // lock still held, avoid incrementing ready. + case kWorkerOnce: + lock.unlock(); + self->data_func_(self->jpegxl_opaque_, thread, thread); + break; + case kWorkerExit: + return; // exits thread + default: + lock.unlock(); + RunRange(self, command, thread); + break; + } + } +} + +ThreadParallelRunner::ThreadParallelRunner(const int num_worker_threads) + : num_worker_threads_(num_worker_threads), + num_threads_(std::max(num_worker_threads, 1)) { + PROFILER_ZONE("ThreadParallelRunner ctor"); + + threads_.reserve(num_worker_threads_); + + // Suppress "unused-private-field" warning. + (void)padding1; + (void)padding2; + + // Safely handle spurious worker wakeups. + worker_start_command_ = kWorkerWait; + + for (uint32_t i = 0; i < num_worker_threads_; ++i) { + threads_.emplace_back(ThreadFunc, this, i); + } + + if (num_worker_threads_ != 0) { + WorkersReadyBarrier(); + } + + // Warm up profiler on worker threads so its expensive initialization + // doesn't count towards other timer measurements. + RunOnEachThread( + [](const int task, const int thread) { PROFILER_ZONE("@InitWorkers"); }); +} + +ThreadParallelRunner::~ThreadParallelRunner() { + if (num_worker_threads_ != 0) { + StartWorkers(kWorkerExit); + } + + for (std::thread& thread : threads_) { + JXL_ASSERT(thread.joinable()); + thread.join(); + } +} +} // namespace jpegxl diff --git a/media/libjxl/src/lib/threads/thread_parallel_runner_internal.h b/media/libjxl/src/lib/threads/thread_parallel_runner_internal.h new file mode 100644 index 000000000..372c6a895 --- /dev/null +++ b/media/libjxl/src/lib/threads/thread_parallel_runner_internal.h @@ -0,0 +1,172 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// + +// C++ implementation using std::thread of a ::JxlParallelRunner. + +// The main class in this module, ThreadParallelRunner, implements a static +// method ThreadParallelRunner::Runner than can be passed as a +// JxlParallelRunner when using the JPEG XL library. This uses std::thread +// internally and related synchronization functions. The number of threads +// created is fixed at construction time and the threads are re-used for every +// ThreadParallelRunner::Runner call. Only one concurrent Runner() call per +// instance is allowed at a time. +// +// This is a scalable, lower-overhead thread pool runner, especially suitable +// for data-parallel computations in the fork-join model, where clients need to +// know when all tasks have completed. +// +// This thread pool can efficiently load-balance millions of tasks using an +// atomic counter, thus avoiding per-task virtual or system calls. With 48 +// hyperthreads and 1M tasks that add to an atomic counter, overall runtime is +// 10-20x higher when using std::async, and ~200x for a queue-based thread +// pool. +// +// Usage: +// ThreadParallelRunner runner; +// JxlDecode( +// ... , &ThreadParallelRunner::Runner, static_cast(&runner)); + +#ifndef LIB_THREADS_THREAD_PARALLEL_RUNNER_INTERNAL_H_ +#define LIB_THREADS_THREAD_PARALLEL_RUNNER_INTERNAL_H_ + +#include +#include +#include + +#include +#include //NOLINT +#include //NOLINT +#include //NOLINT +#include + +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" + +namespace jpegxl { + +// Main helper class implementing the ::JxlParallelRunner interface. +class ThreadParallelRunner { + public: + // ::JxlParallelRunner interface. + static JxlParallelRetCode Runner(void* runner_opaque, void* jpegxl_opaque, + JxlParallelRunInit init, + JxlParallelRunFunction func, + uint32_t start_range, uint32_t end_range); + + // Starts the given number of worker threads and blocks until they are ready. + // "num_worker_threads" defaults to one per hyperthread. If zero, all tasks + // run on the main thread. + explicit ThreadParallelRunner( + int num_worker_threads = std::thread::hardware_concurrency()); + + // Waits for all threads to exit. + ~ThreadParallelRunner(); + + // Returns number of worker threads created (some may be sleeping and never + // wake up in time to participate in Run). Useful for characterizing + // performance; 0 means "run on main thread". + size_t NumWorkerThreads() const { return num_worker_threads_; } + + // Returns maximum number of main/worker threads that may call Func. Useful + // for allocating per-thread storage. + size_t NumThreads() const { return num_threads_; } + + // Runs func(thread, thread) on all thread(s) that may participate in Run. + // If NumThreads() == 0, runs on the main thread with thread == 0, otherwise + // concurrently called by each worker thread in [0, NumThreads()). + template + void RunOnEachThread(const Func& func) { + if (num_worker_threads_ == 0) { + const int thread = 0; + func(thread, thread); + return; + } + + data_func_ = reinterpret_cast(&CallClosure); + jpegxl_opaque_ = const_cast(static_cast(&func)); + StartWorkers(kWorkerOnce); + WorkersReadyBarrier(); + } + + JxlMemoryManager memory_manager; + + private: + // After construction and between calls to Run, workers are "ready", i.e. + // waiting on worker_start_cv_. They are "started" by sending a "command" + // and notifying all worker_start_cv_ waiters. (That is why all workers + // must be ready/waiting - otherwise, the notification will not reach all of + // them and the main thread waits in vain for them to report readiness.) + using WorkerCommand = uint64_t; + + // Special values; all others encode the begin/end parameters. Note that all + // these are no-op ranges (begin >= end) and therefore never used to encode + // ranges. + static constexpr WorkerCommand kWorkerWait = ~1ULL; + static constexpr WorkerCommand kWorkerOnce = ~2ULL; + static constexpr WorkerCommand kWorkerExit = ~3ULL; + + // Calls f(task, thread). Used for type erasure of Func arguments. The + // signature must match JxlParallelRunFunction, hence a void* argument. + template + static void CallClosure(void* f, const uint32_t task, const size_t thread) { + (*reinterpret_cast(f))(task, thread); + } + + void WorkersReadyBarrier() { + std::unique_lock lock(mutex_); + // Typically only a single iteration. + while (workers_ready_ != threads_.size()) { + workers_ready_cv_.wait(lock); + } + workers_ready_ = 0; + + // Safely handle spurious worker wakeups. + worker_start_command_ = kWorkerWait; + } + + // Precondition: all workers are ready. + void StartWorkers(const WorkerCommand worker_command) { + mutex_.lock(); + worker_start_command_ = worker_command; + // Workers will need this lock, so release it before they wake up. + mutex_.unlock(); + worker_start_cv_.notify_all(); + } + + // Attempts to reserve and perform some work from the global range of tasks, + // which is encoded within "command". Returns after all tasks are reserved. + static void RunRange(ThreadParallelRunner* self, const WorkerCommand command, + const int thread); + + static void ThreadFunc(ThreadParallelRunner* self, int thread); + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector threads_; + + const uint32_t num_worker_threads_; // == threads_.size() + const uint32_t num_threads_; + + std::atomic depth_{0}; // detects if Run is re-entered (not supported). + + std::mutex mutex_; // guards both cv and their variables. + std::condition_variable workers_ready_cv_; + uint32_t workers_ready_ = 0; + std::condition_variable worker_start_cv_; + WorkerCommand worker_start_command_; + + // Written by main thread, read by workers (after mutex lock/unlock). + JxlParallelRunFunction data_func_; + void* jpegxl_opaque_; + + // Updated by workers; padding avoids false sharing. + uint8_t padding1[64]; + std::atomic num_reserved_{0}; + uint8_t padding2[64]; +}; + +} // namespace jpegxl + +#endif // LIB_THREADS_THREAD_PARALLEL_RUNNER_INTERNAL_H_ diff --git a/media/libjxl/src/lib/threads/thread_parallel_runner_test.cc b/media/libjxl/src/lib/threads/thread_parallel_runner_test.cc new file mode 100644 index 000000000..2293b5ceb --- /dev/null +++ b/media/libjxl/src/lib/threads/thread_parallel_runner_test.cc @@ -0,0 +1,121 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/thread_pool_internal.h" + +namespace jpegxl { +namespace { + +int PopulationCount(uint64_t bits) { + int num_set = 0; + while (bits != 0) { + num_set += bits & 1; + bits >>= 1; + } + return num_set; +} + +// Ensures task parameter is in bounds, every parameter is reached, +// pool can be reused (multiple consecutive Run calls), pool can be destroyed +// (joining with its threads), num_threads=0 works (runs on current thread). +TEST(ThreadParallelRunnerTest, TestPool) { + for (int num_threads = 0; num_threads <= 18; ++num_threads) { + jxl::ThreadPoolInternal pool(num_threads); + for (int num_tasks = 0; num_tasks < 32; ++num_tasks) { + std::vector mementos(num_tasks); + for (int begin = 0; begin < 32; ++begin) { + std::fill(mementos.begin(), mementos.end(), 0); + EXPECT_TRUE(RunOnPool( + &pool, begin, begin + num_tasks, jxl::ThreadPool::NoInit, + [begin, num_tasks, &mementos](const int task, const int thread) { + // Parameter is in the given range + EXPECT_GE(task, begin); + EXPECT_LT(task, begin + num_tasks); + + // Store mementos to be sure we visited each task. + mementos.at(task - begin) = 1000 + task; + }, + "TestPool")); + for (int task = begin; task < begin + num_tasks; ++task) { + EXPECT_EQ(1000 + task, mementos.at(task - begin)); + } + } + } + } +} + +// Verify "thread" parameter when processing few tasks. +TEST(ThreadParallelRunnerTest, TestSmallAssignments) { + // WARNING: cumulative total threads must not exceed profiler.h kMaxThreads. + const int kMaxThreads = 8; + for (int num_threads = 1; num_threads <= kMaxThreads; ++num_threads) { + jxl::ThreadPoolInternal pool(num_threads); + + // (Avoid mutex because it may perturb the worker thread scheduling) + std::atomic id_bits{0}; + std::atomic num_calls{0}; + + EXPECT_TRUE(RunOnPool( + &pool, 0, num_threads, jxl::ThreadPool::NoInit, + [&num_calls, num_threads, &id_bits](const int task, const int thread) { + num_calls.fetch_add(1, std::memory_order_relaxed); + + EXPECT_LT(thread, num_threads); + uint64_t bits = id_bits.load(std::memory_order_relaxed); + while ( + !id_bits.compare_exchange_weak(bits, bits | (1ULL << thread))) { + } + }, + "TestSmallAssignments")); + + // Correct number of tasks. + EXPECT_EQ(num_threads, num_calls.load()); + + const int num_participants = PopulationCount(id_bits.load()); + // Can't expect equality because other workers may have woken up too late. + EXPECT_LE(num_participants, num_threads); + } +} + +struct Counter { + Counter() { + // Suppress "unused-field" warning. + (void)padding; + } + void Assimilate(const Counter& victim) { counter += victim.counter; } + int counter = 0; + int padding[31]; +}; + +TEST(ThreadParallelRunnerTest, TestCounter) { + const int kNumThreads = 12; + jxl::ThreadPoolInternal pool(kNumThreads); + alignas(128) Counter counters[kNumThreads]; + + const int kNumTasks = kNumThreads * 19; + EXPECT_TRUE(RunOnPool( + &pool, 0, kNumTasks, jxl::ThreadPool::NoInit, + [&counters](const int task, const int thread) { + counters[thread].counter += task; + }, + "TestCounter")); + + int expected = 0; + for (int i = 0; i < kNumTasks; ++i) { + expected += i; + } + + for (int i = 1; i < kNumThreads; ++i) { + counters[0].Assimilate(counters[i]); + } + EXPECT_EQ(expected, counters[0].counter); +} + +} // namespace +} // namespace jpegxl diff --git a/media/libjxl/src/plugins/CMakeLists.txt b/media/libjxl/src/plugins/CMakeLists.txt new file mode 100644 index 000000000..bff1bff29 --- /dev/null +++ b/media/libjxl/src/plugins/CMakeLists.txt @@ -0,0 +1,21 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +if(NOT MSVC) + option(JPEGXL_ENABLE_PLUGIN_GDKPIXBUF "Enable plugin for GdkPixbuf image loading library" ON) + if(JPEGXL_ENABLE_PLUGIN_GDKPIXBUF) + add_subdirectory(gdk-pixbuf) + endif() +endif() + +option(JPEGXL_ENABLE_PLUGIN_GIMP210 "Enable plugin for GIMP 2.10.x series" ON) +if(JPEGXL_ENABLE_PLUGIN_GIMP210) + add_subdirectory(gimp) +endif() + +option(JPEGXL_ENABLE_PLUGIN_MIME "Enable image/jxl declaration for shared-mime-info" ON) +if(JPEGXL_ENABLE_PLUGIN_MIME) + add_subdirectory(mime) +endif() diff --git a/media/libjxl/src/plugins/gdk-pixbuf/CMakeLists.txt b/media/libjxl/src/plugins/gdk-pixbuf/CMakeLists.txt new file mode 100644 index 000000000..e56d312b7 --- /dev/null +++ b/media/libjxl/src/plugins/gdk-pixbuf/CMakeLists.txt @@ -0,0 +1,80 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +find_package(PkgConfig) +pkg_check_modules(Gdk-Pixbuf IMPORTED_TARGET gdk-pixbuf-2.0>=2.36) + +if (NOT Gdk-Pixbuf_FOUND) + message(WARNING "GDK Pixbuf development libraries not found, \ + the Gdk-Pixbuf plugin will not be built") + return () +endif () + +add_library(pixbufloader-jxl SHARED pixbufloader-jxl.c) + +# Mark all symbols as hidden by default. The PkgConfig::Gdk-Pixbuf dependency +# will cause fill_info and fill_vtable entry points to be made public. +set_target_properties(pixbufloader-jxl PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 +) + +# Note: This only needs the decoder library, but we don't install the decoder +# shared library. +target_link_libraries(pixbufloader-jxl jxl jxl_threads skcms-interface PkgConfig::Gdk-Pixbuf) + +execute_process(COMMAND ${PKG_CONFIG_EXECUTABLE} gdk-pixbuf-2.0 --variable gdk_pixbuf_moduledir --define-variable=prefix=${CMAKE_INSTALL_PREFIX} OUTPUT_VARIABLE GDK_PIXBUF_MODULEDIR OUTPUT_STRIP_TRAILING_WHITESPACE) +install(TARGETS pixbufloader-jxl LIBRARY DESTINATION "${GDK_PIXBUF_MODULEDIR}") + +# Instead of the following, we might instead add the +# mime type image/jxl to +# /usr/share/thumbnailers/gdk-pixbuf-thumbnailer.thumbnailer +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/jxl.thumbnailer DESTINATION share/thumbnailers/) + +if(BUILD_TESTING AND NOT CMAKE_CROSSCOMPILING) + pkg_check_modules(Gdk IMPORTED_TARGET gdk-2.0) + if (Gdk_FOUND) + # Test for loading a .jxl file using the pixbufloader library via GDK. This + # requires to have the image/jxl mime type and loader library configured, + # which we do in a fake environment in the CMAKE_CURRENT_BINARY_DIR. + add_executable(pixbufloader_test pixbufloader_test.cc) + target_link_libraries(pixbufloader_test PkgConfig::Gdk) + + # Create a mime cache for test. + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/mime/mime.cache" + COMMAND env XDG_DATA_HOME=${CMAKE_CURRENT_BINARY_DIR} + xdg-mime install --novendor + "${CMAKE_SOURCE_DIR}/plugins/mime/image-jxl.xml" + DEPENDS "${CMAKE_SOURCE_DIR}/plugins/mime/image-jxl.xml" + ) + add_custom_target(pixbufloader_test_mime + DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/mime/mime.cache" + ) + add_dependencies(pixbufloader_test pixbufloader_test_mime) + + # Use a fake X server to run the test if xvfb is installed. + find_program (XVFB_PROGRAM xvfb-run) + if(XVFB_PROGRAM) + set(XVFB_PROGRAM_PREFIX "${XVFB_PROGRAM};-a") + else() + set(XVFB_PROGRAM_PREFIX "") + endif() + + # libX11.so and libgdk-x11-2.0.so are not compiled with MSAN -> report + # use-of-uninitialized-value for string some internal string value. + if (NOT (SANITIZER STREQUAL "msan")) + add_test( + NAME pixbufloader_test_jxl + COMMAND + ${XVFB_PROGRAM_PREFIX} $ + "${CMAKE_CURRENT_SOURCE_DIR}/loaders_test.cache" + "${CMAKE_SOURCE_DIR}/testdata/jxl/blending/cropped_traffic_light.jxl" + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + ) + set_tests_properties(pixbufloader_test_jxl PROPERTIES SKIP_RETURN_CODE 254) + endif() + endif() # Gdk_FOUND +endif() # BUILD_TESTING diff --git a/media/libjxl/src/plugins/gdk-pixbuf/README.md b/media/libjxl/src/plugins/gdk-pixbuf/README.md new file mode 100644 index 000000000..f7174baf3 --- /dev/null +++ b/media/libjxl/src/plugins/gdk-pixbuf/README.md @@ -0,0 +1,50 @@ +## JPEG XL GDK Pixbuf + + +The plugin may already have been installed when following the instructions from the +[Installing section of README.md](../../README.md#installing), in which case it should +already be in the correct place, e.g. + +```/usr/lib/x86_64-linux-gnu/gdk-pixbuf-2.0/2.10.0/loaders/libpixbufloader-jxl.so``` + +Otherwise we can copy it manually: + +```bash +sudo cp $your_build_directory/plugins/gdk-pixbuf/libpixbufloader-jxl.so /usr/lib/x86_64-linux-gnu/gdk-pixbuf-2.0/2.10.0/loaders/libpixbufloader-jxl.so +``` + + +Then we need to update the cache, for example with: + +```bash +sudo /usr/lib/x86_64-linux-gnu/gdk-pixbuf-2.0/gdk-pixbuf-query-loaders --update-cache +``` + +In order to get thumbnails with this, first one has to add the jxl MIME type, see +[../mime/README.md](../mime/README.md). + +Ensure that the thumbnailer file is installed in the correct place, +`/usr/share/thumbnailers/jxl.thumbnailer` or `/usr/local/share/thumbnailers/jxl.thumbnailer`. + +The file should have been copied automatically when following the instructions +in the [Installing section of README.md](../../README.md#installing), but +otherwise it can be copied manually: + +```bash +sudo cp plugins/gdk-pixbuf/jxl.thumbnailer /usr/local/share/thumbnailers/jxl.thumbnailer +``` + +Update the Mime database with +```bash +update-mime --local +``` +or +```bash +sudo update-desktop-database +``` + +Then possibly delete the thumbnail cache with +```bash +rm -r ~/.cache/thumbnails +``` +and restart the application displaying thumbnails, e.g. `nautilus -q` to display thumbnails. diff --git a/media/libjxl/src/plugins/gdk-pixbuf/jxl.thumbnailer b/media/libjxl/src/plugins/gdk-pixbuf/jxl.thumbnailer new file mode 100644 index 000000000..1bcaab61f --- /dev/null +++ b/media/libjxl/src/plugins/gdk-pixbuf/jxl.thumbnailer @@ -0,0 +1,4 @@ +[Thumbnailer Entry] +TryExec=/usr/bin/gdk-pixbuf-thumbnailer +Exec=/usr/bin/gdk-pixbuf-thumbnailer -s %s %u %o +MimeType=image/jxl; diff --git a/media/libjxl/src/plugins/gdk-pixbuf/loaders_test.cache b/media/libjxl/src/plugins/gdk-pixbuf/loaders_test.cache new file mode 100644 index 000000000..95c62c8fc --- /dev/null +++ b/media/libjxl/src/plugins/gdk-pixbuf/loaders_test.cache @@ -0,0 +1,16 @@ +# GdkPixbuf Image Loader Modules file for testing +# Automatically generated file, do not edit +# Created by gdk-pixbuf-query-loaders from gdk-pixbuf-2.42.2 +# +# Generated with: +# GDK_PIXBUF_MODULEDIR=`pwd`/build/plugins/gdk-pixbuf/ gdk-pixbuf-query-loaders +# +# Modified to use the library from the current working directory at runtime. +"./libpixbufloader-jxl.so" +"jxl" 4 "gdk-pixbuf" "JPEG XL image" "BSD-3" +"image/jxl" "" +"jxl" "" +"\377\n" " " 100 +"...\fJXL \r\n\207\n" "zzz " 100 + + diff --git a/media/libjxl/src/plugins/gdk-pixbuf/pixbufloader-jxl.c b/media/libjxl/src/plugins/gdk-pixbuf/pixbufloader-jxl.c new file mode 100644 index 000000000..24bbcf8cf --- /dev/null +++ b/media/libjxl/src/plugins/gdk-pixbuf/pixbufloader-jxl.c @@ -0,0 +1,569 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "jxl/codestream_header.h" +#include "jxl/decode.h" +#include "jxl/resizable_parallel_runner.h" +#include "jxl/types.h" +#include "skcms.h" + +#define GDK_PIXBUF_ENABLE_BACKEND +#include +#undef GDK_PIXBUF_ENABLE_BACKEND + +G_BEGIN_DECLS + +// Information about a single frame. +typedef struct { + uint64_t duration_ms; + GdkPixbuf *data; + gboolean decoded; +} GdkPixbufJxlAnimationFrame; + +// Represent a whole JPEG XL animation; all its fields are owned; as a GObject, +// the Animation struct itself is reference counted (as are the GdkPixbufs for +// individual frames). +struct _GdkPixbufJxlAnimation { + GdkPixbufAnimation parent_instance; + + // GDK interface implementation callbacks. + GdkPixbufModuleSizeFunc image_size_callback; + GdkPixbufModulePreparedFunc pixbuf_prepared_callback; + GdkPixbufModuleUpdatedFunc area_updated_callback; + gpointer user_data; + + // All frames known so far; a frame is added when the JXL_DEC_FRAME event is + // received from the decoder; initially frame.decoded is FALSE, until + // the JXL_DEC_IMAGE event is received. + GArray *frames; + + // JPEG XL decoder and related structures. + JxlParallelRunner *parallel_runner; + JxlDecoder *decoder; + JxlPixelFormat pixel_format; + + // Decoding is `done` when JXL_DEC_SUCCESS is received; calling + // load_increment afterwards gives an error. + gboolean done; + + // Image information. + size_t xsize; + size_t ysize; + gboolean alpha_premultiplied; + gboolean has_animation; + gboolean has_alpha; + uint64_t total_duration_ms; + uint64_t tick_duration_us; + uint64_t repetition_count; // 0 = loop forever + + // ICC profile, to which `icc` might refer to. + gpointer icc_buff; + skcms_ICCProfile icc; +}; + +#define GDK_TYPE_PIXBUF_JXL_ANIMATION (gdk_pixbuf_jxl_animation_get_type()) +G_DECLARE_FINAL_TYPE(GdkPixbufJxlAnimation, gdk_pixbuf_jxl_animation, GDK, + JXL_ANIMATION, GdkPixbufAnimation); + +G_DEFINE_TYPE(GdkPixbufJxlAnimation, gdk_pixbuf_jxl_animation, + GDK_TYPE_PIXBUF_ANIMATION); + +// Iterator to a given point in time in the animation; contains a pointer to the +// full animation. +struct _GdkPixbufJxlAnimationIter { + GdkPixbufAnimationIter parent_instance; + GdkPixbufJxlAnimation *animation; + size_t current_frame; + uint64_t time_offset; +}; + +#define GDK_TYPE_PIXBUF_JXL_ANIMATION_ITER \ + (gdk_pixbuf_jxl_animation_iter_get_type()) +G_DECLARE_FINAL_TYPE(GdkPixbufJxlAnimationIter, gdk_pixbuf_jxl_animation_iter, + GDK, JXL_ANIMATION_ITER, GdkPixbufAnimationIter); +G_DEFINE_TYPE(GdkPixbufJxlAnimationIter, gdk_pixbuf_jxl_animation_iter, + GDK_TYPE_PIXBUF_ANIMATION_ITER); + +static void gdk_pixbuf_jxl_animation_init(GdkPixbufJxlAnimation *obj) { + // Suppress "unused function" warnings. + (void)glib_autoptr_cleanup_GdkPixbufJxlAnimation; + (void)GDK_JXL_ANIMATION; + (void)GDK_IS_JXL_ANIMATION; +} + +static gboolean gdk_pixbuf_jxl_animation_is_static_image( + GdkPixbufAnimation *anim) { + GdkPixbufJxlAnimation *jxl_anim = (GdkPixbufJxlAnimation *)anim; + return !jxl_anim->has_animation; +} + +static GdkPixbuf *gdk_pixbuf_jxl_animation_get_static_image( + GdkPixbufAnimation *anim) { + GdkPixbufJxlAnimation *jxl_anim = (GdkPixbufJxlAnimation *)anim; + if (jxl_anim->frames == NULL || jxl_anim->frames->len == 0) return NULL; + GdkPixbufJxlAnimationFrame *frame = + &g_array_index(jxl_anim->frames, GdkPixbufJxlAnimationFrame, 0); + return frame->decoded ? frame->data : NULL; +} + +static void gdk_pixbuf_jxl_animation_get_size(GdkPixbufAnimation *anim, + int *width, int *height) { + GdkPixbufJxlAnimation *jxl_anim = (GdkPixbufJxlAnimation *)anim; + if (width) *width = jxl_anim->xsize; + if (height) *height = jxl_anim->ysize; +} + +G_GNUC_BEGIN_IGNORE_DEPRECATIONS +static gboolean gdk_pixbuf_jxl_animation_iter_advance( + GdkPixbufAnimationIter *iter, const GTimeVal *current_time); + +static GdkPixbufAnimationIter *gdk_pixbuf_jxl_animation_get_iter( + GdkPixbufAnimation *anim, const GTimeVal *start_time) { + GdkPixbufJxlAnimationIter *iter = + g_object_new(GDK_TYPE_PIXBUF_JXL_ANIMATION_ITER, NULL); + iter->animation = (GdkPixbufJxlAnimation *)anim; + iter->time_offset = start_time->tv_sec * 1000ULL + start_time->tv_usec / 1000; + g_object_ref(iter->animation); + gdk_pixbuf_jxl_animation_iter_advance((GdkPixbufAnimationIter *)iter, + start_time); + return (GdkPixbufAnimationIter *)iter; +} +G_GNUC_END_IGNORE_DEPRECATIONS + +static void gdk_pixbuf_jxl_animation_finalize(GObject *obj) { + GdkPixbufJxlAnimation *decoder_state = (GdkPixbufJxlAnimation *)obj; + if (decoder_state->frames != NULL) { + for (size_t i = 0; i < decoder_state->frames->len; i++) { + g_object_unref( + g_array_index(decoder_state->frames, GdkPixbufJxlAnimationFrame, i) + .data); + } + g_array_free(decoder_state->frames, /*free_segment=*/TRUE); + } + JxlResizableParallelRunnerDestroy(decoder_state->parallel_runner); + JxlDecoderDestroy(decoder_state->decoder); + g_free(decoder_state->icc_buff); +} + +static void gdk_pixbuf_jxl_animation_class_init( + GdkPixbufJxlAnimationClass *klass) { + G_OBJECT_CLASS(klass)->finalize = gdk_pixbuf_jxl_animation_finalize; + klass->parent_class.is_static_image = + gdk_pixbuf_jxl_animation_is_static_image; + klass->parent_class.get_static_image = + gdk_pixbuf_jxl_animation_get_static_image; + klass->parent_class.get_size = gdk_pixbuf_jxl_animation_get_size; + klass->parent_class.get_iter = gdk_pixbuf_jxl_animation_get_iter; +} + +static void gdk_pixbuf_jxl_animation_iter_init(GdkPixbufJxlAnimationIter *obj) { + (void)glib_autoptr_cleanup_GdkPixbufJxlAnimationIter; + (void)GDK_JXL_ANIMATION_ITER; + (void)GDK_IS_JXL_ANIMATION_ITER; +} + +static int gdk_pixbuf_jxl_animation_iter_get_delay_time( + GdkPixbufAnimationIter *iter) { + GdkPixbufJxlAnimationIter *jxl_iter = (GdkPixbufJxlAnimationIter *)iter; + if (jxl_iter->animation->frames->len <= jxl_iter->current_frame) { + return 0; + } + return g_array_index(jxl_iter->animation->frames, GdkPixbufJxlAnimationFrame, + jxl_iter->current_frame) + .duration_ms; +} + +static GdkPixbuf *gdk_pixbuf_jxl_animation_iter_get_pixbuf( + GdkPixbufAnimationIter *iter) { + GdkPixbufJxlAnimationIter *jxl_iter = (GdkPixbufJxlAnimationIter *)iter; + if (jxl_iter->animation->frames->len <= jxl_iter->current_frame) { + return NULL; + } + return g_array_index(jxl_iter->animation->frames, GdkPixbufJxlAnimationFrame, + jxl_iter->current_frame) + .data; +} + +static gboolean gdk_pixbuf_jxl_animation_iter_on_currently_loading_frame( + GdkPixbufAnimationIter *iter) { + GdkPixbufJxlAnimationIter *jxl_iter = (GdkPixbufJxlAnimationIter *)iter; + if (jxl_iter->animation->frames->len <= jxl_iter->current_frame) { + return TRUE; + } + return !g_array_index(jxl_iter->animation->frames, GdkPixbufJxlAnimationFrame, + jxl_iter->current_frame) + .decoded; +} + +G_GNUC_BEGIN_IGNORE_DEPRECATIONS +static gboolean gdk_pixbuf_jxl_animation_iter_advance( + GdkPixbufAnimationIter *iter, const GTimeVal *current_time) { + GdkPixbufJxlAnimationIter *jxl_iter = (GdkPixbufJxlAnimationIter *)iter; + size_t old_frame = jxl_iter->current_frame; + + uint64_t current_time_ms = current_time->tv_sec * 1000ULL + + current_time->tv_usec / 1000 - + jxl_iter->time_offset; + + if (jxl_iter->animation->frames->len == 0) { + jxl_iter->current_frame = 0; + } else if (!jxl_iter->animation->done && + current_time_ms >= jxl_iter->animation->total_duration_ms) { + jxl_iter->current_frame = jxl_iter->animation->frames->len - 1; + } else if (jxl_iter->animation->repetition_count != 0 && + current_time_ms > jxl_iter->animation->repetition_count * + jxl_iter->animation->total_duration_ms) { + jxl_iter->current_frame = jxl_iter->animation->frames->len - 1; + } else { + uint64_t total_duration_ms = jxl_iter->animation->total_duration_ms; + // Guard against divide-by-0 in malicious files. + if (total_duration_ms == 0) total_duration_ms = 1; + uint64_t loop_offset = current_time_ms % total_duration_ms; + jxl_iter->current_frame = 0; + while (true) { + uint64_t duration = + g_array_index(jxl_iter->animation->frames, GdkPixbufJxlAnimationFrame, + jxl_iter->current_frame) + .duration_ms; + if (duration >= loop_offset) { + break; + } + loop_offset -= duration; + jxl_iter->current_frame++; + } + } + + return old_frame != jxl_iter->current_frame; +} +G_GNUC_END_IGNORE_DEPRECATIONS + +static void gdk_pixbuf_jxl_animation_iter_finalize(GObject *obj) { + GdkPixbufJxlAnimationIter *iter = (GdkPixbufJxlAnimationIter *)obj; + g_object_unref(iter->animation); +} + +static void gdk_pixbuf_jxl_animation_iter_class_init( + GdkPixbufJxlAnimationIterClass *klass) { + G_OBJECT_CLASS(klass)->finalize = gdk_pixbuf_jxl_animation_iter_finalize; + klass->parent_class.get_delay_time = + gdk_pixbuf_jxl_animation_iter_get_delay_time; + klass->parent_class.get_pixbuf = gdk_pixbuf_jxl_animation_iter_get_pixbuf; + klass->parent_class.on_currently_loading_frame = + gdk_pixbuf_jxl_animation_iter_on_currently_loading_frame; + klass->parent_class.advance = gdk_pixbuf_jxl_animation_iter_advance; +} + +G_END_DECLS + +static gpointer begin_load(GdkPixbufModuleSizeFunc size_func, + GdkPixbufModulePreparedFunc prepare_func, + GdkPixbufModuleUpdatedFunc update_func, + gpointer user_data, GError **error) { + GdkPixbufJxlAnimation *decoder_state = + g_object_new(GDK_TYPE_PIXBUF_JXL_ANIMATION, NULL); + if (decoder_state == NULL) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Creation of the animation state failed"); + return NULL; + } + decoder_state->image_size_callback = size_func; + decoder_state->pixbuf_prepared_callback = prepare_func; + decoder_state->area_updated_callback = update_func; + decoder_state->user_data = user_data; + decoder_state->frames = + g_array_new(/*zero_terminated=*/FALSE, /*clear_=*/TRUE, + sizeof(GdkPixbufJxlAnimationFrame)); + + if (decoder_state->frames == NULL) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Creation of the frame array failed"); + goto cleanup; + } + + if (!(decoder_state->parallel_runner = + JxlResizableParallelRunnerCreate(NULL))) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Creation of the JXL parallel runner failed"); + goto cleanup; + } + + if (!(decoder_state->decoder = JxlDecoderCreate(NULL))) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Creation of the JXL decoder failed"); + goto cleanup; + } + + JxlDecoderStatus status; + + if ((status = JxlDecoderSetParallelRunner( + decoder_state->decoder, JxlResizableParallelRunner, + decoder_state->parallel_runner)) != JXL_DEC_SUCCESS) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JxlDecoderSetParallelRunner failed: %x", status); + goto cleanup; + } + if ((status = JxlDecoderSubscribeEvents( + decoder_state->decoder, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE | JXL_DEC_FRAME)) != + JXL_DEC_SUCCESS) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JxlDecoderSubscribeEvents failed: %x", status); + goto cleanup; + } + + decoder_state->pixel_format.data_type = JXL_TYPE_FLOAT; + decoder_state->pixel_format.endianness = JXL_NATIVE_ENDIAN; + + return decoder_state; +cleanup: + JxlResizableParallelRunnerDestroy(decoder_state->parallel_runner); + JxlDecoderDestroy(decoder_state->decoder); + g_object_unref(decoder_state); + return NULL; +} + +static gboolean stop_load(gpointer context, GError **error) { + g_object_unref(context); + return TRUE; +} + +static void draw_pixels(void *context, size_t x, size_t y, size_t num_pixels, + const void *pixels) { + GdkPixbufJxlAnimation *decoder_state = context; + gboolean has_alpha = decoder_state->pixel_format.num_channels == 4; + + GdkPixbuf *output = + g_array_index(decoder_state->frames, GdkPixbufJxlAnimationFrame, + decoder_state->frames->len - 1) + .data; + + guchar *dst = gdk_pixbuf_get_pixels(output) + + decoder_state->pixel_format.num_channels * x + + gdk_pixbuf_get_rowstride(output) * y; + + skcms_Transform( + pixels, + has_alpha ? skcms_PixelFormat_RGBA_ffff : skcms_PixelFormat_RGB_fff, + decoder_state->alpha_premultiplied ? skcms_AlphaFormat_PremulAsEncoded + : skcms_AlphaFormat_Unpremul, + &decoder_state->icc, dst, + has_alpha ? skcms_PixelFormat_RGBA_8888 : skcms_PixelFormat_RGB_888, + skcms_AlphaFormat_Unpremul, skcms_sRGB_profile(), num_pixels); +} + +static gboolean load_increment(gpointer context, const guchar *buf, guint size, + GError **error) { + GdkPixbufJxlAnimation *decoder_state = context; + if (decoder_state->done == TRUE) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JXL decoder load_increment called after end of file"); + return FALSE; + } + + JxlDecoderStatus status; + + if ((status = JxlDecoderSetInput(decoder_state->decoder, buf, size)) != + JXL_DEC_SUCCESS) { + // Should never happen if things are done properly. + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JXL decoder logic error: %x", status); + return FALSE; + } + + for (;;) { + status = JxlDecoderProcessInput(decoder_state->decoder); + switch (status) { + case JXL_DEC_NEED_MORE_INPUT: { + JxlDecoderReleaseInput(decoder_state->decoder); + return TRUE; + } + + case JXL_DEC_BASIC_INFO: { + JxlBasicInfo info; + if (JxlDecoderGetBasicInfo(decoder_state->decoder, &info) != + JXL_DEC_SUCCESS) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JXLDecoderGetBasicInfo failed"); + return FALSE; + } + decoder_state->pixel_format.num_channels = info.alpha_bits > 0 ? 4 : 3; + decoder_state->alpha_premultiplied = info.alpha_premultiplied; + decoder_state->xsize = info.xsize; + decoder_state->ysize = info.ysize; + decoder_state->has_animation = info.have_animation; + decoder_state->has_alpha = info.alpha_bits > 0; + if (info.have_animation) { + decoder_state->repetition_count = info.animation.num_loops; + decoder_state->tick_duration_us = 1000000ULL * + info.animation.tps_denominator / + info.animation.tps_numerator; + } + gint width = info.xsize; + gint height = info.ysize; + if (decoder_state->image_size_callback) { + decoder_state->image_size_callback(&width, &height, + decoder_state->user_data); + } + + // GDK convention for signaling being interested only in the basic info. + if (width == 0 || height == 0) { + decoder_state->done = TRUE; + return TRUE; + } + + // Set an appropriate number of threads for the image size. + JxlResizableParallelRunnerSetThreads( + decoder_state->parallel_runner, + JxlResizableParallelRunnerSuggestThreads(info.xsize, info.ysize)); + break; + } + + case JXL_DEC_COLOR_ENCODING: { + // Get the ICC color profile of the pixel data + size_t icc_size; + if (JXL_DEC_SUCCESS != JxlDecoderGetICCProfileSize( + decoder_state->decoder, + &decoder_state->pixel_format, + JXL_COLOR_PROFILE_TARGET_DATA, &icc_size)) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JxlDecoderGetICCProfileSize failed"); + return FALSE; + } + if (!(decoder_state->icc_buff = g_malloc(icc_size))) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Allocating ICC profile failed"); + return FALSE; + } + if (JXL_DEC_SUCCESS != + JxlDecoderGetColorAsICCProfile(decoder_state->decoder, + &decoder_state->pixel_format, + JXL_COLOR_PROFILE_TARGET_DATA, + decoder_state->icc_buff, icc_size)) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JxlDecoderGetColorAsICCProfile failed"); + return FALSE; + } + if (!skcms_Parse(decoder_state->icc_buff, icc_size, + &decoder_state->icc)) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Invalid ICC profile from JXL image decoder"); + return FALSE; + } + break; + } + + case JXL_DEC_FRAME: { + // TODO(veluca): support rescaling. + JxlFrameHeader frame_header; + if (JxlDecoderGetFrameHeader(decoder_state->decoder, &frame_header) != + JXL_DEC_SUCCESS) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Failed to retrieve frame info"); + return FALSE; + } + + { + GdkPixbufJxlAnimationFrame frame; + frame.decoded = FALSE; + frame.duration_ms = + frame_header.duration * decoder_state->tick_duration_us / 1000; + decoder_state->total_duration_ms += frame.duration_ms; + frame.data = + gdk_pixbuf_new(GDK_COLORSPACE_RGB, decoder_state->has_alpha, + /*bits_per_sample=*/8, decoder_state->xsize, + decoder_state->ysize); + if (frame.data == NULL) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Failed to allocate output pixel buffer"); + return FALSE; + } + decoder_state->pixel_format.align = + gdk_pixbuf_get_rowstride(frame.data); + g_array_append_val(decoder_state->frames, frame); + } + if (decoder_state->pixbuf_prepared_callback && + decoder_state->frames->len == 1) { + decoder_state->pixbuf_prepared_callback( + g_array_index(decoder_state->frames, GdkPixbufJxlAnimationFrame, + 0) + .data, + decoder_state->has_animation ? (GdkPixbufAnimation *)decoder_state + : NULL, + decoder_state->user_data); + } + break; + } + + case JXL_DEC_NEED_IMAGE_OUT_BUFFER: { + if (JXL_DEC_SUCCESS != + JxlDecoderSetImageOutCallback(decoder_state->decoder, + &decoder_state->pixel_format, + draw_pixels, decoder_state)) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "JxlDecoderSetImageOutCallback failed"); + return FALSE; + } + break; + } + + case JXL_DEC_FULL_IMAGE: { + // TODO(veluca): consider doing partial updates. + if (decoder_state->area_updated_callback) { + GdkPixbuf *output = g_array_index(decoder_state->frames, + GdkPixbufJxlAnimationFrame, 0) + .data; + decoder_state->area_updated_callback( + output, 0, 0, gdk_pixbuf_get_width(output), + gdk_pixbuf_get_height(output), decoder_state->user_data); + } + g_array_index(decoder_state->frames, GdkPixbufJxlAnimationFrame, + decoder_state->frames->len - 1) + .decoded = TRUE; + break; + } + + case JXL_DEC_SUCCESS: { + decoder_state->done = TRUE; + return TRUE; + } + + default: { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Unexpected JxlDecoderProcessInput return code: %x", + status); + return FALSE; + } + } + } + return TRUE; +} + +void fill_vtable(GdkPixbufModule *module) { + module->begin_load = begin_load; + module->stop_load = stop_load; + module->load_increment = load_increment; + // TODO(veluca): implement saving. +} + +void fill_info(GdkPixbufFormat *info) { + static GdkPixbufModulePattern signature[] = { + {"\xFF\x0A", " ", 100}, + {"...\x0CJXL \x0D\x0A\x87\x0A", "zzz ", 100}, + {NULL, NULL, 0}, + }; + + static gchar *mime_types[] = {"image/jxl", NULL}; + + static gchar *extensions[] = {"jxl", NULL}; + + info->name = "jxl"; + info->signature = signature; + info->description = "JPEG XL image"; + info->mime_types = mime_types; + info->extensions = extensions; + // TODO(veluca): add writing support. + info->flags = GDK_PIXBUF_FORMAT_THREADSAFE; + info->license = "BSD-3"; +} diff --git a/media/libjxl/src/plugins/gdk-pixbuf/pixbufloader_test.cc b/media/libjxl/src/plugins/gdk-pixbuf/pixbufloader_test.cc new file mode 100644 index 000000000..5e5642d49 --- /dev/null +++ b/media/libjxl/src/plugins/gdk-pixbuf/pixbufloader_test.cc @@ -0,0 +1,41 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include +#include + +int main(int argc, char* argv[]) { + if (argc != 3) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + const char* loaders_cache = argv[1]; + const char* filename = argv[2]; + setenv("GDK_PIXBUF_MODULE_FILE", loaders_cache, true); + + // XDG_DATA_HOME is the path where we look for the mime cache. + // XDG_DATA_DIRS directories are used in addition to XDG_DATA_HOME. + setenv("XDG_DATA_HOME", ".", true); + setenv("XDG_DATA_DIRS", "", true); + + if (!gdk_init_check(nullptr, nullptr)) { + fprintf(stderr, "This test requires a DISPLAY\n"); + // Signals ctest that we should mark this test as skipped. + return 254; + } + GError* error = nullptr; + GdkPixbuf* pb = gdk_pixbuf_new_from_file(filename, &error); + if (pb != nullptr) { + g_object_unref(pb); + return 0; + } else { + fprintf(stderr, "Error loading file: %s\n", filename); + g_assert_no_error(error); + return 1; + } +} diff --git a/media/libjxl/src/plugins/gimp/CMakeLists.txt b/media/libjxl/src/plugins/gimp/CMakeLists.txt new file mode 100644 index 000000000..f0a49005e --- /dev/null +++ b/media/libjxl/src/plugins/gimp/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +find_package(PkgConfig) +pkg_check_modules(Gimp IMPORTED_TARGET gimp-2.0>=2.10 gimpui-2.0>=2.10) + +if (NOT Gimp_FOUND) + message(WARNING "Gimp development libraries not found, the Gimp plugin will not be built") + return () +endif () + +add_executable(file-jxl WIN32 + common.h + common.cc + file-jxl-load.cc + file-jxl-load.h + file-jxl-save.cc + file-jxl-save.h + file-jxl.cc) +target_link_libraries(file-jxl jxl jxl_threads PkgConfig::Gimp) + +target_include_directories(file-jxl PUBLIC + ${PROJECT_SOURCE_DIR}) # for plugins/gimp absolute paths. + +pkg_get_variable(GIMP_LIB_DIR gimp-2.0 gimplibdir) +install(TARGETS file-jxl RUNTIME DESTINATION "${GIMP_LIB_DIR}/plug-ins/file-jxl/") diff --git a/media/libjxl/src/plugins/gimp/common.cc b/media/libjxl/src/plugins/gimp/common.cc new file mode 100644 index 000000000..1a884570c --- /dev/null +++ b/media/libjxl/src/plugins/gimp/common.cc @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "plugins/gimp/common.h" + +namespace jxl { + +JpegXlGimpProgress::JpegXlGimpProgress(const char *message) { + cur_progress = 0; + max_progress = 100; + + gimp_progress_init_printf("%s\n", message); +} + +void JpegXlGimpProgress::update() { + gimp_progress_update((float)++cur_progress / (float)max_progress); + return; +} + +void JpegXlGimpProgress::finished() { + gimp_progress_update(1.0); + return; +} + +} // namespace jxl diff --git a/media/libjxl/src/plugins/gimp/common.h b/media/libjxl/src/plugins/gimp/common.h new file mode 100644 index 000000000..95c51bf93 --- /dev/null +++ b/media/libjxl/src/plugins/gimp/common.h @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef PLUGINS_GIMP_COMMON_H_ +#define PLUGINS_GIMP_COMMON_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#define PLUG_IN_BINARY "file-jxl" +#define SAVE_PROC "file-jxl-save" + +// Defined by both FUIF and glib. +#undef MAX +#undef MIN +#undef CLAMP + +#include "jxl/resizable_parallel_runner.h" +#include "jxl/resizable_parallel_runner_cxx.h" + +namespace jxl { + +class JpegXlGimpProgress { + public: + explicit JpegXlGimpProgress(const char *message); + void update(); + void finished(); + + private: + int cur_progress; + int max_progress; + +}; // class JpegXlGimpProgress + +} // namespace jxl + +#endif // PLUGINS_GIMP_COMMON_H_ diff --git a/media/libjxl/src/plugins/gimp/file-jxl-load.cc b/media/libjxl/src/plugins/gimp/file-jxl-load.cc new file mode 100644 index 000000000..b1d1f154e --- /dev/null +++ b/media/libjxl/src/plugins/gimp/file-jxl-load.cc @@ -0,0 +1,378 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "plugins/gimp/file-jxl-load.h" + +#define _PROFILE_ORIGIN_ JXL_COLOR_PROFILE_TARGET_ORIGINAL +#define _PROFILE_TARGET_ JXL_COLOR_PROFILE_TARGET_DATA +#define LOAD_PROC "file-jxl-load" + +namespace jxl { + +bool LoadJpegXlImage(const gchar *const filename, gint32 *const image_id) { + std::vector icc_profile; + GimpColorProfile *profile_icc = nullptr; + GimpColorProfile *profile_int = nullptr; + bool is_linear = false; + + gint32 layer; + + gpointer pixels_buffer_1 = nullptr; + gpointer pixels_buffer_2 = nullptr; + size_t buffer_size = 0; + + GimpImageBaseType image_type = GIMP_RGB; + GimpImageType layer_type = GIMP_RGB_IMAGE; + GimpPrecision precision = GIMP_PRECISION_U16_GAMMA; + JxlBasicInfo info = {}; + JxlPixelFormat format = {}; + + format.num_channels = 4; + format.data_type = JXL_TYPE_FLOAT; + format.endianness = JXL_NATIVE_ENDIAN; + format.align = 0; + + bool is_gray = false; + + JpegXlGimpProgress gimp_load_progress( + ("Opening JPEG XL file:" + std::string(filename)).c_str()); + gimp_load_progress.update(); + + // read file + std::ifstream instream(filename, std::ios::in | std::ios::binary); + std::vector compressed((std::istreambuf_iterator(instream)), + std::istreambuf_iterator()); + instream.close(); + + gimp_load_progress.update(); + + // multi-threaded parallel runner. + auto runner = JxlResizableParallelRunnerMake(nullptr); + + auto dec = JxlDecoderMake(nullptr); + if (JXL_DEC_SUCCESS != + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)) { + g_printerr(LOAD_PROC " Error: JxlDecoderSubscribeEvents failed\n"); + return false; + } + + if (JXL_DEC_SUCCESS != JxlDecoderSetParallelRunner(dec.get(), + JxlResizableParallelRunner, + runner.get())) { + g_printerr(LOAD_PROC " Error: JxlDecoderSetParallelRunner failed\n"); + return false; + } + + // grand decode loop... + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + + while (true) { + gimp_load_progress.update(); + + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + + if (status == JXL_DEC_BASIC_INFO) { + if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec.get(), &info)) { + g_printerr(LOAD_PROC " Error: JxlDecoderGetBasicInfo failed\n"); + return false; + } + + JxlResizableParallelRunnerSetThreads( + runner.get(), + JxlResizableParallelRunnerSuggestThreads(info.xsize, info.ysize)); + } else if (status == JXL_DEC_COLOR_ENCODING) { + // check for ICC profile + size_t icc_size = 0; + JxlColorEncoding color_encoding; + if (JXL_DEC_SUCCESS != + JxlDecoderGetColorAsEncodedProfile( + dec.get(), &format, _PROFILE_ORIGIN_, &color_encoding)) { + // Attempt to load ICC profile when no internal color encoding + if (JXL_DEC_SUCCESS != JxlDecoderGetICCProfileSize(dec.get(), &format, + _PROFILE_ORIGIN_, + &icc_size)) { + g_printerr(LOAD_PROC + " Warning: JxlDecoderGetICCProfileSize failed\n"); + } + + if (icc_size > 0) { + icc_profile.resize(icc_size); + if (JXL_DEC_SUCCESS != JxlDecoderGetColorAsICCProfile( + dec.get(), &format, _PROFILE_ORIGIN_, + icc_profile.data(), icc_profile.size())) { + g_printerr(LOAD_PROC + " Warning: JxlDecoderGetColorAsICCProfile failed\n"); + } + + profile_icc = gimp_color_profile_new_from_icc_profile( + icc_profile.data(), icc_profile.size(), nullptr); + + if (profile_icc) { + is_linear = gimp_color_profile_is_linear(profile_icc); + g_printerr(LOAD_PROC " Info: Color profile is_linear = %d\n", + is_linear); + } else { + g_printerr(LOAD_PROC " Warning: Failed to read ICC profile.\n"); + } + } else { + g_printerr(LOAD_PROC " Warning: Empty ICC data.\n"); + } + } + + // Internal color profile detection... + if (JXL_DEC_SUCCESS == + JxlDecoderGetColorAsEncodedProfile( + dec.get(), &format, _PROFILE_TARGET_, &color_encoding)) { + g_printerr(LOAD_PROC " Info: Internal color encoding detected.\n"); + + // figure out linearity of internal profile + switch (color_encoding.transfer_function) { + case JXL_TRANSFER_FUNCTION_LINEAR: + is_linear = true; + break; + + case JXL_TRANSFER_FUNCTION_709: + case JXL_TRANSFER_FUNCTION_PQ: + case JXL_TRANSFER_FUNCTION_HLG: + case JXL_TRANSFER_FUNCTION_GAMMA: + case JXL_TRANSFER_FUNCTION_DCI: + case JXL_TRANSFER_FUNCTION_SRGB: + is_linear = false; + break; + + case JXL_TRANSFER_FUNCTION_UNKNOWN: + default: + if (profile_icc) { + g_printerr(LOAD_PROC + " Info: Unknown transfer function. " + "ICC profile is present."); + } else { + g_printerr(LOAD_PROC + " Info: Unknown transfer function. " + "No ICC profile present."); + } + break; + } + + switch (color_encoding.color_space) { + case JXL_COLOR_SPACE_RGB: + if (color_encoding.white_point == JXL_WHITE_POINT_D65 && + color_encoding.primaries == JXL_PRIMARIES_SRGB) { + if (is_linear) { + profile_int = gimp_color_profile_new_rgb_srgb_linear(); + } else { + profile_int = gimp_color_profile_new_rgb_srgb(); + } + } else if (!is_linear && + color_encoding.white_point == JXL_WHITE_POINT_D65 && + (color_encoding.primaries_green_xy[0] == 0.2100 || + color_encoding.primaries_green_xy[1] == 0.7100)) { + // Probably Adobe RGB + profile_int = gimp_color_profile_new_rgb_adobe(); + } else if (profile_icc) { + g_printerr(LOAD_PROC + " Info: Unknown RGB colorspace. " + "Using ICC profile.\n"); + } else { + g_printerr(LOAD_PROC + " Info: Unknown RGB colorspace. " + "Treating as sRGB.\n"); + if (is_linear) { + profile_int = gimp_color_profile_new_rgb_srgb_linear(); + } else { + profile_int = gimp_color_profile_new_rgb_srgb(); + } + } + break; + + case JXL_COLOR_SPACE_GRAY: + is_gray = true; + if (!profile_icc || + color_encoding.white_point == JXL_WHITE_POINT_D65) { + if (is_linear) { + profile_int = gimp_color_profile_new_d65_gray_linear(); + } else { + profile_int = gimp_color_profile_new_d65_gray_srgb_trc(); + } + } + break; + case JXL_COLOR_SPACE_XYB: + case JXL_COLOR_SPACE_UNKNOWN: + default: + if (profile_icc) { + g_printerr(LOAD_PROC + " Info: Unknown colorspace. Using ICC profile.\n"); + } else { + g_error( + LOAD_PROC + " Warning: Unknown colorspace. Treating as sRGB profile.\n"); + + if (is_linear) { + profile_int = gimp_color_profile_new_rgb_srgb_linear(); + } else { + profile_int = gimp_color_profile_new_rgb_srgb(); + } + } + break; + } + } + + // set pixel format + if (info.num_color_channels > 1) { + if (info.alpha_bits == 0) { + image_type = GIMP_RGB; + layer_type = GIMP_RGB_IMAGE; + format.num_channels = info.num_color_channels; + } else { + image_type = GIMP_RGB; + layer_type = GIMP_RGBA_IMAGE; + format.num_channels = info.num_color_channels + 1; + } + } else if (info.num_color_channels == 1) { + if (info.alpha_bits == 0) { + image_type = GIMP_GRAY; + layer_type = GIMP_GRAY_IMAGE; + format.num_channels = info.num_color_channels; + } else { + image_type = GIMP_GRAY; + layer_type = GIMP_GRAYA_IMAGE; + format.num_channels = info.num_color_channels + 1; + } + } + + // Set image bit depth and linearity + if (info.bits_per_sample <= 8) { + if (is_linear) { + precision = GIMP_PRECISION_U8_LINEAR; + } else { + precision = GIMP_PRECISION_U8_GAMMA; + } + } else if (info.bits_per_sample <= 16) { + if (info.exponent_bits_per_sample > 0) { + if (is_linear) { + precision = GIMP_PRECISION_HALF_LINEAR; + } else { + precision = GIMP_PRECISION_HALF_GAMMA; + } + } else if (is_linear) { + precision = GIMP_PRECISION_U16_LINEAR; + } else { + precision = GIMP_PRECISION_U16_GAMMA; + } + } else { + if (info.exponent_bits_per_sample > 0) { + if (is_linear) { + precision = GIMP_PRECISION_FLOAT_LINEAR; + } else { + precision = GIMP_PRECISION_FLOAT_GAMMA; + } + } else if (is_linear) { + precision = GIMP_PRECISION_U32_LINEAR; + } else { + precision = GIMP_PRECISION_U32_GAMMA; + } + } + + // create new image + if (is_linear) { + *image_id = gimp_image_new_with_precision( + info.xsize, info.ysize, image_type, GIMP_PRECISION_FLOAT_LINEAR); + } else { + *image_id = gimp_image_new_with_precision( + info.xsize, info.ysize, image_type, GIMP_PRECISION_FLOAT_GAMMA); + } + + if (profile_int) { + gimp_image_set_color_profile(*image_id, profile_int); + } else if (!profile_icc) { + g_printerr(LOAD_PROC " Warning: No color profile.\n"); + } + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + // get image from decoder in FLOAT + format.data_type = JXL_TYPE_FLOAT; + if (JXL_DEC_SUCCESS != + JxlDecoderImageOutBufferSize(dec.get(), &format, &buffer_size)) { + g_printerr(LOAD_PROC " Error: JxlDecoderImageOutBufferSize failed\n"); + return false; + } + pixels_buffer_1 = g_malloc(buffer_size); + if (JXL_DEC_SUCCESS != JxlDecoderSetImageOutBuffer(dec.get(), &format, + pixels_buffer_1, + buffer_size)) { + g_printerr(LOAD_PROC " Error: JxlDecoderSetImageOutBuffer failed\n"); + return false; + } + } else if (status == JXL_DEC_FULL_IMAGE || status == JXL_DEC_FRAME) { + // create and insert layer + layer = gimp_layer_new(*image_id, "Background", info.xsize, info.ysize, + layer_type, /*opacity=*/100, + gimp_image_get_default_new_layer_mode(*image_id)); + + gimp_image_insert_layer(*image_id, layer, /*parent_id=*/-1, + /*position=*/0); + + pixels_buffer_2 = g_malloc(buffer_size); + GeglBuffer *buffer = gimp_drawable_get_buffer(layer); + const Babl *destination_format = gegl_buffer_set_format(buffer, nullptr); + + std::string babl_format_str = ""; + if (is_gray) { + babl_format_str += "Y'"; + } else { + babl_format_str += "R'G'B'"; + } + if (info.alpha_bits > 0) { + babl_format_str += "A"; + } + babl_format_str += " float"; + + const Babl *source_format = babl_format(babl_format_str.c_str()); + + babl_process(babl_fish(source_format, destination_format), + pixels_buffer_1, pixels_buffer_2, info.xsize * info.ysize); + + gegl_buffer_set(buffer, GEGL_RECTANGLE(0, 0, info.xsize, info.ysize), 0, + nullptr, pixels_buffer_2, GEGL_AUTO_ROWSTRIDE); + + g_clear_object(&buffer); + } else if (status == JXL_DEC_SUCCESS) { + // All decoding successfully finished. + // It's not required to call JxlDecoderReleaseInput(dec.get()) + // since the decoder will be destroyed. + break; + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + g_printerr(LOAD_PROC " Error: Already provided all input\n"); + return false; + } else if (status == JXL_DEC_ERROR) { + g_printerr(LOAD_PROC " Error: Decoder error\n"); + return false; + } else { + g_printerr(LOAD_PROC " Error: Unknown decoder status\n"); + return false; + } + } // end grand decode loop + + gimp_load_progress.update(); + + if (profile_icc) { + gimp_image_set_color_profile(*image_id, profile_icc); + } + + gimp_load_progress.update(); + + // TODO(xiota): Add option to keep image as float + if (info.bits_per_sample < 32) { + gimp_image_convert_precision(*image_id, precision); + } + + gimp_image_set_filename(*image_id, filename); + + gimp_load_progress.finished(); + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/plugins/gimp/file-jxl-load.h b/media/libjxl/src/plugins/gimp/file-jxl-load.h new file mode 100644 index 000000000..c9ca6d99b --- /dev/null +++ b/media/libjxl/src/plugins/gimp/file-jxl-load.h @@ -0,0 +1,19 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef PLUGINS_GIMP_FILE_JXL_LOAD_H_ +#define PLUGINS_GIMP_FILE_JXL_LOAD_H_ + +#include "jxl/decode.h" +#include "jxl/decode_cxx.h" +#include "plugins/gimp/common.h" + +namespace jxl { + +bool LoadJpegXlImage(const gchar* filename, gint32* image_id); + +} // namespace jxl + +#endif // PLUGINS_GIMP_FILE_JXL_LOAD_H_ diff --git a/media/libjxl/src/plugins/gimp/file-jxl-save.cc b/media/libjxl/src/plugins/gimp/file-jxl-save.cc new file mode 100644 index 000000000..5eb141229 --- /dev/null +++ b/media/libjxl/src/plugins/gimp/file-jxl-save.cc @@ -0,0 +1,897 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "plugins/gimp/file-jxl-save.h" + +#include + +#include "gobject/gsignal.h" + +#define PLUG_IN_BINARY "file-jxl" +#define SAVE_PROC "file-jxl-save" + +#define SCALE_WIDTH 200 + +namespace jxl { + +namespace { + +#ifndef g_clear_signal_handler +// g_clear_signal_handler was added in glib 2.62 +void g_clear_signal_handler(gulong* handler, gpointer instance) { + if (handler != nullptr && *handler != 0) { + g_signal_handler_disconnect(instance, *handler); + *handler = 0; + } +} +#endif // g_clear_signal_handler + +class JpegXlSaveOpts { + public: + float distance; + float quality; + + bool lossless = false; + bool is_linear = false; + bool has_alpha = false; + bool is_gray = false; + bool icc_attached = false; + + bool advanced_mode = false; + bool use_container = true; + bool save_exif = false; + int encoding_effort = 7; + int faster_decoding = 0; + + std::string babl_format_str = "RGB u16"; + std::string babl_type_str = "u16"; + std::string babl_model_str = "RGB"; + + JxlPixelFormat pixel_format; + JxlBasicInfo basic_info; + + // functions + JpegXlSaveOpts(); + + bool SetDistance(float dist); + bool SetQuality(float qual); + bool SetDimensions(int x, int y); + bool SetNumChannels(int channels); + + bool UpdateDistance(); + bool UpdateQuality(); + + bool SetModel(bool is_linear_); + + bool UpdateBablFormat(); + bool SetBablModel(std::string model); + bool SetBablType(std::string type); + + bool SetPrecision(int gimp_precision); + + private: +}; // class JpegXlSaveOpts + +JpegXlSaveOpts jxl_save_opts; + +class JpegXlSaveGui { + public: + bool SaveDialog(); + + private: + GtkWidget* toggle_lossless = nullptr; + GtkAdjustment* entry_distance = nullptr; + GtkAdjustment* entry_quality = nullptr; + GtkAdjustment* entry_effort = nullptr; + GtkAdjustment* entry_faster = nullptr; + GtkWidget* frame_advanced = nullptr; + GtkWidget* toggle_no_xyb = nullptr; + GtkWidget* toggle_raw = nullptr; + gulong handle_toggle_lossless = 0; + gulong handle_entry_quality = 0; + gulong handle_entry_distance = 0; + + static bool GuiOnChangeQuality(GtkAdjustment* adj_qual, void* this_pointer); + + static bool GuiOnChangeDistance(GtkAdjustment* adj_dist, void* this_pointer); + + static bool GuiOnChangeEffort(GtkAdjustment* adj_effort); + static bool GuiOnChangeLossless(GtkWidget* toggle, void* this_pointer); + static bool GuiOnChangeCodestream(GtkWidget* toggle); + static bool GuiOnChangeNoXYB(GtkWidget* toggle); + + static bool GuiOnChangeAdvancedMode(GtkWidget* toggle, void* this_pointer); +}; // class JpegXlSaveGui + +JpegXlSaveGui jxl_save_gui; + +bool JpegXlSaveGui::GuiOnChangeQuality(GtkAdjustment* adj_qual, + void* this_pointer) { + JpegXlSaveGui* self = static_cast(this_pointer); + + g_clear_signal_handler(&self->handle_entry_distance, self->entry_distance); + g_clear_signal_handler(&self->handle_entry_quality, self->entry_quality); + g_clear_signal_handler(&self->handle_toggle_lossless, self->toggle_lossless); + + GtkAdjustment* adj_dist = self->entry_distance; + jxl_save_opts.quality = gtk_adjustment_get_value(adj_qual); + jxl_save_opts.UpdateDistance(); + gtk_adjustment_set_value(adj_dist, jxl_save_opts.distance); + + self->handle_toggle_lossless = g_signal_connect( + self->toggle_lossless, "toggled", G_CALLBACK(GuiOnChangeLossless), self); + self->handle_entry_distance = + g_signal_connect(self->entry_distance, "value-changed", + G_CALLBACK(GuiOnChangeDistance), self); + self->handle_entry_quality = + g_signal_connect(self->entry_quality, "value-changed", + G_CALLBACK(GuiOnChangeQuality), self); + return true; +} + +bool JpegXlSaveGui::GuiOnChangeDistance(GtkAdjustment* adj_dist, + void* this_pointer) { + JpegXlSaveGui* self = static_cast(this_pointer); + GtkAdjustment* adj_qual = self->entry_quality; + + g_clear_signal_handler(&self->handle_entry_distance, self->entry_distance); + g_clear_signal_handler(&self->handle_entry_quality, self->entry_quality); + g_clear_signal_handler(&self->handle_toggle_lossless, self->toggle_lossless); + + jxl_save_opts.distance = gtk_adjustment_get_value(adj_dist); + jxl_save_opts.UpdateQuality(); + gtk_adjustment_set_value(adj_qual, jxl_save_opts.quality); + + if (!(jxl_save_opts.distance < 0.001)) { + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(self->toggle_lossless), + false); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(self->toggle_no_xyb), false); + } + + self->handle_toggle_lossless = g_signal_connect( + self->toggle_lossless, "toggled", G_CALLBACK(GuiOnChangeLossless), self); + self->handle_entry_distance = + g_signal_connect(self->entry_distance, "value-changed", + G_CALLBACK(GuiOnChangeDistance), self); + self->handle_entry_quality = + g_signal_connect(self->entry_quality, "value-changed", + G_CALLBACK(GuiOnChangeQuality), self); + return true; +} + +bool JpegXlSaveGui::GuiOnChangeEffort(GtkAdjustment* adj_effort) { + float new_effort = 10 - gtk_adjustment_get_value(adj_effort); + jxl_save_opts.encoding_effort = new_effort; + return true; +} + +bool JpegXlSaveGui::GuiOnChangeLossless(GtkWidget* toggle, void* this_pointer) { + JpegXlSaveGui* self = static_cast(this_pointer); + GtkAdjustment* adj_distance = self->entry_distance; + GtkAdjustment* adj_quality = self->entry_quality; + GtkAdjustment* adj_effort = self->entry_effort; + + jxl_save_opts.lossless = + gtk_toggle_button_get_active(GTK_TOGGLE_BUTTON(toggle)); + + g_clear_signal_handler(&self->handle_entry_distance, self->entry_distance); + g_clear_signal_handler(&self->handle_entry_quality, self->entry_quality); + g_clear_signal_handler(&self->handle_toggle_lossless, self->toggle_lossless); + + if (jxl_save_opts.lossless) { + gtk_adjustment_set_value(adj_quality, 100.0); + gtk_adjustment_set_value(adj_distance, 0.0); + jxl_save_opts.distance = 0; + jxl_save_opts.UpdateQuality(); + gtk_adjustment_set_value(adj_effort, 7); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(self->toggle_no_xyb), true); + } else { + gtk_adjustment_set_value(adj_quality, 90.0); + gtk_adjustment_set_value(adj_distance, 1.0); + jxl_save_opts.distance = 1.0; + jxl_save_opts.UpdateQuality(); + gtk_adjustment_set_value(adj_effort, 3); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(self->toggle_no_xyb), false); + } + self->handle_toggle_lossless = g_signal_connect( + self->toggle_lossless, "toggled", G_CALLBACK(GuiOnChangeLossless), self); + self->handle_entry_distance = + g_signal_connect(self->entry_distance, "value-changed", + G_CALLBACK(GuiOnChangeDistance), self); + self->handle_entry_quality = + g_signal_connect(self->entry_quality, "value-changed", + G_CALLBACK(GuiOnChangeQuality), self); + return true; +} + +bool JpegXlSaveGui::GuiOnChangeCodestream(GtkWidget* toggle) { + jxl_save_opts.use_container = + !gtk_toggle_button_get_active(GTK_TOGGLE_BUTTON(toggle)); + return true; +} + +bool JpegXlSaveGui::GuiOnChangeNoXYB(GtkWidget* toggle) { + jxl_save_opts.basic_info.uses_original_profile = + gtk_toggle_button_get_active(GTK_TOGGLE_BUTTON(toggle)); + return true; +} + +bool JpegXlSaveGui::GuiOnChangeAdvancedMode(GtkWidget* toggle, + void* this_pointer) { + JpegXlSaveGui* self = static_cast(this_pointer); + jxl_save_opts.advanced_mode = + gtk_toggle_button_get_active(GTK_TOGGLE_BUTTON(toggle)); + + gtk_widget_set_sensitive(self->frame_advanced, jxl_save_opts.advanced_mode); + + if (!jxl_save_opts.advanced_mode) { + jxl_save_opts.basic_info.uses_original_profile = false; + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(self->toggle_no_xyb), false); + + jxl_save_opts.use_container = true; + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(self->toggle_raw), false); + + jxl_save_opts.faster_decoding = 0; + gtk_adjustment_set_value(GTK_ADJUSTMENT(self->entry_faster), 0); + } + return true; +} + +bool JpegXlSaveGui::SaveDialog() { + gboolean run; + GtkWidget* dialog; + GtkWidget* content_area; + GtkWidget* main_vbox; + GtkWidget* frame; + GtkWidget* toggle; + GtkWidget* table; + GtkWidget* vbox; + GtkWidget* separator; + + // initialize export dialog + gimp_ui_init(PLUG_IN_BINARY, true); + dialog = gimp_export_dialog_new("JPEG XL", PLUG_IN_BINARY, SAVE_PROC); + + gtk_window_set_resizable(GTK_WINDOW(dialog), false); + content_area = gimp_export_dialog_get_content_area(dialog); + + main_vbox = gtk_vbox_new(false, 6); + gtk_container_set_border_width(GTK_CONTAINER(main_vbox), 6); + gtk_box_pack_start(GTK_BOX(content_area), main_vbox, true, true, 0); + gtk_widget_show(main_vbox); + + // Standard Settings Frame + frame = gtk_frame_new(nullptr); + gtk_frame_set_shadow_type(GTK_FRAME(frame), GTK_SHADOW_ETCHED_IN); + gtk_box_pack_start(GTK_BOX(main_vbox), frame, false, false, 0); + gtk_widget_show(frame); + + vbox = gtk_vbox_new(false, 6); + gtk_container_set_border_width(GTK_CONTAINER(vbox), 6); + gtk_container_add(GTK_CONTAINER(frame), vbox); + gtk_widget_show(vbox); + + // Layout Table + table = gtk_table_new(20, 3, false); + gtk_table_set_col_spacings(GTK_TABLE(table), 6); + gtk_box_pack_start(GTK_BOX(vbox), table, false, false, 0); + gtk_widget_show(table); + + // Distance Slider + static gchar distance_help[] = + "Butteraugli distance target. Suggested values:" + "\n\td\u00A0=\u00A00.3\tExcellent" + "\n\td\u00A0=\u00A01\tVery Good" + "\n\td\u00A0=\u00A02\tGood" + "\n\td\u00A0=\u00A03\tFair" + "\n\td\u00A0=\u00A06\tPoor"; + + entry_distance = (GtkAdjustment*)gimp_scale_entry_new( + GTK_TABLE(table), 0, 0, "Distance", SCALE_WIDTH, 0, + jxl_save_opts.distance, 0.0, 15.0, 0.001, 1.0, 3, true, 0.0, 0.0, + distance_help, SAVE_PROC); + gimp_scale_entry_set_logarithmic((GtkObject*)entry_distance, true); + + // Quality Slider + static gchar quality_help[] = + "JPEG-style Quality is remapped to distance. " + "Values roughly match libjpeg quality settings."; + entry_quality = (GtkAdjustment*)gimp_scale_entry_new( + GTK_TABLE(table), 0, 1, "Quality", SCALE_WIDTH, 0, jxl_save_opts.quality, + 8.26, 100.0, 1.0, 10.0, 2, true, 0.0, 0.0, quality_help, SAVE_PROC); + + // Distance and Quality Signals + handle_entry_distance = g_signal_connect( + entry_distance, "value-changed", G_CALLBACK(GuiOnChangeDistance), this); + handle_entry_quality = g_signal_connect(entry_quality, "value-changed", + G_CALLBACK(GuiOnChangeQuality), this); + + // ---------- + separator = gtk_vseparator_new(); + gtk_table_attach(GTK_TABLE(table), separator, 0, 2, 2, 3, GTK_EXPAND, + GTK_EXPAND, 9, 9); + gtk_widget_show(separator); + + // Encoding Effort / Speed + static gchar effort_help[] = + "Adjust encoding speed. Higher values are faster because " + "the encoder uses less effort to hit distance targets. " + "As\u00A0a\u00A0result, image quality may be decreased. " + "Default\u00A0=\u00A03."; + entry_effort = (GtkAdjustment*)gimp_scale_entry_new( + GTK_TABLE(table), 0, 3, "Speed", SCALE_WIDTH, 0, + 10 - jxl_save_opts.encoding_effort, 1, 9, 1, 2, 0, true, 0.0, 0.0, + effort_help, SAVE_PROC); + + // effort signal + g_signal_connect(entry_effort, "value-changed", G_CALLBACK(GuiOnChangeEffort), + nullptr); + + // ---------- + separator = gtk_vseparator_new(); + gtk_table_attach(GTK_TABLE(table), separator, 0, 2, 4, 5, GTK_EXPAND, + GTK_EXPAND, 9, 9); + gtk_widget_show(separator); + + // Lossless Mode Convenience Checkbox + static gchar lossless_help[] = + "Compress using modular lossless mode. " + "Speed\u00A0is adjusted to improve performance."; + toggle_lossless = gtk_check_button_new_with_label("Lossless Mode"); + gimp_help_set_help_data(toggle_lossless, lossless_help, nullptr); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(toggle_lossless), + jxl_save_opts.lossless); + gtk_table_attach_defaults(GTK_TABLE(table), toggle_lossless, 0, 2, 5, 6); + gtk_widget_show(toggle_lossless); + + // lossless signal + handle_toggle_lossless = g_signal_connect( + toggle_lossless, "toggled", G_CALLBACK(GuiOnChangeLossless), this); + + // ---------- + separator = gtk_vseparator_new(); + gtk_box_pack_start(GTK_BOX(main_vbox), separator, false, false, 1); + gtk_widget_show(separator); + + // Advanced Settings Frame + std::vector advanced_opts; + + frame_advanced = gtk_frame_new("Advanced Settings"); + gimp_help_set_help_data(frame_advanced, + "Some advanced settings may produce malformed files.", + nullptr); + gtk_frame_set_shadow_type(GTK_FRAME(frame_advanced), GTK_SHADOW_ETCHED_IN); + gtk_box_pack_start(GTK_BOX(main_vbox), frame_advanced, true, true, 0); + gtk_widget_show(frame_advanced); + + gtk_widget_set_sensitive(frame_advanced, false); + + vbox = gtk_vbox_new(false, 6); + gtk_container_set_border_width(GTK_CONTAINER(vbox), 6); + gtk_container_add(GTK_CONTAINER(frame_advanced), vbox); + gtk_widget_show(vbox); + + // uses_original_profile + static gchar uses_original_profile_help[] = + "Prevents conversion to the XYB colorspace. " + "File sizes are approximately doubled."; + toggle_no_xyb = gtk_check_button_new_with_label("Do not use XYB colorspace"); + gimp_help_set_help_data(toggle_no_xyb, uses_original_profile_help, nullptr); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(toggle_no_xyb), + jxl_save_opts.basic_info.uses_original_profile); + gtk_box_pack_start(GTK_BOX(vbox), toggle_no_xyb, false, false, 0); + gtk_widget_show(toggle_no_xyb); + + g_signal_connect(toggle_no_xyb, "toggled", G_CALLBACK(GuiOnChangeNoXYB), + nullptr); + + // save raw codestream + static gchar codestream_help[] = + "Save the raw codestream, without a container. " + "The container is required for metadata and some other features."; + toggle_raw = gtk_check_button_new_with_label("Save Raw Codestream"); + gimp_help_set_help_data(toggle_raw, codestream_help, nullptr); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(toggle_raw), + !jxl_save_opts.use_container); + gtk_box_pack_start(GTK_BOX(vbox), toggle_raw, false, false, 0); + gtk_widget_show(toggle_raw); + + g_signal_connect(toggle_raw, "toggled", G_CALLBACK(GuiOnChangeCodestream), + nullptr); + + // ---------- + separator = gtk_vseparator_new(); + gtk_box_pack_start(GTK_BOX(vbox), separator, false, false, 1); + gtk_widget_show(separator); + + // Faster Decoding / Decoding Speed + static gchar faster_help[] = + "Improve decoding speed at the expense of quality. " + "Default\u00A0=\u00A00."; + table = gtk_table_new(1, 3, false); + gtk_table_set_col_spacings(GTK_TABLE(table), 6); + gtk_container_add(GTK_CONTAINER(vbox), table); + gtk_widget_show(table); + + entry_faster = (GtkAdjustment*)gimp_scale_entry_new( + GTK_TABLE(table), 0, 0, "Faster Decoding", SCALE_WIDTH, 0, + jxl_save_opts.faster_decoding, 0, 4, 1, 1, 0, true, 0.0, 0.0, faster_help, + SAVE_PROC); + + // Faster Decoding Signals + g_signal_connect(entry_faster, "value-changed", + G_CALLBACK(gimp_int_adjustment_update), + &jxl_save_opts.faster_decoding); + + // Enable Advanced Settings + frame = gtk_frame_new(nullptr); + gtk_frame_set_shadow_type(GTK_FRAME(frame), GTK_SHADOW_NONE); + gtk_box_pack_start(GTK_BOX(main_vbox), frame, true, true, 0); + gtk_widget_show(frame); + + vbox = gtk_vbox_new(false, 6); + gtk_container_set_border_width(GTK_CONTAINER(vbox), 6); + gtk_container_add(GTK_CONTAINER(frame), vbox); + gtk_widget_show(vbox); + + static gchar advanced_help[] = + "Some advanced settings may produce malformed files."; + toggle = gtk_check_button_new_with_label("Enable Advanced Settings"); + gimp_help_set_help_data(toggle, advanced_help, nullptr); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(toggle), + jxl_save_opts.advanced_mode); + gtk_box_pack_start(GTK_BOX(vbox), toggle, false, false, 0); + gtk_widget_show(toggle); + + g_signal_connect(toggle, "toggled", G_CALLBACK(GuiOnChangeAdvancedMode), + this); + + // show dialog + gtk_widget_show(dialog); + + GtkAllocation allocation; + gtk_widget_get_allocation(dialog, &allocation); + + int height = allocation.height; + gtk_widget_set_size_request(dialog, height * 1.5, height); + + run = (gimp_dialog_run(GIMP_DIALOG(dialog)) == GTK_RESPONSE_OK); + gtk_widget_destroy(dialog); + + return run; +} // JpegXlSaveGui::SaveDialog + +JpegXlSaveOpts::JpegXlSaveOpts() { + SetDistance(1.0); + + pixel_format.num_channels = 4; + pixel_format.data_type = JXL_TYPE_FLOAT; + pixel_format.endianness = JXL_NATIVE_ENDIAN; + pixel_format.align = 0; + + JxlEncoderInitBasicInfo(&basic_info); + return; +} // JpegXlSaveOpts constructor + +bool JpegXlSaveOpts::SetModel(bool is_linear_) { + int channels; + std::string model; + + if (is_gray) { + channels = 1; + if (is_linear_) { + model = "Y"; + } else { + model = "Y'"; + } + } else { + channels = 3; + if (is_linear_) { + model = "RGB"; + } else { + model = "R'G'B'"; + } + } + if (has_alpha) { + SetBablModel(model + "A"); + SetNumChannels(channels + 1); + } else { + SetBablModel(model); + SetNumChannels(channels); + } + return true; +} // JpegXlSaveOpts::SetModel + +bool JpegXlSaveOpts::SetDistance(float dist) { + distance = dist; + return UpdateQuality(); +} + +bool JpegXlSaveOpts::SetQuality(float qual) { + quality = qual; + return UpdateDistance(); +} + +bool JpegXlSaveOpts::UpdateQuality() { + float qual; + + if (distance < 0.1) { + qual = 100; + } else if (distance > 6.56) { + qual = 30 - 5 * log(abs(6.25 * distance - 40)) / log(2.5); + lossless = false; + } else { + qual = 100 - (distance - 0.1) / 0.09; + lossless = false; + } + + if (qual < 0) { + quality = 0.0; + } else if (qual >= 100) { + quality = 100.0; + } else { + quality = qual; + } + + return true; +} + +bool JpegXlSaveOpts::UpdateDistance() { + float dist; + if (quality >= 30) { + dist = 0.1 + (100 - quality) * 0.09; + } else { + dist = 6.4 + pow(2.5, (30 - quality) / 5.0) / 6.25; + } + + if (dist > 15) { + distance = 15; + } else { + distance = dist; + } + return true; +} + +bool JpegXlSaveOpts::SetDimensions(int x, int y) { + basic_info.xsize = x; + basic_info.ysize = y; + return true; +} + +bool JpegXlSaveOpts::SetNumChannels(int channels) { + switch (channels) { + case 1: + pixel_format.num_channels = 1; + basic_info.num_color_channels = 1; + basic_info.num_extra_channels = 0; + basic_info.alpha_bits = 0; + basic_info.alpha_exponent_bits = 0; + break; + case 2: + pixel_format.num_channels = 2; + basic_info.num_color_channels = 1; + basic_info.num_extra_channels = 1; + basic_info.alpha_bits = int(std::fmin(16, basic_info.bits_per_sample)); + basic_info.alpha_exponent_bits = 0; + break; + case 3: + pixel_format.num_channels = 3; + basic_info.num_color_channels = 3; + basic_info.num_extra_channels = 0; + basic_info.alpha_bits = 0; + basic_info.alpha_exponent_bits = 0; + break; + case 4: + pixel_format.num_channels = 4; + basic_info.num_color_channels = 3; + basic_info.num_extra_channels = 1; + basic_info.alpha_bits = int(std::fmin(16, basic_info.bits_per_sample)); + basic_info.alpha_exponent_bits = 0; + break; + default: + SetNumChannels(3); + } // switch + return true; +} // JpegXlSaveOpts::SetNumChannels + +bool JpegXlSaveOpts::UpdateBablFormat() { + babl_format_str = babl_model_str + " " + babl_type_str; + return true; +} + +bool JpegXlSaveOpts::SetBablModel(std::string model) { + babl_model_str = model; + return UpdateBablFormat(); +} + +bool JpegXlSaveOpts::SetBablType(std::string type) { + babl_type_str = type; + return UpdateBablFormat(); +} + +bool JpegXlSaveOpts::SetPrecision(int gimp_precision) { + switch (gimp_precision) { + case GIMP_PRECISION_HALF_GAMMA: + case GIMP_PRECISION_HALF_LINEAR: + basic_info.bits_per_sample = 16; + basic_info.exponent_bits_per_sample = 5; + break; + + // UINT32 not supported by encoder; using FLOAT instead + case GIMP_PRECISION_U32_GAMMA: + case GIMP_PRECISION_U32_LINEAR: + case GIMP_PRECISION_FLOAT_GAMMA: + case GIMP_PRECISION_FLOAT_LINEAR: + basic_info.bits_per_sample = 32; + basic_info.exponent_bits_per_sample = 8; + break; + + case GIMP_PRECISION_U16_GAMMA: + case GIMP_PRECISION_U16_LINEAR: + basic_info.bits_per_sample = 16; + basic_info.exponent_bits_per_sample = 0; + break; + + default: + case GIMP_PRECISION_U8_LINEAR: + case GIMP_PRECISION_U8_GAMMA: + basic_info.bits_per_sample = 8; + basic_info.exponent_bits_per_sample = 0; + break; + } + return true; +} // JpegXlSaveOpts::SetPrecision + +} // namespace + +bool SaveJpegXlImage(const gint32 image_id, const gint32 drawable_id, + const gint32 orig_image_id, const gchar* const filename) { + if (!jxl_save_gui.SaveDialog()) { + return true; + } + + gint32 nlayers; + gint32* layers; + gint32 duplicate = gimp_image_duplicate(image_id); + + JpegXlGimpProgress gimp_save_progress( + ("Saving JPEG XL file:" + std::string(filename)).c_str()); + gimp_save_progress.update(); + + // try to get ICC color profile... + std::vector icc; + + GimpColorProfile* profile = gimp_image_get_effective_color_profile(image_id); + jxl_save_opts.is_gray = gimp_color_profile_is_gray(profile); + jxl_save_opts.is_linear = gimp_color_profile_is_linear(profile); + + profile = gimp_image_get_color_profile(image_id); + if (profile) { + g_printerr(SAVE_PROC " Info: Extracting ICC Profile...\n"); + gsize icc_size; + const guint8* const icc_bytes = + gimp_color_profile_get_icc_profile(profile, &icc_size); + + icc.assign(icc_bytes, icc_bytes + icc_size); + } else { + g_printerr(SAVE_PROC " Info: No ICC profile. Exporting image anyway.\n"); + } + + gimp_save_progress.update(); + + jxl_save_opts.SetDimensions(gimp_image_width(image_id), + gimp_image_height(image_id)); + + jxl_save_opts.SetPrecision(gimp_image_get_precision(image_id)); + layers = gimp_image_get_layers(duplicate, &nlayers); + + for (int i = 0; i < nlayers; i++) { + if (gimp_drawable_has_alpha(layers[i])) { + jxl_save_opts.has_alpha = true; + break; + } + } + + gimp_save_progress.update(); + + // layers need to match image size, for now + for (int i = 0; i < nlayers; i++) { + gimp_layer_resize_to_image_size(layers[i]); + } + + // treat layers as animation frames, for now + if (nlayers > 1) { + jxl_save_opts.basic_info.have_animation = true; + jxl_save_opts.basic_info.animation.tps_numerator = 100; + } + + gimp_save_progress.update(); + + // multi-threaded parallel runner. + auto runner = JxlResizableParallelRunnerMake(nullptr); + + JxlResizableParallelRunnerSetThreads( + runner.get(), + JxlResizableParallelRunnerSuggestThreads(jxl_save_opts.basic_info.xsize, + jxl_save_opts.basic_info.ysize)); + + auto enc = JxlEncoderMake(/*memory_manager=*/nullptr); + JxlEncoderUseContainer(enc.get(), jxl_save_opts.use_container); + + if (JXL_ENC_SUCCESS != JxlEncoderSetParallelRunner(enc.get(), + JxlResizableParallelRunner, + runner.get())) { + g_printerr(SAVE_PROC " Error: JxlEncoderSetParallelRunner failed\n"); + return false; + } + + // try to use ICC profile + if (!icc.empty() && !jxl_save_opts.is_gray) { + if (JXL_ENC_SUCCESS == + JxlEncoderSetICCProfile(enc.get(), icc.data(), icc.size())) { + jxl_save_opts.icc_attached = true; + } else { + g_printerr(SAVE_PROC " Warning: JxlEncoderSetICCProfile failed.\n"); + jxl_save_opts.basic_info.uses_original_profile = false; + jxl_save_opts.lossless = false; + } + } else { + g_printerr(SAVE_PROC " Warning: Using internal profile.\n"); + jxl_save_opts.basic_info.uses_original_profile = false; + jxl_save_opts.lossless = false; + } + + // set up internal color profile + JxlColorEncoding color_encoding = {}; + + if (jxl_save_opts.is_linear) { + JxlColorEncodingSetToLinearSRGB(&color_encoding, jxl_save_opts.is_gray); + } else { + JxlColorEncodingSetToSRGB(&color_encoding, jxl_save_opts.is_gray); + } + + if (JXL_ENC_SUCCESS != + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)) { + g_printerr(SAVE_PROC " Warning: JxlEncoderSetColorEncoding failed\n"); + } + + // set encoder options + JxlEncoderFrameSettings* frame_settings; + frame_settings = JxlEncoderFrameSettingsCreate(enc.get(), nullptr); + + JxlEncoderFrameSettingsSetOption(frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, + jxl_save_opts.encoding_effort); + JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_SETTING_DECODING_SPEED, + jxl_save_opts.faster_decoding); + + // lossless mode + if (jxl_save_opts.lossless || jxl_save_opts.distance < 0.01) { + if (jxl_save_opts.basic_info.exponent_bits_per_sample > 0) { + // lossless mode doesn't work well with floating point + jxl_save_opts.distance = 0.01; + jxl_save_opts.lossless = false; + JxlEncoderSetFrameLossless(frame_settings, false); + JxlEncoderSetFrameDistance(frame_settings, 0.01); + } else { + JxlEncoderSetFrameDistance(frame_settings, 0); + JxlEncoderSetFrameLossless(frame_settings, true); + } + } else { + jxl_save_opts.lossless = false; + JxlEncoderSetFrameLossless(frame_settings, false); + JxlEncoderSetFrameDistance(frame_settings, jxl_save_opts.distance); + } + + // this sets some basic_info properties + jxl_save_opts.SetModel(jxl_save_opts.is_linear); + + if (JXL_ENC_SUCCESS != + JxlEncoderSetBasicInfo(enc.get(), &jxl_save_opts.basic_info)) { + g_printerr(SAVE_PROC " Error: JxlEncoderSetBasicInfo failed\n"); + return false; + } + + // convert precision and colorspace + if (jxl_save_opts.is_linear && + jxl_save_opts.basic_info.bits_per_sample < 32) { + gimp_image_convert_precision(duplicate, GIMP_PRECISION_FLOAT_LINEAR); + } else { + gimp_image_convert_precision(duplicate, GIMP_PRECISION_FLOAT_GAMMA); + } + + // process layers and compress into JXL + size_t buffer_size = + jxl_save_opts.basic_info.xsize * jxl_save_opts.basic_info.ysize * + jxl_save_opts.pixel_format.num_channels * 4; // bytes per sample + + for (int i = nlayers - 1; i >= 0; i--) { + gimp_save_progress.update(); + + // copy image into buffer... + gpointer pixels_buffer_1; + gpointer pixels_buffer_2; + pixels_buffer_1 = g_malloc(buffer_size); + pixels_buffer_2 = g_malloc(buffer_size); + + gimp_layer_resize_to_image_size(layers[i]); + + GeglBuffer* buffer = gimp_drawable_get_buffer(layers[i]); + + // using gegl_buffer_set_format to get the format because + // gegl_buffer_get_format doesn't always get the original format + const Babl* native_format = gegl_buffer_set_format(buffer, nullptr); + + gegl_buffer_get(buffer, + GEGL_RECTANGLE(0, 0, jxl_save_opts.basic_info.xsize, + jxl_save_opts.basic_info.ysize), + 1.0, native_format, pixels_buffer_1, GEGL_AUTO_ROWSTRIDE, + GEGL_ABYSS_NONE); + g_clear_object(&buffer); + + // use babl to fix gamma mismatch issues + if (jxl_save_opts.icc_attached) { + jxl_save_opts.SetModel(jxl_save_opts.is_linear); + } else { + jxl_save_opts.SetModel(!jxl_save_opts.is_linear); + } + jxl_save_opts.pixel_format.data_type = JXL_TYPE_FLOAT; + jxl_save_opts.SetBablType("float"); + const Babl* destination_format = + babl_format(jxl_save_opts.babl_format_str.c_str()); + + babl_process( + babl_fish(native_format, destination_format), pixels_buffer_1, + pixels_buffer_2, + jxl_save_opts.basic_info.xsize * jxl_save_opts.basic_info.ysize); + + gimp_save_progress.update(); + + // send layer to encoder + if (JXL_ENC_SUCCESS != + JxlEncoderAddImageFrame(frame_settings, &jxl_save_opts.pixel_format, + pixels_buffer_2, buffer_size)) { + g_printerr(SAVE_PROC " Error: JxlEncoderAddImageFrame failed\n"); + return false; + } + } + + JxlEncoderCloseInput(enc.get()); + + // get data from encoder + std::vector compressed; + compressed.resize(262144); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size(); + + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + gimp_save_progress.update(); + + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() + 262144); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + + if (JXL_ENC_SUCCESS != process_result) { + g_printerr(SAVE_PROC " Error: JxlEncoderProcessOutput failed\n"); + return false; + } + + // write file + std::ofstream outstream(filename, std::ios::out | std::ios::binary); + copy(compressed.begin(), compressed.end(), + std::ostream_iterator(outstream)); + + gimp_save_progress.finished(); + return true; +} // SaveJpegXlImage() + +} // namespace jxl diff --git a/media/libjxl/src/plugins/gimp/file-jxl-save.h b/media/libjxl/src/plugins/gimp/file-jxl-save.h new file mode 100644 index 000000000..9dfa45c59 --- /dev/null +++ b/media/libjxl/src/plugins/gimp/file-jxl-save.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef PLUGINS_GIMP_FILE_JXL_SAVE_H_ +#define PLUGINS_GIMP_FILE_JXL_SAVE_H_ + +#include "jxl/encode.h" +#include "jxl/encode_cxx.h" +#include "plugins/gimp/common.h" + +namespace jxl { + +bool SaveJpegXlImage(gint32 image_id, gint32 drawable_id, gint32 orig_image_id, + const gchar* filename); + +} // namespace jxl + +#endif // PLUGINS_GIMP_FILE_JXL_SAVE_H_ diff --git a/media/libjxl/src/plugins/gimp/file-jxl.cc b/media/libjxl/src/plugins/gimp/file-jxl.cc new file mode 100644 index 000000000..743495a2e --- /dev/null +++ b/media/libjxl/src/plugins/gimp/file-jxl.cc @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include + +#include "plugins/gimp/common.h" +#include "plugins/gimp/file-jxl-load.h" +#include "plugins/gimp/file-jxl-save.h" + +namespace jxl { +namespace { + +constexpr char kLoadProc[] = "file-jxl-load"; +constexpr char kSaveProc[] = "file-jxl-save"; + +void Query() { + { + static char run_mode_name[] = "run-mode"; + static char run_mode_description[] = "Run mode"; + static char filename_name[] = "filename"; + static char filename_description[] = "The name of the file to load"; + static char raw_filename_name[] = "raw-filename"; + static char raw_filename_description[] = + "The name of the file, as entered by the user"; + static const GimpParamDef load_args[] = { + {GIMP_PDB_INT32, run_mode_name, run_mode_description}, + {GIMP_PDB_STRING, filename_name, filename_description}, + {GIMP_PDB_STRING, raw_filename_name, raw_filename_description}, + }; + static char image_name[] = "image"; + static char image_description[] = "Loaded image"; + static const GimpParamDef load_return_vals[] = { + {GIMP_PDB_IMAGE, image_name, image_description}, + }; + + gimp_install_procedure( + /*name=*/kLoadProc, /*blurb=*/"Loads JPEG XL image files", + /*help=*/"Loads JPEG XL image files", /*author=*/"JPEG XL Project", + /*copyright=*/"JPEG XL Project", /*date=*/"2019", + /*menu_label=*/"JPEG XL image", /*image_types=*/nullptr, + /*type=*/GIMP_PLUGIN, /*n_params=*/G_N_ELEMENTS(load_args), + /*n_return_vals=*/G_N_ELEMENTS(load_return_vals), /*params=*/load_args, + /*return_vals=*/load_return_vals); + gimp_register_file_handler_mime(kLoadProc, "image/jxl"); + gimp_register_magic_load_handler( + kLoadProc, "jxl", "", + "0,string,\xFF\x0A," + "0,string,\\000\\000\\000\x0CJXL\\040\\015\\012\x87\\012"); + } + + { + static char run_mode_name[] = "run-mode"; + static char run_mode_description[] = "Run mode"; + static char image_name[] = "image"; + static char image_description[] = "Input image"; + static char drawable_name[] = "drawable"; + static char drawable_description[] = "Drawable to save"; + static char filename_name[] = "filename"; + static char filename_description[] = "The name of the file to save"; + static char raw_filename_name[] = "raw-filename"; + static char raw_filename_description[] = "The name of the file to save"; + static const GimpParamDef save_args[] = { + {GIMP_PDB_INT32, run_mode_name, run_mode_description}, + {GIMP_PDB_IMAGE, image_name, image_description}, + {GIMP_PDB_DRAWABLE, drawable_name, drawable_description}, + {GIMP_PDB_STRING, filename_name, filename_description}, + {GIMP_PDB_STRING, raw_filename_name, raw_filename_description}, + }; + + gimp_install_procedure( + /*name=*/kSaveProc, /*blurb=*/"Saves JPEG XL image files", + /*help=*/"Saves JPEG XL image files", /*author=*/"JPEG XL Project", + /*copyright=*/"JPEG XL Project", /*date=*/"2019", + /*menu_label=*/"JPEG XL image", /*image_types=*/"RGB*, GRAY*", + /*type=*/GIMP_PLUGIN, /*n_params=*/G_N_ELEMENTS(save_args), + /*n_return_vals=*/0, /*params=*/save_args, + /*return_vals=*/nullptr); + gimp_register_file_handler_mime(kSaveProc, "image/jxl"); + gimp_register_save_handler(kSaveProc, "jxl", ""); + } +} + +void Run(const gchar* const name, const gint nparams, + const GimpParam* const params, gint* const nreturn_vals, + GimpParam** const return_vals) { + gegl_init(nullptr, nullptr); + + static GimpParam values[2]; + + *nreturn_vals = 1; + *return_vals = values; + + values[0].type = GIMP_PDB_STATUS; + values[0].data.d_status = GIMP_PDB_EXECUTION_ERROR; + + if (strcmp(name, kLoadProc) == 0) { + if (nparams != 3) { + values[0].data.d_status = GIMP_PDB_CALLING_ERROR; + return; + } + + const gchar* const filename = params[1].data.d_string; + gint32 image_id; + if (!LoadJpegXlImage(filename, &image_id)) { + values[0].data.d_status = GIMP_PDB_EXECUTION_ERROR; + return; + } + + *nreturn_vals = 2; + values[0].data.d_status = GIMP_PDB_SUCCESS; + values[1].type = GIMP_PDB_IMAGE; + values[1].data.d_image = image_id; + } else if (strcmp(name, kSaveProc) == 0) { + if (nparams != 5) { + values[0].data.d_status = GIMP_PDB_CALLING_ERROR; + return; + } + + gint32 image_id = params[1].data.d_image; + gint32 drawable_id = params[2].data.d_drawable; + const gchar* const filename = params[3].data.d_string; + const gint32 orig_image_id = image_id; + const GimpExportReturn export_result = gimp_export_image( + &image_id, &drawable_id, "JPEG XL", + static_cast(GIMP_EXPORT_CAN_HANDLE_RGB | + GIMP_EXPORT_CAN_HANDLE_GRAY | + GIMP_EXPORT_CAN_HANDLE_ALPHA)); + switch (export_result) { + case GIMP_EXPORT_CANCEL: + values[0].data.d_status = GIMP_PDB_CANCEL; + return; + case GIMP_EXPORT_IGNORE: + break; + case GIMP_EXPORT_EXPORT: + break; + } + if (!SaveJpegXlImage(image_id, drawable_id, orig_image_id, filename)) { + return; + } + if (image_id != orig_image_id) { + gimp_image_delete(image_id); + } + values[0].data.d_status = GIMP_PDB_SUCCESS; + } +} + +} // namespace +} // namespace jxl + +static const GimpPlugInInfo PLUG_IN_INFO = {nullptr, nullptr, &jxl::Query, + &jxl::Run}; + +MAIN() diff --git a/media/libjxl/src/plugins/mime/CMakeLists.txt b/media/libjxl/src/plugins/mime/CMakeLists.txt new file mode 100644 index 000000000..6f2a0f919 --- /dev/null +++ b/media/libjxl/src/plugins/mime/CMakeLists.txt @@ -0,0 +1,6 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +install(FILES image-jxl.xml DESTINATION share/mime/packages/) diff --git a/media/libjxl/src/plugins/mime/README.md b/media/libjxl/src/plugins/mime/README.md new file mode 100644 index 000000000..4b5373ce5 --- /dev/null +++ b/media/libjxl/src/plugins/mime/README.md @@ -0,0 +1,20 @@ +## JPEG XL MIME type + +If not already installed by the [Installing section of README.md](../../README.md#installing), then it can be done manually: + +### Install +```bash +sudo xdg-mime install --novendor image-jxl.xml +``` + +Then run: +``` +update-mime --local +``` + + +### Uninstall +```bash +sudo xdg-mime uninstall image-jxl.xml +``` + diff --git a/media/libjxl/src/plugins/mime/image-jxl.xml b/media/libjxl/src/plugins/mime/image-jxl.xml new file mode 100644 index 000000000..cab9018c7 --- /dev/null +++ b/media/libjxl/src/plugins/mime/image-jxl.xml @@ -0,0 +1,13 @@ + + + + JPEG XL image + image JPEG XL + JPEG XL afbeelding + + + + + + + diff --git a/media/libjxl/src/third_party/CMakeLists.txt b/media/libjxl/src/third_party/CMakeLists.txt new file mode 100644 index 000000000..50cc72c92 --- /dev/null +++ b/media/libjxl/src/third_party/CMakeLists.txt @@ -0,0 +1,157 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if((SANITIZER STREQUAL "asan") OR (SANITIZER STREQUAL "msan")) + set(BUILD_TESTING OFF) +endif() + +# Highway +set(HWY_SYSTEM_GTEST ON CACHE INTERNAL "") +set(HWY_FORCE_STATIC_LIBS ON CACHE INTERNAL "") +set(HWY_ENABLE_CONTRIB OFF CACHE INTERNAL "") +set(HWY_ENABLE_EXAMPLES OFF CACHE INTERNAL "") +if((SANITIZER STREQUAL "asan") OR (SANITIZER STREQUAL "msan")) + set(HWY_ENABLE_INSTALL OFF CACHE INTERNAL "") +endif() +if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/highway/CMakeLists.txt" AND + NOT JPEGXL_FORCE_SYSTEM_HWY) + add_subdirectory(highway) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/highway/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.highway COPYONLY) +else() + find_package(HWY 0.15.0) + if (NOT HWY_FOUND) + message(FATAL_ERROR + "Highway library (hwy) not found. Install libhwy-dev or download it " + "to third_party/highway from https://github.com/google/highway . " + "Highway is required to build JPEG XL. You can run " + "${PROJECT_SOURCE_DIR}/deps.sh to download this dependency.") + endif() + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libhwy-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.highway COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +endif() + +# brotli +if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/brotli/c/include/brotli/decode.h" OR + JPEGXL_FORCE_SYSTEM_BROTLI) + find_package(Brotli) + if (NOT Brotli_FOUND) + message(FATAL_ERROR + "Brotli not found, install brotli-dev or download brotli source code to" + " third_party/brotli from https://github.com/google/brotli. You can use" + " ${PROJECT_SOURCE_DIR}/deps.sh to download this dependency.") + endif () + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libbrotli-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.brotli COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +else() + # Compile brotli from sources. + set(BROTLI_DISABLE_TESTS ON CACHE STRING "Disable Brotli tests") + # Override default "no-install" policy. + if((NOT SANITIZER STREQUAL "asan") AND (NOT SANITIZER STREQUAL "msan")) + set(BROTLI_BUNDLED_MODE OFF CACHE INTERNAL "") + endif() + add_subdirectory(brotli) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/brotli/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.brotli COPYONLY) + if(BROTLI_EMSCRIPTEN) + # Brotli only defines the -static targets when using emscripten. + foreach(brlib IN ITEMS brotlienc brotlidec brotlicommon) + add_library(${brlib} ALIAS ${brlib}-static) + endforeach() + endif() # BROTLI_EMSCRIPTEN +endif() + +# *cms +if (JPEGXL_ENABLE_SKCMS OR JPEGXL_ENABLE_PLUGINS) + if( NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/skcms/skcms.h" ) + message(FATAL_ERROR "Please run ${PROJECT_SOURCE_DIR}/deps.sh to fetch the " + "build dependencies.") + endif() + include(skcms.cmake) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/skcms/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.skcms COPYONLY) +endif () +if (JPEGXL_ENABLE_VIEWERS OR NOT JPEGXL_ENABLE_SKCMS) + if( NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/lcms/.git" OR JPEGXL_FORCE_SYSTEM_LCMS2 ) + find_package(LCMS2 2.13) + if ( NOT LCMS2_FOUND ) + message(FATAL_ERROR "Please install lcms2 or run git submodule update --init") + endif () + else() + include(lcms2.cmake) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/lcms/COPYING" + ${PROJECT_BINARY_DIR}/LICENSE.lcms COPYONLY) + endif() +endif() + +# libpng +if (JPEGXL_EMSCRIPTEN) + if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/libpng/CMakeLists.txt") + message(FATAL_ERROR "Please run ${PROJECT_SOURCE_DIR}/deps.sh to fetch the " + "build dependencies.") + endif() + file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/libpng/scripts/pnglibconf.h.prebuilt" DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/libpng") + file(RENAME "${CMAKE_CURRENT_SOURCE_DIR}/libpng/pnglibconf.h.prebuilt" "${CMAKE_CURRENT_SOURCE_DIR}/libpng/pnglibconf.h") + set(ZLIB_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/zlib/") + set(ZLIB_LIBRARY "") + set(PNG_FOUND YES PARENT_SCOPE) + set(PNG_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/libpng/" PARENT_SCOPE) + set(PNG_LIBRARIES "" PARENT_SCOPE) +elseif (JPEGXL_BUNDLE_LIBPNG) + if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/libpng/CMakeLists.txt") + message(FATAL_ERROR "Please run ${PROJECT_SOURCE_DIR}/deps.sh to fetch the " + "build dependencies.") + endif() + add_subdirectory(zlib) + set(PNG_STATIC ON CACHE BOOL "") + set(PNG_EXECUTABLES OFF CACHE BOOL "") + set(PNG_BUILD_ZLIB ON CACHE BOOL "") + set(PNG_TESTS OFF CACHE BOOL "") + set(SKIP_INSTALL_ALL ON CACHE BOOL "") + set(ZLIB_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/zlib/") + set(ZLIB_LIBRARY zlibstatic) + add_subdirectory(libpng EXCLUDE_FROM_ALL) + set(PNG_FOUND YES PARENT_SCOPE) + set(PNG_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/libpng/" PARENT_SCOPE) + set(PNG_LIBRARIES png_static PARENT_SCOPE) + set_property(TARGET png_static PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET zlibstatic PROPERTY POSITION_INDEPENDENT_CODE ON) + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/libpng/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.libpng COPYONLY) + endif() +else() + find_package(PNG) + if(PNG_FOUND AND JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/zlib1g-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.zlib COPYONLY) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libpng-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.libpng COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +endif() + +# sjpeg +if (JPEGXL_ENABLE_SJPEG) + if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/sjpeg/CMakeLists.txt") + message(FATAL_ERROR "Please run ${PROJECT_SOURCE_DIR}/deps.sh to fetch the " + "build dependencies.") + endif() + include(sjpeg.cmake) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/sjpeg/COPYING" + ${PROJECT_BINARY_DIR}/LICENSE.sjpeg COPYONLY) +endif () diff --git a/media/libjxl/src/third_party/HEVCSoftware/README.md b/media/libjxl/src/third_party/HEVCSoftware/README.md new file mode 100644 index 000000000..70ebaeba3 --- /dev/null +++ b/media/libjxl/src/third_party/HEVCSoftware/README.md @@ -0,0 +1,2 @@ +This directory contains modified configuration files from the reference HEVC +encoder, the source code of which can be found at: https://hevc.hhi.fraunhofer.de/svn/svn_HEVCSoftware/ diff --git a/media/libjxl/src/third_party/HEVCSoftware/cfg/LICENSE b/media/libjxl/src/third_party/HEVCSoftware/cfg/LICENSE new file mode 100644 index 000000000..a9d8844e4 --- /dev/null +++ b/media/libjxl/src/third_party/HEVCSoftware/cfg/LICENSE @@ -0,0 +1,31 @@ +The copyright in this software is being made available under the BSD +License, included below. This software may be subject to other third party +and contributor rights, including patent rights, and no such rights are +granted under this license.   + +Copyright (c) 2010-2017, ITU/ISO/IEC +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the ITU/ISO/IEC nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. diff --git a/media/libjxl/src/third_party/HEVCSoftware/cfg/encoder_intra_main_scc_10.cfg b/media/libjxl/src/third_party/HEVCSoftware/cfg/encoder_intra_main_scc_10.cfg new file mode 100644 index 000000000..293cc5af5 --- /dev/null +++ b/media/libjxl/src/third_party/HEVCSoftware/cfg/encoder_intra_main_scc_10.cfg @@ -0,0 +1,136 @@ +#======== File I/O ===================== +BitstreamFile : str.bin +ReconFile : rec.yuv + +#======== Profile definition ============== +Profile : main-SCC # Profile name to use for encoding. Use main (for FDIS main), main10 (for FDIS main10), main-still-picture, main-RExt, high-throughput-RExt, main-SCC +Tier : main # Tier to use for interpretation of --Level (main or high only)" + +#======== Unit definition ================ +MaxCUWidth : 64 # Maximum coding unit width in pixel +MaxCUHeight : 64 # Maximum coding unit height in pixel +MaxPartitionDepth : 4 # Maximum coding unit depth +QuadtreeTULog2MaxSize : 5 # Log2 of maximum transform size for + # quadtree-based TU coding (2...6) +QuadtreeTULog2MinSize : 2 # Log2 of minimum transform size for + # quadtree-based TU coding (2...6) +QuadtreeTUMaxDepthInter : 3 +QuadtreeTUMaxDepthIntra : 3 + +#======== Coding Structure ============= +IntraPeriod : 1 # Period of I-Frame ( -1 = only first) +DecodingRefreshType : 1 # Random Access 0:none, 1:CRA, 2:IDR, 3:Recovery Point SEI +GOPSize : 1 # GOP Size (number of B slice = GOPSize-1) +ReWriteParamSetsFlag : 1 # Write parameter sets with every IRAP +# Type POC QPoffset QPfactor tcOffsetDiv2 betaOffsetDiv2 temporal_id #ref_pics_active #ref_pics reference pictures + +#=========== Motion Search ============= +FastSearch : 1 # 0:Full search 1:TZ search +SearchRange : 64 # (0: Search range is a Full frame) +HadamardME : 1 # Use of hadamard measure for fractional ME +FEN : 1 # Fast encoder decision +FDM : 1 # Fast Decision for Merge RD cost + +#======== Quantization ============= +QP : 32 # Quantization parameter(0-51) +MaxDeltaQP : 0 # CU-based multi-QP optimization +MaxCuDQPDepth : 0 # Max depth of a minimum CuDQP for sub-LCU-level delta QP +DeltaQpRD : 0 # Slice-based multi-QP optimization +RDOQ : 1 # RDOQ +RDOQTS : 1 # RDOQ for transform skip +CbQpOffset : 6 +CrQpOffset : 6 + +#=========== Deblock Filter ============ +LoopFilterOffsetInPPS : 1 # Dbl params: 0=varying params in SliceHeader, param = base_param + GOP_offset_param; 1 (default) =constant params in PPS, param = base_param) +LoopFilterDisable : 0 # Disable deblocking filter (0=Filter, 1=No Filter) +LoopFilterBetaOffset_div2 : 0 # base_param: -6 ~ 6 +LoopFilterTcOffset_div2 : 0 # base_param: -6 ~ 6 +DeblockingFilterMetric : 0 # blockiness metric (automatically configures deblocking parameters in bitstream). Applies slice-level loop filter offsets (LoopFilterOffsetInPPS and LoopFilterDisable must be 0) + +#=========== Misc. ============ +InternalBitDepth : 10 # codec operating bit-depth + +#=========== Coding Tools ================= +SAO : 1 # Sample adaptive offset (0: OFF, 1: ON) +AMP : 1 # Asymmetric motion partitions (0: OFF, 1: ON) +TransformSkip : 1 # Transform skipping (0: OFF, 1: ON) +TransformSkipFast : 1 # Fast Transform skipping (0: OFF, 1: ON) +SAOLcuBoundary : 0 # SAOLcuBoundary using non-deblocked pixels (0: OFF, 1: ON) + +#============ Slices ================ +SliceMode : 0 # 0: Disable all slice options. + # 1: Enforce maximum number of LCU in an slice, + # 2: Enforce maximum number of bytes in an 'slice' + # 3: Enforce maximum number of tiles in a slice +SliceArgument : 1500 # Argument for 'SliceMode'. + # If SliceMode==1 it represents max. SliceGranularity-sized blocks per slice. + # If SliceMode==2 it represents max. bytes per slice. + # If SliceMode==3 it represents max. tiles per slice. + +LFCrossSliceBoundaryFlag : 1 # In-loop filtering, including ALF and DB, is across or not across slice boundary. + # 0:not across, 1: across + +#============ PCM ================ +PCMEnabledFlag : 0 # 0: No PCM mode +PCMLog2MaxSize : 5 # Log2 of maximum PCM block size. +PCMLog2MinSize : 3 # Log2 of minimum PCM block size. +PCMInputBitDepthFlag : 1 # 0: PCM bit-depth is internal bit-depth. 1: PCM bit-depth is input bit-depth. +PCMFilterDisableFlag : 0 # 0: Enable loop filtering on I_PCM samples. 1: Disable loop filtering on I_PCM samples. + +#============ Tiles ================ +TileUniformSpacing : 0 # 0: the column boundaries are indicated by TileColumnWidth array, the row boundaries are indicated by TileRowHeight array + # 1: the column and row boundaries are distributed uniformly +NumTileColumnsMinus1 : 0 # Number of tile columns in a picture minus 1 +TileColumnWidthArray : 2 3 # Array containing tile column width values in units of CTU (from left to right in picture) +NumTileRowsMinus1 : 0 # Number of tile rows in a picture minus 1 +TileRowHeightArray : 2 # Array containing tile row height values in units of CTU (from top to bottom in picture) + +LFCrossTileBoundaryFlag : 1 # In-loop filtering is across or not across tile boundary. + # 0:not across, 1: across + +#============ WaveFront ================ +WaveFrontSynchro : 0 # 0: No WaveFront synchronisation (WaveFrontSubstreams must be 1 in this case). + # >0: WaveFront synchronises with the LCU above and to the right by this many LCUs. + +#=========== Quantization Matrix ================= +ScalingList : 0 # ScalingList 0 : off, 1 : default, 2 : file read +ScalingListFile : scaling_list.txt # Scaling List file name. If file is not exist, use Default Matrix. + +#============ Lossless ================ +TransquantBypassEnable : 0 # Value of PPS flag. +CUTransquantBypassFlagForce: 0 # Force transquant bypass mode, when transquant_bypass_enable_flag is enabled + +#=========== RExt ============ +ExtendedPrecision : 0 # Increased internal accuracies to support high bit depths (not valid in V1 profiles) +TransformSkipLog2MaxSize : 2 # Specify transform-skip maximum size. Minimum 2. (not valid in V1 profiles) +ImplicitResidualDPCM : 1 # Enable implicitly signalled residual DPCM for intra (also known as sample-adaptive intra predict) (not valid in V1 profiles) +ExplicitResidualDPCM : 1 # Enable explicitly signalled residual DPCM for inter and intra-block-copy (not valid in V1 profiles) +ResidualRotation : 1 # Enable rotation of transform-skipped and transquant-bypassed TUs through 180 degrees prior to entropy coding (not valid in V1 profiles) +SingleSignificanceMapContext : 1 # Enable, for transform-skipped and transquant-bypassed TUs, the selection of a single significance map context variable for all coefficients (not valid in V1 profiles) +IntraReferenceSmoothing : 1 # 0: Disable use of intra reference smoothing (not valid in V1 profiles). 1: Enable use of intra reference smoothing (same as V1) +GolombRiceParameterAdaptation : 1 # Enable the partial retention of the Golomb-Rice parameter value from one coefficient group to the next +HighPrecisionPredictionWeighting : 1 # Use high precision option for weighted prediction (not valid in V1 profiles) +CrossComponentPrediction : 1 # Enable the use of cross-component prediction (not valid in V1 profiles) + +#=========== SCC ============ +IntraBlockCopyEnabled : 1 # Enable the use of intra block copying +HashBasedIntraBlockCopySearchEnabled : 1 # Use hash based search for intra block copying on 8x8 blocks +IntraBlockCopySearchWidthInCTUs : -1 # Search range for IBC (-1: full frame search) +IntraBlockCopyNonHashSearchWidthInCTUs : 3 # Search range for IBC non-hash search method (i.e., fast/full search) +MSEBasedSequencePSNR : 1 # 0:Emit sequence PSNR only as a linear average of the frame PSNRs, 1: also emit a sequence PSNR based on an average of the frame MSEs +PrintClippedPSNR : 1 # 0:Print lossless PSNR values as 999.99 dB, 1: clip lossless PSNR according to resolution +PrintFrameMSE : 1 # 0:emit only bit count and PSNRs for each frame, 1: also emit MSE values +PrintSequenceMSE : 1 # 0:emit only bit rate and PSNRs for the whole sequence, 1 = also emit MSE values +ColourTransform : 1 # Enable the use of color transform(not valid in V1 profiles) +PaletteMode : 1 # Enable the use of palette mode(not valid in V1 profiles) +PaletteMaxSize : 63 # Supported maximum palette size (not valid in V1 profiles) +PaletteMaxPredSize : 128 # Supported maximum palette predictor size (not valid in V1 profiles) +IntraBoundaryFilterDisabled : 1 # Disable the use of intra boundary filtering (not valid in V1 profiles) +TransquantBypassInferTUSplit : 1 # Infer TU splitting for transquant bypass CUs +PalettePredInSPSEnabled : 0 # Transmit palette predictor initializer in SPS (not valid in V1 profiles) +PalettePredInPPSEnabled : 0 # Transmit palette predictor initializer in PPS (not valid in V1 profiles) +SelectiveRDOQ : 1 # Selective RDOQ + +### DO NOT ADD ANYTHING BELOW THIS LINE ### +### DO NOT DELETE THE EMPTY LINE BELOW ### diff --git a/media/libjxl/src/third_party/dirent.cc b/media/libjxl/src/third_party/dirent.cc new file mode 100644 index 000000000..81015ed0f --- /dev/null +++ b/media/libjxl/src/third_party/dirent.cc @@ -0,0 +1,142 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(_WIN32) || defined(_WIN64) +#include "third_party/dirent.h" + +#include "lib/jxl/base/status.h" + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#include + +#include +#include + +int mkdir(const char* path, mode_t /*mode*/) { + const LPSECURITY_ATTRIBUTES sec = nullptr; + if (!CreateDirectory(path, sec)) { + JXL_NOTIFY_ERROR("Failed to create directory %s", path); + return -1; + } + return 0; +} + +// Modified from code bearing the following notice: +// https://trac.wildfiregames.com/browser/ps/trunk/source/lib/sysdep/os/ +/* Copyright (C) 2010 Wildfire Games. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +struct DIR { + HANDLE hFind; + + WIN32_FIND_DATA findData; // indeterminate if hFind == INVALID_HANDLE_VALUE + + // readdir will return the address of this member. + // (must be stored in DIR to allow multiple independent + // opendir/readdir sequences). + dirent ent; + + // used by readdir to skip the first FindNextFile. + size_t numCalls = 0; +}; + +static bool IsValidDirectory(const char* path) { + const DWORD fileAttributes = GetFileAttributes(path); + + // path not found + if (fileAttributes == INVALID_FILE_ATTRIBUTES) return false; + + // not a directory + if ((fileAttributes & FILE_ATTRIBUTE_DIRECTORY) == 0) return false; + + return true; +} + +DIR* opendir(const char* path) { + if (!IsValidDirectory(path)) { + errno = ENOENT; + return nullptr; + } + + std::unique_ptr d(new DIR); + + // NB: "c:\\path" only returns information about that directory; + // trailing slashes aren't allowed. append "\\*" to retrieve its entries. + std::string searchPath(path); + if (searchPath.back() != '/' && searchPath.back() != '\\') { + searchPath += '\\'; + } + searchPath += '*'; + + // (we don't defer FindFirstFile until readdir because callers + // expect us to return 0 if directory reading will/did fail.) + d->hFind = FindFirstFile(searchPath.c_str(), &d->findData); + if (d->hFind != INVALID_HANDLE_VALUE) return d.release(); + if (GetLastError() == ERROR_NO_MORE_FILES) return d.release(); // empty + + JXL_NOTIFY_ERROR("Failed to open directory %s", searchPath.c_str()); + return nullptr; +} + +int closedir(DIR* dir) { + delete dir; + return 0; +} + +dirent* readdir(DIR* d) { + // "empty" case from opendir + if (d->hFind == INVALID_HANDLE_VALUE) return nullptr; + + // until end of directory or a valid entry was found: + for (;;) { + if (d->numCalls++ != 0) // (skip first call to FindNextFile - see opendir) + { + if (!FindNextFile(d->hFind, &d->findData)) { + JXL_ASSERT(GetLastError() == ERROR_NO_MORE_FILES); + SetLastError(0); + return nullptr; // end of directory or error + } + } + + // only return non-hidden and non-system entries + if ((d->findData.dwFileAttributes & + (FILE_ATTRIBUTE_HIDDEN | FILE_ATTRIBUTE_SYSTEM)) == 0) { + d->ent.d_name = d->findData.cFileName; + return &d->ent; + } + } +} + +#endif // #if defined(_WIN32) || defined(_WIN64) diff --git a/media/libjxl/src/third_party/dirent.h b/media/libjxl/src/third_party/dirent.h new file mode 100644 index 000000000..37a08f425 --- /dev/null +++ b/media/libjxl/src/third_party/dirent.h @@ -0,0 +1,49 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_THIRD_PARTY_DIRENT_H_ +#define LIB_JXL_THIRD_PARTY_DIRENT_H_ + +// Emulates POSIX readdir for Windows + +#if defined(_WIN32) || defined(_WIN64) + +#include // S_IFREG + +#ifndef _MODE_T_ +typedef unsigned int mode_t; +#endif // _MODE_T_ +int mkdir(const char* path, mode_t mode); + +struct dirent { + char* d_name; // no path +}; + +#define stat _stat64 + +#ifndef S_ISDIR +#define S_ISDIR(m) (m & S_IFDIR) +#endif // S_ISDIR + +#ifndef S_ISREG +#define S_ISREG(m) (m & S_IFREG) +#endif // S_ISREG + +struct DIR; +DIR* opendir(const char* path); +int closedir(DIR* dir); +dirent* readdir(DIR* d); + +#endif // #if defined(_WIN32) || defined(_WIN64) +#endif // LIB_JXL_THIRD_PARTY_DIRENT_H_ diff --git a/media/libjxl/src/third_party/lcms2.cmake b/media/libjxl/src/third_party/lcms2.cmake new file mode 100644 index 000000000..c4551de86 --- /dev/null +++ b/media/libjxl/src/third_party/lcms2.cmake @@ -0,0 +1,77 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library(lcms2 STATIC EXCLUDE_FROM_ALL + lcms/src/cmsalpha.c + lcms/src/cmscam02.c + lcms/src/cmscgats.c + lcms/src/cmscnvrt.c + lcms/src/cmserr.c + lcms/src/cmsgamma.c + lcms/src/cmsgmt.c + lcms/src/cmshalf.c + lcms/src/cmsintrp.c + lcms/src/cmsio0.c + lcms/src/cmsio1.c + lcms/src/cmslut.c + lcms/src/cmsmd5.c + lcms/src/cmsmtrx.c + lcms/src/cmsnamed.c + lcms/src/cmsopt.c + lcms/src/cmspack.c + lcms/src/cmspcs.c + lcms/src/cmsplugin.c + lcms/src/cmsps2.c + lcms/src/cmssamp.c + lcms/src/cmssm.c + lcms/src/cmstypes.c + lcms/src/cmsvirt.c + lcms/src/cmswtpnt.c + lcms/src/cmsxform.c + lcms/src/lcms2_internal.h +) +target_include_directories(lcms2 + PUBLIC "${CMAKE_CURRENT_LIST_DIR}/lcms/include") +# This warning triggers with gcc-8. +if (CMAKE_C_COMPILER_ID MATCHES "GNU") +target_compile_options(lcms2 + PRIVATE + # gcc-only flags. + -Wno-stringop-truncation + -Wno-strict-aliasing +) +endif() +# By default LCMS uses sizeof(void*) for memory alignment, but in arm 32-bits we +# can't access doubles not aligned to 8 bytes. This forces the alignment to 8 +# bytes. +target_compile_definitions(lcms2 + PRIVATE "-DCMS_PTR_ALIGNMENT=8") +target_compile_definitions(lcms2 + PUBLIC "-DCMS_NO_REGISTER_KEYWORD=1") + +# Ensure that a thread safe alternative of gmtime is used in LCMS +include(CheckSymbolExists) +check_symbol_exists(gmtime_r "time.h" HAVE_GMTIME_R) +if (HAVE_GMTIME_R) + target_compile_definitions(lcms2 + PUBLIC "-DHAVE_GMTIME_R=1") +else() + check_symbol_exists(gmtime_s "time.h" HAVE_GMTIME_S) + if (HAVE_GMTIME_S) + target_compile_definitions(lcms2 + PUBLIC "-DHAVE_GMTIME_S=1") + endif() +endif() + +set_property(TARGET lcms2 PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/media/libjxl/src/third_party/sjpeg.cmake b/media/libjxl/src/third_party/sjpeg.cmake new file mode 100644 index 000000000..f1a69252b --- /dev/null +++ b/media/libjxl/src/third_party/sjpeg.cmake @@ -0,0 +1,27 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# We need to CACHE the SJPEG_BUILD_EXAMPLES to not be removed by the option() +# inside SJPEG. +set(SJPEG_BUILD_EXAMPLES NO CACHE BOOL "Examples") +# SJPEG uses OpenGL which throws a warning if multiple options are installed. +# This setting makes it prefer the new version. +set(OpenGL_GL_PREFERENCE GLVND) + +# Build SJPEG as a static library. +set(BUILD_SHARED_LIBS_BACKUP ${BUILD_SHARED_LIBS}) +set(BUILD_SHARED_LIBS OFF) +add_subdirectory(sjpeg EXCLUDE_FROM_ALL) +target_include_directories(sjpeg PUBLIC "${CMAKE_CURRENT_LIST_DIR}/sjpeg/src/") +set(BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS_BACKUP}) diff --git a/media/libjxl/src/third_party/skcms.cmake b/media/libjxl/src/third_party/skcms.cmake new file mode 100644 index 000000000..4d2a79cdb --- /dev/null +++ b/media/libjxl/src/third_party/skcms.cmake @@ -0,0 +1,51 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library(skcms-obj OBJECT EXCLUDE_FROM_ALL skcms/skcms.cc) +target_include_directories(skcms-obj PUBLIC "${CMAKE_CURRENT_LIST_DIR}/skcms/") + +# This library is meant to be compiled/used by external libs (such as plugins) +# that need to use skcms. We use a wrapper for libjxl. +add_library(skcms-interface INTERFACE) +target_sources(skcms-interface INTERFACE ${CMAKE_CURRENT_LIST_DIR}/skcms/skcms.cc) +target_include_directories(skcms-interface INTERFACE ${CMAKE_CURRENT_LIST_DIR}/skcms) + +include(CheckCXXCompilerFlag) +check_cxx_compiler_flag("-Wno-psabi" CXX_WPSABI_SUPPORTED) +if(CXX_WPSABI_SUPPORTED) + target_compile_options(skcms-obj PRIVATE -Wno-psabi) + target_compile_options(skcms-interface INTERFACE -Wno-psabi) +endif() + +if(JPEGXL_BUNDLE_SKCMS) + target_compile_options(skcms-obj PRIVATE -DJPEGXL_BUNDLE_SKCMS=1) + if(MSVC) + target_compile_options(skcms-obj + PRIVATE /FI${CMAKE_CURRENT_SOURCE_DIR}/../lib/jxl/enc_jxl_skcms.h) + else() + target_compile_options(skcms-obj + PRIVATE -include ${CMAKE_CURRENT_SOURCE_DIR}/../lib/jxl/enc_jxl_skcms.h) + endif() +endif() + +set_target_properties(skcms-obj PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 +) + +add_library(skcms STATIC EXCLUDE_FROM_ALL $) +target_include_directories(skcms + PUBLIC $) + diff --git a/media/libjxl/src/third_party/testing.cmake b/media/libjxl/src/third_party/testing.cmake new file mode 100644 index 000000000..9f49737e5 --- /dev/null +++ b/media/libjxl/src/third_party/testing.cmake @@ -0,0 +1,83 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Enable tests in third_party/ as well. +enable_testing() +include(CTest) + +set(SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party") + +if(BUILD_TESTING) +# Add GTest from source and alias it to what the find_package(GTest) workflow +# defines. Omitting googletest/ directory would require it to be available in +# the base system instead, but it would work just fine. This makes packages +# using GTest and calling find_package(GTest) actually work. +if (EXISTS "${SOURCE_DIR}/googletest/CMakeLists.txt" AND + NOT JPEGXL_FORCE_SYSTEM_GTEST) + add_subdirectory(third_party/googletest EXCLUDE_FROM_ALL) + + set(GTEST_ROOT "${SOURCE_DIR}/googletest/googletest") + set(GTEST_INCLUDE_DIR "$" + CACHE STRING "") + set(GMOCK_INCLUDE_DIR "$") + set(GTEST_LIBRARY "$") + set(GTEST_MAIN_LIBRARY "$") + add_library(GTest::GTest ALIAS gtest) + add_library(GTest::Main ALIAS gtest_main) + + set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE TRUE) + set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE TRUE) + set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE TRUE) + set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE TRUE) + + # googletest doesn't compile clean with clang-cl (-Wundef) + if (WIN32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set_target_properties(gtest PROPERTIES COMPILE_FLAGS "-Wno-error") + set_target_properties(gmock PROPERTIES COMPILE_FLAGS "-Wno-error") + set_target_properties(gtest_main PROPERTIES COMPILE_FLAGS "-Wno-error") + set_target_properties(gmock_main PROPERTIES COMPILE_FLAGS "-Wno-error") + endif () + configure_file("${SOURCE_DIR}/googletest/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.googletest COPYONLY) +else() + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/googletest/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.googletest COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +endif() +find_package(GTest) +if (NOT GTEST_FOUND) + set(BUILD_TESTING OFF CACHE BOOL "Build tests" FORCE) + message(SEND_ERROR "GTest not found. Install googletest package " + "(libgtest-dev) in the system or download googletest to " + "third_party/googletest from https://github.com/google/googletest ." + "To disable tests instead re-run cmake with -DBUILD_TESTING=OFF.") +endif() # NOT GTEST_FOUND + +# Look for gmock in the system too. +if (NOT DEFINED GMOCK_INCLUDE_DIR) + find_path( + GMOCK_INCLUDE_DIR "gmock/gmock.h" + HINTS ${GTEST_INCLUDE_DIRS}) + if (NOT GMOCK_INCLUDE_DIR) + set(BUILD_TESTING OFF CACHE BOOL "Build tests" FORCE) + message(SEND_ERROR "GMock not found. Install googletest package " + "(libgmock-dev) in the system or download googletest to " + "third_party/googletest from https://github.com/google/googletest ." + "To disable tests instead re-run cmake with -DBUILD_TESTING=OFF.") + else() + message(STATUS "Found GMock: ${GMOCK_INCLUDE_DIR}") + endif() # NOT GMOCK_INCLUDE_DIR +endif() # NOT DEFINED GMOCK_INCLUDE_DIR +endif() # BUILD_TESTING diff --git a/media/libjxl/src/tools/CMakeLists.txt b/media/libjxl/src/tools/CMakeLists.txt new file mode 100644 index 000000000..934ed895c --- /dev/null +++ b/media/libjxl/src/tools/CMakeLists.txt @@ -0,0 +1,469 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# ICC detection library used by the comparison and viewer tools. +if(JPEGXL_ENABLE_VIEWERS) +if(WIN32) + find_package(Qt5 QUIET COMPONENTS Widgets) + if (NOT Qt5_FOUND) + message(WARNING "Qt5 was not found.") + else() + add_library(icc_detect STATIC EXCLUDE_FROM_ALL + icc_detect/icc_detect_win32.cc + icc_detect/icc_detect.h + ) + target_include_directories(icc_detect PRIVATE "${PROJECT_SOURCE_DIR}") + target_link_libraries(icc_detect PUBLIC Qt5::Widgets) + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libqt5widgets5/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.libqt5widgets5 COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR + endif() +elseif(APPLE) + find_package(Qt5 QUIET COMPONENTS Widgets) + if (Qt5_FOUND) + add_library(icc_detect STATIC EXCLUDE_FROM_ALL + icc_detect/icc_detect_empty.cc + icc_detect/icc_detect.h + ) + target_include_directories(icc_detect PRIVATE "${PROJECT_SOURCE_DIR}") + target_link_libraries(icc_detect PUBLIC Qt5::Widgets) + else() + message(WARNING "APPLE: Qt5 was not found.") + endif() +else() + find_package(Qt5 QUIET COMPONENTS Widgets X11Extras) + find_package(ECM QUIET NO_MODULE) + if (NOT Qt5_FOUND OR NOT ECM_FOUND) + if (NOT Qt5_FOUND) + message(WARNING "Qt5 was not found.") + else() + message(WARNING "extra-cmake-modules were not found.") + endif() + else() + set(CMAKE_MODULE_PATH ${ECM_FIND_MODULE_DIR}) + find_package(XCB COMPONENTS XCB) + if (XCB_FOUND) + add_library(icc_detect STATIC EXCLUDE_FROM_ALL + icc_detect/icc_detect_x11.cc + icc_detect/icc_detect.h + ) + target_link_libraries(icc_detect PUBLIC jxl-static Qt5::Widgets Qt5::X11Extras XCB::XCB) + endif () + endif() +endif() +endif() # JPEGXL_ENABLE_VIEWERS + +# Tools are added conditionally below. +set(TOOL_BINARIES) +# Tools that depend on jxl internal functions. +set(INTERNAL_TOOL_BINARIES) + +add_library(jxl_tool STATIC EXCLUDE_FROM_ALL + cmdline.cc + codec_config.cc + speed_stats.cc + file_io.cc + tool_version.cc +) +target_compile_options(jxl_tool PUBLIC "${JPEGXL_INTERNAL_FLAGS}") +target_include_directories(jxl_tool PUBLIC "${PROJECT_SOURCE_DIR}") +target_link_libraries(jxl_tool hwy) + +# The JPEGXL_VERSION is set from the builders. +if(NOT DEFINED JPEGXL_VERSION OR JPEGXL_VERSION STREQUAL "") + find_package(Git QUIET) + execute_process( + COMMAND "${GIT_EXECUTABLE}" rev-parse --short HEAD + OUTPUT_VARIABLE GIT_REV + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" + ERROR_QUIET) + string(STRIP "${GIT_REV}" GIT_REV) + if(GIT_REV STREQUAL "") + set(JPEGXL_VERSION "(unknown)") + endif() +endif() + +if(NOT DEFINED JPEGXL_VERSION OR JPEGXL_VERSION STREQUAL "") + # We are building from a git environment and the user didn't set + # JPEGXL_VERSION. Make a target that computes the GIT_REV at build-time always + # but only updates the file if it changed. This allows rebuilds without + # modifying cmake files to update the JPEGXL_VERSION. + message(STATUS "Building with JPEGXL_VERSION=${GIT_REV} (auto-updated)") + add_custom_target( + tool_version_git + ${CMAKE_COMMAND} + -D JPEGXL_ROOT_DIR=${CMAKE_SOURCE_DIR} + -D DST=${CMAKE_CURRENT_BINARY_DIR}/tool_version_git.h + -P ${CMAKE_CURRENT_SOURCE_DIR}/git_version.cmake + BYPRODUCTS "${CMAKE_CURRENT_BINARY_DIR}/tool_version_git.h" + ) + add_dependencies(jxl_tool tool_version_git) + + set_source_files_properties(tool_version.cc PROPERTIES + COMPILE_DEFINITIONS JPEGXL_VERSION_FROM_GIT=1) + target_include_directories(jxl_tool PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") + # Note: Ninja looks for dependencies on the jxl_tool target before running + # the tool_version_git targets, so when updating the tool_version_git.h the + # jxl_tool target is not rebuilt. This forces to generate it at configure time + # if needed. + execute_process( + COMMAND ${CMAKE_COMMAND} + -D JPEGXL_ROOT_DIR=${CMAKE_SOURCE_DIR} + -D DST=${CMAKE_CURRENT_BINARY_DIR}/tool_version_git.h + -P ${CMAKE_CURRENT_SOURCE_DIR}/git_version.cmake) +else() + message(STATUS "Building with JPEGXL_VERSION=${JPEGXL_VERSION}") + set_source_files_properties(tool_version.cc PROPERTIES + COMPILE_DEFINITIONS JPEGXL_VERSION=\"${JPEGXL_VERSION}\") +endif() + +if(JPEGXL_ENABLE_TOOLS) + # Main compressor. + add_executable(cjxl cjxl_main.cc) + target_link_libraries(cjxl + jxl + jxl_extras_codec-static + jxl_threads + jxl_tool + ) + list(APPEND TOOL_BINARIES cjxl) + + # Main decompressor. + add_executable(djxl djxl_main.cc) + target_link_libraries(djxl + jxl + jxl_extras_codec-static + jxl_threads + jxl_tool + ) + list(APPEND TOOL_BINARIES djxl) + + add_executable(cjpeg_hdr cjpeg_hdr.cc) + list(APPEND INTERNAL_TOOL_BINARIES cjpeg_hdr) + + add_executable(jxlinfo jxlinfo.c) + target_link_libraries(jxlinfo jxl) + list(APPEND TOOL_BINARIES jxlinfo) + + if(NOT SANITIZER STREQUAL "none") + # Linking a C test binary with the C++ JPEG XL implementation when using + # address sanitizer is not well supported by clang 9, so force using clang++ + # for linking this test if a sanitizer is used. + set_target_properties(jxlinfo PROPERTIES LINKER_LANGUAGE CXX) + endif() # SANITIZER != "none" + +endif() # JPEGXL_ENABLE_TOOLS + +# Other developer tools. +if(JPEGXL_ENABLE_DEVTOOLS) + list(APPEND INTERNAL_TOOL_BINARIES + fuzzer_corpus + butteraugli_main + decode_and_encode + display_to_hlg + pq_to_hlg + render_hlg + tone_map + texture_to_cube + generate_lut_template + ssimulacra_main + xyb_range + jxl_from_tree + ) + + add_executable(fuzzer_corpus fuzzer_corpus.cc) + + add_executable(ssimulacra_main ssimulacra_main.cc ssimulacra.cc) + add_executable(butteraugli_main butteraugli_main.cc) + add_executable(decode_and_encode decode_and_encode.cc) + add_executable(display_to_hlg hdr/display_to_hlg.cc) + add_executable(pq_to_hlg hdr/pq_to_hlg.cc) + add_executable(render_hlg hdr/render_hlg.cc) + add_executable(tone_map hdr/tone_map.cc) + add_executable(texture_to_cube hdr/texture_to_cube.cc) + add_executable(generate_lut_template hdr/generate_lut_template.cc) + add_executable(xyb_range xyb_range.cc) + add_executable(jxl_from_tree jxl_from_tree.cc) +endif() # JPEGXL_ENABLE_DEVTOOLS + +# Benchmark tools. +if(JPEGXL_ENABLE_BENCHMARK AND JPEGXL_ENABLE_TOOLS) + list(APPEND INTERNAL_TOOL_BINARIES + benchmark_xl + ) + + add_executable(benchmark_xl + benchmark/benchmark_xl.cc + benchmark/benchmark_args.cc + benchmark/benchmark_codec.cc + benchmark/benchmark_file_io.cc + benchmark/benchmark_stats.cc + benchmark/benchmark_utils.cc + benchmark/benchmark_utils.h + benchmark/benchmark_codec_custom.cc + benchmark/benchmark_codec_custom.h + benchmark/benchmark_codec_jxl.cc + benchmark/benchmark_codec_jxl.h + ../third_party/dirent.cc + ) + target_link_libraries(benchmark_xl Threads::Threads) + if(MINGW) + # MINGW doesn't support glob.h. + target_compile_definitions(benchmark_xl PRIVATE "-DHAS_GLOB=0") + endif() # MINGW + + find_package(JPEG) + if(JPEG_FOUND) + target_sources(benchmark_xl PRIVATE + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_jpeg.cc" + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_jpeg.h" + ) + endif () + + if(NOT JPEGXL_BUNDLE_LIBPNG) + find_package(PNG) + endif() + if(PNG_FOUND) + target_sources(benchmark_xl PRIVATE + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_png.cc" + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_png.h" + ) + endif() + + find_package(PkgConfig) + pkg_check_modules(WebP IMPORTED_TARGET libwebp) + if(WebP_FOUND) + target_sources(benchmark_xl PRIVATE + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_webp.cc" + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_webp.h" + ) + target_compile_definitions(benchmark_xl PRIVATE -DBENCHMARK_WEBP) + + # Use the static version of webp if available. + find_library(WebP_STATIC_LINK_LIBRARY NAMES libwebp.a + PATHS "${WebP_LIBDIR}") + if(NOT WebP_STATIC_LINK_LIBRARY) + message(WARNING "Using dynamic libwebp") + target_link_libraries(benchmark_xl PkgConfig::WebP) + else() + target_link_libraries(benchmark_xl "${WebP_STATIC_LINK_LIBRARY}") + target_include_directories(benchmark_xl + PRIVATE ${WebP_STATIC_INCLUDE_DIRS}) + target_compile_options(benchmark_xl PRIVATE ${WebP_STATIC_CFLAGS_OTHER}) + endif() # NOT WebP_STATIC_LINK_LIBRARY + endif() + + pkg_check_modules(AVIF IMPORTED_TARGET libavif) + if(AVIF_FOUND) + target_sources(benchmark_xl PRIVATE + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_avif.cc" + "${CMAKE_CURRENT_LIST_DIR}/benchmark/benchmark_codec_avif.h" + ) + target_compile_definitions(benchmark_xl PRIVATE -DBENCHMARK_AVIF) + target_link_libraries(benchmark_xl PkgConfig::AVIF) + endif() +endif() # JPEGXL_ENABLE_BENCHMARK + +# All tool binaries depend on "jxl" library and the tool helpers. +foreach(BINARY IN LISTS INTERNAL_TOOL_BINARIES) + target_link_libraries("${BINARY}" + jxl_extras-static + jxl_tool + ) +endforeach() + +list(APPEND TOOL_BINARIES ${INTERNAL_TOOL_BINARIES}) + +foreach(BINARY IN LISTS TOOL_BINARIES) + if(JPEGXL_EMSCRIPTEN) + set_target_properties(${BINARY} PROPERTIES LINK_FLAGS "-s USE_LIBPNG=1") + endif() +endforeach() + +install(TARGETS ${TOOL_BINARIES} RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}") +message(STATUS "Building tools: ${TOOL_BINARIES}") + +set(FUZZER_BINARIES + color_encoding_fuzzer + decode_basic_info_fuzzer + cjxl_fuzzer + djxl_fuzzer + icc_codec_fuzzer + fields_fuzzer + rans_fuzzer + set_from_bytes_fuzzer + transforms_fuzzer +) + +# Fuzzers. +foreach(FUZZER IN LISTS FUZZER_BINARIES) + if(JPEGXL_ENABLE_FUZZERS) + set(BINARY "${FUZZER}") + add_executable("${BINARY}" "${BINARY}.cc") + target_link_libraries("${BINARY}" ${JPEGXL_FUZZER_LINK_FLAGS}) + else() + # When not enabled we want a lightweight alternative for regular fuzzers + # that just run the target. + set(BINARY "${FUZZER}_runner") + add_executable("${BINARY}" EXCLUDE_FROM_ALL + "fuzzer_stub.cc" "${FUZZER}.cc") + endif() # JPEGXL_ENABLE_FUZZERS + target_include_directories("${BINARY}" PRIVATE "${CMAKE_SOURCE_DIR}") + if(FUZZER STREQUAL djxl_fuzzer) + target_link_libraries("${BINARY}" + jxl_dec-static + jxl_threads-static + ) + else() + target_link_libraries("${BINARY}" + jxl_extras-static + jxl_tool + ) + endif() +endforeach() + +# EMSCRIPTEN doesn't support dynamic libraries so testing for linkage there +# doesn't make much sense. +if(BUILD_TESTING AND TARGET jxl AND NOT JPEGXL_EMSCRIPTEN) +# Library API test. This test is only to check that we can link against the +# shared library from C99 file and don't need to use internal symbols. +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests) +add_executable(libjxl_test libjxl_test.c) +set_property(TARGET libjxl_test PROPERTY C_STANDARD 99) +if(NOT SANITIZER STREQUAL "none") + # Linking a C test binary with the C++ JPEG XL implementation when using + # address sanitizer is not well supported by clang 9, so force using clang++ + # for linking this test if a sanitizer is used. + set_target_properties(libjxl_test PROPERTIES LINKER_LANGUAGE CXX) +endif() # SANITIZER != "none" +set_target_properties(libjxl_test PROPERTIES PREFIX "tests/") +target_link_libraries(libjxl_test jxl) +if (NOT MSVC) +target_compile_options(libjxl_test PRIVATE -Wall -Wextra -Werror) +if(NOT WIN32) + target_compile_options(libjxl_test PRIVATE -pedantic) +endif() # NOT WIN32 +endif() # NOT MSVC + +add_test( + NAME LibraryCLinkageTest + COMMAND libjxl_test + WORKING_DIRECTORY $ +) +# if user decide to set CMAKE_SKIP_RPATH:BOOL=ON make sure libjxl.so.0.7 can +# still be found: +if(UNIX AND CMAKE_SKIP_RPATH) + set_property(TEST LibraryCLinkageTest PROPERTY ENVIRONMENT + LD_LIBRARY_PATH=${CMAKE_CURRENT_BINARY_DIR}/.. + ) +endif() + +endif() # BUILD_TESTING AND TARGET jxl AND NOT JPEGXL_EMSCRIPTEN + +# Tools defined in subdirectories. +if(JPEGXL_ENABLE_VIEWERS) +add_subdirectory(viewer) +add_subdirectory(comparison_viewer) +add_subdirectory(flicker_test) +endif() + +add_subdirectory(box) +add_subdirectory(conformance) + + +if (JPEGXL_ENABLE_TOOLS AND JPEGXL_EMSCRIPTEN) +# WASM API facade. +add_executable(jxl_emcc jxl_emcc.cc) +target_link_libraries(jxl_emcc + jxl_extras-static +) +set_target_properties(jxl_emcc PROPERTIES LINK_FLAGS "\ + -O3\ + --closure 1 \ + -s TOTAL_MEMORY=75mb \ + -s USE_LIBPNG=1 \ + -s DISABLE_EXCEPTION_CATCHING=1 \ + -s MODULARIZE=1 \ + -s FILESYSTEM=0 \ + -s USE_PTHREADS=1 \ + -s PTHREAD_POOL_SIZE=4 \ + -s EXPORT_NAME=\"JxlCodecModule\"\ + -s \"EXPORTED_FUNCTIONS=[\ + _malloc,\ + _free,\ + _jxlCreateInstance,\ + _jxlDestroyInstance,\ + _jxlFlush,\ + _jxlProcessInput\ + ]\"\ +") +endif () # JPEGXL_ENABLE_TOOLS AND JPEGXL_EMSCRIPTEN + +if(JPEGXL_ENABLE_JNI) +find_package(JNI QUIET) +find_package(Java QUIET) + +if (JNI_FOUND AND Java_FOUND) + include(UseJava) + + # decoder_jni_onload.cc might be necessary for Android; not used yet. + add_library(jxl_jni SHARED jni/org/jpeg/jpegxl/wrapper/decoder_jni.cc) + target_include_directories(jxl_jni PRIVATE "${JNI_INCLUDE_DIRS}" "${PROJECT_SOURCE_DIR}") + target_link_libraries(jxl_jni PUBLIC jxl_dec-static jxl_threads-static) + if(NOT DEFINED JPEGXL_INSTALL_JNIDIR) + set(JPEGXL_INSTALL_JNIDIR ${CMAKE_INSTALL_LIBDIR}) + endif() + install(TARGETS jxl_jni DESTINATION ${JPEGXL_INSTALL_JNIDIR}) + + add_jar(jxl_jni_wrapper SOURCES + jni/org/jpeg/jpegxl/wrapper/Decoder.java + jni/org/jpeg/jpegxl/wrapper/DecoderJni.java + jni/org/jpeg/jpegxl/wrapper/ImageData.java + jni/org/jpeg/jpegxl/wrapper/PixelFormat.java + jni/org/jpeg/jpegxl/wrapper/Status.java + jni/org/jpeg/jpegxl/wrapper/StreamInfo.java + OUTPUT_NAME org.jpeg.jpegxl + ) + get_target_property(JXL_JNI_WRAPPER_JAR jxl_jni_wrapper JAR_FILE) + if(NOT DEFINED JPEGXL_INSTALL_JARDIR) + set(JPEGXL_INSTALL_JARDIR ${CMAKE_INSTALL_LIBDIR}) + endif() + install_jar(jxl_jni_wrapper DESTINATION ${JPEGXL_INSTALL_JARDIR}) + + add_jar(jxl_jni_wrapper_test + SOURCES jni/org/jpeg/jpegxl/wrapper/DecoderTest.java + INCLUDE_JARS jxl_jni_wrapper + ) + get_target_property(JXL_JNI_WRAPPER_TEST_JAR jxl_jni_wrapper_test JAR_FILE) + + if(NOT SANITIZER MATCHES ".san") + # NB: Vanilla OpenJDK 8 / 11 are known to work well (i.e. either + # "which java" or JAVA_HOME environment variable point to the path like + # "/usr/lib/jvm/java-xx-openjdk-yyy" on Debian Linux). + add_test( + NAME test_jxl_jni_wrapper + COMMAND ${Java_JAVA_EXECUTABLE} + -cp "${JXL_JNI_WRAPPER_JAR}:${JXL_JNI_WRAPPER_TEST_JAR}" + -Dorg.jpeg.jpegxl.wrapper.lib=$ + org.jpeg.jpegxl.wrapper.DecoderTest + ) + endif() # JPEGXL_ENABLE_FUZZERS +endif() # JNI_FOUND & Java_FOUND +endif() # JPEGXL_ENABLE_JNI + +# End-to-end tests for the tools +if(BUILD_TESTING AND JPEGXL_ENABLE_TOOLS AND JPEGXL_ENABLE_DEVTOOLS AND JPEGXL_ENABLE_TRANSCODE_JPEG AND (NOT JPEGXL_ENABLE_JNI)) +find_program (BASH_PROGRAM bash) +if(BASH_PROGRAM AND $ AND $ AND $) + add_test( + NAME roundtrip_test + COMMAND ${BASH_PROGRAM} ${CMAKE_CURRENT_SOURCE_DIR}/roundtrip_test.sh + ${CMAKE_BINARY_DIR}) + if (CMAKE_CROSSCOMPILING_EMULATOR) + set_tests_properties(roundtrip_test PROPERTIES ENVIRONMENT "EMULATOR=${CMAKE_CROSSCOMPILING_EMULATOR}") + endif() +endif() +endif() # BUILD_TESTING diff --git a/media/libjxl/src/tools/README.cjpeg_hdr.md b/media/libjxl/src/tools/README.cjpeg_hdr.md new file mode 100644 index 000000000..bd7c793bd --- /dev/null +++ b/media/libjxl/src/tools/README.cjpeg_hdr.md @@ -0,0 +1,73 @@ +# High bit depth JPEG encoder +`cjpeg_hdr` is an (experimental) JPEG encoder that can preserve a higher bit +depth than a traditional JPEG encoder. In particular, it may be used to produce +HDR JPEGs that do not show obvious signs of banding. + +Note that at this point in time `cjpeg_hdr` does not attempt to actually +*compress* the image - it behaves in the same way as a "quality 100" JPEG +encoder would normally do, i.e. no quantization, to achieve the maximum +possible visual quality. Moreover, no Huffman optimization is performed. + +## Generating HBD JPEGs +Note: this and the following sections assume that `libjxl` has been built in +the `build/` directory, either by using CMake or by running `./ci.sh opt`. + +It should be sufficient to run `build/tools/cjpeg_hdr input_image output.jpg`. +Various input formats are supported, including NetBPM and (8- or 16-bit) PNG. + +If the PNG image includes a colour profile, it will be copied in the resulting +JPEG image. If this colour profile approximates the PQ or HLG transfer curves, +some applications will consider the resulting image to be HDR. + +To attach a PQ profile to an image without a colour profile (or with a +different colour profile), the following command can be used: + +``` + build/tools/decode_and_encode input RGB_D65_202_Rel_PeQ output_with_pq.png 16 +``` + +Similarly, to attach an HLG profile, the following command can be used + +``` + build/tools/decode_and_encode input RGB_D65_202_Rel_HLG output_with_pq.png 16 +``` + +## Decoding HBD JPEGs +HBD JPEGs are fully retrocompatible with libjpeg, and any JPEG viewer ought to +be able to visualize them. Nonetheless, to achieve the best visual quality, a +high bit depth decoder should be used. + +Such a decoder does not exist today. As a workaround, it is possible to do a +lossless conversion to JPEG XL and then view the resulting image: + +``` + build/tools/cjxl --jpeg_transcode_disable_cfl hbd.jpeg hbd.jxl +``` + +The resulting JPEG XL file can be visualized, for example, in a browser, +assuming that the corresponding flag is enabled in the settings. + +In particular, if the HBD JPEG has a PQ or HLG profile attached and the current +display is an HDR display, Chrome ought to visualize the image as HDR content. + +It is also possible to convert the JPEG XL file back to a 16-bit PNG: + +``` + build/tools/djxl hbd.jxl --bits_per_sample=16 output.png +``` + +Note however that as of today (2 Nov 2021) Chrome does not interpret such a PNG +as an HDR image, even if a PQ or HLG profile is attached. Thus, to display the +HDR image correctly it is recommended to either display the JPEG XL image +directly or to convert the PNG to a format that Chrome interprets as HDR, such +as AVIF. This can be done with the following command for a PQ image: + +``` + avifenc -l -y 444 --depth 10 --cicp 9/16/9 image.png output.avif +``` + +and the following one for an HLG image: + +``` + avifenc -l -y 444 --depth 10 --cicp 9/18/9 image.png output.avif +``` diff --git a/media/libjxl/src/tools/args.h b/media/libjxl/src/tools/args.h new file mode 100644 index 000000000..7d04ce3a7 --- /dev/null +++ b/media/libjxl/src/tools/args.h @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_ARGS_H_ +#define TOOLS_ARGS_H_ + +// Helpers for parsing command line arguments. No include guard needed. + +#include +#include +#include +#include + +#include +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" // DecoderHints +#include "lib/jxl/gaborish.h" +#include "lib/jxl/modular/options.h" + +namespace jpegxl { +namespace tools { + +static inline bool ParseOverride(const char* arg, jxl::Override* out) { + const std::string s_arg(arg); + if (s_arg == "1") { + *out = jxl::Override::kOn; + return true; + } + if (s_arg == "0") { + *out = jxl::Override::kOff; + return true; + } + fprintf(stderr, "Invalid flag, %s must be 0 or 1\n", arg); + return JXL_FAILURE("Args"); +} + +static inline bool ParseFloatPair(const char* arg, + std::pair* out) { + int parsed = sscanf(arg, "%f,%f", &out->first, &out->second); + if (parsed == 1) { + out->second = out->first; + } else if (parsed != 2) { + fprintf(stderr, + "Unable to interpret as float pair separated by a comma: %s.\n", + arg); + return JXL_FAILURE("Args"); + } + return true; +} + +static inline bool ParseAndAppendKeyValue(const char* arg, + jxl::extras::ColorHints* out) { + const char* eq = strchr(arg, '='); + if (!eq) { + fprintf(stderr, "Expected argument as 'key=value' but received '%s'\n", + arg); + return false; + } + std::string key(arg, eq); + out->Add(key, std::string(eq + 1)); + return true; +} + +static inline bool ParsePredictor(const char* arg, jxl::Predictor* out) { + char* end; + uint64_t p = static_cast(strtoull(arg, &end, 0)); + if (end[0] != '\0') { + fprintf(stderr, "Invalid predictor: %s.\n", arg); + return JXL_FAILURE("Args"); + } + if (p >= jxl::kNumModularEncoderPredictors) { + fprintf(stderr, + "Invalid predictor value %" PRIu64 ", must be less than %" PRIu64 + ".\n", + p, static_cast(jxl::kNumModularEncoderPredictors)); + return JXL_FAILURE("Args"); + } + *out = static_cast(p); + return true; +} + +static inline bool ParseCString(const char* arg, const char** out) { + *out = arg; + return true; +} + +static inline bool IncrementUnsigned(size_t* out) { + (*out)++; + return true; +} + +} // namespace tools +} // namespace jpegxl + +#endif // TOOLS_ARGS_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_args.cc b/media/libjxl/src/tools/benchmark/benchmark_args.cc new file mode 100644 index 000000000..2bd3eb893 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_args.cc @@ -0,0 +1,281 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/benchmark/benchmark_args.h" + +#include +#include + +#include +#include +#include + +#include "lib/extras/codec.h" +#include "lib/extras/dec/color_description.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "tools/benchmark/benchmark_codec_jpeg.h" // for AddCommand.. +#include "tools/benchmark/benchmark_codec_jxl.h" +#if JPEGXL_ENABLE_APNG +#include "tools/benchmark/benchmark_codec_png.h" +#endif + +#ifdef BENCHMARK_WEBP +#include "tools/benchmark/benchmark_codec_webp.h" +#endif // BENCHMARK_WEBP + +#ifdef BENCHMARK_AVIF +#include "tools/benchmark/benchmark_codec_avif.h" +#endif // BENCHMARK_AVIF + +namespace jxl { + +std::vector SplitString(const std::string& s, char c) { + std::vector result; + size_t pos = 0; + for (size_t i = 0; i <= s.size(); i++) { + if (i == s.size() || s[i] == c) { + result.push_back(s.substr(pos, i - pos)); + pos = i + 1; + } + } + return result; +} + +int ParseIntParam(const std::string& param, int lower_bound, int upper_bound) { + int val = strtol(param.substr(1).c_str(), nullptr, 10); + JXL_CHECK(val >= lower_bound && val <= upper_bound); + return val; +} + +BenchmarkArgs* Args() { + static BenchmarkArgs args; + return &args; +} + +Status BenchmarkArgs::AddCommandLineOptions() { + AddString(&input, "input", "File or file pattern matching input files."); + AddString(&codec, "codec", + "Comma separated list of image codec descriptions to benchmark.", + "jxl"); + AddFlag(&print_details, "print_details", + "Prints size and distortion for each image. Not safe for " + "concurrent benchmark runs.", + false); + AddFlag(&print_details_csv, "print_details_csv", + "When print_details is used, print as CSV.", false); + AddString(&extra_metrics, "extra_metrics", + "Extra metrics to be computed. Only displayed with --print_details " + "or --print_details_csv. Comma-separated list of NAME:COMMAND " + "pairs; COMMAND is invoked with the original image as the first " + "argument, the decompressed image as a second argument, and the " + "name of the file where to write the metric value (as a single " + "floating point number) as the third argument.", + ""); + AddFlag( + &print_more_stats, "print_more_stats", + "Prints codec-specific stats. Not safe for concurrent benchmark runs.", + false); + AddFlag(&print_distance_percentiles, "print_distance_percentiles", + "Prints distance percentiles for the corpus. Not safe for " + "concurrent benchmark runs.", + false); + AddFlag(&silent_errors, "silent_errors", + "If true, doesn't print error messages on compression or" + " decompression errors. Errors counts are still visible in the" + " 'Errors' column of the result table. Please note that depending" + " depending on the JXL build settings, error messages and asserts" + " from within the codec may be printed irrespective of this flag" + " anyway, use release build to ensure no messages.", + false); + AddFlag(&save_compressed, "save_compressed", + "Saves the compressed files for each input image and each codec.", + false); + AddFlag(&save_decompressed, "save_decompressed", + "Saves the decompressed files as PNG for each input image " + "and each codec.", + false); + AddString(&output_extension, "output_extension", + "Extension (starting with dot) to use for saving output images.", + ".png"); + AddString(&output_description, "output_description", + "Color encoding (see ParseDescription; e.g. RGB_D65_SRG_Rel_709) " + "for saving output images, " + " defaults to sRGB."); + + AddFloat(&intensity_target, "intensity_target", + "Intended viewing intensity target in nits. Defaults to 255 for " + "SDR images, 4000 for HDR images (when the input image uses PQ or " + "HLG transfer function)", + 0); + + AddString(&color_hints_string, "dec-hints", + "Color encoding hints for the input images to encoder. Comma " + "separated key=value pairs. The key color_space indicates " + "ColorEncoding (see ParseDescription; e.g. RGB_D65_SRG_Rel_709) " + "for input images without color encoding (such as PNM)"); + + AddUnsigned( + &override_bitdepth, "override_bitdepth", + "If nonzero, store the given bit depth in the JPEG XL file metadata" + " (1-32), instead of using the bit depth from the original input" + " image.", + 0); + + AddDouble(&mul_output, "mul_output", + "If nonzero, multiplies linear sRGB by this and clamps to 255", + 0.0); + AddDouble(&heatmap_good, "heatmap_good", + "If greater than zero, use this as the good " + "threshold for creating heatmap images.", + 0.0); + AddDouble(&heatmap_bad, "heatmap_bad", + "If greater than zero, use this as the bad " + "threshold for creating heatmap images.", + 0.0); + + AddFlag(&write_html_report, "write_html_report", + "Creates an html report with original and compressed images.", false); + AddFlag(&html_report_self_contained, "html_report_self_contained", + "Base64-encode the images in the HTML report rather than use " + "external file names. May cause very large HTML data size.", + false); + + AddFlag( + &markdown, "markdown", + "Adds formatting around ASCII table to render correctly in Markdown based" + " interfaces", + true); + + AddFlag(&more_columns, "more_columns", "Print extra columns in the table", + false); + + AddString(&originals_url, "originals_url", + "Url prefix to serve original images from in the html report."); + AddString(&output_dir, "output_dir", + "If not empty, save compressed and decompressed " + "images here."); + + AddSigned(&num_threads, "num_threads", + "The number of threads for concurrent benchmarking. Defaults to " + "1 thread per CPU core (if negative).", + -1); + AddSigned(&inner_threads, "inner_threads", + "The number of extra threads per task. " + "Defaults to occupy cores (if negative).", + -1); + AddUnsigned(&encode_reps, "encode_reps", + "How many times to encode (>1 for more precise measurements). " + "Defaults to 1.", + 1); + AddUnsigned(&decode_reps, "decode_reps", + "How many times to decode (>1 for more precise measurements). " + "Defaults to 1.", + 1); + + AddString(&sample_tmp_dir, "sample_tmp_dir", + "Directory to put samples from input images."); + + AddSigned(&num_samples, "num_samples", "How many sample areas to take.", 0); + AddSigned(&sample_dimensions, "sample_dimensions", + "How big areas to sample from the input.", 64); + + AddDouble(&error_pnorm, "error_pnorm", + "smallest p norm for pooling butteraugli values", 3.0); + + AddFloat(&ba_params.hf_asymmetry, "hf_asymmetry", + "Multiplier for weighting HF artefacts more than features " + "being smoothed out. 1.0 means no HF asymmetry. 0.3 is " + "a good value to start exploring for asymmetry.", + 0.8f); + AddFlag(&profiler, "profiler", "If true, print profiler results.", false); + + AddFlag(&show_progress, "show_progress", + "Show activity dots per completed file during benchmark.", false); + + AddFlag(&skip_butteraugli, "skip_butteraugli", + "If true, doesn't compute distance metrics, only compression and" + " decompression speed and size. Distance numbers shown in the" + " table are invalid.", + false); + + AddFlag( + &decode_only, "decode_only", + "If true, only decodes, and the input files must be compressed with a " + "compatible format for the given codec(s). Only measures decompression " + "speed and sizes, and can only use a single set of compatible decoders. " + "Distance numbers and compression speeds shown in the table are invalid.", + false); + + if (!AddCommandLineOptionsJxlCodec(this)) return false; +#ifdef BENCHMARK_JPEG + if (!AddCommandLineOptionsJPEGCodec(this)) return false; +#endif // BENCHMARK_JPEG +#if JPEGXL_ENABLE_APNG + if (!AddCommandLineOptionsPNGCodec(this)) return false; +#endif +#ifdef BENCHMARK_WEBP + if (!AddCommandLineOptionsWebPCodec(this)) return false; +#endif // BENCHMARK_WEBP +#ifdef BENCHMARK_AVIF + if (!AddCommandLineOptionsAvifCodec(this)) return false; +#endif // BENCHMARK_AVIF + + return true; +} + +Status BenchmarkArgs::ValidateArgs() { + size_t bits_per_sample = 0; // unused + if (input.empty()) { + fprintf(stderr, "Missing --input filename(s).\n"); + return false; + } + if (extras::CodecFromExtension(output_extension, &bits_per_sample) == + extras::Codec::kUnknown) { + JXL_WARNING("Unrecognized output_extension %s, try .png", + output_extension.c_str()); + return false; // already warned + } + + // If empty, don't do anything; callers must only use output_encoding if + // output_description is not empty. + if (!output_description.empty()) { + // Validate, but also create the profile (only needs to happen once). + JxlColorEncoding output_encoding_external; + if (!ParseDescription(output_description, &output_encoding_external)) { + JXL_WARNING("Unrecognized output_description %s, try RGB_D65_SRG_Rel_Lin", + output_description.c_str()); + return false; // already warned + } + JXL_RETURN_IF_ERROR(jxl::ConvertExternalToInternalColorEncoding( + output_encoding_external, &output_encoding)); + JXL_RETURN_IF_ERROR(output_encoding.CreateICC()); + } + + JXL_RETURN_IF_ERROR(ValidateArgsJxlCodec(this)); + + if (print_details_csv) print_details = true; + + if (override_bitdepth > 32) { + return JXL_FAILURE("override_bitdepth must be <= 32"); + } + + if (!color_hints_string.empty()) { + std::vector hints = SplitString(color_hints_string, ','); + for (const auto& hint : hints) { + std::vector kv = SplitString(hint, '='); + if (kv.size() != 2) { + return JXL_FAILURE( + "dec-hints key value pairs must have the form 'key=value'"); + } + color_hints.Add(kv[0], kv[1]); + } + } + + return true; +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_args.h b/media/libjxl/src/tools/benchmark/benchmark_args.h new file mode 100644 index 000000000..bebc0ac49 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_args.h @@ -0,0 +1,174 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_ARGS_H_ +#define TOOLS_BENCHMARK_BENCHMARK_ARGS_H_ + +// Command line parsing and arguments for benchmark_xl + +#include + +#include +#include +#include +#include + +#include "lib/extras/dec/color_hints.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/color_encoding_internal.h" +#include "tools/args.h" +#include "tools/cmdline.h" + +namespace jxl { + +std::vector SplitString(const std::string& s, char c); + +int ParseIntParam(const std::string& param, int lower_bound, int upper_bound); + +struct BenchmarkArgs { + using OptionId = jpegxl::tools::CommandLineParser::OptionId; + + void AddFlag(bool* field, const char* longName, const char* help, + bool defaultValue) { + const char* noName = RememberString_(std::string("no") + longName); + cmdline.AddOptionFlag('\0', longName, nullptr, field, + &jpegxl::tools::SetBooleanTrue); + cmdline.AddOptionFlag('\0', noName, help, field, + &jpegxl::tools::SetBooleanFalse); + *field = defaultValue; + } + + OptionId AddOverride(Override* field, const char* longName, + const char* help) { + OptionId result = cmdline.AddOptionValue('\0', longName, "0|1", help, field, + &jpegxl::tools::ParseOverride); + *field = Override::kDefault; + return result; + } + + OptionId AddString(std::string* field, const char* longName, const char* help, + const std::string& defaultValue = "") { + OptionId result = cmdline.AddOptionValue( + '\0', longName, "", help, field, &jpegxl::tools::ParseString); + *field = defaultValue; + return result; + } + + OptionId AddFloat(float* field, const char* longName, const char* help, + float defaultValue) { + OptionId result = cmdline.AddOptionValue('\0', longName, "", help, + field, &jpegxl::tools::ParseFloat); + *field = defaultValue; + return result; + } + + OptionId AddDouble(double* field, const char* longName, const char* help, + double defaultValue) { + OptionId result = cmdline.AddOptionValue( + '\0', longName, "", help, field, &jpegxl::tools::ParseDouble); + *field = defaultValue; + return result; + } + + OptionId AddSigned(int* field, const char* longName, const char* help, + int defaultValue) { + OptionId result = cmdline.AddOptionValue( + '\0', longName, "", help, field, &jpegxl::tools::ParseSigned); + *field = defaultValue; + return result; + } + + OptionId AddUnsigned(size_t* field, const char* longName, const char* help, + size_t defaultValue) { + OptionId result = + cmdline.AddOptionValue('\0', longName, "", help, field, + &jpegxl::tools::ParseUnsigned); + *field = defaultValue; + return result; + } + + Status AddCommandLineOptions(); + + Status ValidateArgs(); + + bool Parse(int argc, const char** argv) { return cmdline.Parse(argc, argv); } + + void PrintHelp() const { cmdline.PrintHelp(); } + + std::string input; + std::string codec; + bool print_details; + bool print_details_csv; + bool print_more_stats; + bool print_distance_percentiles; + bool silent_errors; + bool save_compressed; + bool save_decompressed; + std::string output_extension; // see CodecFromExtension + std::string output_description; // see ParseDescription + ColorEncoding output_encoding; // determined by output_description + + bool decode_only; + bool skip_butteraugli; + + float intensity_target; + + std::string color_hints_string; + jxl::extras::ColorHints color_hints; + + size_t override_bitdepth; + + double mul_output; + double heatmap_good; + double heatmap_bad; + + bool write_html_report; + bool html_report_self_contained; + bool markdown; + bool more_columns; + + std::string originals_url; + std::string output_dir; + + int num_threads; + int inner_threads; + size_t decode_reps; + size_t encode_reps; + + std::string sample_tmp_dir; + + int num_samples; + int sample_dimensions; + ButteraugliParams ba_params; + + bool profiler; + double error_pnorm; + bool show_progress; + + std::string extra_metrics; + + jpegxl::tools::CommandLineParser cmdline; + + private: + const char* RememberString_(const std::string& text) { + const char* data = text.c_str(); + std::vector copy(data, data + text.size() + 1); + string_pool_.push_back(copy); + return string_pool_.back().data(); + } + + // A memory pool with stable addresses for strings to provide stable + // const char pointers to cmdline.h for dynamic help/name strings. + std::deque> string_pool_; +}; + +// Returns singleton +BenchmarkArgs* Args(); + +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_ARGS_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec.cc b/media/libjxl/src/tools/benchmark/benchmark_codec.cc new file mode 100644 index 000000000..230665bba --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec.cc @@ -0,0 +1,191 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/benchmark/benchmark_codec.h" + +#include +#include +#include + +#include +#include +#include + +#include "lib/extras/time.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec_custom.h" +#ifdef JPEGXL_ENABLE_JPEG +#include "tools/benchmark/benchmark_codec_jpeg.h" +#endif // JPEG_ENABLE_JPEG +#include "tools/benchmark/benchmark_codec_jxl.h" +#include "tools/benchmark/benchmark_codec_png.h" +#include "tools/benchmark/benchmark_stats.h" + +#ifdef BENCHMARK_WEBP +#include "tools/benchmark/benchmark_codec_webp.h" +#endif // BENCHMARK_WEBP + +#ifdef BENCHMARK_AVIF +#include "tools/benchmark/benchmark_codec_avif.h" +#endif // BENCHMARK_AVIF + +namespace jxl { + +void ImageCodec::ParseParameters(const std::string& parameters) { + params_ = parameters; + std::vector parts = SplitString(parameters, ':'); + for (size_t i = 0; i < parts.size(); ++i) { + if (!ParseParam(parts[i])) { + JXL_ABORT("Invalid parameter %s", parts[i].c_str()); + } + } +} + +Status ImageCodec::ParseParam(const std::string& param) { + if (param[0] == 'q') { // libjpeg-style quality, [0,100] + const std::string quality_param = param.substr(1); + char* end; + const float q_target = strtof(quality_param.c_str(), &end); + if (end == quality_param.c_str() || + end != quality_param.c_str() + quality_param.size()) { + return false; + } + q_target_ = q_target; + return true; + } + if (param[0] == 'd') { // butteraugli distance + const std::string distance_param = param.substr(1); + char* end; + const float butteraugli_target = strtof(distance_param.c_str(), &end); + if (end == distance_param.c_str() || + end != distance_param.c_str() + distance_param.size()) { + return false; + } + butteraugli_target_ = butteraugli_target; + + // full hf asymmetry at high distance + static const double kHighDistance = 2.5; + + // no hf asymmetry at low distance + static const double kLowDistance = 0.6; + + if (butteraugli_target_ >= kHighDistance) { + ba_params_.hf_asymmetry = args_.ba_params.hf_asymmetry; + } else if (butteraugli_target_ >= kLowDistance) { + float w = + (butteraugli_target_ - kLowDistance) / (kHighDistance - kLowDistance); + ba_params_.hf_asymmetry = + args_.ba_params.hf_asymmetry * w + 1.0f * (1.0f - w); + } else { + ba_params_.hf_asymmetry = 1.0f; + } + return true; + } else if (param[0] == 'r') { + ba_params_.hf_asymmetry = args_.ba_params.hf_asymmetry; + bitrate_target_ = strtof(param.substr(1).c_str(), nullptr); + return true; + } + return false; +} + +// Low-overhead "codec" for measuring benchmark overhead. +class NoneCodec : public ImageCodec { + public: + explicit NoneCodec(const BenchmarkArgs& args) : ImageCodec(args) {} + Status ParseParam(const std::string& param) override { return true; } + + Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) override { + PROFILER_ZONE("NoneCompress"); + const double start = Now(); + // Encode image size so we "decompress" something of the same size, as + // required by butteraugli. + const uint32_t xsize = io->xsize(); + const uint32_t ysize = io->ysize(); + compressed->resize(8); + memcpy(compressed->data(), &xsize, 4); + memcpy(compressed->data() + 4, &ysize, 4); + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + return true; + } + + Status Decompress(const std::string& filename, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) override { + PROFILER_ZONE("NoneDecompress"); + const double start = Now(); + JXL_ASSERT(compressed.size() == 8); + uint32_t xsize, ysize; + memcpy(&xsize, compressed.data(), 4); + memcpy(&ysize, compressed.data() + 4, 4); + Image3F image(xsize, ysize); + ZeroFillImage(&image); + io->metadata.m.SetFloat32Samples(); + io->metadata.m.color_encoding = ColorEncoding::SRGB(); + io->SetFromImage(std::move(image), io->metadata.m.color_encoding); + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + return true; + } + + void GetMoreStats(BenchmarkStats* stats) override {} +}; + +ImageCodecPtr CreateImageCodec(const std::string& description) { + std::string name = description; + std::string parameters = ""; + size_t colon = description.find(':'); + if (colon < description.size()) { + name = description.substr(0, colon); + parameters = description.substr(colon + 1); + } + ImageCodecPtr result; + if (name == "jxl") { + result.reset(CreateNewJxlCodec(*Args())); +#if !defined(__wasm__) + } else if (name == "custom") { + result.reset(CreateNewCustomCodec(*Args())); +#endif +#ifdef JPEGXL_ENABLE_JPEG + } else if (name == "jpeg") { + result.reset(CreateNewJPEGCodec(*Args())); +#endif // BENCHMARK_JPEG +#if JPEGXL_ENABLE_APNG + } else if (name == "png") { + result.reset(CreateNewPNGCodec(*Args())); +#endif + } else if (name == "none") { + result.reset(new NoneCodec(*Args())); +#ifdef BENCHMARK_WEBP + } else if (name == "webp") { + result.reset(CreateNewWebPCodec(*Args())); +#endif // BENCHMARK_WEBP +#ifdef BENCHMARK_AVIF + } else if (name == "avif") { + result.reset(CreateNewAvifCodec(*Args())); +#endif // BENCHMARK_AVIF + } else { + JXL_ABORT("Unknown image codec: %s", name.c_str()); + } + result->set_description(description); + if (!parameters.empty()) result->ParseParameters(parameters); + return result; +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec.h b/media/libjxl/src/tools/benchmark/benchmark_codec.h new file mode 100644 index 000000000..e554fc280 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec.h @@ -0,0 +1,103 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_CODEC_H_ +#define TOOLS_BENCHMARK_BENCHMARK_CODEC_H_ + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/image.h" +#include "tools/args.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_stats.h" +#include "tools/cmdline.h" +#include "tools/speed_stats.h" + +namespace jxl { + +// Thread-compatible. +class ImageCodec { + public: + explicit ImageCodec(const BenchmarkArgs& args) + : args_(args), + butteraugli_target_(1.0f), + q_target_(100.0f), + bitrate_target_(0.0f) {} + + virtual ~ImageCodec() = default; + + void set_description(const std::string& desc) { description_ = desc; } + const std::string& description() const { return description_; } + + const ButteraugliParams& BaParams() const { return ba_params_; } + + virtual void ParseParameters(const std::string& parameters); + + virtual Status ParseParam(const std::string& param); + + // Returns true iff the codec instance (including parameters) can tolerate + // ImageBundle c_current() != metadata()->color_encoding, and the possibility + // of negative (out of gamut) pixel values. + virtual bool IsColorAware() const { return false; } + + // Returns true iff the codec instance (including parameters) will operate + // only with quantized DCT (JPEG) coefficients in input. + virtual bool IsJpegTranscoder() const { return false; } + + virtual Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, + std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) = 0; + + virtual Status Decompress(const std::string& filename, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) = 0; + + virtual void GetMoreStats(BenchmarkStats* stats) {} + + virtual Status CanRecompressJpeg() const { return false; } + virtual Status RecompressJpeg(const std::string& filename, + const std::string& data, + std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) { + return false; + } + + virtual std::string GetErrorMessage() const { return error_message_; } + + protected: + const BenchmarkArgs& args_; + std::string params_; + std::string description_; + float butteraugli_target_; + float q_target_; + float bitrate_target_; + ButteraugliParams ba_params_; + std::string error_message_; +}; + +using ImageCodecPtr = std::unique_ptr; + +// Creates an image codec by name, e.g. "jxl" to get a new instance of the +// jxl codec. Optionally, behind a colon, parameters can be specified, +// then ParseParameters of the codec gets called with the part behind the colon. +ImageCodecPtr CreateImageCodec(const std::string& description); + +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_CODEC_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_avif.cc b/media/libjxl/src/tools/benchmark/benchmark_codec_avif.cc new file mode 100644 index 000000000..fbe36b5a0 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_avif.cc @@ -0,0 +1,358 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "tools/benchmark/benchmark_codec_avif.h" + +#include + +#include "lib/extras/time.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_external_image.h" +#include "tools/cmdline.h" + +#define JXL_RETURN_IF_AVIF_ERROR(result) \ + do { \ + avifResult jxl_return_if_avif_error_result = (result); \ + if (jxl_return_if_avif_error_result != AVIF_RESULT_OK) { \ + return JXL_FAILURE("libavif error: %s", \ + avifResultToString(jxl_return_if_avif_error_result)); \ + } \ + } while (false) + +namespace jxl { + +namespace { + +struct AvifArgs { + avifPixelFormat chroma_subsampling = AVIF_PIXEL_FORMAT_YUV444; +}; + +AvifArgs* const avifargs = new AvifArgs; + +bool ParseChromaSubsampling(const char* arg, avifPixelFormat* subsampling) { + if (strcmp(arg, "444") == 0) { + *subsampling = AVIF_PIXEL_FORMAT_YUV444; + return true; + } + if (strcmp(arg, "422") == 0) { + *subsampling = AVIF_PIXEL_FORMAT_YUV422; + return true; + } + if (strcmp(arg, "420") == 0) { + *subsampling = AVIF_PIXEL_FORMAT_YUV420; + return true; + } + if (strcmp(arg, "400") == 0) { + *subsampling = AVIF_PIXEL_FORMAT_YUV400; + return true; + } + return false; +} + +void SetUpAvifColor(const ColorEncoding& color, avifImage* const image) { + bool need_icc = color.white_point != WhitePoint::kD65; + + image->matrixCoefficients = AVIF_MATRIX_COEFFICIENTS_BT709; + if (!color.HasPrimaries()) { + need_icc = true; + } else { + switch (color.primaries) { + case Primaries::kSRGB: + image->colorPrimaries = AVIF_COLOR_PRIMARIES_BT709; + break; + case Primaries::k2100: + image->colorPrimaries = AVIF_COLOR_PRIMARIES_BT2020; + image->matrixCoefficients = AVIF_MATRIX_COEFFICIENTS_BT2020_NCL; + break; + default: + need_icc = true; + image->colorPrimaries = AVIF_COLOR_PRIMARIES_UNKNOWN; + break; + } + } + + switch (color.tf.GetTransferFunction()) { + case TransferFunction::kSRGB: + image->transferCharacteristics = AVIF_TRANSFER_CHARACTERISTICS_SRGB; + break; + case TransferFunction::kLinear: + image->transferCharacteristics = AVIF_TRANSFER_CHARACTERISTICS_LINEAR; + break; + case TransferFunction::kPQ: + image->transferCharacteristics = AVIF_TRANSFER_CHARACTERISTICS_SMPTE2084; + break; + case TransferFunction::kHLG: + image->transferCharacteristics = AVIF_TRANSFER_CHARACTERISTICS_HLG; + break; + default: + need_icc = true; + image->transferCharacteristics = AVIF_TRANSFER_CHARACTERISTICS_UNKNOWN; + break; + } + + if (need_icc) { + avifImageSetProfileICC(image, color.ICC().data(), color.ICC().size()); + } +} + +Status ReadAvifColor(const avifImage* const image, ColorEncoding* const color) { + if (image->icc.size != 0) { + PaddedBytes icc; + icc.assign(image->icc.data, image->icc.data + image->icc.size); + return color->SetICC(std::move(icc)); + } + + color->white_point = WhitePoint::kD65; + switch (image->colorPrimaries) { + case AVIF_COLOR_PRIMARIES_BT709: + color->primaries = Primaries::kSRGB; + break; + case AVIF_COLOR_PRIMARIES_BT2020: + color->primaries = Primaries::k2100; + break; + default: + return JXL_FAILURE("unsupported avif primaries"); + } + switch (image->transferCharacteristics) { + case AVIF_TRANSFER_CHARACTERISTICS_BT470M: + JXL_RETURN_IF_ERROR(color->tf.SetGamma(2.2)); + break; + case AVIF_TRANSFER_CHARACTERISTICS_BT470BG: + JXL_RETURN_IF_ERROR(color->tf.SetGamma(2.8)); + break; + case AVIF_TRANSFER_CHARACTERISTICS_LINEAR: + color->tf.SetTransferFunction(TransferFunction::kLinear); + break; + case AVIF_TRANSFER_CHARACTERISTICS_SRGB: + color->tf.SetTransferFunction(TransferFunction::kSRGB); + break; + case AVIF_TRANSFER_CHARACTERISTICS_SMPTE2084: + color->tf.SetTransferFunction(TransferFunction::kPQ); + break; + case AVIF_TRANSFER_CHARACTERISTICS_HLG: + color->tf.SetTransferFunction(TransferFunction::kHLG); + break; + default: + return JXL_FAILURE("unsupported avif TRC"); + } + return color->CreateICC(); +} + +} // namespace + +Status AddCommandLineOptionsAvifCodec(BenchmarkArgs* args) { + args->cmdline.AddOptionValue( + '\0', "avif_chroma_subsampling", "444/422/420/400", + "default AVIF chroma subsampling (default: 444).", + &avifargs->chroma_subsampling, &ParseChromaSubsampling); + return true; +} + +class AvifCodec : public ImageCodec { + public: + explicit AvifCodec(const BenchmarkArgs& args) : ImageCodec(args) { + chroma_subsampling_ = avifargs->chroma_subsampling; + } + + Status ParseParam(const std::string& param) override { + if (param.compare(0, 3, "yuv") == 0) { + if (param.size() != 6) return false; + return ParseChromaSubsampling(param.c_str() + 3, &chroma_subsampling_); + } + if (param.compare(0, 10, "log2_cols=") == 0) { + log2_cols = strtol(param.c_str() + 10, nullptr, 10); + return true; + } + if (param.compare(0, 10, "log2_rows=") == 0) { + log2_rows = strtol(param.c_str() + 10, nullptr, 10); + return true; + } + if (param[0] == 's') { + speed_ = strtol(param.c_str() + 1, nullptr, 10); + return true; + } + if (param == "aomenc") { + encoder_ = AVIF_CODEC_CHOICE_AOM; + return true; + } + if (param == "aomdec") { + decoder_ = AVIF_CODEC_CHOICE_AOM; + return true; + } + if (param == "aom") { + encoder_ = AVIF_CODEC_CHOICE_AOM; + decoder_ = AVIF_CODEC_CHOICE_AOM; + return true; + } + if (param == "rav1e") { + encoder_ = AVIF_CODEC_CHOICE_RAV1E; + return true; + } + if (param == "dav1d") { + decoder_ = AVIF_CODEC_CHOICE_DAV1D; + return true; + } + if (param.compare(0, 2, "a=") == 0) { + std::string subparam = param.substr(2); + size_t pos = subparam.find('='); + if (pos == std::string::npos) { + codec_specific_options_.emplace_back(subparam, ""); + } else { + std::string key = subparam.substr(0, pos); + std::string value = subparam.substr(pos + 1); + codec_specific_options_.emplace_back(key, value); + } + return true; + } + return ImageCodec::ParseParam(param); + } + + Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) override { + double elapsed_convert_image = 0; + const double start = Now(); + { + const auto depth = + std::min(16, io->metadata.m.bit_depth.bits_per_sample); + std::unique_ptr encoder( + avifEncoderCreate(), &avifEncoderDestroy); + encoder->codecChoice = encoder_; + // TODO(sboukortt): configure this separately. + encoder->minQuantizer = 0; + encoder->maxQuantizer = 63; + encoder->tileColsLog2 = log2_cols; + encoder->tileRowsLog2 = log2_rows; + encoder->speed = speed_; + encoder->maxThreads = pool->NumThreads(); + for (const auto& opts : codec_specific_options_) { + avifEncoderSetCodecSpecificOption(encoder.get(), opts.first.c_str(), + opts.second.c_str()); + } + avifAddImageFlags add_image_flags = AVIF_ADD_IMAGE_FLAG_SINGLE; + if (io->metadata.m.have_animation) { + encoder->timescale = std::lround( + static_cast(io->metadata.m.animation.tps_numerator) / + io->metadata.m.animation.tps_denominator); + add_image_flags = AVIF_ADD_IMAGE_FLAG_NONE; + } + for (const ImageBundle& ib : io->frames) { + std::unique_ptr image( + avifImageCreate(ib.xsize(), ib.ysize(), depth, chroma_subsampling_), + &avifImageDestroy); + image->width = ib.xsize(); + image->height = ib.ysize(); + image->depth = depth; + SetUpAvifColor(ib.c_current(), image.get()); + std::unique_ptr icc_freer( + &image->icc, &avifRWDataFree); + avifRGBImage rgb_image; + avifRGBImageSetDefaults(&rgb_image, image.get()); + rgb_image.format = + ib.HasAlpha() ? AVIF_RGB_FORMAT_RGBA : AVIF_RGB_FORMAT_RGB; + avifRGBImageAllocatePixels(&rgb_image); + std::unique_ptr pixels_freer( + &rgb_image, &avifRGBImageFreePixels); + const double start_convert_image = Now(); + JXL_RETURN_IF_ERROR(ConvertToExternal( + ib, depth, /*float_out=*/false, + /*num_channels=*/ib.HasAlpha() ? 4 : 3, JXL_NATIVE_ENDIAN, + /*stride=*/rgb_image.rowBytes, pool, rgb_image.pixels, + rgb_image.rowBytes * rgb_image.height, + /*out_callback=*/{}, jxl::Orientation::kIdentity)); + const double end_convert_image = Now(); + elapsed_convert_image += end_convert_image - start_convert_image; + JXL_RETURN_IF_AVIF_ERROR(avifImageRGBToYUV(image.get(), &rgb_image)); + JXL_RETURN_IF_AVIF_ERROR(avifEncoderAddImage( + encoder.get(), image.get(), ib.duration, add_image_flags)); + } + avifRWData buffer = AVIF_DATA_EMPTY; + JXL_RETURN_IF_AVIF_ERROR(avifEncoderFinish(encoder.get(), &buffer)); + compressed->assign(buffer.data, buffer.data + buffer.size); + avifRWDataFree(&buffer); + } + const double end = Now(); + speed_stats->NotifyElapsed(end - start - elapsed_convert_image); + return true; + } + + Status Decompress(const std::string& filename, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) override { + io->frames.clear(); + io->dec_pixels = 0; + double elapsed_convert_image = 0; + const double start = Now(); + { + std::unique_ptr decoder( + avifDecoderCreate(), &avifDecoderDestroy); + decoder->codecChoice = decoder_; + decoder->maxThreads = pool->NumThreads(); + JXL_RETURN_IF_AVIF_ERROR(avifDecoderSetIOMemory( + decoder.get(), compressed.data(), compressed.size())); + JXL_RETURN_IF_AVIF_ERROR(avifDecoderParse(decoder.get())); + const bool has_alpha = decoder->alphaPresent; + io->metadata.m.have_animation = decoder->imageCount > 1; + io->metadata.m.animation.tps_numerator = decoder->timescale; + io->metadata.m.animation.tps_denominator = 1; + io->metadata.m.SetUintSamples(decoder->image->depth); + io->SetSize(decoder->image->width, decoder->image->height); + avifResult next_image; + while ((next_image = avifDecoderNextImage(decoder.get())) == + AVIF_RESULT_OK) { + ColorEncoding color; + JXL_RETURN_IF_ERROR(ReadAvifColor(decoder->image, &color)); + avifRGBImage rgb_image; + avifRGBImageSetDefaults(&rgb_image, decoder->image); + rgb_image.format = + has_alpha ? AVIF_RGB_FORMAT_RGBA : AVIF_RGB_FORMAT_RGB; + avifRGBImageAllocatePixels(&rgb_image); + std::unique_ptr pixels_freer( + &rgb_image, &avifRGBImageFreePixels); + JXL_RETURN_IF_AVIF_ERROR(avifImageYUVToRGB(decoder->image, &rgb_image)); + const double start_convert_image = Now(); + { + ImageBundle ib(&io->metadata.m); + JXL_RETURN_IF_ERROR(ConvertFromExternal( + Span(rgb_image.pixels, + rgb_image.height * rgb_image.rowBytes), + rgb_image.width, rgb_image.height, color, (has_alpha ? 4 : 3), + /*alpha_is_premultiplied=*/false, rgb_image.depth, + JXL_NATIVE_ENDIAN, pool, &ib, + /*float_in=*/false, /*align=*/0)); + io->frames.push_back(std::move(ib)); + io->dec_pixels += rgb_image.width * rgb_image.height; + } + const double end_convert_image = Now(); + elapsed_convert_image += end_convert_image - start_convert_image; + } + if (next_image != AVIF_RESULT_NO_IMAGES_REMAINING) { + JXL_RETURN_IF_AVIF_ERROR(next_image); + } + } + const double end = Now(); + speed_stats->NotifyElapsed(end - start - elapsed_convert_image); + return true; + } + + protected: + avifPixelFormat chroma_subsampling_; + avifCodecChoice encoder_ = AVIF_CODEC_CHOICE_AUTO; + avifCodecChoice decoder_ = AVIF_CODEC_CHOICE_AUTO; + int speed_ = AVIF_SPEED_DEFAULT; + int log2_cols = 0; + int log2_rows = 0; + std::vector> codec_specific_options_; +}; + +ImageCodec* CreateNewAvifCodec(const BenchmarkArgs& args) { + return new AvifCodec(args); +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_avif.h b/media/libjxl/src/tools/benchmark/benchmark_codec_avif.h new file mode 100644 index 000000000..b3dc38e97 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_avif.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_CODEC_AVIF_H_ +#define TOOLS_BENCHMARK_BENCHMARK_CODEC_AVIF_H_ + +#include "lib/jxl/base/status.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec.h" + +namespace jxl { +ImageCodec* CreateNewAvifCodec(const BenchmarkArgs& args); + +// Registers the avif-specific command line options. +Status AddCommandLineOptionsAvifCodec(BenchmarkArgs* args); +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_CODEC_AVIF_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_custom.cc b/media/libjxl/src/tools/benchmark/benchmark_codec_custom.cc new file mode 100644 index 000000000..eefae6e65 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_custom.cc @@ -0,0 +1,161 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/benchmark/benchmark_codec_custom.h" + +// Not supported on Windows due to Linux-specific functions. +#ifndef _WIN32 + +#include + +#include + +#include "lib/extras/codec.h" +#include "lib/extras/enc/apng.h" +#include "lib/extras/time.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/image_bundle.h" +#include "tools/benchmark/benchmark_utils.h" + +namespace jxl { +namespace { + +std::string GetBaseName(std::string filename) { + std::string result = std::move(filename); + result = basename(&result[0]); + const size_t dot = result.rfind('.'); + if (dot != std::string::npos) { + result.resize(dot); + } + return result; +} + +// This uses `output_filename` to determine the name of the corresponding +// `.time` file. +template +Status ReportCodecRunningTime(F&& function, std::string output_filename, + jpegxl::tools::SpeedStats* const speed_stats) { + const double start = Now(); + JXL_RETURN_IF_ERROR(function()); + const double end = Now(); + const std::string time_filename = + GetBaseName(std::move(output_filename)) + ".time"; + std::ifstream time_stream(time_filename); + double time; + if (time_stream >> time) { + // Report the time measured by the external codec itself. + speed_stats->NotifyElapsed(time); + } else { + // Fall back to the less accurate time that we measured. + speed_stats->NotifyElapsed(end - start); + } + if (time_stream.is_open()) { + remove(time_filename.c_str()); + } + return true; +} + +class CustomCodec : public ImageCodec { + public: + explicit CustomCodec(const BenchmarkArgs& args) : ImageCodec(args) {} + + Status ParseParam(const std::string& param) override { + switch (param_index_) { + case 0: + extension_ = param; + break; + + case 1: + compress_command_ = param; + break; + + case 2: + decompress_command_ = param; + break; + + default: + compress_args_.push_back(param); + break; + } + ++param_index_; + return true; + } + + Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) override { + JXL_RETURN_IF_ERROR(param_index_ > 2); + + const std::string basename = GetBaseName(filename); + TemporaryFile png_file(basename, "png"), encoded_file(basename, extension_); + std::string png_filename, encoded_filename; + JXL_RETURN_IF_ERROR(png_file.GetFileName(&png_filename)); + JXL_RETURN_IF_ERROR(encoded_file.GetFileName(&encoded_filename)); + saved_intensity_target_ = io->metadata.m.IntensityTarget(); + + const size_t bits = io->metadata.m.bit_depth.bits_per_sample; + JXL_RETURN_IF_ERROR( + EncodeToFile(*io, io->Main().c_current(), bits, png_filename, pool)); + std::vector arguments = compress_args_; + arguments.push_back(png_filename); + arguments.push_back(encoded_filename); + JXL_RETURN_IF_ERROR(ReportCodecRunningTime( + [&, this] { return RunCommand(compress_command_, arguments); }, + encoded_filename, speed_stats)); + return ReadFile(encoded_filename, compressed); + } + + Status Decompress(const std::string& filename, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) override { + const std::string basename = GetBaseName(filename); + TemporaryFile encoded_file(basename, extension_), png_file(basename, "png"); + std::string encoded_filename, png_filename; + JXL_RETURN_IF_ERROR(encoded_file.GetFileName(&encoded_filename)); + JXL_RETURN_IF_ERROR(png_file.GetFileName(&png_filename)); + + JXL_RETURN_IF_ERROR(WriteFile(compressed, encoded_filename)); + JXL_RETURN_IF_ERROR(ReportCodecRunningTime( + [&, this] { + return RunCommand( + decompress_command_, + std::vector{encoded_filename, png_filename}); + }, + png_filename, speed_stats)); + JXL_RETURN_IF_ERROR( + SetFromFile(png_filename, extras::ColorHints(), io, pool)); + io->metadata.m.SetIntensityTarget(saved_intensity_target_); + return true; + } + + private: + std::string extension_; + std::string compress_command_; + std::string decompress_command_; + std::vector compress_args_; + int param_index_ = 0; + int saved_intensity_target_ = 255; +}; + +} // namespace + +ImageCodec* CreateNewCustomCodec(const BenchmarkArgs& args) { + return new CustomCodec(args); +} + +} // namespace jxl + +#else + +namespace jxl { + +ImageCodec* CreateNewCustomCodec(const BenchmarkArgs& args) { return nullptr; } + +} // namespace jxl + +#endif // _MSC_VER diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_custom.h b/media/libjxl/src/tools/benchmark/benchmark_codec_custom.h new file mode 100644 index 000000000..b2711cd5c --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_custom.h @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_CODEC_CUSTOM_H_ +#define TOOLS_BENCHMARK_BENCHMARK_CODEC_CUSTOM_H_ + +// This is a benchmark codec that can be used with any command-line +// encoder/decoder that satisfies the following conditions: +// +// - the encoder can read from a PNG file `$input.png` and write the encoded +// image to `$encoded.$ext` if it is called as: +// +// $encoder [OPTIONS] $input.png $encoded.$ext +// +// - the decoder can read from an encoded file `$encoded.$ext` and write to a +// PNG file `$decoded.png` if it is called as: +// +// $decoder $encoded.$ext $decoded.png +// +// On the benchmark command line, the codec must be specified as: +// +// custom:$ext:$encoder:$decoder:$options +// +// Where the options are also separated by colons. +// +// An example with JPEG XL itself would be: +// +// custom:jxl:cjxl:djxl:--distance:3 +// +// Optionally, to have encoding and decoding speed reported, the codec may write +// the number of seconds (as a floating point number) elapsed during actual +// encoding/decoding to $encoded.time and $decoded.time, respectively (replacing +// the .$ext and .png extensions). + +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec.h" + +namespace jxl { + +ImageCodec* CreateNewCustomCodec(const BenchmarkArgs& args); + +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_CODEC_CUSTOM_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_jpeg.cc b/media/libjxl/src/tools/benchmark/benchmark_codec_jpeg.cc new file mode 100644 index 000000000..ae3215a32 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_jpeg.cc @@ -0,0 +1,116 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "tools/benchmark/benchmark_codec_jpeg.h" + +#include +#include +// After stddef/stdio +#include +#include + +#include // partial_sum +#include + +#include "lib/extras/dec/jpg.h" +#include "lib/extras/enc/jpg.h" +#include "lib/extras/packed_image.h" +#include "lib/extras/packed_image_convert.h" +#include "lib/extras/time.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "tools/cmdline.h" + +namespace jxl { + +namespace { + +struct JPEGArgs { + std::string jpeg_encoder = "libjpeg"; + std::string chroma_subsampling = "444"; +}; + +JPEGArgs* const jpegargs = new JPEGArgs; + +} // namespace + +Status AddCommandLineOptionsJPEGCodec(BenchmarkArgs* args) { + args->cmdline.AddOptionValue( + '\0', "chroma_subsampling", "444/422/420/411", + "default JPEG chroma subsampling (default: 444).", + &jpegargs->chroma_subsampling, &jpegxl::tools::ParseString); + return true; +} + +class JPEGCodec : public ImageCodec { + public: + explicit JPEGCodec(const BenchmarkArgs& args) : ImageCodec(args) { + jpeg_encoder_ = jpegargs->jpeg_encoder; + chroma_subsampling_ = jpegargs->chroma_subsampling; + } + + Status ParseParam(const std::string& param) override { + if (ImageCodec::ParseParam(param)) { + return true; + } + if (param == "sjpeg") { + jpeg_encoder_ = param; + return true; + } + if (param.compare(0, 3, "yuv") == 0) { + if (param.size() != 6) return false; + chroma_subsampling_ = param.substr(3); + return true; + } + return false; + } + + Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) override { + extras::PackedPixelFile ppf; + JxlPixelFormat format = {0, JXL_TYPE_UINT8, JXL_BIG_ENDIAN, 0}; + JXL_RETURN_IF_ERROR(ConvertCodecInOutToPackedPixelFile( + *io, format, io->metadata.m.color_encoding, pool, &ppf)); + extras::EncodedImage encoded; + std::unique_ptr encoder = extras::GetJPEGEncoder(); + std::ostringstream os; + os << static_cast(std::round(q_target_)); + encoder->SetOption("q", os.str()); + encoder->SetOption("jpeg_encoder", jpeg_encoder_); + encoder->SetOption("chroma_subsampling", chroma_subsampling_); + const double start = Now(); + JXL_RETURN_IF_ERROR(encoder->Encode(ppf, &encoded, pool)); + const double end = Now(); + *compressed = encoded.bitstreams.back(); + speed_stats->NotifyElapsed(end - start); + return true; + } + + Status Decompress(const std::string& filename, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) override { + extras::PackedPixelFile ppf; + const double start = Now(); + JXL_RETURN_IF_ERROR(DecodeImageJPG(compressed, extras::ColorHints(), + SizeConstraints(), &ppf)); + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + JXL_RETURN_IF_ERROR(ConvertPackedPixelFileToCodecInOut(ppf, pool, io)); + return true; + } + + protected: + std::string jpeg_encoder_; + std::string chroma_subsampling_; +}; + +ImageCodec* CreateNewJPEGCodec(const BenchmarkArgs& args) { + return new JPEGCodec(args); +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_jpeg.h b/media/libjxl/src/tools/benchmark/benchmark_codec_jpeg.h new file mode 100644 index 000000000..cd4b009a7 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_jpeg.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_CODEC_JPEG_H_ +#define TOOLS_BENCHMARK_BENCHMARK_CODEC_JPEG_H_ + +#include "lib/jxl/base/status.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec.h" + +namespace jxl { +ImageCodec* CreateNewJPEGCodec(const BenchmarkArgs& args); + +// Registers the jpeg-specific command line options. +Status AddCommandLineOptionsJPEGCodec(BenchmarkArgs* args); +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_CODEC_JPEG_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_jxl.cc b/media/libjxl/src/tools/benchmark/benchmark_codec_jxl.cc new file mode 100644 index 000000000..655785832 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_jxl.cc @@ -0,0 +1,338 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "tools/benchmark/benchmark_codec_jxl.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "jxl/thread_parallel_runner_cxx.h" +#include "lib/extras/codec.h" +#include "lib/extras/dec/jxl.h" +#if JPEGXL_ENABLE_JPEG +#include "lib/extras/enc/jpg.h" +#endif +#include "lib/extras/packed_image_convert.h" +#include "lib/extras/time.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "tools/benchmark/benchmark_file_io.h" +#include "tools/benchmark/benchmark_stats.h" +#include "tools/cmdline.h" + +namespace jxl { + +// Output function for EncodeBrunsli. +size_t OutputToBytes(void* data, const uint8_t* buf, size_t count) { + PaddedBytes* output = reinterpret_cast(data); + output->append(buf, buf + count); + return count; +} + +struct JxlArgs { + double xmul; + double quant_bias; + + bool use_ac_strategy; + bool qprogressive; // progressive with shift-quantization. + bool progressive; + int progressive_dc; + + Override noise; + Override dots; + Override patches; + + bool log_search_state; + std::string debug_image_dir; +}; + +static JxlArgs* const jxlargs = new JxlArgs; + +Status AddCommandLineOptionsJxlCodec(BenchmarkArgs* args) { + args->AddDouble(&jxlargs->xmul, "xmul", + "Multiplier for the difference in X channel in Butteraugli.", + 1.0); + args->AddDouble(&jxlargs->quant_bias, "quant_bias", + "Bias border pixels during quantization by this ratio.", 0.0); + args->AddFlag(&jxlargs->use_ac_strategy, "use_ac_strategy", + "If true, AC strategy will be used.", false); + args->AddFlag(&jxlargs->qprogressive, "qprogressive", + "Enable quantized progressive mode for AC.", false); + args->AddFlag(&jxlargs->progressive, "progressive", + "Enable progressive mode for AC.", false); + args->AddSigned(&jxlargs->progressive_dc, "progressive_dc", + "Enable progressive mode for DC.", -1); + + args->AddOverride(&jxlargs->noise, "noise", + "Enable(1)/disable(0) noise generation."); + args->AddOverride(&jxlargs->dots, "dots", + "Enable(1)/disable(0) dots generation."); + args->AddOverride(&jxlargs->patches, "patches", + "Enable(1)/disable(0) patch dictionary."); + + args->AddFlag(&jxlargs->log_search_state, "log_search_state", + "Print out debug info for tortoise mode AQ loop.", false); + + args->AddString( + &jxlargs->debug_image_dir, "debug_image_dir", + "If not empty, saves debug images for each " + "input image and each codec that provides it to this directory."); + + return true; +} + +Status ValidateArgsJxlCodec(BenchmarkArgs* args) { return true; } + +class JxlCodec : public ImageCodec { + public: + explicit JxlCodec(const BenchmarkArgs& args) : ImageCodec(args) {} + + Status ParseParam(const std::string& param) override { + const std::string kMaxPassesPrefix = "max_passes="; + const std::string kDownsamplingPrefix = "downsampling="; + const std::string kResamplingPrefix = "resampling="; + const std::string kEcResamplingPrefix = "ec_resampling="; + + if (param.substr(0, kResamplingPrefix.size()) == kResamplingPrefix) { + std::istringstream parser(param.substr(kResamplingPrefix.size())); + parser >> cparams_.resampling; + } else if (param.substr(0, kEcResamplingPrefix.size()) == + kEcResamplingPrefix) { + std::istringstream parser(param.substr(kEcResamplingPrefix.size())); + parser >> cparams_.ec_resampling; + } else if (ImageCodec::ParseParam(param)) { + if (param[0] == 'd' && butteraugli_target_ == 0.0) { + cparams_.SetLossless(); + } + } else if (param == "uint8") { + uint8_ = true; + } else if (param[0] == 'u') { + char* end; + cparams_.uniform_quant = strtof(param.c_str() + 1, &end); + if (end == param.c_str() + 1 || *end != '\0') { + return JXL_FAILURE("failed to parse uniform quant parameter %s", + param.c_str()); + } + ba_params_.hf_asymmetry = args_.ba_params.hf_asymmetry; + } else if (param.substr(0, kMaxPassesPrefix.size()) == kMaxPassesPrefix) { + std::istringstream parser(param.substr(kMaxPassesPrefix.size())); + parser >> dparams_.max_passes; + } else if (param.substr(0, kDownsamplingPrefix.size()) == + kDownsamplingPrefix) { + std::istringstream parser(param.substr(kDownsamplingPrefix.size())); + parser >> dparams_.max_downsampling; + } else if (ParseSpeedTier(param, &cparams_.speed_tier)) { + // Nothing to do. + } else if (param[0] == 'X') { + cparams_.channel_colors_pre_transform_percent = + strtol(param.substr(1).c_str(), nullptr, 10); + } else if (param[0] == 'Y') { + cparams_.channel_colors_percent = + strtol(param.substr(1).c_str(), nullptr, 10); + } else if (param[0] == 'p') { + cparams_.palette_colors = strtol(param.substr(1).c_str(), nullptr, 10); + } else if (param == "lp") { + cparams_.lossy_palette = true; + } else if (param[0] == 'C') { + cparams_.colorspace = strtol(param.substr(1).c_str(), nullptr, 10); + } else if (param[0] == 'c') { + cparams_.color_transform = + (jxl::ColorTransform)strtol(param.substr(1).c_str(), nullptr, 10); + has_ctransform_ = true; + } else if (param[0] == 'I') { + cparams_.options.nb_repeats = strtof(param.substr(1).c_str(), nullptr); + } else if (param[0] == 'E') { + cparams_.options.max_properties = + strtof(param.substr(1).c_str(), nullptr); + } else if (param[0] == 'P') { + cparams_.options.predictor = + static_cast(strtof(param.substr(1).c_str(), nullptr)); + } else if (param == "slow") { + cparams_.options.nb_repeats = 2; + } else if (param == "R") { + cparams_.responsive = 1; + } else if (param[0] == 'R') { + cparams_.responsive = strtol(param.substr(1).c_str(), nullptr, 10); + } else if (param == "m") { + cparams_.modular_mode = true; + cparams_.color_transform = jxl::ColorTransform::kNone; + } else if (param.substr(0, 3) == "gab") { + long gab = strtol(param.substr(3).c_str(), nullptr, 10); + if (gab != 0 && gab != 1) { + return JXL_FAILURE("Invalid gab value"); + } + cparams_.gaborish = static_cast(gab); + } else if (param[0] == 'g') { + long gsize = strtol(param.substr(1).c_str(), nullptr, 10); + if (gsize < 0 || gsize > 3) { + return JXL_FAILURE("Invalid group size shift value"); + } + cparams_.modular_group_size_shift = gsize; + } else if (param == "plt") { + cparams_.options.max_properties = 0; + cparams_.options.nb_repeats = 0; + cparams_.options.predictor = Predictor::Zero; + cparams_.responsive = 0; + cparams_.colorspace = 0; + cparams_.channel_colors_pre_transform_percent = 0; + cparams_.channel_colors_percent = 0; + } else if (param.substr(0, 3) == "epf") { + cparams_.epf = strtol(param.substr(3).c_str(), nullptr, 10); + if (cparams_.epf > 3) { + return JXL_FAILURE("Invalid epf value"); + } + } else if (param.substr(0, 2) == "nr") { + normalize_bitrate_ = true; + } else if (param.substr(0, 16) == "faster_decoding=") { + cparams_.decoding_speed_tier = + strtol(param.substr(16).c_str(), nullptr, 10); + } else { + return JXL_FAILURE("Unrecognized param"); + } + return true; + } + + bool IsColorAware() const override { + // Can't deal with negative values from color space conversion. + if (cparams_.modular_mode) return false; + if (normalize_bitrate_) return false; + // Otherwise, input may be in any color space. + return true; + } + + bool IsJpegTranscoder() const override { + // TODO(veluca): figure out when to turn this on. + return false; + } + + Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) override { + if (!jxlargs->debug_image_dir.empty()) { + cinfo_.dump_image = [](const CodecInOut& io, const std::string& path) { + return EncodeToFile(io, path); + }; + cinfo_.debug_prefix = + JoinPath(jxlargs->debug_image_dir, FileBaseName(filename)) + + ".jxl:" + params_ + ".dbg/"; + JXL_RETURN_IF_ERROR(MakeDir(cinfo_.debug_prefix)); + } + cparams_.butteraugli_distance = butteraugli_target_; + cparams_.target_bitrate = bitrate_target_; + + cparams_.dots = jxlargs->dots; + cparams_.patches = jxlargs->patches; + + cparams_.progressive_mode = jxlargs->progressive; + cparams_.qprogressive_mode = jxlargs->qprogressive; + cparams_.progressive_dc = jxlargs->progressive_dc; + + cparams_.noise = jxlargs->noise; + + cparams_.quant_border_bias = static_cast(jxlargs->quant_bias); + cparams_.ba_params.hf_asymmetry = ba_params_.hf_asymmetry; + cparams_.ba_params.xmul = static_cast(jxlargs->xmul); + + if (cparams_.butteraugli_distance > 0.f && + cparams_.color_transform == ColorTransform::kNone && + cparams_.modular_mode && !has_ctransform_) { + cparams_.color_transform = ColorTransform::kXYB; + } + + cparams_.log_search_state = jxlargs->log_search_state; + +#if JPEGXL_ENABLE_JPEG + if (normalize_bitrate_ && cparams_.butteraugli_distance > 0.0f) { + extras::PackedPixelFile ppf; + JxlPixelFormat format = {0, JXL_TYPE_UINT8, JXL_BIG_ENDIAN, 0}; + JXL_RETURN_IF_ERROR(ConvertCodecInOutToPackedPixelFile( + *io, format, io->metadata.m.color_encoding, pool, &ppf)); + extras::EncodedImage encoded; + std::unique_ptr encoder = extras::GetJPEGEncoder(); + encoder->SetOption("q", "95"); + JXL_RETURN_IF_ERROR(encoder->Encode(ppf, &encoded, pool)); + float jpeg_bits = encoded.bitstreams.back().size() * kBitsPerByte; + float jpeg_bitrate = jpeg_bits / (io->xsize() * io->ysize()); + // Formula fitted on jyrki31 corpus for distances between 1.0 and 8.0. + cparams_.target_bitrate = (jpeg_bitrate * 0.36f / + (0.6f * cparams_.butteraugli_distance + 0.4f)); + } +#endif + + const double start = Now(); + PassesEncoderState passes_encoder_state; + PaddedBytes compressed_padded; + JXL_RETURN_IF_ERROR(EncodeFile(cparams_, io, &passes_encoder_state, + &compressed_padded, GetJxlCms(), &cinfo_, + pool)); + const double end = Now(); + compressed->assign(compressed_padded.begin(), compressed_padded.end()); + speed_stats->NotifyElapsed(end - start); + return true; + } + + Status Decompress(const std::string& filename, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) override { + dparams_.runner = pool->runner(); + dparams_.runner_opaque = pool->runner_opaque(); + JxlDataType data_type = uint8_ ? JXL_TYPE_UINT8 : JXL_TYPE_FLOAT; + dparams_.accepted_formats = {{3, data_type, JXL_NATIVE_ENDIAN, 0}, + {4, data_type, JXL_NATIVE_ENDIAN, 0}}; + // By default, the decoder will undo exif orientation, giving an image + // with identity exif rotation as result. However, the benchmark does + // not undo exif orientation of the originals, and compares against the + // originals, so we must set the option to keep the original orientation + // instead. + dparams_.keep_orientation = true; + extras::PackedPixelFile ppf; + size_t decoded_bytes; + const double start = Now(); + JXL_RETURN_IF_ERROR(DecodeImageJXL(compressed.data(), compressed.size(), + dparams_, &decoded_bytes, &ppf)); + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + JXL_RETURN_IF_ERROR(ConvertPackedPixelFileToCodecInOut(ppf, pool, io)); + return true; + } + + void GetMoreStats(BenchmarkStats* stats) override { + JxlStats jxl_stats; + jxl_stats.num_inputs = 1; + jxl_stats.aux_out = cinfo_; + stats->jxl_stats.Assimilate(jxl_stats); + } + + protected: + AuxOut cinfo_; + CompressParams cparams_; + bool has_ctransform_ = false; + extras::JXLDecompressParams dparams_; + bool uint8_ = false; + bool normalize_bitrate_ = false; +}; + +ImageCodec* CreateNewJxlCodec(const BenchmarkArgs& args) { + return new JxlCodec(args); +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_jxl.h b/media/libjxl/src/tools/benchmark/benchmark_codec_jxl.h new file mode 100644 index 000000000..12e9fef79 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_jxl.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_CODEC_JXL_H_ +#define TOOLS_BENCHMARK_BENCHMARK_CODEC_JXL_H_ + +#include + +#include "lib/jxl/base/status.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec.h" + +namespace jxl { +ImageCodec* CreateNewJxlCodec(const BenchmarkArgs& args); + +// Registers the jxl-specific command line options. +Status AddCommandLineOptionsJxlCodec(BenchmarkArgs* args); +Status ValidateArgsJxlCodec(BenchmarkArgs* args); +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_CODEC_JXL_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_png.cc b/media/libjxl/src/tools/benchmark/benchmark_codec_png.cc new file mode 100644 index 000000000..b310b11bd --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_png.cc @@ -0,0 +1,76 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if JPEGXL_ENABLE_APNG + +#include "tools/benchmark/benchmark_codec_png.h" + +#include +#include + +#include + +#include "lib/extras/codec.h" +#include "lib/extras/dec/apng.h" +#include "lib/extras/enc/apng.h" +#include "lib/extras/packed_image.h" +#include "lib/extras/packed_image_convert.h" +#include "lib/extras/time.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { + +struct PNGArgs { + // Empty, no PNG-specific args currently. +}; + +static PNGArgs* const pngargs = new PNGArgs; + +Status AddCommandLineOptionsPNGCodec(BenchmarkArgs* args) { return true; } + +// Lossless. +class PNGCodec : public ImageCodec { + public: + explicit PNGCodec(const BenchmarkArgs& args) : ImageCodec(args) {} + + Status ParseParam(const std::string& param) override { return true; } + + Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) override { + const size_t bits = io->metadata.m.bit_depth.bits_per_sample; + const double start = Now(); + JXL_RETURN_IF_ERROR(Encode(*io, extras::Codec::kPNG, io->Main().c_current(), + bits, compressed, pool)); + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + return true; + } + + Status Decompress(const std::string& /*filename*/, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) override { + extras::PackedPixelFile ppf; + const double start = Now(); + JXL_RETURN_IF_ERROR(extras::DecodeImageAPNG( + compressed, extras::ColorHints(), SizeConstraints(), &ppf)); + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + JXL_RETURN_IF_ERROR(ConvertPackedPixelFileToCodecInOut(ppf, pool, io)); + return true; + } +}; + +ImageCodec* CreateNewPNGCodec(const BenchmarkArgs& args) { + return new PNGCodec(args); +} + +} // namespace jxl + +#endif diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_png.h b/media/libjxl/src/tools/benchmark/benchmark_codec_png.h new file mode 100644 index 000000000..23d982e17 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_png.h @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_CODEC_PNG_H_ +#define TOOLS_BENCHMARK_BENCHMARK_CODEC_PNG_H_ + +#if JPEGXL_ENABLE_APNG + +#include + +#include "lib/jxl/base/status.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec.h" + +namespace jxl { +ImageCodec* CreateNewPNGCodec(const BenchmarkArgs& args); + +// Registers the png-specific command line options. +Status AddCommandLineOptionsPNGCodec(BenchmarkArgs* args); +} // namespace jxl + +#endif + +#endif // TOOLS_BENCHMARK_BENCHMARK_CODEC_PNG_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_webp.cc b/media/libjxl/src/tools/benchmark/benchmark_codec_webp.cc new file mode 100644 index 000000000..3b1bb264d --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_webp.cc @@ -0,0 +1,280 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "tools/benchmark/benchmark_codec_webp.h" + +#include +#include +#include +#include + +#include +#include + +#include "lib/extras/time.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_image_bundle.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +// Sets image data from 8-bit sRGB pixel array in bytes. +// Amount of input bytes per pixel must be: +// (is_gray ? 1 : 3) + (has_alpha ? 1 : 0) +Status FromSRGB(const size_t xsize, const size_t ysize, const bool is_gray, + const bool has_alpha, const bool alpha_is_premultiplied, + const bool is_16bit, const JxlEndianness endianness, + const uint8_t* pixels, const uint8_t* end, ThreadPool* pool, + ImageBundle* ib) { + const ColorEncoding& c = ColorEncoding::SRGB(is_gray); + const size_t bits_per_sample = (is_16bit ? 2 : 1) * kBitsPerByte; + const Span span(pixels, end - pixels); + return ConvertFromExternal( + span, xsize, ysize, c, (is_gray ? 1 : 3) + (has_alpha ? 1 : 0), + alpha_is_premultiplied, bits_per_sample, endianness, pool, ib, + /*float_in=*/false, /*align=*/0); +} + +struct WebPArgs { + // Empty, no WebP-specific args currently. +}; + +static WebPArgs* const webpargs = new WebPArgs; + +Status AddCommandLineOptionsWebPCodec(BenchmarkArgs* args) { return true; } + +class WebPCodec : public ImageCodec { + public: + explicit WebPCodec(const BenchmarkArgs& args) : ImageCodec(args) {} + + Status ParseParam(const std::string& param) override { + // Ensure that the 'q' parameter is not used up by ImageCodec. + if (param[0] == 'q') { + if (near_lossless_) { + near_lossless_quality_ = ParseIntParam(param, 0, 99); + } else { + quality_ = ParseIntParam(param, 1, 100); + } + return true; + } else if (ImageCodec::ParseParam(param)) { + return true; + } else if (param == "ll") { + lossless_ = true; + JXL_CHECK(!near_lossless_); + return true; + } else if (param == "nl") { + near_lossless_ = true; + JXL_CHECK(!lossless_); + return true; + } else if (param[0] == 'm') { + method_ = ParseIntParam(param, 1, 6); + return true; + } + return false; + } + + Status Compress(const std::string& filename, const CodecInOut* io, + ThreadPoolInternal* pool, std::vector* compressed, + jpegxl::tools::SpeedStats* speed_stats) override { + const double start = Now(); + const ImageBundle& ib = io->Main(); + + if (ib.HasAlpha() && ib.metadata()->GetAlphaBits() > 8) { + return JXL_FAILURE("WebP alpha must be 8-bit"); + } + + size_t num_chans = (ib.HasAlpha() ? 4 : 3); + ImageMetadata metadata = io->metadata.m; + ImageBundle store(&metadata); + const ImageBundle* transformed; + const ColorEncoding& c_desired = ColorEncoding::SRGB(false); + JXL_RETURN_IF_ERROR(TransformIfNeeded(ib, c_desired, GetJxlCms(), pool, + &store, &transformed)); + size_t xsize = ib.oriented_xsize(); + size_t ysize = ib.oriented_ysize(); + size_t stride = xsize * num_chans; + std::vector srgb(stride * ysize); + JXL_RETURN_IF_ERROR(ConvertToExternal( + *transformed, 8, /*float_out=*/false, num_chans, JXL_BIG_ENDIAN, stride, + pool, srgb.data(), srgb.size(), + /*out_callback=*/{}, metadata.GetOrientation())); + + if (lossless_ || near_lossless_) { + // The lossless codec does not support 16-bit channels. + // Color models are currently not supported here and the sRGB 8-bit + // conversion causes loss due to clipping. + if (!ib.IsSRGB() || ib.metadata()->bit_depth.bits_per_sample > 8 || + ib.metadata()->bit_depth.exponent_bits_per_sample > 0) { + return JXL_FAILURE("%s: webp:ll/nl requires 8-bit sRGB", + filename.c_str()); + } + JXL_RETURN_IF_ERROR( + CompressInternal(srgb, xsize, ysize, num_chans, 100, compressed)); + } else if (bitrate_target_ > 0.0) { + int quality_bad = 100; + int quality_good = 92; + size_t target_size = xsize * ysize * bitrate_target_ / 8.0; + while (quality_good > 0 && + CompressInternal(srgb, xsize, ysize, num_chans, quality_good, + compressed) && + compressed->size() > target_size) { + quality_bad = quality_good; + quality_good -= 8; + } + if (quality_good <= 0) quality_good = 1; + while (quality_good + 1 < quality_bad) { + int quality = (quality_bad + quality_good) / 2; + if (!CompressInternal(srgb, xsize, ysize, num_chans, quality, + compressed)) { + break; + } + if (compressed->size() <= target_size) { + quality_good = quality; + } else { + quality_bad = quality; + } + } + JXL_RETURN_IF_ERROR(CompressInternal(srgb, xsize, ysize, num_chans, + quality_good, compressed)); + } else if (quality_ > 0) { + JXL_RETURN_IF_ERROR(CompressInternal(srgb, xsize, ysize, num_chans, + quality_, compressed)); + } else { + return false; + } + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + return true; + } + + Status Decompress(const std::string& filename, + const Span compressed, + ThreadPoolInternal* pool, CodecInOut* io, + jpegxl::tools::SpeedStats* speed_stats) override { + WebPDecoderConfig config; +#ifdef MEMORY_SANITIZER + // config is initialized by libwebp, which we are not instrumenting with + // msan, therefore we need to initialize it here. + memset(&config, 0, sizeof(config)); +#endif + JXL_RETURN_IF_ERROR(WebPInitDecoderConfig(&config) == 1); + config.options.use_threads = 0; + config.options.dithering_strength = 0; + config.options.bypass_filtering = 0; + config.options.no_fancy_upsampling = 0; + WebPDecBuffer* const buf = &config.output; + buf->colorspace = MODE_RGBA; + const uint8_t* webp_data = compressed.data(); + const int webp_size = compressed.size(); + const double start = Now(); + if (WebPDecode(webp_data, webp_size, &config) != VP8_STATUS_OK) { + return JXL_FAILURE("WebPDecode failed"); + } + const double end = Now(); + speed_stats->NotifyElapsed(end - start); + JXL_CHECK(buf->u.RGBA.stride == buf->width * 4); + + const bool is_gray = false; + const bool has_alpha = true; + const uint8_t* data_begin = &buf->u.RGBA.rgba[0]; + const uint8_t* data_end = data_begin + buf->width * buf->height * 4; + // The image data is initialized by libwebp, which we are not instrumenting + // with msan. + msan::UnpoisonMemory(data_begin, data_end - data_begin); + if (io->metadata.m.color_encoding.IsGray() != is_gray) { + // TODO(lode): either ensure is_gray matches what the color profile says, + // or set a correct color profile, e.g. + // io->metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + // Return a standard failure because SetFromSRGB triggers a fatal assert + // for this instead. + return JXL_FAILURE("Color profile is-gray mismatch"); + } + io->metadata.m.SetAlphaBits(8); + const Status ok = + FromSRGB(buf->width, buf->height, is_gray, has_alpha, + /*alpha_is_premultiplied=*/false, /*is_16bit=*/false, + JXL_LITTLE_ENDIAN, data_begin, data_end, pool, &io->Main()); + WebPFreeDecBuffer(buf); + JXL_RETURN_IF_ERROR(ok); + io->dec_pixels = buf->width * buf->height; + return true; + } + + private: + static int WebPStringWrite(const uint8_t* data, size_t data_size, + const WebPPicture* const picture) { + if (data_size) { + std::vector* const out = + static_cast*>(picture->custom_ptr); + const size_t pos = out->size(); + out->resize(pos + data_size); + memcpy(out->data() + pos, data, data_size); + } + return 1; + } + Status CompressInternal(const std::vector& srgb, size_t xsize, + size_t ysize, size_t num_chans, int quality, + std::vector* compressed) { + compressed->clear(); + WebPConfig config; + WebPConfigInit(&config); + JXL_ASSERT(!lossless_ || !near_lossless_); // can't have both + config.lossless = lossless_; + config.quality = quality; + config.method = method_; +#if WEBP_ENCODER_ABI_VERSION >= 0x020a + config.near_lossless = near_lossless_ ? near_lossless_quality_ : 100; +#else + if (near_lossless_) { + JXL_WARNING("Near lossless not supported by this WebP version"); + } +#endif + JXL_CHECK(WebPValidateConfig(&config)); + + WebPPicture pic; + WebPPictureInit(&pic); + pic.width = static_cast(xsize); + pic.height = static_cast(ysize); + pic.writer = &WebPStringWrite; + if (lossless_ || near_lossless_) pic.use_argb = 1; + pic.custom_ptr = compressed; + + if (num_chans == 3) { + WebPPictureImportRGB(&pic, srgb.data(), 3 * xsize); + } else { + WebPPictureImportRGBA(&pic, srgb.data(), 4 * xsize); + } + + // WebP encoding may fail, for example, if the image is more than 16384 + // pixels high or wide. + bool ok = WebPEncode(&config, &pic); + WebPPictureFree(&pic); + // Compressed image data is initialized by libwebp, which we are not + // instrumenting with msan. + msan::UnpoisonMemory(compressed->data(), compressed->size()); + return ok; + } + + int quality_ = 90; + bool lossless_ = false; + bool near_lossless_ = false; + bool near_lossless_quality_ = 40; // only used if near_lossless_ + int method_ = 6; // smallest, some speed cost +}; + +ImageCodec* CreateNewWebPCodec(const BenchmarkArgs& args) { + return new WebPCodec(args); +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_codec_webp.h b/media/libjxl/src/tools/benchmark/benchmark_codec_webp.h new file mode 100644 index 000000000..cd4c60fb5 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_codec_webp.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#ifndef TOOLS_BENCHMARK_BENCHMARK_CODEC_WEBP_H_ +#define TOOLS_BENCHMARK_BENCHMARK_CODEC_WEBP_H_ + +// To support webp, install libwebp-dev and rerun cmake. + +#include + +#include "lib/jxl/base/status.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec.h" + +namespace jxl { +ImageCodec* CreateNewWebPCodec(const BenchmarkArgs& args); + +// Registers the webp-specific command line options. +Status AddCommandLineOptionsWebPCodec(BenchmarkArgs* args); +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_CODEC_WEBP_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_file_io.cc b/media/libjxl/src/tools/benchmark/benchmark_file_io.cc new file mode 100644 index 000000000..c5db02b8f --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_file_io.cc @@ -0,0 +1,232 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "tools/benchmark/benchmark_file_io.h" + +#include +#include + +#include + +#if defined(_WIN32) || defined(_WIN64) +#include "third_party/dirent.h" +#else +#include +#include +#endif + +#ifndef HAS_GLOB +#define HAS_GLOB 0 +#if defined __has_include +// is included in previous APIs but glob() function is not defined +// until API 28. +#if __has_include() && \ + (!defined(__ANDROID_API__) || __ANDROID_API__ >= 28) +#undef HAS_GLOB +#define HAS_GLOB 1 +#endif // __has_include() +#endif // __has_include +#endif // HAS_GLOB + +#if HAS_GLOB +#include +#endif // HAS_GLOB + +// There is no "user" in embedded filesystems. +#ifndef GLOB_TILDE +#define GLOB_TILDE 0 +#endif + +namespace jxl { + +const char kPathSeparator = '/'; + +// RAII, ensures dir is closed even when returning early. +class DirWrapper { + public: + DirWrapper(const DirWrapper& other) = delete; + DirWrapper& operator=(const DirWrapper& other) = delete; + + explicit DirWrapper(const std::string& pathname) + : dir_(opendir(pathname.c_str())) {} + + ~DirWrapper() { + if (dir_ != nullptr) { + const int err = closedir(dir_); + JXL_CHECK(err == 0); + } + } + + // NOLINTNEXTLINE(google-explicit-constructor) + operator DIR*() const { return dir_; } + + private: + DIR* const dir_; +}; + +// Checks if the file exists, either as file or as directory +bool PathExists(const std::string& fname) { + struct stat s; + if (stat(fname.c_str(), &s) != 0) return false; + return true; +} + +// Checks if the file exists and is a regular file. +bool IsRegularFile(const std::string& fname) { + struct stat s; + if (stat(fname.c_str(), &s) != 0) return false; + return S_ISREG(s.st_mode); +} + +// Checks if the file exists and is a directory. +bool IsDirectory(const std::string& fname) { + struct stat s; + if (stat(fname.c_str(), &s) != 0) return false; + return S_ISDIR(s.st_mode); +} + +// Recursively makes dir, or successfully does nothing if it already exists. +Status MakeDir(const std::string& dirname) { + size_t pos = 0; + for (pos = dirname.size(); pos > 0; pos--) { + if (pos == dirname.size() || dirname[pos] == kPathSeparator) { + // Found existing dir or regular file, break and then start creating + // from here (in the latter case we'll get error below). + if (PathExists(dirname.substr(0, pos + 1))) { + pos += 1; // Skip past this existing path + break; + } + } + } + for (; pos <= dirname.size(); pos++) { + if (pos == dirname.size() || dirname[pos] == kPathSeparator) { + std::string subdir = dirname.substr(0, pos + 1); + if (mkdir(subdir.c_str(), 0777) && errno != EEXIST) { + return JXL_FAILURE("Failed to create directory"); + } + } + } + if (!IsDirectory(dirname)) return JXL_FAILURE("Failed to create directory"); + return true; // success +} + +Status DeleteFile(const std::string& fname) { + if (!IsRegularFile(fname)) { + return JXL_FAILURE("Trying to delete non-regular file"); + } + if (std::remove(fname.c_str())) return JXL_FAILURE("Failed to delete file"); + return true; +} + +std::string FileBaseName(const std::string& fname) { + size_t pos = fname.rfind('/'); + if (pos == std::string::npos) return fname; + return fname.substr(pos + 1); +} + +std::string FileDirName(const std::string& fname) { + size_t pos = fname.rfind('/'); + if (pos == std::string::npos) return ""; + return fname.substr(0, pos); +} + +std::string FileExtension(const std::string& fname) { + size_t pos = fname.rfind('.'); + if (pos == std::string::npos) return ""; + return fname.substr(pos); +} + +std::string JoinPath(const std::string& first, const std::string& second) { + JXL_CHECK(second.empty() || second[0] != kPathSeparator); + return (!first.empty() && first.back() == kPathSeparator) + ? (first + second) + : (first + kPathSeparator + second); +} + +// Can match a single file, or multiple files in a directory (non-recursive). +// With POSIX, supports glob(), otherwise supports a subset. +Status MatchFiles(const std::string& pattern, std::vector* list) { +#if HAS_GLOB + glob_t g; + memset(&g, 0, sizeof(g)); + int error = glob(pattern.c_str(), GLOB_TILDE, NULL, &g); + if (!error) { + for (size_t i = 0; i < g.gl_pathc; ++i) { + list->push_back(g.gl_pathv[i]); + } + } + globfree(&g); + if (error) return JXL_FAILURE("glob failed for %s", pattern.c_str()); + return true; +#else + std::string dirname = FileDirName(pattern); + std::string basename = FileBaseName(pattern); + size_t pos0 = basename.find('*'); + size_t pos1 = pos0 == std::string::npos ? pos0 : basename.find('*', pos0 + 1); + std::string prefix, middle, suffix; + if (pos0 != std::string::npos) { + prefix = basename.substr(0, pos0); + if (pos1 != std::string::npos) { + middle = basename.substr(pos0 + 1, pos1 - pos0 - 1); + suffix = basename.substr(pos1 + 1); + } else { + suffix = basename.substr(pos0 + 1); + } + } + + if (prefix.find_first_of("*?[") != std::string::npos || + middle.find_first_of("*?[") != std::string::npos || + suffix.find_first_of("*?[") != std::string::npos || + dirname.find_first_of("*?[") != std::string::npos) { + return JXL_FAILURE( + "Only glob patterns with max two '*' in the basename" + " are supported, e.g. directory/path/*.png or" + " /directory/path/*heatmap*"); + } + + if (pos0 != std::string::npos) { + DirWrapper dir(dirname); + if (!dir) return JXL_FAILURE("directory %s doesn't exist", dirname.c_str()); + for (;;) { + dirent* ent = readdir(dir); + if (!ent) break; + std::string name = ent->d_name; + // If there was a suffix, only add if it matches (e.g. ".png") + bool matches = + name.size() >= (prefix.size() + middle.size() + suffix.size()); + if (matches) { + if (!prefix.empty() && name.substr(0, prefix.size()) != prefix) { + matches = false; + } + if (!middle.empty()) { + size_t pos = name.find(middle, prefix.size()); + if (pos == std::string::npos || + pos + middle.size() > name.size() - suffix.size()) { + matches = false; + } + } + if (!suffix.empty() && + name.substr(name.size() - suffix.size()) != suffix) { + matches = false; + } + } + if (matches) { + std::string path = JoinPath(dirname, name); + + if (IsRegularFile(path)) { + list->push_back(path); + } + } + } + return true; + } + // No *, so a single regular file is intended + if (IsRegularFile(pattern)) { + list->push_back(pattern); + } + return true; +#endif // HAS_GLOB +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_file_io.h b/media/libjxl/src/tools/benchmark/benchmark_file_io.h new file mode 100644 index 000000000..ecb83590d --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_file_io.h @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// File utilities for benchmarking and testing, but which are not needed for +// main jxl itself. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_FILE_IO_H_ +#define TOOLS_BENCHMARK_BENCHMARK_FILE_IO_H_ + +#include +#include + +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Checks if the file exists, either as file or as directory +bool PathExists(const std::string& fname); + +// Checks if the file exists and is a regular file. +bool IsRegularFile(const std::string& fname); + +// Checks if the file exists and is a directory. +bool IsDirectory(const std::string& fname); + +// Recursively makes dir, or successfully does nothing if it already exists. +Status MakeDir(const std::string& dirname); + +// Deletes a single regular file. +Status DeleteFile(const std::string& fname); + +// Returns value similar to unix basename, except it returns empty string if +// fname ends in '/'. +std::string FileBaseName(const std::string& fname); +// Returns value similar to unix dirname, except returns up to before the last +// slash if fname ends in '/'. +std::string FileDirName(const std::string& fname); + +// Returns the part of the filename starting from the last dot, or empty +// string if there is no dot. +std::string FileExtension(const std::string& fname); + +// Matches one or more files given glob pattern. +Status MatchFiles(const std::string& pattern, std::vector* list); + +std::string JoinPath(const std::string& first, const std::string& second); + +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_FILE_IO_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_stats.cc b/media/libjxl/src/tools/benchmark/benchmark_stats.cc new file mode 100644 index 000000000..f22e89c84 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_stats.cc @@ -0,0 +1,376 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/benchmark/benchmark_stats.h" + +#include +#include +#include +#include + +#include +#include + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "tools/benchmark/benchmark_args.h" + +namespace jxl { +namespace { + +// Computes longest codec name from Args()->codec, for table alignment. +uint32_t ComputeLargestCodecName() { + std::vector methods = SplitString(Args()->codec, ','); + size_t max = strlen("Aggregate:"); // Include final row's name + for (const auto& method : methods) { + max = std::max(max, method.size()); + } + return max; +} + +// The benchmark result is a table of heterogeneous data, the column type +// specifies its data type. The type affects how it is printed as well as how +// aggregate values are computed. +enum ColumnType { + // Formatted string + TYPE_STRING, + // Positive size, prints 0 as "---" + TYPE_SIZE, + // Floating point value (double precision) which is interpreted as + // "not applicable" if <= 0, must be strictly positive to be valid but can be + // set to 0 or negative to be printed as "---", for example for a speed that + // is not measured. + TYPE_POSITIVE_FLOAT, + // Counts of some event + TYPE_COUNT, +}; + +struct ColumnDescriptor { + // Column name + std::string label; + // Total width to render the values of this column. If t his is a floating + // point value, make sure this is large enough to contain a space and the + // point, plus precision digits after the point, plus the max amount of + // integer digits you expect in front of the point. + uint32_t width; + // Amount of digits after the point, or 0 if not a floating point value. + uint32_t precision; + ColumnType type; + bool more; // Whether to print only if more_columns is enabled +}; + +static const ColumnDescriptor ExtraMetricDescriptor() { + ColumnDescriptor d{{"DO NOT USE"}, 12, 4, TYPE_POSITIVE_FLOAT, false}; + return d; +} + +// To add or change a column to the benchmark ASCII table output, add/change +// an entry here with table header line 1, table header line 2, width of the +// column, precision after the point in case of floating point, and the +// data type. Then add/change the corresponding formula or formatting in +// the function ComputeColumns. +std::vector GetColumnDescriptors(size_t num_extra_metrics) { + // clang-format off + std::vector result = { + {{"Encoding"}, ComputeLargestCodecName() + 1, 0, TYPE_STRING, false}, + {{"kPixels"}, 10, 0, TYPE_SIZE, false}, + {{"Bytes"}, 9, 0, TYPE_SIZE, false}, + {{"BPP"}, 13, 7, TYPE_POSITIVE_FLOAT, false}, + {{"E MP/s"}, 8, 3, TYPE_POSITIVE_FLOAT, false}, + {{"D MP/s"}, 8, 3, TYPE_POSITIVE_FLOAT, false}, + {{"Max norm"}, 13, 8, TYPE_POSITIVE_FLOAT, false}, + {{"pnorm"}, 13, 8, TYPE_POSITIVE_FLOAT, false}, + {{"PSNR"}, 7, 2, TYPE_POSITIVE_FLOAT, true}, + {{"QABPP"}, 8, 3, TYPE_POSITIVE_FLOAT, true}, + {{"SmallB"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"DCT4x8"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"AFV"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"DCT8x8"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"8x16"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"8x32"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"16"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"16x32"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"32"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"32x64"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"64"}, 8, 4, TYPE_POSITIVE_FLOAT, true}, + {{"BPP*pnorm"}, 16, 12, TYPE_POSITIVE_FLOAT, false}, + {{"Bugs"}, 7, 5, TYPE_COUNT, false}, + }; + // clang-format on + + for (size_t i = 0; i < num_extra_metrics; i++) { + result.push_back(ExtraMetricDescriptor()); + } + + return result; +} + +// Computes throughput [megapixels/s] as reported in the report table +static double ComputeSpeed(size_t pixels, double time_s) { + if (time_s == 0.0) return 0; + return pixels * 1E-6 / time_s; +} + +static std::string FormatFloat(const ColumnDescriptor& label, double value) { + std::string result = + StringPrintf("%*.*f", label.width - 1, label.precision, value); + + // Reduce precision if the value is too wide for the column. However, keep + // at least one digit to the right of the point, and especially the integer + // digits. + if (result.size() >= label.width) { + size_t point = result.rfind('.'); + if (point != std::string::npos) { + int end = std::max(point + 2, label.width - 1); + result = result.substr(0, end); + } + } + return result; +} + +} // namespace + +std::string StringPrintf(const char* format, ...) { + char buf[2000]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + return std::string(buf); +} + +void BenchmarkStats::Assimilate(const BenchmarkStats& victim) { + total_input_files += victim.total_input_files; + total_input_pixels += victim.total_input_pixels; + total_compressed_size += victim.total_compressed_size; + total_adj_compressed_size += victim.total_adj_compressed_size; + total_time_encode += victim.total_time_encode; + total_time_decode += victim.total_time_decode; + max_distance = std::max(max_distance, victim.max_distance); + distance_p_norm += victim.distance_p_norm; + distance_2 += victim.distance_2; + distances.insert(distances.end(), victim.distances.begin(), + victim.distances.end()); + total_errors += victim.total_errors; + jxl_stats.Assimilate(victim.jxl_stats); + if (extra_metrics.size() < victim.extra_metrics.size()) { + extra_metrics.resize(victim.extra_metrics.size()); + } + for (size_t i = 0; i < victim.extra_metrics.size(); i++) { + extra_metrics[i] += victim.extra_metrics[i]; + } +} + +void BenchmarkStats::PrintMoreStats() const { + if (Args()->print_more_stats) { + jxl_stats.Print(); + size_t total_bits = jxl_stats.aux_out.TotalBits(); + size_t compressed_bits = total_compressed_size * kBitsPerByte; + if (total_bits != compressed_bits) { + printf("Total layer bits: %" PRIuS " vs total compressed bits: %" PRIuS + " (%.2f%% accounted for)\n", + total_bits, compressed_bits, total_bits * 100.0 / compressed_bits); + } + } + if (Args()->print_distance_percentiles) { + std::vector sorted = distances; + std::sort(sorted.begin(), sorted.end()); + int p50idx = 0.5 * distances.size(); + int p90idx = 0.9 * distances.size(); + printf("50th/90th percentile distance: %.8f %.8f\n", sorted[p50idx], + sorted[p90idx]); + } +} + +std::vector BenchmarkStats::ComputeColumns( + const std::string& codec_desc, size_t corpus_size) const { + JXL_CHECK(total_input_files == corpus_size); + const double comp_bpp = total_compressed_size * 8.0 / total_input_pixels; + const double adj_comp_bpp = + total_adj_compressed_size * 8.0 / total_input_pixels; + // Note: this is not affected by alpha nor bit depth. + const double compression_speed = + ComputeSpeed(total_input_pixels, total_time_encode); + const double decompression_speed = + ComputeSpeed(total_input_pixels, total_time_decode); + // Already weighted, no need to divide by #channels. + const double rmse = std::sqrt(distance_2 / total_input_pixels); + const double psnr = total_compressed_size == 0 ? 0.0 + : (distance_2 == 0) ? 99.99 + : (20 * std::log10(1 / rmse)); + const double p_norm = distance_p_norm / total_input_pixels; + const double bpp_p_norm = p_norm * comp_bpp; + + std::vector values( + GetColumnDescriptors(extra_metrics.size()).size()); + + values[0].s = codec_desc; + values[1].i = total_input_pixels / 1000; + values[2].i = total_compressed_size; + values[3].f = comp_bpp; + values[4].f = compression_speed; + values[5].f = decompression_speed; + values[6].f = static_cast(max_distance); + values[7].f = p_norm; + values[8].f = psnr; + values[9].f = adj_comp_bpp; + // The DCT2, DCT4, AFV and DCT4X8 are applied to an 8x8 block by having 4x4 + // DCT2X2s, 2x2 DCT4x4s/AFVs, or 2x1 DCT4X8s, filling the whole 8x8 blocks. + // Thus we need to multiply the block count by 8.0 * 8.0 pixels for these + // transforms. + values[10].f = 100.f * jxl_stats.aux_out.num_small_blocks * 8.0 * 8.0 / + total_input_pixels; + values[11].f = 100.f * jxl_stats.aux_out.num_dct4x8_blocks * 8.0 * 8.0 / + total_input_pixels; + values[12].f = + 100.f * jxl_stats.aux_out.num_afv_blocks * 8.0 * 8.0 / total_input_pixels; + values[13].f = 100.f * jxl_stats.aux_out.num_dct8_blocks * 8.0 * 8.0 / + total_input_pixels; + values[14].f = 100.f * jxl_stats.aux_out.num_dct8x16_blocks * 8.0 * 16.0 / + total_input_pixels; + values[15].f = 100.f * jxl_stats.aux_out.num_dct8x32_blocks * 8.0 * 32.0 / + total_input_pixels; + values[16].f = 100.f * jxl_stats.aux_out.num_dct16_blocks * 16.0 * 16.0 / + total_input_pixels; + values[17].f = 100.f * jxl_stats.aux_out.num_dct16x32_blocks * 16.0 * 32.0 / + total_input_pixels; + values[18].f = 100.f * jxl_stats.aux_out.num_dct32_blocks * 32.0 * 32.0 / + total_input_pixels; + values[19].f = 100.f * jxl_stats.aux_out.num_dct32x64_blocks * 32.0 * 64.0 / + total_input_pixels; + values[20].f = 100.f * jxl_stats.aux_out.num_dct64_blocks * 64.0 * 64.0 / + total_input_pixels; + values[21].f = bpp_p_norm; + values[22].i = total_errors; + for (size_t i = 0; i < extra_metrics.size(); i++) { + values[23 + i].f = extra_metrics[i] / total_input_files; + } + return values; +} + +static std::string PrintFormattedEntries( + size_t num_extra_metrics, const std::vector& values) { + const auto& descriptors = GetColumnDescriptors(num_extra_metrics); + + std::string out; + for (size_t i = 0; i < descriptors.size(); i++) { + if (!Args()->more_columns && descriptors[i].more) continue; + std::string value; + if (descriptors[i].type == TYPE_STRING) { + value = values[i].s; + } else if (descriptors[i].type == TYPE_SIZE) { + value = values[i].i ? StringPrintf("%" PRIdS, values[i].i) : "---"; + } else if (descriptors[i].type == TYPE_POSITIVE_FLOAT) { + value = FormatFloat(descriptors[i], values[i].f); + value = FormatFloat(descriptors[i], values[i].f); + } else if (descriptors[i].type == TYPE_COUNT) { + value = StringPrintf("%" PRIdS, values[i].i); + } + + int numspaces = descriptors[i].width - value.size(); + if (numspaces < 1) { + numspaces = 1; + } + // All except the first one are right-aligned, the first one is the name, + // others are numbers with digits matching from the right. + if (i == 0) out += value.c_str(); + out += std::string(numspaces, ' '); + if (i != 0) out += value.c_str(); + } + return out + "\n"; +} + +std::string BenchmarkStats::PrintLine(const std::string& codec_desc, + size_t corpus_size) const { + std::vector values = ComputeColumns(codec_desc, corpus_size); + return PrintFormattedEntries(extra_metrics.size(), values); +} + +std::string PrintHeader(const std::vector& extra_metrics_names) { + std::string out; + // Extra metrics are handled separately. + const auto& descriptors = GetColumnDescriptors(0); + for (size_t i = 0; i < descriptors.size(); i++) { + if (!Args()->more_columns && descriptors[i].more) continue; + const std::string& label = descriptors[i].label; + int numspaces = descriptors[i].width - label.size(); + // All except the first one are right-aligned. + if (i == 0) out += label.c_str(); + out += std::string(numspaces, ' '); + if (i != 0) out += label.c_str(); + } + for (const std::string& em : extra_metrics_names) { + int numspaces = ExtraMetricDescriptor().width - em.size(); + JXL_CHECK(numspaces >= 1); + out += std::string(numspaces, ' '); + out += em; + } + out += '\n'; + for (const auto& descriptor : descriptors) { + if (!Args()->more_columns && descriptor.more) continue; + out += std::string(descriptor.width, '-'); + } + out += std::string(ExtraMetricDescriptor().width * extra_metrics_names.size(), + '-'); + return out + "\n"; +} + +std::string PrintAggregate( + size_t num_extra_metrics, + const std::vector>& aggregate) { + const auto& descriptors = GetColumnDescriptors(num_extra_metrics); + + for (size_t i = 0; i < aggregate.size(); i++) { + // Check when statistics has wrong amount of column entries + JXL_CHECK(aggregate[i].size() == descriptors.size()); + } + + std::vector result(descriptors.size()); + + // Statistics for the aggregate row are combined together with different + // formulas than Assimilate uses for combining the statistics of files. + for (size_t i = 0; i < descriptors.size(); i++) { + if (descriptors[i].type == TYPE_STRING) { + // "---" for the Iters column since this does not have meaning for + // the aggregate stats. + result[i].s = i == 0 ? "Aggregate:" : "---"; + continue; + } + if (descriptors[i].type == TYPE_COUNT) { + size_t sum = 0; + for (size_t j = 0; j < aggregate.size(); j++) { + sum += aggregate[j][i].i; + } + result[i].i = sum; + continue; + } + + ColumnType type = descriptors[i].type; + + double logsum = 0; + size_t numvalid = 0; + for (size_t j = 0; j < aggregate.size(); j++) { + double value = + (type == TYPE_SIZE) ? aggregate[j][i].i : aggregate[j][i].f; + if (value > 0) { + numvalid++; + logsum += std::log2(value); + } + } + double geomean = numvalid ? std::exp2(logsum / numvalid) : 0.0; + + if (type == TYPE_SIZE || type == TYPE_COUNT) { + result[i].i = static_cast(geomean + 0.5); + } else if (type == TYPE_POSITIVE_FLOAT) { + result[i].f = geomean; + } else { + JXL_ABORT("unknown entry type"); + } + } + + return PrintFormattedEntries(num_extra_metrics, result); +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/benchmark/benchmark_stats.h b/media/libjxl/src/tools/benchmark/benchmark_stats.h new file mode 100644 index 000000000..a23c4a1ae --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_stats.h @@ -0,0 +1,81 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_STATS_H_ +#define TOOLS_BENCHMARK_BENCHMARK_STATS_H_ + +#include +#include + +#include +#include + +#include "lib/jxl/aux_out.h" + +namespace jxl { + +std::string StringPrintf(const char* format, ...); + +struct JxlStats { + JxlStats() { + num_inputs = 0; + aux_out = AuxOut(); + } + void Assimilate(const JxlStats& victim) { + num_inputs += victim.num_inputs; + aux_out.Assimilate(victim.aux_out); + } + void Print() const { aux_out.Print(num_inputs); } + + size_t num_inputs; + AuxOut aux_out; +}; + +// The value of an entry in the table. Depending on the ColumnType, the string, +// size_t or double should be used. +struct ColumnValue { + std::string s; // for TYPE_STRING + size_t i; // for TYPE_SIZE and TYPE_COUNT + double f; // for TYPE_POSITIVE_FLOAT +}; + +struct BenchmarkStats { + void Assimilate(const BenchmarkStats& victim); + + std::vector ComputeColumns(const std::string& codec_desc, + size_t corpus_size) const; + + std::string PrintLine(const std::string& codec_desc, + size_t corpus_size) const; + + void PrintMoreStats() const; + + size_t total_input_files = 0; + size_t total_input_pixels = 0; + size_t total_compressed_size = 0; + size_t total_adj_compressed_size = 0; + double total_time_encode = 0.0; + double total_time_decode = 0.0; + float max_distance = -1.0; // Max butteraugli score + // sum of 8th powers of butteraugli distmap pixels. + double distance_p_norm = 0.0; + // sum of 2nd powers of differences between R, G, B. + double distance_2 = 0.0; + std::vector distances; + size_t total_errors = 0; + JxlStats jxl_stats; + std::vector extra_metrics; +}; + +std::string PrintHeader(const std::vector& extra_metrics_names); + +// Given the rows of all printed statistics, print an aggregate row. +std::string PrintAggregate( + size_t num_extra_metrics, + const std::vector>& aggregate); + +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_STATS_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_utils.cc b/media/libjxl/src/tools/benchmark/benchmark_utils.cc new file mode 100644 index 000000000..4b531317e --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_utils.cc @@ -0,0 +1,90 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#define _DEFAULT_SOURCE // for mkstemps(). + +#include "tools/benchmark/benchmark_utils.h" + +// Not supported on Windows due to Linux-specific functions. +// Not supported in Android NDK before API 28. +#if !defined(_WIN32) && !defined(__EMSCRIPTEN__) && \ + (!defined(__ANDROID_API__) || __ANDROID_API__ >= 28) + +#include +#include +#include +#include +#include +#include + +#include + +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/image_bundle.h" + +extern char** environ; + +namespace jxl { +TemporaryFile::TemporaryFile(std::string basename, std::string extension) { + const auto extension_size = 1 + extension.size(); + temp_filename_ = std::move(basename) + "_XXXXXX." + std::move(extension); + const int fd = mkstemps(&temp_filename_[0], extension_size); + if (fd == -1) { + ok_ = false; + return; + } + close(fd); +} +TemporaryFile::~TemporaryFile() { + if (ok_) { + unlink(temp_filename_.c_str()); + } +} + +Status TemporaryFile::GetFileName(std::string* const output) const { + JXL_RETURN_IF_ERROR(ok_); + *output = temp_filename_; + return true; +} + +Status RunCommand(const std::string& command, + const std::vector& arguments) { + std::vector args; + args.reserve(arguments.size() + 2); + args.push_back(const_cast(command.c_str())); + for (const std::string& argument : arguments) { + args.push_back(const_cast(argument.c_str())); + } + args.push_back(nullptr); + pid_t pid; + JXL_RETURN_IF_ERROR(posix_spawnp(&pid, command.c_str(), nullptr, nullptr, + args.data(), environ) == 0); + int wstatus; + waitpid(pid, &wstatus, 0); + return WIFEXITED(wstatus) && WEXITSTATUS(wstatus) == EXIT_SUCCESS; +} + +} // namespace jxl + +#else + +namespace jxl { + +TemporaryFile::TemporaryFile(std::string basename, std::string extension) {} +TemporaryFile::~TemporaryFile() {} +Status TemporaryFile::GetFileName(std::string* const output) const { + (void)ok_; + return JXL_FAILURE("Not supported on this build"); +} + +Status RunCommand(const std::string& command, + const std::vector& arguments) { + return JXL_FAILURE("Not supported on this build"); +} + +} // namespace jxl + +#endif // _MSC_VER diff --git a/media/libjxl/src/tools/benchmark/benchmark_utils.h b/media/libjxl/src/tools/benchmark/benchmark_utils.h new file mode 100644 index 000000000..027fa0868 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_utils.h @@ -0,0 +1,35 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_BENCHMARK_BENCHMARK_UTILS_H_ +#define TOOLS_BENCHMARK_BENCHMARK_UTILS_H_ + +#include +#include + +#include "lib/jxl/base/status.h" + +namespace jxl { + +class TemporaryFile final { + public: + explicit TemporaryFile(std::string basename, std::string extension); + TemporaryFile(const TemporaryFile&) = delete; + TemporaryFile& operator=(const TemporaryFile&) = delete; + ~TemporaryFile(); + Status GetFileName(std::string* output) const; + + private: + bool ok_ = true; + + std::string temp_filename_; +}; + +Status RunCommand(const std::string& command, + const std::vector& arguments); + +} // namespace jxl + +#endif // TOOLS_BENCHMARK_BENCHMARK_UTILS_H_ diff --git a/media/libjxl/src/tools/benchmark/benchmark_xl.cc b/media/libjxl/src/tools/benchmark/benchmark_xl.cc new file mode 100644 index 000000000..fed5e9b1b --- /dev/null +++ b/media/libjxl/src/tools/benchmark/benchmark_xl.cc @@ -0,0 +1,1090 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "jxl/decode.h" +#include "lib/extras/codec.h" +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/time.h" +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_butteraugli_pnorm.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "tools/benchmark/benchmark_args.h" +#include "tools/benchmark/benchmark_codec.h" +#include "tools/benchmark/benchmark_file_io.h" +#include "tools/benchmark/benchmark_stats.h" +#include "tools/benchmark/benchmark_utils.h" +#include "tools/codec_config.h" +#include "tools/speed_stats.h" + +namespace jxl { +namespace { + +Status WriteImage(Image3F&& image, ThreadPool* pool, + const std::string& filename) { + CodecInOut io; + io.metadata.m.SetUintSamples(8); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + io.SetFromImage(std::move(image), io.metadata.m.color_encoding); + return EncodeToFile(io, filename, pool); +} + +Status ReadPNG(const std::string& filename, Image3F* image) { + CodecInOut io; + JXL_CHECK(SetFromFile(filename, extras::ColorHints(), &io)); + *image = CopyImage(*io.Main().color()); + return true; +} + +void DoCompress(const std::string& filename, const CodecInOut& io, + const std::vector& extra_metrics_commands, + ImageCodec* codec, ThreadPoolInternal* inner_pool, + std::vector* compressed, BenchmarkStats* s) { + PROFILER_FUNC; + ++s->total_input_files; + + if (io.frames.size() != 1) { + // Multiple frames not supported (io.xsize() will checkfail) + s->total_errors++; + if (!Args()->silent_errors) { + JXL_WARNING("multiframe input image not supported %s", filename.c_str()); + } + return; + } + const size_t xsize = io.xsize(); + const size_t ysize = io.ysize(); + const size_t input_pixels = xsize * ysize; + + jpegxl::tools::SpeedStats speed_stats; + jpegxl::tools::SpeedStats::Summary summary; + + bool valid = true; // false if roundtrip, encoding or decoding errors occur. + + if (!Args()->decode_only && (io.xsize() == 0 || io.ysize() == 0)) { + // This means the benchmark couldn't load the image, e.g. due to invalid + // ICC profile. Warning message about that was already printed. Continue + // this function to indicate it as error in the stats. + valid = false; + } + + std::string ext = FileExtension(filename); + if (valid && !Args()->decode_only) { + for (size_t i = 0; i < Args()->encode_reps; ++i) { + if (codec->CanRecompressJpeg() && (ext == ".jpg" || ext == ".jpeg")) { + std::string data_in; + JXL_CHECK(ReadFile(filename, &data_in)); + JXL_CHECK( + codec->RecompressJpeg(filename, data_in, compressed, &speed_stats)); + } else { + Status status = codec->Compress(filename, &io, inner_pool, compressed, + &speed_stats); + if (!status) { + valid = false; + if (!Args()->silent_errors) { + std::string message = codec->GetErrorMessage(); + if (!message.empty()) { + fprintf(stderr, "Error in %s codec: %s\n", + codec->description().c_str(), message.c_str()); + } else { + fprintf(stderr, "Error in %s codec\n", + codec->description().c_str()); + } + } + } + } + } + JXL_CHECK(speed_stats.GetSummary(&summary)); + s->total_time_encode += summary.central_tendency; + } + + if (valid && Args()->decode_only) { + std::vector data_in; + JXL_CHECK(ReadFile(filename, &data_in)); + compressed->insert(compressed->end(), data_in.begin(), data_in.end()); + } + + // Decompress + CodecInOut io2; + io2.metadata.m = io.metadata.m; + if (valid) { + speed_stats = jpegxl::tools::SpeedStats(); + for (size_t i = 0; i < Args()->decode_reps; ++i) { + if (!codec->Decompress(filename, Span(*compressed), + inner_pool, &io2, &speed_stats)) { + if (!Args()->silent_errors) { + fprintf(stderr, + "%s failed to decompress encoded image. Original source:" + " %s\n", + codec->description().c_str(), filename.c_str()); + } + valid = false; + } + + // io2.dec_pixels increases each time, but the total should be independent + // of decode_reps, so only take the value from the first iteration. + if (i == 0) s->total_input_pixels += io2.dec_pixels; + } + JXL_CHECK(speed_stats.GetSummary(&summary)); + s->total_time_decode += summary.central_tendency; + } + + std::string name = FileBaseName(filename); + std::string codec_name = codec->description(); + + if (!valid) { + s->total_errors++; + } + + if (io.frames.size() != io2.frames.size()) { + if (!Args()->silent_errors) { + // Animated gifs not supported yet? + fprintf(stderr, + "Frame sizes not equal, is this an animated gif? %s %s %" PRIuS + " %" PRIuS "\n", + codec_name.c_str(), name.c_str(), io.frames.size(), + io2.frames.size()); + } + valid = false; + } + + bool lossless = codec->IsJpegTranscoder(); + bool skip_butteraugli = + Args()->skip_butteraugli || Args()->decode_only || lossless; + ImageF distmap; + float max_distance = 1.0f; + + if (valid && !skip_butteraugli) { + JXL_ASSERT(io.frames.size() == io2.frames.size()); + for (size_t i = 0; i < io.frames.size(); i++) { + const ImageBundle& ib1 = io.frames[i]; + ImageBundle& ib2 = io2.frames[i]; + + // Verify output + PROFILER_ZONE("Benchmark stats"); + float distance; + if (SameSize(ib1, ib2)) { + ButteraugliParams params = codec->BaParams(); + if (ib1.metadata()->IntensityTarget() != + ib2.metadata()->IntensityTarget()) { + fprintf(stderr, + "WARNING: input and output images have different intensity " + "targets"); + } + params.intensity_target = ib1.metadata()->IntensityTarget(); + // Hack the default intensity target value to be 80.0, the intensity + // target of sRGB images and a more reasonable viewing default than + // JPEG XL file format's default. + if (fabs(params.intensity_target - 255.0f) < 1e-3) { + params.intensity_target = 80.0; + } + distance = ButteraugliDistance(ib1, ib2, params, GetJxlCms(), &distmap, + inner_pool); + // Ensure pixels in range 0-1 + s->distance_2 += ComputeDistance2(ib1, ib2, GetJxlCms()); + } else { + // TODO(veluca): re-upsample and compute proper distance. + distance = 1e+4f; + distmap = ImageF(1, 1); + distmap.Row(0)[0] = distance; + s->distance_2 += distance; + } + // Update stats + s->distance_p_norm += + ComputeDistanceP(distmap, Args()->ba_params, Args()->error_pnorm) * + input_pixels; + s->max_distance = std::max(s->max_distance, distance); + s->distances.push_back(distance); + max_distance = std::max(max_distance, distance); + } + } + + s->total_compressed_size += compressed->size(); + s->total_adj_compressed_size += compressed->size() * max_distance; + codec->GetMoreStats(s); + + if (io2.frames.size() == 1 && + (Args()->save_compressed || Args()->save_decompressed)) { + JXL_ASSERT(io2.frames.size() == 1); + ImageBundle& ib2 = io2.Main(); + + // By default the benchmark will save the image after roundtrip with the + // same color encoding as the image before roundtrip. Not all codecs + // necessarily preserve the amount of channels (1 for gray, 3 for RGB) + // though, since not all image formats necessarily allow a way to remember + // what amount of channels you happened to give the benchmark codec + // input (say, an RGB-only format) and that is fine since in the end what + // matters is that the pixels look the same on a 3-channel RGB monitor + // while using grayscale encoding is an internal compression optimization. + // If that is the case, output with the current color model instead, + // because CodecInOut does not automatically convert between 1 or 3 + // channels, and giving a ColorEncoding with a different amount of + // channels is not allowed. + const ColorEncoding* c_desired = + (ib2.metadata()->color_encoding.Channels() == + ib2.c_current().Channels()) + ? &ib2.metadata()->color_encoding + : &ib2.c_current(); + // Allow overriding via --output_encoding. + if (!Args()->output_description.empty()) { + c_desired = &Args()->output_encoding; + } + + std::string dir = FileDirName(filename); + std::string outdir = + Args()->output_dir.empty() ? dir + "/out" : Args()->output_dir; + std::string compressed_fn = outdir + "/" + name; + // Add in the parameters of the codec_name in reverse order, so that the + // name of the file format (e.g. jxl) is last. + int pos = static_cast(codec_name.size()) - 1; + while (pos > 0) { + int prev = codec_name.find_last_of(':', pos); + if (prev > pos) prev = -1; + compressed_fn += '.' + codec_name.substr(prev + 1, pos - prev); + pos = prev - 1; + } + std::string decompressed_fn = compressed_fn + Args()->output_extension; +#if JPEGXL_ENABLE_APNG + std::string heatmap_fn = compressed_fn + ".heatmap.png"; +#else + std::string heatmap_fn = compressed_fn + ".heatmap.ppm"; +#endif + JXL_CHECK(MakeDir(outdir)); + if (Args()->save_compressed) { + std::string compressed_str( + reinterpret_cast(compressed->data()), + compressed->size()); + JXL_CHECK(WriteFile(compressed_str, compressed_fn)); + } + if (Args()->save_decompressed && valid) { + // For verifying HDR: scale output. + if (Args()->mul_output != 0.0) { + fprintf(stderr, "WARNING: scaling outputs by %f\n", Args()->mul_output); + JXL_CHECK(ib2.TransformTo(ColorEncoding::LinearSRGB(ib2.IsGray()), + GetJxlCms(), inner_pool)); + ScaleImage(static_cast(Args()->mul_output), ib2.color()); + } + + JXL_CHECK(EncodeToFile(io2, *c_desired, + ib2.metadata()->bit_depth.bits_per_sample, + decompressed_fn)); + if (!skip_butteraugli) { + float good = Args()->heatmap_good > 0.0f ? Args()->heatmap_good + : ButteraugliFuzzyInverse(1.5); + float bad = Args()->heatmap_bad > 0.0f ? Args()->heatmap_bad + : ButteraugliFuzzyInverse(0.5); + JXL_CHECK(WriteImage(CreateHeatMapImage(distmap, good, bad), inner_pool, + heatmap_fn)); + } + } + } + if (!extra_metrics_commands.empty()) { + CodecInOut in_copy; + in_copy.SetFromImage(std::move(*io.Main().Copy().color()), + io.Main().c_current()); + TemporaryFile tmp_in("original", "pfm"); + TemporaryFile tmp_out("decoded", "pfm"); + TemporaryFile tmp_res("result", "txt"); + std::string tmp_in_fn, tmp_out_fn, tmp_res_fn; + JXL_CHECK(tmp_in.GetFileName(&tmp_in_fn)); + JXL_CHECK(tmp_out.GetFileName(&tmp_out_fn)); + JXL_CHECK(tmp_res.GetFileName(&tmp_res_fn)); + + // Convert everything to non-linear SRGB - this is what most metrics expect. + const ColorEncoding& c_desired = ColorEncoding::SRGB(io.Main().IsGray()); + JXL_CHECK(EncodeToFile(io, c_desired, + io.metadata.m.bit_depth.bits_per_sample, tmp_in_fn)); + JXL_CHECK(EncodeToFile( + io2, c_desired, io.metadata.m.bit_depth.bits_per_sample, tmp_out_fn)); + if (io.metadata.m.IntensityTarget() != io2.metadata.m.IntensityTarget()) { + fprintf(stderr, + "WARNING: original and decoded have different intensity targets " + "(%f vs. %f).\n", + io.metadata.m.IntensityTarget(), + io2.metadata.m.IntensityTarget()); + } + std::string intensity_target; + { + std::ostringstream intensity_target_oss; + intensity_target_oss << io.metadata.m.IntensityTarget(); + intensity_target = intensity_target_oss.str(); + } + for (size_t i = 0; i < extra_metrics_commands.size(); i++) { + float res = nanf(""); + bool error = false; + if (RunCommand(extra_metrics_commands[i], + {tmp_in_fn, tmp_out_fn, tmp_res_fn, intensity_target})) { + FILE* f = fopen(tmp_res_fn.c_str(), "r"); + if (fscanf(f, "%f", &res) != 1) { + error = true; + } + fclose(f); + } else { + error = true; + } + if (error) { + fprintf(stderr, + "WARNING: Computation of metric with command %s failed\n", + extra_metrics_commands[i].c_str()); + } + s->extra_metrics.push_back(res); + } + } + + if (Args()->show_progress) { + fprintf(stderr, "."); + fflush(stderr); + } +} + +// Makes a base64 data URI for embedded image in HTML +std::string Base64Image(const std::string& filename) { + PaddedBytes bytes; + if (!ReadFile(filename, &bytes)) { + return ""; + } + static const char* symbols = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string result; + for (size_t i = 0; i < bytes.size(); i += 3) { + uint8_t o0 = bytes[i + 0]; + uint8_t o1 = (i + 1 < bytes.size()) ? bytes[i + 1] : 0; + uint8_t o2 = (i + 2 < bytes.size()) ? bytes[i + 2] : 0; + uint32_t value = (o0 << 16) | (o1 << 8) | o2; + for (size_t j = 0; j < 4; j++) { + result += (i + j <= bytes.size()) ? symbols[(value >> (6 * (3 - j))) & 63] + : '='; + } + } + // NOTE: Chrome supports max 2MB of data this way for URLs, but appears to + // support larger images anyway as long as it's embedded in the HTML file + // itself. If more data is needed, use createObjectURL. + return "data:image;base64," + result; +} + +struct Task { + ImageCodecPtr codec; + size_t idx_image; + size_t idx_method; + const CodecInOut* image; + BenchmarkStats stats; +}; + +void WriteHtmlReport(const std::string& codec_desc, + const std::vector& fnames, + const std::vector& tasks, + const std::vector& images, + bool self_contained) { + std::string toggle_js = + " +)"; + std::string out_html; + std::string outdir; + out_html += "\n"; + out_html += "\n"; + std::string codec_name = codec_desc; + // Make compatible for filename + std::replace(codec_name.begin(), codec_name.end(), ':', '_'); + for (size_t i = 0; i < fnames.size(); ++i) { + std::string name = FileBaseName(fnames[i]); + std::string dir = FileDirName(fnames[i]); + outdir = Args()->output_dir.empty() ? dir + "/out" : Args()->output_dir; + std::string name_out = name + "." + codec_name + Args()->output_extension; + std::string heatmap_out = name + "." + codec_name + ".heatmap.png"; + + std::string fname_orig = fnames[i]; + std::string fname_out = outdir + "/" + name_out; + std::string fname_heatmap = outdir + "/" + heatmap_out; + std::string url_orig = Args()->originals_url.empty() + ? ("file://" + fnames[i]) + : (Args()->originals_url + "/" + name); + std::string url_out = name_out; + std::string url_heatmap = heatmap_out; + if (self_contained) { + url_orig = Base64Image(fname_orig); + url_out = Base64Image(fname_out); + url_heatmap = Base64Image(fname_heatmap); + } + std::string number = StringPrintf("%" PRIuS, i); + const CodecInOut& image = *images[i]; + size_t xsize = image.frames.size() == 1 ? image.xsize() : 0; + size_t ysize = image.frames.size() == 1 ? image.ysize() : 0; + std::string html_width = StringPrintf("%" PRIuS "px", xsize); + std::string html_height = StringPrintf("%" PRIuS "px", ysize); + double bpp = tasks[i]->stats.total_compressed_size * 8.0 / + tasks[i]->stats.total_input_pixels; + double pnorm = + tasks[i]->stats.distance_p_norm / tasks[i]->stats.total_input_pixels; + double max_dist = tasks[i]->stats.max_distance; + std::string compressed_title = StringPrintf( + "compressed. bpp: %f, pnorm: %f, max dist: %f", bpp, pnorm, max_dist); + out_html += "
\n" + " \n" + " \n" + " \n
\n"; + } + out_html += "\n"; + out_html += toggle_js; + JXL_CHECK(WriteFile(out_html, outdir + "/index." + codec_name + ".html")); +} + +// Prints the detailed and aggregate statistics, in the correct order but as +// soon as possible when multithreaded tasks are done. +struct StatPrinter { + StatPrinter(const std::vector& methods, + const std::vector& extra_metrics_names, + const std::vector& fnames, + const std::vector& tasks) + : methods_(&methods), + extra_metrics_names_(&extra_metrics_names), + fnames_(&fnames), + tasks_(&tasks), + tasks_done_(0), + stats_printed_(0), + details_printed_(0) { + stats_done_.resize(methods.size(), 0); + details_done_.resize(tasks.size(), 0); + max_fname_width_ = 0; + for (const auto& fname : fnames) { + max_fname_width_ = std::max(max_fname_width_, FileBaseName(fname).size()); + } + max_method_width_ = 0; + for (const auto& method : methods) { + max_method_width_ = + std::max(max_method_width_, FileBaseName(method).size()); + } + } + + void TaskDone(size_t task_index, const Task& t) { + PROFILER_FUNC; + std::lock_guard guard(mutex); + tasks_done_++; + if (Args()->print_details || Args()->show_progress) { + if (Args()->print_details) { + // Render individual results as soon as they are ready and all previous + // ones in task order are ready. + details_done_[task_index] = 1; + if (task_index == details_printed_) { + while (details_printed_ < tasks_->size() && + details_done_[details_printed_]) { + PrintDetails((*tasks_)[details_printed_]); + details_printed_++; + } + } + } + // When using "show_progress" or "print_details", the table must be + // rendered at the very end, else the details or progress would be + // rendered in-between the table rows. + if (tasks_done_ == tasks_->size()) { + PrintStatsHeader(); + for (size_t i = 0; i < methods_->size(); i++) { + PrintStats((*methods_)[i], i); + } + PrintStatsFooter(); + } + } else { + if (tasks_done_ == 1) { + PrintStatsHeader(); + } + // Render lines of the table as soon as it is ready and all previous + // lines have been printed. + stats_done_[t.idx_method]++; + if (stats_done_[t.idx_method] == fnames_->size() && + t.idx_method == stats_printed_) { + while (stats_printed_ < stats_done_.size() && + stats_done_[stats_printed_] == fnames_->size()) { + PrintStats((*methods_)[stats_printed_], stats_printed_); + stats_printed_++; + } + } + if (tasks_done_ == tasks_->size()) { + PrintStatsFooter(); + } + } + } + + void PrintDetails(const Task& t) { + double comp_bpp = + t.stats.total_compressed_size * 8.0 / t.stats.total_input_pixels; + double p_norm = t.stats.distance_p_norm / t.stats.total_input_pixels; + double bpp_p_norm = p_norm * comp_bpp; + + const double adj_comp_bpp = + t.stats.total_adj_compressed_size * 8.0 / t.stats.total_input_pixels; + + const double rmse = + std::sqrt(t.stats.distance_2 / t.stats.total_input_pixels); + const double psnr = t.stats.total_compressed_size == 0 ? 0.0 + : (t.stats.distance_2 == 0) + ? 99.99 + : (20 * std::log10(1 / rmse)); + size_t pixels = t.stats.total_input_pixels; + + const double enc_mps = + t.stats.total_input_pixels / (1000000.0 * t.stats.total_time_encode); + const double dec_mps = + t.stats.total_input_pixels / (1000000.0 * t.stats.total_time_decode); + if (Args()->print_details_csv) { + printf("%s,%s,%" PRIdS ",%" PRIdS ",%" PRIdS + ",%.8f,%.8f,%.8f,%.8f,%.8f,%.8f,%.8f,%.8f", + (*methods_)[t.idx_method].c_str(), + FileBaseName((*fnames_)[t.idx_image]).c_str(), + t.stats.total_errors, t.stats.total_compressed_size, pixels, + enc_mps, dec_mps, comp_bpp, t.stats.max_distance, psnr, p_norm, + bpp_p_norm, adj_comp_bpp); + for (float m : t.stats.extra_metrics) { + printf(",%.8f", m); + } + printf("\n"); + } else { + printf("%s", (*methods_)[t.idx_method].c_str()); + for (size_t i = (*methods_)[t.idx_method].size(); i <= max_method_width_; + i++) { + printf(" "); + } + printf("%s", FileBaseName((*fnames_)[t.idx_image]).c_str()); + for (size_t i = FileBaseName((*fnames_)[t.idx_image]).size(); + i <= max_fname_width_; i++) { + printf(" "); + } + printf( + "error:%" PRIdS " size:%8" PRIdS " pixels:%9" PRIdS + " enc_speed:%8.8f dec_speed:%8.8f bpp:%10.8f dist:%10.8f" + " psnr:%10.8f p:%10.8f bppp:%10.8f qabpp:%10.8f ", + t.stats.total_errors, t.stats.total_compressed_size, pixels, enc_mps, + dec_mps, comp_bpp, t.stats.max_distance, psnr, p_norm, bpp_p_norm, + adj_comp_bpp); + for (size_t i = 0; i < t.stats.extra_metrics.size(); i++) { + printf(" %s:%.8f", (*extra_metrics_names_)[i].c_str(), + t.stats.extra_metrics[i]); + } + printf("\n"); + } + fflush(stdout); + } + + void PrintStats(const std::string& method, size_t idx_method) { + PROFILER_FUNC; + // Assimilate all tasks with the same idx_method. + BenchmarkStats method_stats; + std::vector images; + std::vector tasks; + for (const Task& t : *tasks_) { + if (t.idx_method == idx_method) { + method_stats.Assimilate(t.stats); + images.push_back(t.image); + tasks.push_back(&t); + } + } + + std::string out; + + method_stats.PrintMoreStats(); // not concurrent + out += method_stats.PrintLine(method, fnames_->size()); + + if (Args()->write_html_report) { + WriteHtmlReport(method, *fnames_, tasks, images, + Args()->html_report_self_contained); + } + + stats_aggregate_.push_back( + method_stats.ComputeColumns(method, fnames_->size())); + + printf("%s", out.c_str()); + fflush(stdout); + } + + void PrintStatsHeader() { + if (Args()->markdown) { + if (Args()->show_progress) { + fprintf(stderr, "\n"); + fflush(stderr); + } + printf("```\n"); + } + if (fnames_->size() == 1) printf("%s\n", (*fnames_)[0].c_str()); + printf("%s", PrintHeader(*extra_metrics_names_).c_str()); + fflush(stdout); + } + + void PrintStatsFooter() { + printf( + "%s", + PrintAggregate(extra_metrics_names_->size(), stats_aggregate_).c_str()); + if (Args()->markdown) printf("```\n"); + printf("\n"); + fflush(stdout); + } + + const std::vector* methods_; + const std::vector* extra_metrics_names_; + const std::vector* fnames_; + const std::vector* tasks_; + + size_t tasks_done_; + + size_t stats_printed_; + std::vector stats_done_; + + size_t details_printed_; + std::vector details_done_; + + size_t max_fname_width_; + size_t max_method_width_; + + std::vector> stats_aggregate_; + + std::mutex mutex; +}; + +class Benchmark { + using StringVec = std::vector; + + public: + // Return the exit code of the program. + static int Run() { + int ret = EXIT_SUCCESS; + { + PROFILER_FUNC; + + const StringVec methods = GetMethods(); + const StringVec extra_metrics_names = GetExtraMetricsNames(); + const StringVec extra_metrics_commands = GetExtraMetricsCommands(); + const StringVec fnames = GetFilenames(); + bool all_color_aware; + bool jpeg_transcoding_requested; + // (non-const because Task.stats are updated) + std::vector tasks = CreateTasks(methods, fnames, &all_color_aware, + &jpeg_transcoding_requested); + + std::unique_ptr pool; + std::vector> inner_pools; + InitThreads(static_cast(tasks.size()), &pool, &inner_pools); + + const std::vector loaded_images = LoadImages( + fnames, all_color_aware, jpeg_transcoding_requested, pool.get()); + + if (RunTasks(methods, extra_metrics_names, extra_metrics_commands, fnames, + loaded_images, pool.get(), inner_pools, &tasks) != 0) { + ret = EXIT_FAILURE; + if (!Args()->silent_errors) { + fprintf(stderr, "There were error(s) in the benchmark.\n"); + } + } + } + + // Must have exited profiler zone above before calling. + if (Args()->profiler) { + PROFILER_PRINT_RESULTS(); + } + CacheAligned::PrintStats(); + return ret; + } + + private: + static int NumOuterThreads(const int num_hw_threads, const int num_tasks) { + int num_threads = Args()->num_threads; + // Default to #cores + if (num_threads < 0) num_threads = num_hw_threads; + + // As a safety precaution, limit the number of threads to 4x the number of + // available CPUs. + num_threads = + std::min(num_threads, 4 * std::thread::hardware_concurrency()); + + // Don't create more threads than there are tasks (pointless/wasteful). + num_threads = std::min(num_threads, num_tasks); + + // Just one thread is counterproductive. + if (num_threads == 1) num_threads = 0; + + return num_threads; + } + + static int NumInnerThreads(const int num_hw_threads, const int num_threads) { + int num_inner = Args()->inner_threads; + + // Default: distribute remaining cores among tasks. + if (num_inner < 0) { + const int cores_for_outer = num_hw_threads - num_threads; + num_inner = + num_threads == 0 ? num_hw_threads : cores_for_outer / num_threads; + } + + // Just one thread is counterproductive. + if (num_inner == 1) num_inner = 0; + + return num_inner; + } + + static void InitThreads( + const int num_tasks, std::unique_ptr* pool, + std::vector>* inner_pools) { + const int num_hw_threads = std::thread::hardware_concurrency(); + const int num_threads = NumOuterThreads(num_hw_threads, num_tasks); + const int num_inner = NumInnerThreads(num_hw_threads, num_threads); + + fprintf(stderr, + "%d total threads, %d tasks, %d threads, %d inner threads\n", + num_hw_threads, num_tasks, num_threads, num_inner); + + pool->reset(new ThreadPoolInternal(num_threads)); + // Main thread OR worker threads in pool each get a possibly empty nested + // pool (helps use all available cores when #tasks < #threads) + for (size_t i = 0; i < (*pool)->NumThreads(); ++i) { + inner_pools->emplace_back(new ThreadPoolInternal(num_inner)); + } + } + + static StringVec GetMethods() { + StringVec methods = SplitString(Args()->codec, ','); + for (auto it = methods.begin(); it != methods.end();) { + if (it->empty()) { + it = methods.erase(it); + } else { + ++it; + } + } + return methods; + } + + static StringVec GetExtraMetricsNames() { + StringVec metrics = SplitString(Args()->extra_metrics, ','); + for (auto it = metrics.begin(); it != metrics.end();) { + if (it->empty()) { + it = metrics.erase(it); + } else { + *it = SplitString(*it, ':')[0]; + ++it; + } + } + return metrics; + } + + static StringVec GetExtraMetricsCommands() { + StringVec metrics = SplitString(Args()->extra_metrics, ','); + for (auto it = metrics.begin(); it != metrics.end();) { + if (it->empty()) { + it = metrics.erase(it); + } else { + auto s = SplitString(*it, ':'); + JXL_CHECK(s.size() == 2); + *it = s[1]; + ++it; + } + } + return metrics; + } + + static StringVec SampleFromInput(const StringVec& fnames, + const std::string& sample_tmp_dir, + int num_samples, size_t size) { + JXL_CHECK(!sample_tmp_dir.empty()); + fprintf(stderr, "Creating samples of %" PRIuS "x%" PRIuS " tiles...\n", + size, size); + StringVec fnames_out; + std::vector images; + std::vector offsets; + size_t total_num_tiles = 0; + for (const auto& fname : fnames) { + Image3F img; + JXL_CHECK(ReadPNG(fname, &img)); + JXL_CHECK(img.xsize() >= size); + JXL_CHECK(img.ysize() >= size); + total_num_tiles += (img.xsize() - size + 1) * (img.ysize() - size + 1); + offsets.push_back(total_num_tiles); + images.emplace_back(std::move(img)); + } + JXL_CHECK(MakeDir(sample_tmp_dir)); + Rng rng(0); + for (int i = 0; i < num_samples; ++i) { + int val = rng.UniformI(0, offsets.back()); + size_t idx = (std::lower_bound(offsets.begin(), offsets.end(), val) - + offsets.begin()); + JXL_CHECK(idx < images.size()); + const Image3F& img = images[idx]; + int x0 = rng.UniformI(0, img.xsize() - size); + int y0 = rng.UniformI(0, img.ysize() - size); + Image3F sample(size, size); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < size; ++y) { + const float* JXL_RESTRICT row_in = img.PlaneRow(c, y0 + y); + float* JXL_RESTRICT row_out = sample.PlaneRow(c, y); + memcpy(row_out, &row_in[x0], size * sizeof(row_out[0])); + } + } + std::string fn_output = + StringPrintf("%s/%s.crop_%dx%d+%d+%d.png", sample_tmp_dir.c_str(), + FileBaseName(fnames[idx]).c_str(), size, size, x0, y0); + ThreadPool* null_pool = nullptr; + JXL_CHECK(WriteImage(std::move(sample), null_pool, fn_output)); + fnames_out.push_back(fn_output); + } + fprintf(stderr, "Created %d sample tiles\n", num_samples); + return fnames_out; + } + + static StringVec GetFilenames() { + StringVec fnames; + JXL_CHECK(MatchFiles(Args()->input, &fnames)); + if (fnames.empty()) { + JXL_ABORT("No input file matches pattern: '%s'", Args()->input.c_str()); + } + if (Args()->print_details) { + std::sort(fnames.begin(), fnames.end()); + } + + if (Args()->num_samples > 0) { + fnames = SampleFromInput(fnames, Args()->sample_tmp_dir, + Args()->num_samples, Args()->sample_dimensions); + } + return fnames; + } + + // (Load only once, not for every codec) + static std::vector LoadImages( + const StringVec& fnames, const bool all_color_aware, + const bool jpeg_transcoding_requested, ThreadPool* pool) { + PROFILER_FUNC; + std::vector loaded_images; + loaded_images.resize(fnames.size()); + JXL_CHECK(RunOnPool( + pool, 0, static_cast(fnames.size()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t i = static_cast(task); + Status ok = true; + + if (!Args()->decode_only) { + PaddedBytes encoded; + ok = ReadFile(fnames[i], &encoded) && + (jpeg_transcoding_requested + ? jpeg::DecodeImageJPG(Span(encoded), + &loaded_images[i]) + : SetFromBytes(Span(encoded), + Args()->color_hints, &loaded_images[i])); + if (ok && Args()->intensity_target != 0) { + loaded_images[i].metadata.m.SetIntensityTarget( + Args()->intensity_target); + } + } + if (!ok) { + if (!Args()->silent_errors) { + fprintf(stderr, "Failed to load image %s\n", fnames[i].c_str()); + } + return; + } + + if (!Args()->decode_only && all_color_aware) { + const bool is_gray = loaded_images[i].Main().IsGray(); + const ColorEncoding& c_desired = ColorEncoding::LinearSRGB(is_gray); + if (!loaded_images[i].TransformTo(c_desired, GetJxlCms(), + /*pool=*/nullptr)) { + JXL_ABORT("Failed to transform to lin. sRGB %s", + fnames[i].c_str()); + } + } + + if (!Args()->decode_only && Args()->override_bitdepth != 0) { + if (Args()->override_bitdepth == 32) { + loaded_images[i].metadata.m.SetFloat32Samples(); + } else { + loaded_images[i].metadata.m.SetUintSamples( + Args()->override_bitdepth); + } + } + }, + "Load images")); + return loaded_images; + } + + static std::vector CreateTasks(const StringVec& methods, + const StringVec& fnames, + bool* all_color_aware, + bool* jpeg_transcoding_requested) { + std::vector tasks; + tasks.reserve(methods.size() * fnames.size()); + *all_color_aware = true; + *jpeg_transcoding_requested = false; + for (size_t idx_image = 0; idx_image < fnames.size(); ++idx_image) { + for (size_t idx_method = 0; idx_method < methods.size(); ++idx_method) { + tasks.emplace_back(); + Task& t = tasks.back(); + t.codec = CreateImageCodec(methods[idx_method]); + *all_color_aware &= t.codec->IsColorAware(); + *jpeg_transcoding_requested |= t.codec->IsJpegTranscoder(); + t.idx_image = idx_image; + t.idx_method = idx_method; + // t.stats is default-initialized. + } + } + JXL_ASSERT(tasks.size() == tasks.capacity()); + return tasks; + } + + // Return the total number of errors. + static size_t RunTasks( + const StringVec& methods, const StringVec& extra_metrics_names, + const StringVec& extra_metrics_commands, const StringVec& fnames, + const std::vector& loaded_images, ThreadPoolInternal* pool, + const std::vector>& inner_pools, + std::vector* tasks) { + PROFILER_FUNC; + StatPrinter printer(methods, extra_metrics_names, fnames, *tasks); + if (Args()->print_details_csv) { + // Print CSV header + printf( + "method,image,error,size,pixels,enc_speed,dec_speed," + "bpp,dist,psnr,p,bppp,qabpp"); + for (const std::string& s : extra_metrics_names) { + printf(",%s", s.c_str()); + } + printf("\n"); + } + + std::vector errors_thread; + JXL_CHECK(RunOnPool( + pool, 0, tasks->size(), + [&](const size_t num_threads) { + // Reduce false sharing by only writing every 8th slot (64 bytes). + errors_thread.resize(8 * num_threads); + return true; + }, + [&](const uint32_t i, const size_t thread) { + Task& t = (*tasks)[i]; + const CodecInOut& image = loaded_images[t.idx_image]; + t.image = ℑ + std::vector compressed; + DoCompress(fnames[t.idx_image], image, extra_metrics_commands, + t.codec.get(), inner_pools[thread].get(), &compressed, + &t.stats); + printer.TaskDone(i, t); + errors_thread[8 * thread] += t.stats.total_errors; + }, + "Benchmark tasks")); + if (Args()->show_progress) fprintf(stderr, "\n"); + return std::accumulate(errors_thread.begin(), errors_thread.end(), 0); + } +}; + +int BenchmarkMain(int argc, const char** argv) { + fprintf(stderr, "benchmark_xl %s\n", + jpegxl::tools::CodecConfigString(JxlDecoderVersion()).c_str()); + + JXL_CHECK(Args()->AddCommandLineOptions()); + + if (!Args()->Parse(argc, argv)) { + fprintf(stderr, "Use '%s -h' for more information\n", argv[0]); + return 1; + } + + if (Args()->cmdline.HelpFlagPassed()) { + Args()->PrintHelp(); + return 0; + } + if (!Args()->ValidateArgs()) { + fprintf(stderr, "Use '%s -h' for more information\n", argv[0]); + return 1; + } + return Benchmark::Run(); +} + +} // namespace +} // namespace jxl + +int main(int argc, const char** argv) { return jxl::BenchmarkMain(argc, argv); } diff --git a/media/libjxl/src/tools/benchmark/hm/README.md b/media/libjxl/src/tools/benchmark/hm/README.md new file mode 100644 index 000000000..e54904eff --- /dev/null +++ b/media/libjxl/src/tools/benchmark/hm/README.md @@ -0,0 +1,12 @@ +This directory contains encoding and decoding scripts for HEVC, for use with +the benchmark custom codec. They use the HEVC reference encoder at https://hevc.hhi.fraunhofer.de/svn/svn_HEVCSoftware/ +and require the `TAppEncoderHighBitDepthStatic` and +`TAppDecoderHighBitDepthStatic` binaries to be placed in this directory. + +Example usage, for encoding at QP = 30: + +``` +tools/benchmark_xl --input=image.png --codec='custom:bin:.../tools/benchmark/hm/encode.sh:.../tools/benchmark/hm/decode.sh:-q:30' +``` + +The paths to the encode and decode scripts should be adjusted as necessary. diff --git a/media/libjxl/src/tools/benchmark/hm/decode.sh b/media/libjxl/src/tools/benchmark/hm/decode.sh new file mode 100644 index 000000000..624c8ba72 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/hm/decode.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash + +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -euo pipefail + +decoder="$(dirname "$0")"/TAppDecoderHighBitDepthStatic + +usage() { + echo "$0 [-v] " >&2 + exit 1 +} + +verbose=0 + +while getopts ':hv' arg; do + case "$arg" in + h) + usage + ;; + + v) + verbose=1 + ;; + + \?) + echo "Unrecognized option -$OPTARG" >&2 + exit 1 + ;; + esac +done +shift $((OPTIND-1)) + +if [ $# -lt 2 ]; then + usage +fi + +run() { + if [ "$verbose" -eq 1 ]; then + "$@" + else + "$@" > /dev/null 2>&1 + fi +} + +input="$1" +output="$2" + +bin="$(mktemp)" +yuv="$(mktemp)" +width_file="$(mktemp)" +height_file="$(mktemp)" +icc_file="$(mktemp --suffix=.icc)" + +cleanup() { + rm -- "$bin" "$yuv" "$width_file" "$height_file" "$icc_file" +} +trap cleanup EXIT + +unpack_program="$(cat <<'END' + use File::Copy; + my ($input, $bin, $width_file, $height_file, $icc_file) = @ARGV; + open my $input_fh, '<:raw', $input; + sysread($input_fh, my $size, 8) == 8 or die; + my ($width, $height) = unpack 'NN', $size; + open my $width_fh, '>', $width_file; + print {$width_fh} "$width\n"; + open my $height_fh, '>', $height_file; + print {$height_fh} "$height\n"; + sysread($input_fh, my $icc_size, 4) == 4 or die; + $icc_size = unpack 'N', $icc_size; + sysread($input_fh, my $icc_data, $icc_size) == $icc_size or die; + open my $icc_fh, '>', $icc_file; + print {$icc_fh} $icc_data; + copy $input_fh, $bin; +END +)" +run perl -Mstrict -Mwarnings -Mautodie -e "$unpack_program" -- "$input" "$bin" "$width_file" "$height_file" "$icc_file" + +width="$(cat "$width_file")" +height="$(cat "$height_file")" + +start="$EPOCHREALTIME" +run "$decoder" --OutputBitDepth=10 -b "$bin" -o "$yuv" +end="$EPOCHREALTIME" + +elapsed="$(echo "$end - $start" | bc)" +run echo "Completed in $elapsed seconds" + +echo "$elapsed" > "${output%.png}".time + +run ffmpeg -hide_banner -f rawvideo -vcodec rawvideo -s "${width}x$height" -r 25 -pix_fmt yuv444p10le -i "$yuv" -pix_fmt rgb24 -vf scale=in_color_matrix=bt709 -y "$output" +if [ -s "$icc_file" ]; then + mogrify -profile "$icc_file" "$output" +fi diff --git a/media/libjxl/src/tools/benchmark/hm/encode.sh b/media/libjxl/src/tools/benchmark/hm/encode.sh new file mode 100644 index 000000000..319ba6953 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/hm/encode.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -euo pipefail + +encoder="$(dirname "$0")"/TAppEncoderHighBitDepthStatic +cfg_dir="$(dirname "$0")"/../../../third_party/HEVCSoftware/cfg + +usage() { + echo "$0 [-v] [-q ] " >&2 + exit 1 +} + +q=27 +verbose=0 + +while getopts ':hq:v' arg; do + case "$arg" in + h) + usage + ;; + + q) + q="$OPTARG" + ;; + + v) + verbose=1 + ;; + + \?) + echo "Unrecognized option -$OPTARG" >&2 + exit 1 + ;; + esac +done +shift $((OPTIND-1)) + +if [ $# -lt 2 ]; then + usage +fi + +run() { + if [ "$verbose" -eq 1 ]; then + "$@" + else + "$@" > /dev/null 2>&1 + fi +} + +input="$1" +output="$2" + +yuv="$(mktemp)" +bin="$(mktemp)" + +to_clean=("$yuv" "$bin") +cleanup() { + rm -- "${to_clean[@]}" +} +trap cleanup EXIT + +run ffmpeg -hide_banner -i "$input" -pix_fmt yuv444p10le -vf scale=out_color_matrix=bt709 -color_primaries bt709 -color_trc bt709 -colorspace bt709 -f rawvideo -y "$yuv" + +width="$(identify -format '%w' "$input")" +height="$(identify -format '%h' "$input")" + +start="$EPOCHREALTIME" +run "$encoder" -c "$cfg_dir"/encoder_intra_main_scc_10.cfg -f 1 -fr 1 -wdt "$width" -hgt "$height" --InputChromaFormat=444 --InputBitDepth=10 --ConformanceWindowMode=1 -i "$yuv" -b "$bin" -q "$q" +end="$EPOCHREALTIME" + +elapsed="$(echo "$end - $start" | bc)" +run echo "Completed in $elapsed seconds" + +echo "$elapsed" > "${output%.bin}".time + +icc="${output%.*}.icc" +if run convert "$input" "$icc"; then + to_clean+=("$icc") +fi + +pack_program="$(cat <<'END' + use File::Copy; + use IO::Handle; + my ($width, $height, $bin, $icc, $output) = @ARGV; + open my $output_fh, '>:raw', $output; + syswrite $output_fh, pack 'NN', $width, $height; + syswrite $output_fh, pack 'N', -s $icc; + copy $icc, $output_fh; + copy $bin, $output_fh; +END +)" +run perl -Mstrict -Mwarnings -Mautodie -e "$pack_program" -- "$width" "$height" "$bin" "$icc" "$output" diff --git a/media/libjxl/src/tools/benchmark/metrics/compute-hdrvdp.m b/media/libjxl/src/tools/benchmark/metrics/compute-hdrvdp.m new file mode 100644 index 000000000..60e40bf32 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/compute-hdrvdp.m @@ -0,0 +1,17 @@ +% Copyright (c) the JPEG XL Project Authors. All rights reserved. +% +% Use of this source code is governed by a BSD-style +% license that can be found in the LICENSE file. + +pkg load image; + +args = argv(); + +original_filename = args{1}; +decoded_filename = args{2}; + +original = pfs_read_luminance(original_filename); +decoded = pfs_read_luminance(decoded_filename); + +res = hdrvdp(decoded, original, 'luminance', 30, {}); +printf("%f\n", res.Q); diff --git a/media/libjxl/src/tools/benchmark/metrics/compute-pumetrics.m b/media/libjxl/src/tools/benchmark/metrics/compute-pumetrics.m new file mode 100644 index 000000000..df0fe4bd0 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/compute-pumetrics.m @@ -0,0 +1,26 @@ +% Copyright (c) the JPEG XL Project Authors. All rights reserved. +% +% Use of this source code is governed by a BSD-style +% license that can be found in the LICENSE file. + +pkg load image; + +args = argv(); + +metric = args{1}; +original_filename = args{2}; +decoded_filename = args{3}; + +original = pfs_read_luminance(original_filename); +decoded = pfs_read_luminance(decoded_filename); + +switch (metric) + case "psnr" + res = qm_pu2_psnr(original, decoded); + case "ssim" + res = qm_pu2_ssim(original, decoded); + otherwise + error(sprintf("unrecognized metric %s", metric)); +end + +printf("%f\n", res); diff --git a/media/libjxl/src/tools/benchmark/metrics/compute_octave_metric.sh b/media/libjxl/src/tools/benchmark/metrics/compute_octave_metric.sh new file mode 100644 index 000000000..a31c26659 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/compute_octave_metric.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Usage: ./compute-octave-metric.sh [octave args...] +# Where octave args do not need to contain -qf or the path to the original and decoded images. + +set -euo pipefail + +original="$1" +decoded="$2" +output="$3" +intensity_target="$4" +shift 4 + +tmpdir="$(mktemp --directory)" + +linearized_original="$(mktemp --tmpdir="$tmpdir" --suffix='.pfm')" +linearized_decoded="$(mktemp --tmpdir="$tmpdir" --suffix='.pfm')" + +cleanup() { + rm -- "$linearized_original" "$linearized_decoded" + rmdir --ignore-fail-on-non-empty -- "$tmpdir" +} +trap cleanup EXIT + +linearize() { + local input="$1" + local output="$2" + convert "$input" -set colorspace sRGB -colorspace RGB -evaluate multiply "$intensity_target" "$output" +} + +linearize "$original" "$linearized_original" +linearize "$decoded" "$linearized_decoded" + +octave -qf "$@" \ + "$linearized_original" "$linearized_decoded" \ + 2> /dev/null \ + > "$output" diff --git a/media/libjxl/src/tools/benchmark/metrics/dists-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/dists-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/dists-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/fsim-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/fsim-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/fsim-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/fsim-y.sh b/media/libjxl/src/tools/benchmark/metrics/fsim-y.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/fsim-y.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/gmsd-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/gmsd-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/gmsd-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/hdr_plots.sh b/media/libjxl/src/tools/benchmark/metrics/hdr_plots.sh new file mode 100644 index 000000000..4ce5d9fc4 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/hdr_plots.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +"$(dirname "$0")/run_all_hdr_metrics.sh" "$@" | sed -n '/```/q;p' > hdr_results.csv +mkdir -p hdr_plots/ +rm -rf hdr_plots/* +python3 "$(dirname "$0")/plots.py" hdr_results.csv hdr_plots diff --git a/media/libjxl/src/tools/benchmark/metrics/hdrvdp-fixes.patch b/media/libjxl/src/tools/benchmark/metrics/hdrvdp-fixes.patch new file mode 100644 index 000000000..23f3f17b6 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/hdrvdp-fixes.patch @@ -0,0 +1,110 @@ +From 44a21be2c4de409f80d90cbcc2c20cb3f42e859e Mon Sep 17 00:00:00 2001 +From: Sami Boukortt +Date: Fri, 16 Oct 2020 20:01:02 +0200 +Subject: [PATCH] Fixes for Octave +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 +Content-Transfer-Encoding: 8bit + +Copyright (c) the JPEG XL Project Authors. All rights reserved. + +Use of this source code is governed by a BSD-style +license that can be found in the LICENSE file. + +---- + +ifft2: https://savannah.gnu.org/bugs/?43742 + +Removing #include : https://octave.org/doc/v5.2.0/Getting-Started-with-Mex_002dFiles.html +“One important difference between Octave and MATLAB is that the header +"matrix.h" is implicitly included through the inclusion of "mex.h".” + +Length checks: it appears that functions(…).file for MEX files in Octave +is empty. +--- + fast_conv_fft.m | 2 +- + matlabPyrTools_1.4_fixed/MEX/corrDn.c | 1 - + matlabPyrTools_1.4_fixed/MEX/pointOp.c | 1 - + matlabPyrTools_1.4_fixed/MEX/upConv.c | 1 - + matlabPyrTools_1.4_fixed/reconSpyr.m | 2 +- + matlabPyrTools_1.4_fixed/reconSpyrLevs.m | 2 +- + 6 files changed, 3 insertions(+), 6 deletions(-) + +diff --git a/fast_conv_fft.m b/fast_conv_fft.m +index 65ceef8..b89e54b 100644 +--- a/fast_conv_fft.m ++++ b/fast_conv_fft.m +@@ -16,7 +16,7 @@ pad_size = (size(fH)-size(X)); + + fX = fft2( padarray( X, pad_size, pad_value, 'post' ) ); + +-Yl = real(ifft2( fX.*fH, size(fX,1), size(fX,2), 'symmetric' )); ++Yl = real(ifft2( fX.*fH, size(fX,1), size(fX,2))); + + Y = Yl(1:size(X,1),1:size(X,2)); + +diff --git a/matlabPyrTools_1.4_fixed/MEX/corrDn.c b/matlabPyrTools_1.4_fixed/MEX/corrDn.c +index d02e272..17e739e 100755 +--- a/matlabPyrTools_1.4_fixed/MEX/corrDn.c ++++ b/matlabPyrTools_1.4_fixed/MEX/corrDn.c +@@ -6,7 +6,6 @@ RES = corrDn(IM, FILT, EDGES, STEP, START, STOP); + */ + + #define V4_COMPAT +-#include /* Matlab matrices */ + #include + + #include "convolve.h" +diff --git a/matlabPyrTools_1.4_fixed/MEX/pointOp.c b/matlabPyrTools_1.4_fixed/MEX/pointOp.c +index 3623a02..e553adf 100755 +--- a/matlabPyrTools_1.4_fixed/MEX/pointOp.c ++++ b/matlabPyrTools_1.4_fixed/MEX/pointOp.c +@@ -5,7 +5,6 @@ RES = pointOp(IM, LUT, ORIGIN, INCREMENT, WARNINGS) + */ + + #define V4_COMPAT +-#include /* Matlab matrices */ + #include + + #include /* NULL */ +diff --git a/matlabPyrTools_1.4_fixed/MEX/upConv.c b/matlabPyrTools_1.4_fixed/MEX/upConv.c +index 98a2bec..08fdf75 100755 +--- a/matlabPyrTools_1.4_fixed/MEX/upConv.c ++++ b/matlabPyrTools_1.4_fixed/MEX/upConv.c +@@ -6,7 +6,6 @@ RES = upConv(IM, FILT, EDGES, STEP, START, STOP, RES); + */ + + #define V4_COMPAT +-#include /* Matlab matrices */ + #include + + #include "convolve.h" +diff --git a/matlabPyrTools_1.4_fixed/reconSpyr.m b/matlabPyrTools_1.4_fixed/reconSpyr.m +index 05eeafb..1440d8a 100644 +--- a/matlabPyrTools_1.4_fixed/reconSpyr.m ++++ b/matlabPyrTools_1.4_fixed/reconSpyr.m +@@ -31,7 +31,7 @@ function res = reconSpyr(pyr, pind, filtfile, edges, levs, bands) + % Deterimine whether a MEX version of upConv is available + is_mex = true; + finfo = functions( @upConv ); +-if( strcmp( finfo.file((end-2):end), '.m') ) ++if( length(finfo.file) > 2 && strcmp( finfo.file((end-2):end), '.m') ) + is_mex = false; + end + +diff --git a/matlabPyrTools_1.4_fixed/reconSpyrLevs.m b/matlabPyrTools_1.4_fixed/reconSpyrLevs.m +index ac5e2b1..d3b91d5 100644 +--- a/matlabPyrTools_1.4_fixed/reconSpyrLevs.m ++++ b/matlabPyrTools_1.4_fixed/reconSpyrLevs.m +@@ -11,7 +11,7 @@ function res = reconSpyrLevs(pyr,pind,lofilt,bfilts,edges,levs,bands) + % Deterimine whether MEX version of upConv is available + is_mex = true; + finfo = functions( @upConv ); +-if( strcmp( finfo.file((end-2):end), '.m') ) ++if( length(finfo.file) > 2 && strcmp( finfo.file((end-2):end), '.m') ) + is_mex = false; + end + +-- +2.28.0 + diff --git a/media/libjxl/src/tools/benchmark/metrics/hdrvdp.sh b/media/libjxl/src/tools/benchmark/metrics/hdrvdp.sh new file mode 100644 index 000000000..659ab8530 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/hdrvdp.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +"$(dirname "$0")"/compute_octave_metric.sh "$@" \ + --path "$(dirname "$0")"/../../../third_party/hdrvdp-2.2.2/ \ + "$(dirname "$0")"/compute-hdrvdp.m diff --git a/media/libjxl/src/tools/benchmark/metrics/iqa.py b/media/libjxl/src/tools/benchmark/metrics/iqa.py new file mode 100644 index 000000000..1be969992 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/iqa.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +import os +import sys +import pathlib +import torch +from torchvision import transforms +import numpy as np + +path = pathlib.Path(__file__).parent.absolute( +) / '..' / '..' / '..' / 'third_party' / 'IQA-optimization' +sys.path.append(str(path)) + +from IQA_pytorch import SSIM, MS_SSIM, CW_SSIM, GMSD, LPIPSvgg, DISTS, NLPD, FSIM, VSI, VIFs, VIF, MAD + + +# only really works with the output from JXL, but we don't need more than that. +def read_pfm(fname): + with open(fname, 'rb') as f: + header_width_height = [] + while len(header_width_height) < 3: + header_width_height += f.readline().rstrip().split() + header, width, height = header_width_height + assert header == b'PF' or header == b'Pf' + width, height = int(width), int(height) + scale = float(f.readline().rstrip()) + fmt = 'f' + data = np.fromfile(f, fmt) + if header == b'PF': + out = np.reshape(data, (height, width, 3))[::-1, :, :] + else: + out = np.reshape(data, (height, width))[::-1, :] + return out.astype(np.float) + + +D_dict = { + 'cwssim': CW_SSIM, + 'dists': DISTS, + 'fsim': FSIM, + 'gmsd': GMSD, + 'lpips': LPIPSvgg, + 'mad': MAD, + 'msssim': MS_SSIM, + 'nlpd': NLPD, + 'ssim': SSIM, + 'vif': VIF, + 'vsi': VSI, +} + +algo = os.path.basename(sys.argv[1]).split('.')[0] +algo, color = algo.split('-') + +channels = 3 + +if color == 'y': + channels = 1 + + +def Load(path): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + img = read_pfm(path) + if len(img.shape) == 3 and channels == 1: # rgb -> Y + assert img.shape[2] == 3 + tmp = np.zeros((img.shape[0], img.shape[1], 1), dtype=float) + tmp[:, :, 0] = (0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + + 0.0722 * img[:, :, 2]) + img = tmp + if len(img.shape) == 2 and channels == 3: # Y -> rgb + gray = img + img = np.zeros((img.shape[0], img.shape[1], 3), dtype=float) + img[:, :, 0] = img[:, :, 1] = img[:, :, 2] = gray + if len(img.shape) == 3: + img = np.transpose(img, axes=(2, 0, 1)).copy() + return torch.FloatTensor(img).unsqueeze(0).to(device) + + +ref_img = Load(sys.argv[2]) +enc_img = Load(sys.argv[3]) +D = D_dict[algo](channels=channels) +score = D(ref_img, enc_img, as_loss=False) + +with open(sys.argv[4], 'w') as f: + print(score.item(), file=f) diff --git a/media/libjxl/src/tools/benchmark/metrics/iqa_wrapper.sh b/media/libjxl/src/tools/benchmark/metrics/iqa_wrapper.sh new file mode 100644 index 000000000..1d179fded --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/iqa_wrapper.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +python3 "$(dirname "$0")/iqa.py" "$0" "$@" diff --git a/media/libjxl/src/tools/benchmark/metrics/lpips-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/lpips-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/lpips-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/mrse.sh b/media/libjxl/src/tools/benchmark/metrics/mrse.sh new file mode 100644 index 000000000..54d18d6fe --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/mrse.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -euo pipefail + +original="$1" +decoded="$2" +output="$3" +intensity_target="$4" + +tmpdir="$(mktemp --directory)" + +linearized_original="$(mktemp --tmpdir="$tmpdir" --suffix='.pfm')" +linearized_decoded="$(mktemp --tmpdir="$tmpdir" --suffix='.pfm')" + +cleanup() { + rm -- "$linearized_original" "$linearized_decoded" + rmdir --ignore-fail-on-non-empty -- "$tmpdir" +} +trap cleanup EXIT + +linearize() { + local input="$1" + local output="$2" + convert "$input" -set colorspace sRGB -colorspace RGB -evaluate multiply "$intensity_target" "$output" +} + +linearize "$original" "$linearized_original" +linearize "$decoded" "$linearized_decoded" + +"$(dirname "$0")"/../../../third_party/difftest_ng/difftest_ng --mrse "$linearized_original" "$linearized_decoded" \ + | sed -e 's/^MRSE:\s*//' \ + > "$output" diff --git a/media/libjxl/src/tools/benchmark/metrics/msssim-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/msssim-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/msssim-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/msssim-y.sh b/media/libjxl/src/tools/benchmark/metrics/msssim-y.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/msssim-y.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/nlpd-y.sh b/media/libjxl/src/tools/benchmark/metrics/nlpd-y.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/nlpd-y.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/plots.py b/media/libjxl/src/tools/benchmark/metrics/plots.py new file mode 100644 index 000000000..04b2bb24e --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/plots.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +import csv +import sys +import math +import plotly.graph_objects as go + +_, results, output_dir, *rest = sys.argv +OUTPUT = rest[0] if rest else 'svg' +# valid values: html, svg, png, webp, jpeg, pdf + +with open(results, 'r') as f: + reader = csv.DictReader(f) + all_results = list(reader) + +nonmetric_columns = set([ + "method", "image", "error", "size", "pixels", "enc_speed", "dec_speed", + "bpp", "bppp", "qabpp" +]) + +metrics = set(all_results[0].keys()) - nonmetric_columns + + +def codec(method): + sm = method.split(':') + ssm = set(sm) + speeds = set([ + 'kitten', 'falcon', 'wombat', 'cheetah', 'tortoise', 'squirrel', + 'hare', 'fast' + ]) + s = speeds.intersection(ssm) + if sm[0] == 'custom': + return sm[1] + if sm[0] == 'jxl' and s: + return 'jxl-' + list(s)[0] + return sm[0] + + +data = {(m, img): {c: [] + for c in {codec(x['method']) + for x in all_results}} + for m in metrics for img in {x['image'] + for x in all_results}} + +for r in all_results: + c = codec(r['method']) + img = r['image'] + bpp = r['bpp'] + for m in metrics: + data[(m, img)][c].append((float(bpp), float(r[m]))) + + +def pos(codec): + if 'jxl-dis' in codec: + return 6, codec + elif 'jxl' in codec: + return 7, codec + elif 'avif' in codec: + return 5, codec + elif 'kdu' in codec: + return 4, codec + elif 'heif' in codec: + return 3, codec + elif 'fuif' in codec or 'pik' in codec: + return 2, codec + elif 'jpg' in codec or 'jpeg' in codec or 'web' in codec: + return 1, codec + else: + return 0, codec + + +def style(codec): + configs = { + 'jxl-cheetah': { + 'color': '#e41a1c', + 'dash': '1px, 1px', + 'width': 2 + }, + 'jxl-wombat': { + 'color': '#e41a1c', + 'dash': '2px, 2px', + 'width': 2 + }, + 'jxl-squirrel': { + 'color': '#e41a1c', + 'dash': '5px, 5px', + 'width': 2 + }, + 'jxl-kitten': { + 'color': '#e41a1c', + 'width': 2 + }, + 'jxl-dis-cheetah': { + 'color': '#377eb8', + 'dash': '1px, 1px', + 'width': 2 + }, + 'jxl-dis-wombat': { + 'color': '#377eb8', + 'dash': '2px, 2px', + 'width': 2 + }, + 'jxl-dis-squirrel': { + 'color': '#377eb8', + 'dash': '5px, 5px', + 'width': 2 + }, + 'jxl-dis-kitten': { + 'color': '#377eb8', + 'width': 2 + }, + 'rav1e.avif': { + 'color': '#4daf4a', + 'dash': '3px, 3px', + 'width': 2 + }, + '420.rav1e.avif': { + 'color': '#4daf4a', + 'dash': '1px, 1px', + 'width': 2 + }, + '444.rav1e.avif': { + 'color': '#4daf4a', + 'dash': '3px, 3px', + 'width': 2 + }, + 'psnr.420.aom.avif': { + 'color': '#4daf4a', + 'dash': '5px, 5px', + 'width': 2 + }, + 'psnr.444.aom.avif': { + 'color': '#4daf4a', + 'dash': '7px, 7px', + 'width': 2 + }, + 'ssim.420.aom.avif': { + 'color': '#4daf4a', + 'dash': '9px, 9px', + 'width': 2 + }, + 'ssim.444.aom.avif': { + 'color': '#4daf4a', + 'width': 2 + }, + 'heif': { + 'color': '#984ea3', + 'width': 2 + }, + 'fuif': { + 'color': '#ff7f00', + 'dash': '2px, 2px', + 'width': 2 + }, + 'pik-cfp': { + 'color': '#ff7f00', + 'width': 2 + }, + 'pik-cfp-fast': { + 'color': '#ff7f00', + 'dash': '4px, 4px', + 'width': 2 + }, + 'webp': { + 'color': '#000000', + 'width': 2 + }, + 'jpeg': { + 'color': '#a65628', + 'width': 2 + }, + 'xt.jpg': { + 'color': '#a65628', + 'width': 2 + }, + 'perc1.kdu.j2k': { + 'color': '#f781bf', + 'dash': '1px, 1px', + 'width': 2 + }, + 'perc2.kdu.j2k': { + 'color': '#f781bf', + 'dash': '3px, 3px', + 'width': 2 + }, + 'perc3.kdu.j2k': { + 'color': '#f781bf', + 'dash': '5px, 5px', + 'width': 2 + }, + 'perc4.kdu.j2k': { + 'color': '#f781bf', + 'dash': '7px, 7px', + 'width': 2 + }, + 'default.kdu.j2k': { + 'color': '#f781bf', + 'width': 2 + }, + } + return configs.get(codec, dict()) + + +visible_by_default = set([ + 'jxl-kitten', 'ssim.444.aom.avif', 'heif', 'webp', 'jpeg', 'xt.jpg', + 'default.kdu.j2k' +]) + +column_remap = { + 'p': '6-Butteraugli', + 'dist': 'Max-Butteraugli', + 'psnr': "PSNR-YUV 6/8 Y", + 'MS-SSIM-Y': '-log10(1 - MS-SSIM-Y)', + 'puSSIM': '-log10(1 - puSSIM)', + 'FSIM-Y': '-log10(1 - FSIM-Y)', + 'FSIM-RGB': '-log10(1 - FSIM-RGB)', + 'VMAF': '-log10(1 - VMAF / 100)', +} + + +def remap(metric): + funs = { + 'MS-SSIM-Y': lambda x: -math.log10(1 - x), + 'puSSIM': lambda x: -math.log10(1 - x), + 'FSIM-Y': lambda x: -math.log10(1 - x), + 'FSIM-RGB': lambda x: -math.log10(1 - x), + 'VMAF': lambda x: -math.log10(1 + 1e-8 - x / 100), + } + return funs.get(metric, lambda x: x) + + +for (m, img) in data: + fname = "%s/%s_%s" % (output_dir, m, img) + fig = go.Figure() + for method in sorted(data[(m, img)].keys(), key=pos): + vals = data[(m, img)][method] + zvals = list(zip(*sorted(vals))) + if not zvals: + continue + fig.add_trace( + go.Scatter(x=zvals[0], + y=[remap(m)(x) for x in zvals[1]], + mode='lines', + name=method, + line=style(method), + visible=True + if method in visible_by_default else 'legendonly')) + fig.update_layout(title=img, + xaxis_title='bpp', + yaxis_title=column_remap.get(m, m)) + fig.update_xaxes(type='log') + if OUTPUT == 'html': + fig.write_html(fname + '.html', include_plotlyjs='directory') + else: + fig.write_image(fname + '.' + OUTPUT, scale=4) diff --git a/media/libjxl/src/tools/benchmark/metrics/prepare_metrics.sh b/media/libjxl/src/tools/benchmark/metrics/prepare_metrics.sh new file mode 100644 index 000000000..7ecfaaf19 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/prepare_metrics.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -eu + +MYDIR=$(dirname $(realpath "$0")) + + +main() { + cd "${MYDIR}/../../../third_party" + local zipurl + local repourl + for repourl in \ + 'https://github.com/veluca93/IQA-optimization.git' \ + 'https://github.com/Netflix/vmaf.git' \ + 'https://github.com/thorfdbg/difftest_ng.git' + do + local reponame=$(basename "${repourl%.git}") + local dirname=$(basename "${reponame}") + if [[ ! -e "${dirname}" ]]; then + git clone "${repourl}" + fi + done + for zipurl in \ + 'https://sourceforge.net/projects/hdrvdp/files/hdrvdp/2.2.2/hdrvdp-2.2.2.zip' \ + 'https://sourceforge.net/projects/hdrvdp/files/simple_metrics/1.0/hdr_metrics.zip' + do + local zipfile="$(basename "${zipurl}")" + local dirname="$(basename "${zipfile}" '.zip')" + rm -fr "${dirname}" + if [[ ! -e "${zipfile}" ]]; then + wget -O "${zipfile}.tmp" "${zipurl}" + mv "${zipfile}.tmp" "${zipfile}" + fi + unzip "${zipfile}" "${dirname}"/'*' + done + + pushd hdrvdp-2.2.2 + patch -p1 < ../../tools/benchmark/metrics/hdrvdp-fixes.patch + pushd matlabPyrTools_1.4_fixed + mkoctfile --mex MEX/corrDn.c MEX/convolve.c MEX/wrap.c MEX/edges.c + mkoctfile --mex MEX/pointOp.c + mkoctfile --mex MEX/upConv.c + popd + popd + + + pushd difftest_ng + ./configure + make + popd + + + pushd vmaf/libvmaf + rm -rf build + meson build --buildtype release + ninja -vC build + popd +} +main "$@" + diff --git a/media/libjxl/src/tools/benchmark/metrics/pupsnr.sh b/media/libjxl/src/tools/benchmark/metrics/pupsnr.sh new file mode 100644 index 000000000..869fc3617 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/pupsnr.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +./compute_octave_metric.sh "$@" \ + --path "$(dirname "$0")"/../../../third_party/hdr_metrics/ \ + "$(dirname "$0")"/compute-pumetrics.m 'psnr' diff --git a/media/libjxl/src/tools/benchmark/metrics/pussim.sh b/media/libjxl/src/tools/benchmark/metrics/pussim.sh new file mode 100644 index 000000000..957cfa1dc --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/pussim.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +./compute_octave_metric.sh "$@" \ + --path "$(dirname "$0")"/../../../third_party/hdr_metrics/ \ + "$(dirname "$0")"/compute-pumetrics.m 'ssim' diff --git a/media/libjxl/src/tools/benchmark/metrics/run_all_hdr_metrics.sh b/media/libjxl/src/tools/benchmark/metrics/run_all_hdr_metrics.sh new file mode 100644 index 000000000..5fb769d66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/run_all_hdr_metrics.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -eu +dir="$(dirname "$0")" + +main() { + local metrics=( + HDR-VDP:"${dir}"/hdrvdp.sh + MRSE:"${dir}"/mrse.sh + puPSNR:"${dir}"/pupsnr.sh + puSSIM:"${dir}"/pussim.sh + ) + + local metrics_args=$(printf '%s' "${metrics[@]/#/,}") + metrics_args=${metrics_args:1} + + + "${dir}/../../../build/tools/benchmark_xl" \ + --print_details_csv \ + --num_threads=32 \ + --error_pnorm=6 \ + --extra_metrics ${metrics_args} \ + "$@" +} + +main "$@" diff --git a/media/libjxl/src/tools/benchmark/metrics/run_all_sdr_metrics.sh b/media/libjxl/src/tools/benchmark/metrics/run_all_sdr_metrics.sh new file mode 100644 index 000000000..def887b09 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/run_all_sdr_metrics.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -eu +dir="$(dirname "$0")" + +main() { + local metrics=( + FSIM-Y:"${dir}"/fsim-y.sh + FSIM-RGB:"${dir}"/fsim-rgb.sh + LPIPS:"${dir}"/lpips-rgb.sh + MS-SSIM-Y:"${dir}"/msssim-y.sh + NLPD:"${dir}"/nlpd-y.sh + SSIMULACRA:"${dir}"/ssimulacra.sh + VIF:"${dir}"/vif-rgb.sh + VMAF:"${dir}"/vmaf.sh + ) + # other metrics, not in core experiments: +# VSI:"${dir}"/vsi-rgb.sh +# SSIM-RGB:"${dir}"/ssim-rgb.sh +# SSIM-Y:"${dir}"/ssim-y.sh +# GMSD:"${dir}"/gmsd.sh +# DISTS:"${dir}"/dists-rgb.sh +# MS-SSIM-RGB:"${dir}"/msssim-rgb.sh + + local metrics_args=$(printf '%s' "${metrics[@]/#/,}") + metrics_args=${metrics_args:1} + + + "${dir}/../../../build/tools/benchmark_xl" \ + --print_details_csv \ + --num_threads=1 \ + --error_pnorm=6 \ + --extra_metrics ${metrics_args} \ + "$@" +} + +main "$@" diff --git a/media/libjxl/src/tools/benchmark/metrics/sdr_plots.sh b/media/libjxl/src/tools/benchmark/metrics/sdr_plots.sh new file mode 100644 index 000000000..d97648e8f --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/sdr_plots.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +"$(dirname "$0")/run_all_sdr_metrics.sh" "$@" | sed -n '/```/q;p' > sdr_results.csv +mkdir -p sdr_plots/ +rm -rf sdr_plots/* +python3 "$(dirname "$0")/plots.py" sdr_results.csv sdr_plots diff --git a/media/libjxl/src/tools/benchmark/metrics/ssim-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/ssim-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/ssim-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/ssim-y.sh b/media/libjxl/src/tools/benchmark/metrics/ssim-y.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/ssim-y.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/ssimulacra.sh b/media/libjxl/src/tools/benchmark/metrics/ssimulacra.sh new file mode 100644 index 000000000..65617d1c0 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/ssimulacra.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +"$(dirname "$0")"/../../../build/tools/ssimulacra_main "$1" "$2" > "$3" 2>/dev/null diff --git a/media/libjxl/src/tools/benchmark/metrics/vif-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/vif-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/vif-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/benchmark/metrics/vmaf.sh b/media/libjxl/src/tools/benchmark/metrics/vmaf.sh new file mode 100644 index 000000000..ab406d011 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/vmaf.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +set -euo pipefail + +original="$1" +decoded="$2" +output="$3" + +tmpdir="$(mktemp --directory)" + +exr_original="$(mktemp --tmpdir="$tmpdir" --suffix='.exr')" +exr_decoded="$(mktemp --tmpdir="$tmpdir" --suffix='.exr')" + +yuv_original="$(mktemp --tmpdir="$tmpdir" --suffix='.yuv')" +yuv_decoded="$(mktemp --tmpdir="$tmpdir" --suffix='.yuv')" + +vmaf_csv="$(mktemp --tmpdir="$tmpdir" --suffix='.csv')" + +cleanup() { + rm -- "$exr_original" "$exr_decoded" "$yuv_original" "$yuv_decoded" "$vmaf_csv" + rmdir --ignore-fail-on-non-empty -- "$tmpdir" +} +trap cleanup EXIT + +convert "$original" "$exr_original" +convert "$decoded" "$exr_decoded" + +srgb=(-colorspace bt709 -color_primaries bt709 -color_trc iec61966-2-1) +ffmpeg "${srgb[@]}" -i "$exr_original" -pix_fmt yuv444p10le "${srgb[@]}" -y "$yuv_original" &>/dev/null +ffmpeg "${srgb[@]}" -i "$exr_decoded" -pix_fmt yuv444p10le "${srgb[@]}" -y "$yuv_decoded" &>/dev/null + +"$(dirname "$0")"/../../../third_party/vmaf/libvmaf/build/tools/vmafossexec \ + yuv444p10le \ + "$(identify -format '%w' "$original")" "$(identify -format '%h' "$original")" \ + "$yuv_original" "$yuv_decoded" \ + "$(dirname "$0")/../../../third_party/vmaf/model/vmaf_v0.6.1.pkl" \ + --log-fmt csv --log "$vmaf_csv" &>/dev/null + +read_csv="$(cat <<'END' +import csv +import sys +reader = csv.DictReader(sys.stdin) +for row in reader: + print(row['vmaf']) +END +)" + +python -c "$read_csv" < "$vmaf_csv" > "$output" diff --git a/media/libjxl/src/tools/benchmark/metrics/vsi-rgb.sh b/media/libjxl/src/tools/benchmark/metrics/vsi-rgb.sh new file mode 100644 index 000000000..9e57c8f66 --- /dev/null +++ b/media/libjxl/src/tools/benchmark/metrics/vsi-rgb.sh @@ -0,0 +1 @@ +iqa_wrapper.sh \ No newline at end of file diff --git a/media/libjxl/src/tools/bisector b/media/libjxl/src/tools/bisector new file mode 100644 index 000000000..2552045df --- /dev/null +++ b/media/libjxl/src/tools/bisector @@ -0,0 +1,281 @@ +#!/usr/bin/env python +r"""General-purpose bisector + +Prints a space-separated list of values to stdout: +1_if_success_0_otherwise left_x left_f(x) right_x right_f(x) + +Usage examples: + +# Finding the square root of 200 via bisection: +bisector --var=BB --range=0.0,100.0 --target=200 --maxiter=100 \ + --atol_val=1e-12 --rtol_val=0 --cmd='echo "$BB * $BB" | bc' +# => 1 14.142135623730923 199.99999999999923 14.142135623731633 200.0000000000193 + +# Finding an integer approximation to sqrt(200) via bisection: +bisector --var=BB --range=0,100 --target=200 --maxiter=100 \ + --atol_arg=1 --cmd='echo "$BB * $BB" | bc' +# => 1 14 196.0 15 225.0 + +# Finding a change-id that broke something via bisection: +bisector --var=BB --range=0,1000000 --target=0.5 --maxiter=100 \ + --atol_arg=1 \ + --cmd='test $BB -gt 123456 && echo 1 || echo 0' --verbosity=3 +# => 1 123456 0.0 123457 1.0 + +# Finding settings that compress /usr/share/dict/words to a given target size: +bisector --var=BB --range=1,9 --target=250000 --atol_arg=1 \ + --cmd='gzip -$BB /tmp/w_$BB.gz; wc -c /tmp/w_$BB.gz' \ + --final='mv /tmp/w_$BB.gz /tmp/words.gz; rm /tmp/w_*.gz' \ + --verbosity=1 +# => 1 3 263170.0 4 240043.0 + +# JXL-encoding with bisection-for-size (tolerance 0.5%): +bisector --var=BB --range=0.1,3.0 --target=3500 --rtol_val=0.005 \ + --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && wc -c /tmp/baseball_$BB.jxl)' \ + --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \ + --verbosity=1 +# => 1 1.1875 3573.0 1.278125 3481.0 + +# JXL-encoding with bisection-for-bits-per-pixel (tolerance 0.5%), using helper: +bisector --var=BB --range=0.1,3.0 --target=1.2 --rtol_val=0.005 \ + --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && get_bpp /tmp/baseball_$BB.jxl)' \ + --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \ + --verbosity=1 +# => ... +""" + +import argparse +import os +import re +import subprocess +import sys + + +def _expandvars(vardef, env, + max_recursion=100, + max_length=10**6, + verbosity=0): + """os.path.expandvars() variant using parameter env rather than os.environ.""" + current_expanded = vardef + for num_recursions in range(max_recursion): + if verbosity >= 3: + print(f'_expandvars(): num_recursions={num_recursions}, ' + f'len={len(current_expanded)}' + + (', current: ' + current_expanded if verbosity >= 4 else '')) + if len > max_length: + break + current_expanded, num_replacements = re.subn( + r'$\{(\w+)\}|$(\w+)', + lambda m: env.get(m[1] if m[1] is not None else m[2], ''), + current_expanded) + if num_replacements == 0: + break + return current_expanded + + +def _strtod(string): + """Extracts leftmost float from string (like strtod(3)).""" + match = re.match(r'[+-]?\d*[.]?\d*(?:[eE][+-]?\d+)?', string) + return float(match[0]) if match[0] else None + + +def run_shell_command(shell_command, + bisect_var, bisect_val, + extra_env_defs, + verbosity=0): + """Runs a shell command with env modifications, fetching return value.""" + shell_env = dict(os.environ) + shell_env[bisect_var] = str(bisect_val) + for env_def in extra_env_defs: + varname, vardef = env_def.split('=', 1) + shell_env[varname] = _expandvars(vardev, shell_env, + verbosity=verbosity) + shell_ret = subprocess.run(shell_command, + # We explicitly want subshell semantics! + shell=True, + capture_output=True, + env=shell_env) + stdout = shell_ret.stdout.decode('utf-8') + score = _strtod(stdout) + if verbosity >= 2: + print(f'{bisect_var}={bisect_val} {shell_command} => ' + f'{shell_ret.returncode} # {stdout.strip()}') + return (shell_ret.returncode == 0, # Command was successful? + score) + + +def _bisect(*, + shell_command, + final_shell_command, + target, + int_args, + bisect_var, bisect_left, bisect_right, + rtol_val, atol_val, rtol_arg, atol_arg, + maxiter, + extra_env_defs, + verbosity=0 + ): + """Performs bisection.""" + def _get_val(x): + success, val = run_shell_command(shell_command, + bisect_var, x, + extra_env_defs, + verbosity=verbosity) + if not success: + raise RuntimeError(f'Bisection failed for: {bisect_var}={x}: ' + f'success={success}, val={val}, ' + f'cmd={shell_command}, var={bisect_var}') + return val + # + bisect_mid, value_mid = None, None + try: + value_left = _get_val(bisect_left) + value_right = _get_val(bisect_right) + if (value_left < target) != (target <= value_right): + raise RuntimeError( + f'Cannot bisect: target={target}, value_left={value_left}, ' + f'value_right={value_right}') + for num_iter in range(maxiter): + bisect_mid_f = 0.5 * (bisect_left + bisect_right) + bisect_mid = round(bisect_mid_f) if int_args else bisect_mid_f + value_mid = _get_val(bisect_mid) + if (value_left < target) == (value_mid < target): + # Relative to target, `value_mid` is on the same side + # as `value_left`. + bisect_left = bisect_mid + value_left = value_mid + else: + # Otherwise, this situation must hold for value_right + # ("tertium non datur"). + bisect_right = bisect_mid + value_right = value_mid + if verbosity >= 1: + print(f'bisect target={target}, ' + f'left: {value_left} at {bisect_left}, ' + f'right: {value_right} at {bisect_right}, ' + f'mid: {value_mid} at {bisect_mid}') + delta_val = target - value_mid + if abs(delta_val) <= atol_val + rtol_val * abs(target): + return 1, bisect_left, value_left, bisect_right, value_right + delta_arg = bisect_right - bisect_left + # Also check whether the argument is "within tolerance". + # Here, we have to be careful if bisect_left and bisect_right + # have different signs: Then, their absolute magnitude + # "sets the relevant scale". + if abs(delta_arg) <= atol_arg + ( + rtol_arg * 0.5 * (abs(bisect_left) + abs(bisect_right))): + return 1, bisect_left, value_left, bisect_right, value_right + return 0, bisect_left, value_left, bisect_right, value_right + finally: + # If cleanup is specified, always run it + if final_shell_command: + run_shell_command( + final_shell_command, + bisect_var, + bisect_mid if bisect_mid is not None else bisect_left, + extra_env_defs, verbosity=verbosity) + + +def main(args): + """Main entry point.""" + parser = argparse.ArgumentParser(description='mhtml_walk args') + parser.add_argument( + '--var', + help='The variable to use for bisection.', + default='BISECT') + parser.add_argument( + '--range', + help=('The argument range for bisecting, as {low},{high}. ' + 'If no argument has a decimal dot, assume integer parameters.'), + default='0.0,1.0') + parser.add_argument( + '--max', + help='The maximal value for bisecting.', + type=float, + default=0.0) + parser.add_argument( + '--target', + help='The target value to aim for.', + type=float, + default=1.0) + parser.add_argument( + '--maxiter', + help='The maximal number of iterations to perform.', + type=int, + default=40) + parser.add_argument( + '--rtol_val', + help='Relative tolerance to accept for deviations from target value.', + type=float, + default=0.0) + parser.add_argument( + '--atol_val', + help='Absolute tolerance to accept for deviations from target value.', + type=float, + default=0.0) + parser.add_argument( + '--rtol_arg', + help='Relative tolerance to accept for the argument.', + type=float, + default=0.0) + parser.add_argument( + '--atol_arg', + help=('Absolute tolerance to accept for the argument ' + '(e.g. for bisecting change-IDs).'), + type=float, + default=0.0) + parser.add_argument( + '--verbosity', + help='The verbosity level.', + type=int, + default=1) + parser.add_argument( + '--env', + help=('Comma-separated list of extra environment variables ' + 'to incrementally add before executing the shell-command.'), + default='') + parser.add_argument( + '--cmd', + help=('The shell command to execute. Must print a numerical result ' + 'to stdout.')) + parser.add_argument( + '--final', + help='The cleanup shell command to execute.') + # + parsed = parser.parse_args(args) + extra_env_defs = tuple(filter(None, parsed.env.split(','))) + try: + low_high = parsed.range.split(',') + if len(low_high) != 2: + raise ValueError('--range must be {low},{high}') + int_args = False + low_val, high_val = map(float, low_high) + low_val_int = round(low_val) + high_val_int = round(high_val) + if low_high == [str(low_val_int), str(high_val_int)]: + int_args = True + low_val = low_val_int + high_val = high_val_int + ret = _bisect( + shell_command=parsed.cmd, + final_shell_command=parsed.final, + target=parsed.target, + int_args=int_args, + bisect_var=parsed.var, + bisect_left=low_val, + bisect_right=high_val, + rtol_val=parsed.rtol_val, + atol_val=parsed.atol_val, + rtol_arg=parsed.rtol_arg, + atol_arg=parsed.atol_arg, + maxiter=parsed.maxiter, + extra_env_defs=extra_env_defs, + verbosity=parsed.verbosity, + ) + print(' '.join(map(str, ret))) + except Exception as exn: + sys.exit(f'Problem: {exn}') + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/media/libjxl/src/tools/box/CMakeLists.txt b/media/libjxl/src/tools/box/CMakeLists.txt new file mode 100644 index 000000000..c79add000 --- /dev/null +++ b/media/libjxl/src/tools/box/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +add_library(box STATIC EXCLUDE_FROM_ALL + box.cc + box.h +) +# This library can be included into position independent binaries. +set_target_properties(box PROPERTIES POSITION_INDEPENDENT_CODE TRUE) +target_link_libraries(box + jxl-static + jxl_threads-static +) +target_include_directories(box + PRIVATE + "${PROJECT_SOURCE_DIR}" +) + +if(JPEGXL_ENABLE_DEVTOOLS) +add_executable(box_list + box_list_main.cc +) +target_link_libraries(box_list + box +) +endif() # JPEGXL_ENABLE_DEVTOOLS diff --git a/media/libjxl/src/tools/box/box.cc b/media/libjxl/src/tools/box/box.cc new file mode 100644 index 000000000..db73c7ca7 --- /dev/null +++ b/media/libjxl/src/tools/box/box.cc @@ -0,0 +1,285 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/box/box.h" + +#include "lib/jxl/base/byte_order.h" // for GetMaximumBrunsliEncodedSize +#include "lib/jxl/jpeg/dec_jpeg_data.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jpegxl { +namespace tools { + +namespace { +// Checks if a + b > size, taking possible integer overflow into account. +bool OutOfBounds(size_t a, size_t b, size_t size) { + size_t pos = a + b; + if (pos > size) return true; + if (pos < a) return true; // overflow happened + return false; +} +} // namespace + +// Parses the header of a BMFF box. Returns the result in a Box struct. +// Sets the position to the end of the box header after parsing. The data size +// is output if known, or must be handled by the caller and runs until the end +// of the container file if not known. +jxl::Status ParseBoxHeader(const uint8_t** next_in, size_t* available_in, + Box* box) { + size_t pos = 0; + size_t size = *available_in; + const uint8_t* in = *next_in; + + if (OutOfBounds(pos, 8, size)) return JXL_FAILURE("out of bounds"); + + const size_t initial_pos = pos; + + // Total box_size including this header itself. + uint64_t box_size = LoadBE32(in + pos); + pos += 4; + if (box_size == 1) { + // If the size is 1, it indicates extended size read from 64-bit integer. + if (OutOfBounds(pos, 8, size)) return JXL_FAILURE("out of bounds"); + box_size = LoadBE64(in + pos); + pos += 8; + } + memcpy(box->type, in + pos, 4); + pos += 4; + if (!memcmp("uuid", box->type, 4)) { + if (OutOfBounds(pos, 16, size)) return JXL_FAILURE("out of bounds"); + memcpy(box->extended_type, in + pos, 16); + pos += 16; + } + + // This is the end of the box header, the box data begins here. Handle + // the data size now. + const size_t data_pos = pos; + const size_t header_size = data_pos - initial_pos; + + if (box_size != 0) { + if (box_size < header_size) { + return JXL_FAILURE("invalid box size"); + } + box->data_size_given = true; + box->data_size = box_size - header_size; + } else { + // The size extends to the end of the file. We don't necessarily know the + // end of the file here, since the input size may be only part of the full + // container file. Indicate the size is not given, the caller must handle + // this. + box->data_size_given = false; + box->data_size = 0; + } + + // The remaining bytes are the data. If the box is a full box, the first + // bytes of the data have a certain structure but this is to be handled by + // the caller for the appropriate box type. + *next_in += pos; + *available_in -= pos; + + return true; +} + +jxl::Status AppendBoxHeader(const Box& box, jxl::PaddedBytes* out) { + bool use_extended = !memcmp("uuid", box.type, 4); + + uint64_t box_size = 0; + bool large_size = false; + if (box.data_size_given) { + box_size = box.data_size + 8 + (use_extended ? 16 : 0); + if (box_size >= 0x100000000ull) { + large_size = true; + } + } + + out->resize(out->size() + 4); + StoreBE32(large_size ? 1 : box_size, &out->back() - 4 + 1); + + out->resize(out->size() + 4); + memcpy(&out->back() - 4 + 1, box.type, 4); + + if (large_size) { + out->resize(out->size() + 8); + StoreBE64(box_size, &out->back() - 8 + 1); + } + + if (use_extended) { + out->resize(out->size() + 16); + memcpy(&out->back() - 16 + 1, box.extended_type, 16); + } + + return true; +} + +bool IsContainerHeader(const uint8_t* data, size_t size) { + const uint8_t box_header[] = {0, 0, 0, 0xc, 'J', 'X', + 'L', ' ', 0xd, 0xa, 0x87, 0xa}; + if (size < sizeof(box_header)) return false; + return memcmp(box_header, data, sizeof(box_header)) == 0; +} + +jxl::Status DecodeJpegXlContainerOneShot(const uint8_t* data, size_t size, + JpegXlContainer* container) { + const uint8_t* in = data; + size_t available_in = size; + + container->exif = nullptr; + container->exif_size = 0; + container->exfc = nullptr; + container->exfc_size = 0; + container->xml.clear(); + container->xmlc.clear(); + container->jumb = nullptr; + container->jumb_size = 0; + container->codestream.clear(); + container->jpeg_reconstruction = nullptr; + container->jpeg_reconstruction_size = 0; + + size_t box_index = 0; + + while (available_in != 0) { + Box box; + if (!ParseBoxHeader(&in, &available_in, &box)) { + return JXL_FAILURE("Invalid box header"); + } + + size_t data_size = box.data_size_given ? box.data_size : available_in; + + if (box.data_size > available_in) { + return JXL_FAILURE("Unexpected end of file"); + } + + if (box_index == 0) { + // TODO(lode): leave out magic signature box? + // Must be magic signature box. + if (memcmp("JXL ", box.type, 4) != 0) { + return JXL_FAILURE("Invalid magic signature"); + } + if (box.data_size != 4) return JXL_FAILURE("Invalid magic signature"); + if (in[0] != 0xd || in[1] != 0xa || in[2] != 0x87 || in[3] != 0xa) { + return JXL_FAILURE("Invalid magic signature"); + } + } else if (box_index == 1) { + // Must be ftyp box. + if (memcmp("ftyp", box.type, 4) != 0) { + return JXL_FAILURE("Invalid ftyp"); + } + if (box.data_size != 12) return JXL_FAILURE("Invalid ftyp"); + const char* expected = "jxl \0\0\0\0jxl "; + if (memcmp(expected, in, 12) != 0) return JXL_FAILURE("Invalid ftyp"); + } else if (!memcmp("jxli", box.type, 4)) { + // TODO(lode): parse JXL frame index box + if (!container->codestream.empty()) { + return JXL_FAILURE("frame index must come before codestream"); + } + } else if (!memcmp("jxlc", box.type, 4)) { + container->codestream.append(in, in + data_size); + } else if (!memcmp("jxlp", box.type, 4)) { + if (data_size < 4) return JXL_FAILURE("Invalid jxlp"); + // TODO(jon): don't just ignore the counter + container->codestream.append(in + 4, in + data_size); + } else if (!memcmp("Exif", box.type, 4)) { + if (data_size < 4) return JXL_FAILURE("Invalid Exif"); + uint32_t tiff_header_offset = LoadBE32(in); + if (tiff_header_offset > data_size - 4) + return JXL_FAILURE("Invalid Exif tiff header offset"); + container->exif = in + 4 + tiff_header_offset; + container->exif_size = data_size - 4 - tiff_header_offset; + } else if (!memcmp("Exfc", box.type, 4)) { + container->exfc = in; + container->exfc_size = data_size; + } else if (!memcmp("xml ", box.type, 4)) { + container->xml.emplace_back(in, data_size); + } else if (!memcmp("xmlc", box.type, 4)) { + container->xmlc.emplace_back(in, data_size); + } else if (!memcmp("jumb", box.type, 4)) { + container->jumb = in; + container->jumb_size = data_size; + } else if (!memcmp("jbrd", box.type, 4)) { + container->jpeg_reconstruction = in; + container->jpeg_reconstruction_size = data_size; + } else { + // Do nothing: box not recognized here but may be recognizable by + // other software. + } + + in += data_size; + available_in -= data_size; + box_index++; + } + + return true; +} + +static jxl::Status AppendBoxAndData(const char type[4], const uint8_t* data, + size_t data_size, jxl::PaddedBytes* out, + bool exif = false) { + Box box; + memcpy(box.type, type, 4); + box.data_size = data_size + (exif ? 4 : 0); + box.data_size_given = true; + JXL_RETURN_IF_ERROR(AppendBoxHeader(box, out)); + // for Exif: always use tiff header offset 0 + if (exif) + for (int i = 0; i < 4; i++) out->push_back(0); + out->append(data, data + data_size); + return true; +} + +jxl::Status EncodeJpegXlContainerOneShot(const JpegXlContainer& container, + jxl::PaddedBytes* out) { + const unsigned char header[] = {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', + 0xd, 0xa, 0x87, 0xa, 0, 0, 0, 0x14, + 'f', 't', 'y', 'p', 'j', 'x', 'l', ' ', + 0, 0, 0, 0, 'j', 'x', 'l', ' '}; + size_t header_size = sizeof(header); + out->append(header, header + header_size); + + if (container.exif) { + JXL_RETURN_IF_ERROR(AppendBoxAndData("Exif", container.exif, + container.exif_size, out, true)); + } + + if (container.exfc) { + JXL_RETURN_IF_ERROR( + AppendBoxAndData("Exfc", container.exfc, container.exfc_size, out)); + } + + for (size_t i = 0; i < container.xml.size(); i++) { + JXL_RETURN_IF_ERROR(AppendBoxAndData("xml ", container.xml[i].first, + container.xml[i].second, out)); + } + + for (size_t i = 0; i < container.xmlc.size(); i++) { + JXL_RETURN_IF_ERROR(AppendBoxAndData("xmlc", container.xmlc[i].first, + container.xmlc[i].second, out)); + } + + if (container.jpeg_reconstruction) { + JXL_RETURN_IF_ERROR(AppendBoxAndData("jbrd", container.jpeg_reconstruction, + container.jpeg_reconstruction_size, + out)); + } + + if (!container.codestream.empty()) { + JXL_RETURN_IF_ERROR(AppendBoxAndData("jxlc", container.codestream.data(), + container.codestream.size(), out)); + } else { + return JXL_FAILURE("must have primary image frame"); + } + + if (container.jumb) { + JXL_RETURN_IF_ERROR( + AppendBoxAndData("jumb", container.jumb, container.jumb_size, out)); + } + + return true; +} + +// TODO(veluca): the format defined here encode some things multiple times. Fix +// that. + +} // namespace tools +} // namespace jpegxl diff --git a/media/libjxl/src/tools/box/box.h b/media/libjxl/src/tools/box/box.h new file mode 100644 index 000000000..4cc305898 --- /dev/null +++ b/media/libjxl/src/tools/box/box.h @@ -0,0 +1,113 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tools for reading from / writing to ISOBMFF format for JPEG XL. + +#ifndef TOOLS_BOX_BOX_H_ +#define TOOLS_BOX_BOX_H_ + +#include +#include + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_file.h" + +namespace jpegxl { +namespace tools { + +// A top-level box in the box format. +struct Box { + // The type of the box. + // If "uuid", use extended_type instead + char type[4]; + + // The extended_type is only used when type == "uuid". + // Extended types are not used in JXL. However, the box format itself + // supports this so they are handled correctly. + char extended_type[16]; + + // Size of the data, excluding box header. The box ends, and next box + // begins, at data + size. May not be used if data_size_given is false. + uint64_t data_size; + + // If the size is not given, the datasize extends to the end of the file. + // If this field is false, the size field may not be used. + bool data_size_given; +}; + +// Parses the header of a BMFF box. Returns the result in a Box struct. +// Updates next_in and available_in to point at the data in the box, directly +// after the header. +// Sets the data_size if known, or must be handled by the caller and runs until +// the end of the container file if not known. +// NOTE: available_in should be at least 8 up to 32 bytes to parse the +// header without error. +jxl::Status ParseBoxHeader(const uint8_t** next_in, size_t* available_in, + Box* box); + +// TODO(lode): streaming C API +jxl::Status AppendBoxHeader(const Box& box, jxl::PaddedBytes* out); + +// NOTE: after DecodeJpegXlContainerOneShot, the exif etc. pointers point to +// regions within the input data passed to that function. +struct JpegXlContainer { + // Exif metadata, or null if not present in the container. + // The exif data has the format of 'Exif block' as defined in + // ISO/IEC23008-12:2017 Clause A.2.1 + // Here we assume the tiff header offset is 0 and store only the + // actual Exif data (starting with the tiff header MM or II) + // TODO(lode): support the theoretical case of multiple exif boxes + const uint8_t* exif = nullptr; // Not owned + size_t exif_size = 0; + + // Brotli-compressed exif metadata, if present. The data points to the brotli + // compressed stream, it is not decompressed here. + const uint8_t* exfc = nullptr; // Not owned + size_t exfc_size = 0; + + // XML boxes for XMP. There may be multiple XML boxes. + // Each entry points to XML location and provides size. + // The memory is not owned. + // TODO(lode): for C API, cannot use std::vector. + std::vector> xml; + + // Brotli-compressed xml boxes. The bytes are given in brotli-compressed form + // and are not decompressed here. + std::vector> xmlc; + + // JUMBF superbox data, or null if not present in the container. + // The parsing of the nested boxes inside is not handled here. + const uint8_t* jumb = nullptr; // Not owned + size_t jumb_size = 0; + + // TODO(lode): add frame index data + + // JPEG reconstruction data, or null if not present in the container. + const uint8_t* jpeg_reconstruction = nullptr; + size_t jpeg_reconstruction_size = 0; + + // The main JPEG XL codestream, of which there must be 1 in the container. + jxl::PaddedBytes codestream; +}; + +// Returns whether `data` starts with a container header; definitely returns +// false if `size` is less than 12 bytes. +bool IsContainerHeader(const uint8_t* data, size_t size); + +// NOTE: the input data must remain valid as long as `container` is used, +// because its exif etc. pointers point to that data. +jxl::Status DecodeJpegXlContainerOneShot(const uint8_t* data, size_t size, + JpegXlContainer* container); + +// TODO(lode): streaming C API +jxl::Status EncodeJpegXlContainerOneShot(const JpegXlContainer& container, + jxl::PaddedBytes* out); + +} // namespace tools +} // namespace jpegxl + +#endif // TOOLS_BOX_BOX_H_ diff --git a/media/libjxl/src/tools/box/box_list_main.cc b/media/libjxl/src/tools/box/box_list_main.cc new file mode 100644 index 000000000..40ca910e5 --- /dev/null +++ b/media/libjxl/src/tools/box/box_list_main.cc @@ -0,0 +1,90 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This binary tool lists the boxes of any box-based format (JPEG XL, +// JPEG 2000, MP4, ...). +// This exists as a test for manual verification, rather than an actual tool. + +#include +#include +#include + +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "tools/box/box.h" + +namespace jpegxl { +namespace tools { + +int RunMain(int argc, const char* argv[]) { + if (argc < 2) { + fprintf(stderr, "Usage: %s ", argv[0]); + return 1; + } + + jxl::PaddedBytes compressed; + if (!jxl::ReadFile(argv[1], &compressed)) return 1; + fprintf(stderr, "Read %" PRIuS " compressed bytes\n", compressed.size()); + + const uint8_t* in = compressed.data(); + size_t available_in = compressed.size(); + + fprintf(stderr, "File size: %" PRIuS "\n", compressed.size()); + + while (available_in != 0) { + const uint8_t* start = in; + Box box; + if (!ParseBoxHeader(&in, &available_in, &box)) { + fprintf(stderr, "Failed at %" PRIuS "\n", + compressed.size() - available_in); + break; + } + + size_t data_size = box.data_size_given ? box.data_size : available_in; + size_t header_size = in - start; + size_t box_size = header_size + data_size; + + for (size_t i = 0; i < sizeof(box.type); i++) { + char c = box.type[i]; + if (c < 32 || c > 127) { + printf("Unprintable character in box type, likely not a box file.\n"); + return 0; + } + } + + printf("box: \"%.4s\" box_size:%" PRIuS " data_size:%" PRIuS, box.type, + box_size, data_size); + if (!memcmp("uuid", box.type, 4)) { + printf(" -- extended type:\"%.16s\"", box.extended_type); + } + if (!memcmp("ftyp", box.type, 4) && data_size > 4) { + std::string ftype(in, in + 4); + printf(" -- ftype:\"%s\"", ftype.c_str()); + } + printf("\n"); + + if (data_size > available_in) { + fprintf( + stderr, "Unexpected end of file %" PRIuS " %" PRIuS " %" PRIuS "\n", + static_cast(box.data_size), available_in, compressed.size()); + break; + } + + in += data_size; + available_in -= data_size; + } + + return 0; +} + +} // namespace tools +} // namespace jpegxl + +int main(int argc, const char* argv[]) { + return jpegxl::tools::RunMain(argc, argv); +} diff --git a/media/libjxl/src/tools/box/box_test.cc b/media/libjxl/src/tools/box/box_test.cc new file mode 100644 index 000000000..3146bcfa6 --- /dev/null +++ b/media/libjxl/src/tools/box/box_test.cc @@ -0,0 +1,76 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/box/box.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" + +TEST(BoxTest, BoxTest) { + size_t test_size = 256; + jxl::PaddedBytes exif(test_size); + jxl::PaddedBytes xml0(test_size); + jxl::PaddedBytes xml1(test_size); + jxl::PaddedBytes jumb(test_size); + jxl::PaddedBytes codestream(test_size); + // Generate arbitrary data for the codestreams: the test is not testing + // the contents of them but whether they are preserved in the container. + uint8_t v = 0; + for (size_t i = 0; i < test_size; ++i) { + exif[i] = v++; + xml0[i] = v++; + xml1[i] = v++; + jumb[i] = v++; + codestream[i] = v++; + } + + jpegxl::tools::JpegXlContainer container; + container.exif = exif.data(); + container.exif_size = exif.size(); + container.xml.emplace_back(xml0.data(), xml0.size()); + container.xml.emplace_back(xml1.data(), xml1.size()); + container.xmlc.emplace_back(xml1.data(), xml1.size()); + container.jumb = jumb.data(); + container.jumb_size = jumb.size(); + container.codestream = std::move(codestream); + + jxl::PaddedBytes file; + EXPECT_EQ(true, + jpegxl::tools::EncodeJpegXlContainerOneShot(container, &file)); + + jpegxl::tools::JpegXlContainer container2; + EXPECT_EQ(true, jpegxl::tools::DecodeJpegXlContainerOneShot( + file.data(), file.size(), &container2)); + + EXPECT_EQ(exif.size(), container2.exif_size); + EXPECT_EQ(0, memcmp(exif.data(), container2.exif, container2.exif_size)); + EXPECT_EQ(2u, container2.xml.size()); + if (container2.xml.size() == 2) { + EXPECT_EQ(xml0.size(), container2.xml[0].second); + EXPECT_EQ(0, memcmp(xml0.data(), container2.xml[0].first, + container2.xml[0].second)); + EXPECT_EQ(xml1.size(), container2.xml[1].second); + EXPECT_EQ(0, memcmp(xml1.data(), container2.xml[1].first, + container2.xml[1].second)); + } + EXPECT_EQ(1u, container2.xmlc.size()); + if (container2.xmlc.size() == 1) { + EXPECT_EQ(xml1.size(), container2.xmlc[0].second); + EXPECT_EQ(0, memcmp(xml1.data(), container2.xmlc[0].first, + container2.xmlc[0].second)); + } + EXPECT_EQ(jumb.size(), container2.jumb_size); + EXPECT_EQ(0, memcmp(jumb.data(), container2.jumb, container2.jumb_size)); + EXPECT_EQ(container.codestream.size(), container2.codestream.size()); + EXPECT_EQ(0, memcmp(container.codestream.data(), container2.codestream.data(), + container2.codestream.size())); +} diff --git a/media/libjxl/src/tools/build_cleaner.py b/media/libjxl/src/tools/build_cleaner.py new file mode 100644 index 000000000..76857d799 --- /dev/null +++ b/media/libjxl/src/tools/build_cleaner.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + + +"""build_cleaner.py: Update build files. + +This tool keeps certain parts of the build files up to date. +""" + +import argparse +import collections +import locale +import os +import re +import subprocess +import sys +import tempfile + + +def RepoFiles(src_dir): + """Return the list of files from the source git repository""" + git_bin = os.environ.get('GIT_BIN', 'git') + files = subprocess.check_output([git_bin, '-C', src_dir, 'ls-files']) + ret = files.decode(locale.getpreferredencoding()).splitlines() + ret.sort() + return ret + +def GetPrefixLibFiles(repo_files, prefix, suffixes=('.h', '.cc', '.ui')): + """Gets the library files that start with the prefix and end with source + code suffix.""" + prefix_files = [ + fn for fn in repo_files + if fn.startswith(prefix) and any(fn.endswith(suf) for suf in suffixes)] + return prefix_files + +# Type holding the different types of sources in libjxl: +# * decoder and common sources, +# * encoder-only sources, +# * tests-only sources, +# * google benchmark sources, +# * threads library sources, +# * extras library sources, +# * libjxl (encoder+decoder) public include/ headers and +# * threads public include/ headers. +JxlSources = collections.namedtuple( + 'JxlSources', ['dec', 'enc', 'test', 'gbench', 'threads', + 'extras', 'jxl_public_hdrs', 'threads_public_hdrs']) + +def SplitLibFiles(repo_files): + """Splits the library files into the different groups. + + """ + testonly = ( + 'testdata.h', 'test_utils.h', 'test_image.h', '_test.h', '_test.cc', + # _testonly.* files are library code used in tests only. + '_testonly.h', '_testonly.cc' + ) + main_srcs = GetPrefixLibFiles(repo_files, 'lib/jxl/') + extras_srcs = GetPrefixLibFiles(repo_files, 'lib/extras/') + test_srcs = [fn for fn in main_srcs + if any(patt in fn for patt in testonly)] + lib_srcs = [fn for fn in main_srcs + if not any(patt in fn for patt in testonly)] + + # Google benchmark sources. + gbench_srcs = sorted(fn for fn in lib_srcs + extras_srcs + if fn.endswith('_gbench.cc')) + lib_srcs = [fn for fn in lib_srcs if fn not in gbench_srcs] + # Exclude optional codecs from extras. + exclude_extras = [ + '/dec/gif', + '/dec/apng', '/enc/apng', + '/dec/exr', '/enc/exr', + '/dec/jpg', '/enc/jpg', + ] + extras_srcs = [fn for fn in extras_srcs if fn not in gbench_srcs and + not any(patt in fn for patt in testonly) and + not any(patt in fn for patt in exclude_extras)] + + + enc_srcs = [fn for fn in lib_srcs + if os.path.basename(fn).startswith('enc_') or + os.path.basename(fn).startswith('butteraugli')] + enc_srcs.extend([ + "lib/jxl/encode.cc", + "lib/jxl/encode_internal.h", + "lib/jxl/gaborish.cc", + "lib/jxl/gaborish.h", + "lib/jxl/huffman_tree.cc", + "lib/jxl/huffman_tree.h", + # Only the inlines in linalg.h header are used in the decoder. + # TODO(deymo): split out encoder only linalg.h functions. + "lib/jxl/linalg.cc", + "lib/jxl/optimize.cc", + "lib/jxl/optimize.h", + "lib/jxl/progressive_split.cc", + "lib/jxl/progressive_split.h", + # TODO(deymo): Add luminance.cc and luminance.h here too. Currently used + # by aux_out.h. + ]) + # Temporarily remove enc_bit_writer from the encoder sources: a lot of + # decoder source code still needs to be split up into encoder and decoder. + # Including the enc_bit_writer in the decoder allows to build a working + # libjxl_dec library. + # TODO(lode): remove the dependencies of the decoder on enc_bit_writer and + # remove enc_bit_writer from the dec_srcs again. + enc_srcs.remove("lib/jxl/enc_bit_writer.cc") + enc_srcs.remove("lib/jxl/enc_bit_writer.h") + enc_srcs.sort() + + enc_srcs_set = set(enc_srcs) + lib_srcs = [fn for fn in lib_srcs if fn not in enc_srcs_set] + + # The remaining of the files are in the dec_library. + dec_srcs = lib_srcs + + thread_srcs = GetPrefixLibFiles(repo_files, 'lib/threads/') + thread_srcs = [fn for fn in thread_srcs + if not any(patt in fn for patt in testonly)] + public_hdrs = GetPrefixLibFiles(repo_files, 'lib/include/jxl/') + + threads_public_hdrs = [fn for fn in public_hdrs if '_parallel_runner' in fn] + jxl_public_hdrs = list(sorted(set(public_hdrs) - set(threads_public_hdrs))) + return JxlSources(dec_srcs, enc_srcs, test_srcs, gbench_srcs, thread_srcs, + extras_srcs, jxl_public_hdrs, threads_public_hdrs) + + +def CleanFile(args, filename, pattern_data_list): + """Replace a pattern match with new data in the passed file. + + Given a regular expression pattern with a single () match, it runs the regex + over the passed filename and replaces the match () with the new data. If + args.update is set, it will update the file with the new contents, otherwise + it will return True when no changes were needed. + + Multiple pairs of (regular expression, new data) can be passed to the + pattern_data_list parameter and will be applied in order. + + The regular expression must match at least once in the file. + """ + filepath = os.path.join(args.src_dir, filename) + with open(filepath, 'r') as f: + src_text = f.read() + + if not pattern_data_list: + return True + + new_text = src_text + + for pattern, data in pattern_data_list: + offset = 0 + chunks = [] + for match in re.finditer(pattern, new_text): + chunks.append(new_text[offset:match.start(1)]) + offset = match.end(1) + chunks.append(data) + if not chunks: + raise Exception('Pattern not found for %s: %r' % (filename, pattern)) + chunks.append(new_text[offset:]) + new_text = ''.join(chunks) + + if new_text == src_text: + return True + + if args.update: + print('Updating %s' % filename) + with open(filepath, 'w') as f: + f.write(new_text) + return True + else: + with tempfile.NamedTemporaryFile( + mode='w', prefix=os.path.basename(filename)) as new_file: + new_file.write(new_text) + new_file.flush() + subprocess.call( + ['diff', '-u', filepath, '--label', 'a/' + filename, new_file.name, + '--label', 'b/' + filename]) + return False + + +def BuildCleaner(args): + repo_files = RepoFiles(args.src_dir) + ok = True + + # jxl version + with open(os.path.join(args.src_dir, 'lib/CMakeLists.txt'), 'r') as f: + cmake_text = f.read() + + gni_patterns = [] + for varname in ('JPEGXL_MAJOR_VERSION', 'JPEGXL_MINOR_VERSION', + 'JPEGXL_PATCH_VERSION'): + # Defined in CMakeLists.txt as "set(varname 1234)" + match = re.search(r'set\(' + varname + r' ([0-9]+)\)', cmake_text) + version_value = match.group(1) + gni_patterns.append((r'"' + varname + r'=([0-9]+)"', version_value)) + + jxl_src = SplitLibFiles(repo_files) + + # libjxl + jxl_cmake_patterns = [] + jxl_cmake_patterns.append( + (r'set\(JPEGXL_INTERNAL_SOURCES_DEC\n([^\)]+)\)', + ''.join(' %s\n' % fn[len('lib/'):] for fn in jxl_src.dec))) + jxl_cmake_patterns.append( + (r'set\(JPEGXL_INTERNAL_SOURCES_ENC\n([^\)]+)\)', + ''.join(' %s\n' % fn[len('lib/'):] for fn in jxl_src.enc))) + ok = CleanFile( + args, 'lib/jxl.cmake', + jxl_cmake_patterns) and ok + + ok = CleanFile( + args, 'lib/jxl_benchmark.cmake', + [(r'set\(JPEGXL_INTERNAL_SOURCES_GBENCH\n([^\)]+)\)', + ''.join(' %s\n' % fn[len('lib/'):] for fn in jxl_src.gbench))]) and ok + + gni_patterns.append(( + r'libjxl_dec_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn[len('lib/'):] for fn in jxl_src.dec))) + gni_patterns.append(( + r'libjxl_enc_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn[len('lib/'):] for fn in jxl_src.enc))) + gni_patterns.append(( + r'libjxl_gbench_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn[len('lib/'):] for fn in jxl_src.gbench))) + + + tests = [fn[len('lib/'):] for fn in jxl_src.test if fn.endswith('_test.cc')] + testlib = [fn[len('lib/'):] for fn in jxl_src.test + if not fn.endswith('_test.cc')] + gni_patterns.append(( + r'libjxl_tests_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn for fn in tests))) + gni_patterns.append(( + r'libjxl_testlib_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn for fn in testlib))) + + # libjxl_threads + ok = CleanFile( + args, 'lib/jxl_threads.cmake', + [(r'set\(JPEGXL_THREADS_SOURCES\n([^\)]+)\)', + ''.join(' %s\n' % fn[len('lib/'):] for fn in jxl_src.threads))]) and ok + + gni_patterns.append(( + r'libjxl_threads_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn[len('lib/'):] for fn in jxl_src.threads))) + + # libjxl_extras + ok = CleanFile( + args, 'lib/jxl_extras.cmake', + [(r'set\(JPEGXL_EXTRAS_SOURCES\n([^\)]+)\)', + ''.join(' %s\n' % fn[len('lib/'):] for fn in jxl_src.extras))]) and ok + + gni_patterns.append(( + r'libjxl_extras_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn[len('lib/'):] for fn in jxl_src.extras))) + + # libjxl_profiler + profiler_srcs = [fn[len('lib/'):] for fn in repo_files + if fn.startswith('lib/profiler')] + ok = CleanFile( + args, 'lib/jxl_profiler.cmake', + [(r'set\(JPEGXL_PROFILER_SOURCES\n([^\)]+)\)', + ''.join(' %s\n' % fn for fn in profiler_srcs))]) and ok + + gni_patterns.append(( + r'libjxl_profiler_sources = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn for fn in profiler_srcs))) + + # Public headers. + gni_patterns.append(( + r'libjxl_public_headers = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn[len('lib/'):] + for fn in jxl_src.jxl_public_hdrs))) + gni_patterns.append(( + r'libjxl_threads_public_headers = \[\n([^\]]+)\]', + ''.join(' "%s",\n' % fn[len('lib/'):] + for fn in jxl_src.threads_public_hdrs))) + + + # Update the list of tests. CMake version include test files in other libs, + # not just in libjxl. + tests = [fn[len('lib/'):] for fn in repo_files + if fn.endswith('_test.cc') and fn.startswith('lib/')] + ok = CleanFile( + args, 'lib/jxl_tests.cmake', + [(r'set\(TEST_FILES\n([^\)]+) ### Files before this line', + ''.join(' %s\n' % fn for fn in tests))]) and ok + ok = CleanFile( + args, 'lib/jxl_tests.cmake', + [(r'set\(TESTLIB_FILES\n([^\)]+)\)', + ''.join(' %s\n' % fn for fn in testlib))]) and ok + + # Update lib.gni + ok = CleanFile(args, 'lib/lib.gni', gni_patterns) and ok + + return ok + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('--src-dir', + default=os.path.realpath(os.path.join( + os.path.dirname(__file__), '..')), + help='path to the build directory') + parser.add_argument('--update', default=False, action='store_true', + help='update the build files instead of only checking') + args = parser.parse_args() + if not BuildCleaner(args): + print('Build files need update.') + sys.exit(2) + + +if __name__ == '__main__': + main() diff --git a/media/libjxl/src/tools/build_stats.py b/media/libjxl/src/tools/build_stats.py new file mode 100644 index 000000000..b1dc1ea39 --- /dev/null +++ b/media/libjxl/src/tools/build_stats.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + + +"""build_stats.py: Gather statistics about sizes of dependencies. + +This tools computes a realistic estimate of the size contribution to a binary +from a statically linked library. Statically linked libraries compiled with +-ffunction-sections and linked -gc-sections mean that we could drop part of the +library at the final binary linking time. This tool takes that into account the +symbols that end up in the final binary and not just all the symbols of the +components. +""" + +import argparse +import collections +import itertools +import json +import os +import re +import struct +import subprocess +import sys +import tempfile + +# Ignore functions with stack size smaller than this value. +MIN_STACK_SIZE = 32 + + +Symbol = collections.namedtuple('Symbol', ['address', 'size', 'typ', 'name']) + +# Represents the stack size information of a function (defined by its address). +SymbolStack = collections.namedtuple('SymbolStack', + ['address', 'stack_size']) + +ObjectStats = collections.namedtuple('ObjectStats', + ['name', 'in_partition', 'size_map']) + +# An object target file in the build system. +Target = collections.namedtuple('Target', + ['name', 'deps', 'filename']) + +# Sections that end up in the binary file. +# t - text (code), d - global non-const data, n/r - read-only data, +# w - weak symbols (likely inline code not inlined), +# v - weak symbols (vtable / typeinfo) +# u - unique symbols +BIN_SIZE = 'tdnrwvu' + +# Sections that end up in static RAM. +RAM_SIZE = 'dbs' + +# u - symbols imported from some other library +# a - absolute address symbols +IGNORE_SYMBOLS = 'ua' + +SIMD_NAMESPACES = [ + 'N_SCALAR', 'N_WASM', 'N_NEON', 'N_PPC8', 'N_SSE4', 'N_AVX2', 'N_AVX3'] + + +def LoadSymbols(filename): + ret = [] + nmout = subprocess.check_output(['nm', '--format=posix', filename]) + for line in nmout.decode('utf-8').splitlines(): + if line.rstrip().endswith(':'): + # Ignore object names. + continue + # symbol_name, symbol_type, (optional) address, (optional) size + symlist = line.rstrip().split(' ') + assert 2 <= len(symlist) <= 4 + ret.append(Symbol( + int(symlist[2], 16) if len(symlist) > 2 else None, + int(symlist[3], 16) if len(symlist) > 3 else None, + symlist[1], + symlist[0])) + return ret + +def LoadTargetCommand(target, build_dir): + stdout = subprocess.check_output( + ['ninja', '-C', build_dir, '-t', 'commands', target]) + # The last command is always the command to build (link) the requested + # target. + command = stdout.splitlines()[-1] + return command.decode('utf-8') + + +def LoadTarget(target, build_dir): + """Loads a build system target and its dependencies into a Target object""" + if target.endswith('.o'): + # Speed up this case. + return Target(target, [], target) + + link_params = LoadTargetCommand(target, build_dir).split() + if 'cmake_symlink_library' in link_params: + # The target is a library symlinked, use the target of the symlink + # instead. + target = link_params[link_params.index('cmake_symlink_library') + 1] + link_params = LoadTargetCommand(target, build_dir).split() + + # The target name is not always the same as the filename of the output, for + # example, "djxl" target generates "tools/djxl" file. + if '-o' in link_params: + target_filename = link_params[link_params.index('-o') + 1] + elif target.endswith('.a'): + # Command is '/path/to/ar', 'qc', 'target.a', ... + target_filename = link_params[link_params.index('qc') + 1] + else: + raise Exception('Unknown "%s" output filename in command: %r' % + (target, link_params)) + + tgt_libs = [] + for entry in link_params: + if not entry or not (entry.endswith('.o') or entry.endswith('.a')): + continue + if entry == target_filename: + continue + fn = os.path.join(build_dir, entry) + if not os.path.exists(fn): + continue + if entry in tgt_libs: + continue + tgt_libs.append(entry) + + return Target(target, tgt_libs, target_filename) + + +def TargetTransitiveDeps(all_tgts, target): + """Returns the list of all transitive dependencies of target""" + ret = all_tgts[target].deps + # There can't be loop dependencies in the targets. + i = 0 + while i < len(ret): + ret.extend(all_tgts[ret[i]].deps) + i += 1 + return ret + + +def LoadStackSizes(filename, binutils=''): + """Loads the stack size used by functions from the ELF. + + This function loads the stack size the compiler stored in the .stack_sizes + section, which can be done by compiling with -fstack-size-section in clang. + """ + with tempfile.NamedTemporaryFile() as stack_sizes_sec: + subprocess.check_call( + [binutils + 'objcopy', '-O', 'binary', '--only-section=.stack_sizes', + '--set-section-flags', '.stack_sizes=alloc', filename, + stack_sizes_sec.name]) + stack_sizes = stack_sizes_sec.read() + # From the documentation: + # The section will contain an array of pairs of function symbol values + # (pointer size) and stack sizes (unsigned LEB128). The stack size values + # only include the space allocated in the function prologue. Functions with + # dynamic stack allocations are not included. + + # Get the pointer format based on the ELF file. + output = subprocess.check_output( + [binutils + 'objdump', '-a', filename]).decode('utf-8') + elf_format = re.search('file format (.*)$', output, re.MULTILINE).group(1) + if elf_format.startswith('elf64-little') or elf_format == 'elf64-x86-64': + pointer_fmt = '= i + pointer_size + addr, = struct.unpack_from(pointer_fmt, stack_sizes, i) + i += pointer_size + # Parse LEB128 + size = 0 + for j in range(10): + b = stack_sizes[i] + i += 1 + size += (b & 0x7f) << (7 * j) + if (b & 0x80) == 0: + break + if size >= MIN_STACK_SIZE: + ret.append(SymbolStack(addr, size)) + return ret + + +def TargetSize(symbols, symbol_filter=None): + ret = {} + for sym in symbols: + if not sym.size or (symbol_filter is not None and + sym.name not in symbol_filter): + continue + t = sym.typ.lower() + # We can remove symbols if they appear in multiple objects since they will + # be merged by the linker. + if symbol_filter is not None and (t == sym.typ or t in 'wv'): + symbol_filter.remove(sym.name) + ret.setdefault(t, 0) + ret[t] += sym.size + return ret + + +def PrintStats(stats): + """Print a table with the size stats for a target""" + table = [] + sum_bin_size = 0 + sum_ram_size = 0 + + for objstat in stats: + bin_size = 0 + ram_size = 0 + for typ, size in objstat.size_map.items(): + if typ in BIN_SIZE: + bin_size += size + if typ in RAM_SIZE: + ram_size += size + if typ not in BIN_SIZE + RAM_SIZE: + raise Exception('Unknown type "%s"' % typ) + if objstat.in_partition: + sum_bin_size += bin_size + sum_ram_size += ram_size + + table.append((objstat.name, bin_size, ram_size)) + mx_bin_size = max(row[1] for row in table) + mx_ram_size = max(row[2] for row in table) + + table.append(('-- unknown --', mx_bin_size - sum_bin_size, + mx_ram_size - sum_ram_size)) + + # Print the table + print('%-32s %17s %17s' % ('Object name', 'Binary size', 'Static RAM size')) + for name, bin_size, ram_size in table: + print('%-32s %8d (%5.1f%%) %8d (%5.1f%%)' % ( + name, bin_size, 100. * bin_size / mx_bin_size, + ram_size, (100. * ram_size / mx_ram_size) if mx_ram_size else 0)) + print() + + +def PrintStackStats(tgt_stack_sizes, top_entries=20): + if not tgt_stack_sizes: + return + print(' Stack Symbol name') + for i, (name, size) in zip(itertools.count(), tgt_stack_sizes.items()): + if top_entries > 0 and i >= top_entries: + break + print('%8d %s' % (size, name)) + print() + + +def PrintTopSymbols(tgt_top_symbols): + if not tgt_top_symbols: + return + print(' Size T Symbol name') + for size, typ, name in tgt_top_symbols: + print('%9d %s %s' % (size, typ, name)) + print() + + +def SizeStats(args): + """Main entry point of the program after parsing parameters. + + Computes the size statistics of the given targets and their components.""" + # The dictionary with the stats that we store on disk as a json. This includes + # one entry per passed args.target. + stats = {} + + # Cache of Target object of a target. + tgts = {} + + # Load all the targets. + pending = set(args.target) + while pending: + target = pending.pop() + tgt = LoadTarget(target, args.build_dir) + tgts[target] = tgt + if args.recursive: + for dep in tgt.deps: + if dep not in tgts: + pending.add(dep) + + # Cache of symbols of a target. + syms = {} + # Load the symbols from the all targets and its deps. + all_deps = set(tgts.keys()).union(*[set(tgt.deps) for tgt in tgts.values()]) + for entry in all_deps: + fn = os.path.join(args.build_dir, + tgts[entry].filename if entry in tgts else entry) + syms[entry] = LoadSymbols(fn) + + for target in args.target: + tgt_stats = [] + tgt = tgts[target] + + tgt_syms = syms[target] + used_syms = set() + for sym in tgt_syms: + if sym.typ.lower() in BIN_SIZE + RAM_SIZE: + used_syms.add(sym.name) + elif sym.typ.lower() in IGNORE_SYMBOLS: + continue + else: + print('Unknown: %s %s' % (sym.typ, sym.name)) + + target_path = os.path.join(args.build_dir, tgt.filename) + sym_stacks = [] + if not target_path.endswith('.a'): + sym_stacks = LoadStackSizes(target_path, args.binutils) + symbols_by_addr = {sym.address: sym for sym in tgt_syms + if sym.typ.lower() in 'tw'} + tgt_stack_sizes = collections.OrderedDict() + for sym_stack in sorted(sym_stacks, key=lambda s: -s.stack_size): + tgt_stack_sizes[ + symbols_by_addr[sym_stack.address].name] = sym_stack.stack_size + + tgt_top_symbols = [] + if args.top_symbols: + tgt_top_symbols = [(sym.size, sym.typ, sym.name) for sym in tgt_syms + if sym.name in used_syms and sym.size] + tgt_top_symbols.sort(key=lambda t: (-t[0], t[2])) + tgt_top_symbols = tgt_top_symbols[:args.top_symbols] + + tgt_size = TargetSize(tgt_syms) + tgt_stats.append(ObjectStats(target, False, tgt_size)) + + # Split out by SIMD. + for namespace in SIMD_NAMESPACES: + mangled = str(len(namespace)) + namespace + if not any(mangled in sym.name for sym in tgt_syms): + continue + ret = {} + for sym in tgt_syms: + if not sym.size or mangled not in sym.name: + continue + t = sym.typ.lower() + ret.setdefault(t, 0) + ret[t] += sym.size + # SIMD namespaces are not part of the partition, they are already included + # in the jpegxl-static normally. + if not ret: + continue + tgt_stats.append(ObjectStats('\\--> ' + namespace, False, ret)) + + for obj in tgt.deps: + dep_used_syms = used_syms.copy() + obj_size = TargetSize(syms[obj], used_syms) + if not obj_size: + continue + tgt_stats.append(ObjectStats(os.path.basename(obj), True, obj_size)) + if args.recursive: + # Not really recursive, but it shows all the remaining deps at a second + # level. + for obj_dep in sorted(TargetTransitiveDeps(tgts, obj), + key=os.path.basename): + obj_dep_size = TargetSize(syms[obj_dep], dep_used_syms) + if not obj_dep_size: + continue + tgt_stats.append(ObjectStats( + ' '+ os.path.basename(obj_dep), False, obj_dep_size)) + + PrintStats(tgt_stats) + PrintStackStats(tgt_stack_sizes) + PrintTopSymbols(tgt_top_symbols) + stats[target] = { + 'build': tgt_stats, + 'stack': tgt_stack_sizes, + 'top': tgt_top_symbols, + } + + if args.save: + with open(args.save, 'w') as f: + json.dump(stats, f) + + # Check the maximum stack size. + exit_code = 0 + if args.max_stack: + for name, size in tgt_stack_sizes.items(): + if size > args.max_stack: + print('Error: %s exceeds stack limit: %d vs %d' % ( + name, size, args.max_stack), + file=sys.stderr) + exit_code = 1 + + return exit_code + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('target', type=str, nargs='+', + help='target(s) to analyze') + parser.add_argument('--build-dir', default='build', + help='path to the build directory') + parser.add_argument('--save', default=None, + help='path to save the stats as JSON file') + parser.add_argument('-r', '--recursive', default=False, action='store_true', + help='Print recursive entries.') + parser.add_argument('--top-symbols', default=0, type=int, + help='Number of largest symbols to print') + parser.add_argument('--binutils', default='', + help='prefix path to binutils tools, such as ' + 'aarch64-linux-gnu-') + parser.add_argument('--max-stack', default=None, type=int, + help=('Maximum static stack size of a function. If a ' + 'static stack is larger it will exit with an error ' + 'code.')) + args = parser.parse_args() + sys.exit(SizeStats(args)) + + +if __name__ == '__main__': + main() diff --git a/media/libjxl/src/tools/butteraugli_main.cc b/media/libjxl/src/tools/butteraugli_main.cc new file mode 100644 index 000000000..247ade8d4 --- /dev/null +++ b/media/libjxl/src/tools/butteraugli_main.cc @@ -0,0 +1,144 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include +#include + +#include "lib/extras/codec.h" +#include "lib/extras/dec/color_hints.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_butteraugli_pnorm.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +Status WriteImage(Image3F&& image, const std::string& filename) { + ThreadPoolInternal pool(4); + CodecInOut io; + io.metadata.m.SetUintSamples(8); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + io.SetFromImage(std::move(image), io.metadata.m.color_encoding); + return EncodeToFile(io, filename, &pool); +} + +Status RunButteraugli(const char* pathname1, const char* pathname2, + const std::string& distmap_filename, + const std::string& colorspace_hint, double p, + float intensity_target) { + extras::ColorHints color_hints; + if (!colorspace_hint.empty()) { + color_hints.Add("color_space", colorspace_hint); + } + + CodecInOut io1; + ThreadPoolInternal pool(4); + if (!SetFromFile(pathname1, color_hints, &io1, &pool)) { + fprintf(stderr, "Failed to read image from %s\n", pathname1); + return false; + } + + CodecInOut io2; + if (!SetFromFile(pathname2, color_hints, &io2, &pool)) { + fprintf(stderr, "Failed to read image from %s\n", pathname2); + return false; + } + + if (io1.xsize() != io2.xsize()) { + fprintf(stderr, "Width mismatch: %" PRIuS " %" PRIuS "\n", io1.xsize(), + io2.xsize()); + return false; + } + if (io1.ysize() != io2.ysize()) { + fprintf(stderr, "Height mismatch: %" PRIuS " %" PRIuS "\n", io1.ysize(), + io2.ysize()); + return false; + } + + ImageF distmap; + ButteraugliParams ba_params; + ba_params.hf_asymmetry = 0.8f; + ba_params.xmul = 1.0f; + ba_params.intensity_target = intensity_target; + const float distance = ButteraugliDistance(io1.Main(), io2.Main(), ba_params, + GetJxlCms(), &distmap, &pool); + printf("%.10f\n", distance); + + double pnorm = ComputeDistanceP(distmap, ba_params, p); + printf("%g-norm: %f\n", p, pnorm); + + if (!distmap_filename.empty()) { + float good = ButteraugliFuzzyInverse(1.5); + float bad = ButteraugliFuzzyInverse(0.5); + JXL_CHECK( + WriteImage(CreateHeatMapImage(distmap, good, bad), distmap_filename)); + } + return true; +} + +} // namespace +} // namespace jxl + +int main(int argc, char** argv) { + if (argc < 3) { + fprintf(stderr, + "Usage: %s \n" + " [--distmap ]\n" + " [--intensity_target ]\n" + " [--colorspace ]\n" + " [--pnorm ]\n" + "NOTE: images get converted to linear sRGB for butteraugli. Images" + " without attached profiles (such as ppm or pfm) are interpreted" + " as nonlinear sRGB. The hint format is RGB_D65_SRG_Rel_Lin for" + " linear sRGB. Intensity target is viewing conditions screen nits" + ", defaults to 80.\n", + argv[0]); + return 1; + } + std::string distmap; + std::string colorspace; + double p = 3; + float intensity_target = 80.0; // sRGB intensity target. + for (int i = 3; i < argc; i++) { + if (std::string(argv[i]) == "--distmap" && i + 1 < argc) { + distmap = argv[++i]; + } else if (std::string(argv[i]) == "--colorspace" && i + 1 < argc) { + colorspace = argv[++i]; + } else if (std::string(argv[i]) == "--intensity_target" && i + 1 < argc) { + intensity_target = std::stof(std::string(argv[++i])); + } else if (std::string(argv[i]) == "--pnorm" && i + 1 < argc) { + char* end; + p = strtod(argv[++i], &end); + if (end == argv[i]) { + fprintf(stderr, "Failed to parse pnorm \"%s\".\n", argv[i]); + return 1; + } + } else { + fprintf(stderr, "Unrecognized flag \"%s\".\n", argv[i]); + return 1; + } + } + + return jxl::RunButteraugli(argv[1], argv[2], distmap, colorspace, p, + intensity_target) + ? 0 + : 1; +} diff --git a/media/libjxl/src/tools/check_author.py b/media/libjxl/src/tools/check_author.py new file mode 100644 index 000000000..ae1c2798f --- /dev/null +++ b/media/libjxl/src/tools/check_author.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + + +"""check_author.py: Check that a given author is listed in the AUTHORS file.""" + +import argparse +import fnmatch +import os +import re +import sys + + +def IsAuthorInFile(email, name, filename): + """Return whether we find the name/email in the authors filename""" + # Organization emails have emails listed as <*@domain.com>. This matches those + # patterns. + email_pattern_regex = re.compile(r'.*<([^>]+)>') + + with open(filename, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('#') or not line: + continue + # Exact match for a line without an email is OK. + if line == name: + return True + # Exact email address match is OK, even if the name is different. + if fnmatch.fnmatch(line, '* <%s>' % email): + print( + "User %s <%s> matched with different name %s" % (name, email, line), + file=sys.stderr) + return True + # Organizations often have *@domain.com email patterns which don't match + # the name. + if '*' in line: + m = email_pattern_regex.match(line) + if m and fnmatch.fnmatch(email, m.group(1)): + print("User %s <%s> matched pattern %s" % (name, email, line), + file=sys.stderr) + return True + return False + +def IndividualsInAlphabeticOrder(filename): + """Checks if the names are in alphabetic order""" + with open(filename, 'r') as f: + lines = f.readlines() + individual_header = '# Individuals:\n' + if individual_header in lines: + individual_authors = lines[lines.index(individual_header) + 1:] + sorted_authors = sorted(individual_authors, key=str.casefold) + if sorted_authors == individual_authors: + print("Individual authors are sorted alphabetically.") + return True + else: + print("Individual authors are not sorted alphabetically." + " The expected order is:") + print(''.join(sorted_authors)) + return False + else: + print("Cannot find line '# Individuals:' in file.") + return False + + +def CheckAuthor(args): + authors_path = os.path.join(args.source_dir, 'AUTHORS') + author_in_file = IsAuthorInFile( + args.email, args.name, authors_path) + if not author_in_file: + print("User %s <%s> not found, please add yourself to the AUTHORS file" % ( + args.name, args.email), + file=sys.stderr) + + sorted_alphabetically = IndividualsInAlphabeticOrder(authors_path) + if not sorted_alphabetically: + print("Authors not in alphabetical order, please sort them.", file=sys.stderr) + if not author_in_file or not sorted_alphabetically: + if not args.dry_run: + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('email', type=str, + help='email of the commit author to check') + parser.add_argument('name', type=str, + help='name of the commit author to check') + parser.add_argument( + '--source-dir', + default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), + help='path to the source directory where the AUTHORS file is located') + parser.add_argument('--dry-run', default=False, action='store_true', + help='Don\'t return an exit code in case of failure') + args = parser.parse_args() + CheckAuthor(args) + + +if __name__ == '__main__': + main() diff --git a/media/libjxl/src/tools/cjpeg_hdr.cc b/media/libjxl/src/tools/cjpeg_hdr.cc new file mode 100644 index 000000000..cfe272ee2 --- /dev/null +++ b/media/libjxl/src/tools/cjpeg_hdr.cc @@ -0,0 +1,306 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tools/cjpeg_hdr.cc" +#include +#include + +#include "lib/extras/codec.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/enc_transforms.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" +#include "lib/jxl/quant_weights.h" + +HWY_BEFORE_NAMESPACE(); +namespace jpegxl { +namespace tools { +namespace HWY_NAMESPACE { +void FillJPEGData(const jxl::Image3F& ycbcr, const jxl::PaddedBytes& icc, + const jxl::ImageF& quant_field, + const jxl::FrameDimensions& frame_dim, + jxl::jpeg::JPEGData* out) { + // JFIF + out->marker_order.push_back(0xE0); + out->app_data.emplace_back(std::vector{ + 0xe0, // Marker + 0, 16, // Length + 'J', 'F', 'I', 'F', '\0', // ID + 1, 1, // Version (1.1) + 0, // No density units + 0, 1, 0, 1, // Pixel density 1 + 0, 0 // No thumbnail + }); + // ICC + if (!icc.empty()) { + out->marker_order.push_back(0xE2); + std::vector icc_marker(17 + icc.size()); + icc_marker[0] = 0xe2; + icc_marker[1] = (icc_marker.size() - 1) >> 8; + icc_marker[2] = (icc_marker.size() - 1) & 0xFF; + memcpy(&icc_marker[3], "ICC_PROFILE", 12); + icc_marker[15] = 1; + icc_marker[16] = 1; + memcpy(&icc_marker[17], icc.data(), icc.size()); + out->app_data.push_back(std::move(icc_marker)); + } + + // DQT + out->marker_order.emplace_back(0xdb); + out->quant.resize(2); + out->quant[0].is_last = false; + out->quant[0].index = 0; + out->quant[1].is_last = true; + out->quant[1].index = 1; + jxl::DequantMatrices dequant; + + // mozjpeg q99 + int qluma[64] = { + 1, 1, 1, 1, 1, 1, 1, 2, // + 1, 1, 1, 1, 1, 1, 1, 2, // + 1, 1, 1, 1, 1, 1, 2, 3, // + 1, 1, 1, 1, 1, 1, 2, 3, // + 1, 1, 1, 1, 1, 2, 3, 4, // + 1, 1, 1, 1, 2, 2, 3, 5, // + 1, 1, 2, 2, 3, 3, 5, 6, // + 2, 2, 3, 3, 4, 5, 6, 8, // + }; + // mozjpeg q95 + int qchroma[64] = { + 2, 2, 2, 2, 3, 4, 6, 9, // + 2, 2, 2, 3, 3, 4, 5, 8, // + 2, 2, 2, 3, 4, 6, 9, 14, // + 2, 3, 3, 4, 5, 7, 11, 16, // + 3, 3, 4, 5, 7, 9, 13, 19, // + 4, 4, 6, 7, 9, 12, 17, 24, // + 6, 5, 9, 11, 13, 17, 23, 31, // + 9, 8, 14, 16, 19, 24, 31, 42, // + }; + // Disable quantization for now. + std::fill(std::begin(qluma), std::end(qluma), 1); + std::fill(std::begin(qchroma), std::end(qchroma), 1); + + memcpy(out->quant[0].values.data(), qluma, sizeof(qluma)); + memcpy(out->quant[1].values.data(), qchroma, sizeof(qchroma)); + + // SOF + out->marker_order.emplace_back(0xc2); + out->components.resize(3); + out->height = frame_dim.ysize; + out->width = frame_dim.xsize_padded; + out->components[0].id = 1; + out->components[1].id = 2; + out->components[2].id = 3; + out->components[0].h_samp_factor = out->components[1].h_samp_factor = + out->components[2].h_samp_factor = out->components[0].v_samp_factor = + out->components[1].v_samp_factor = out->components[2].v_samp_factor = + 1; + out->components[0].width_in_blocks = out->components[1].width_in_blocks = + out->components[2].width_in_blocks = frame_dim.xsize_blocks; + out->components[0].quant_idx = 0; + out->components[1].quant_idx = 1; + out->components[2].quant_idx = 1; + out->components[0].coeffs.resize(frame_dim.xsize_blocks * + frame_dim.ysize_blocks * 64); + out->components[1].coeffs.resize(frame_dim.xsize_blocks * + frame_dim.ysize_blocks * 64); + out->components[2].coeffs.resize(frame_dim.xsize_blocks * + frame_dim.ysize_blocks * 64); + + HWY_ALIGN float scratch_space[2 * 64]; + + for (size_t c = 0; c < 3; c++) { + int* qt = c == 0 ? qluma : qchroma; + for (size_t by = 0; by < frame_dim.ysize_blocks; by++) { + for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) { + float deadzone = 0.5f / quant_field.Row(by)[bx]; + // Disable quantization for now. + deadzone = 0; + auto q = [&](float coeff, size_t x, size_t y) -> int { + size_t pos = x * 8 + y; + float scoeff = coeff / qt[pos]; + if (pos == 0) { + return std::round(scoeff); + } + if (std::abs(scoeff) < deadzone) return 0; + if (std::abs(scoeff) < 2 * deadzone && x + y >= 7) return 0; + return std::round(scoeff); + }; + HWY_ALIGN float dct[64]; + TransformFromPixels(jxl::AcStrategy::Type::DCT, + ycbcr.PlaneRow(c, 8 * by) + 8 * bx, + ycbcr.PixelsPerRow(), dct, scratch_space); + for (size_t iy = 0; iy < 8; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + float coeff = dct[iy * 8 + ix] * 2040; // not a typo + out->components[c] + .coeffs[(frame_dim.xsize_blocks * by + bx) * 64 + ix * 8 + iy] = + q(coeff, ix, iy); + } + } + } + } + } + + // DHT + // TODO: optimize + out->marker_order.emplace_back(0xC4); + out->huffman_code.resize(2); + out->huffman_code[0].slot_id = 0x00; // DC + out->huffman_code[0].counts = {{0, 0, 0, 0, 13}}; + std::iota(out->huffman_code[0].values.begin(), + out->huffman_code[0].values.end(), 0); + out->huffman_code[0].is_last = false; + + out->huffman_code[1].slot_id = 0x10; // AC + out->huffman_code[1].counts = {{0, 0, 0, 0, 0, 0, 0, 0, 255}}; + std::iota(out->huffman_code[1].values.begin(), + out->huffman_code[1].values.end(), 0); + out->huffman_code[1].is_last = true; + + // SOS + for (size_t _ = 0; _ < 7; _++) { + out->marker_order.emplace_back(0xDA); + } + out->scan_info.resize(7); + // DC + // comp id, DC tbl, AC tbl + out->scan_info[0].num_components = 3; + out->scan_info[0].components = {{jxl::jpeg::JPEGComponentScanInfo{0, 0, 0}, + jxl::jpeg::JPEGComponentScanInfo{1, 0, 0}, + jxl::jpeg::JPEGComponentScanInfo{2, 0, 0}}}; + out->scan_info[0].Ss = 0; + out->scan_info[0].Se = 0; + out->scan_info[0].Ah = out->scan_info[0].Al = 0; + // AC 1 - highest bits + out->scan_info[1].num_components = 1; + out->scan_info[1].components = {{jxl::jpeg::JPEGComponentScanInfo{0, 0, 0}}}; + out->scan_info[1].Ss = 1; + out->scan_info[1].Se = 63; + out->scan_info[1].Ah = 0; + out->scan_info[1].Al = 1; + + // Copy for X / B-Y + out->scan_info[2] = out->scan_info[1]; + out->scan_info[2].components[0].comp_idx = 1; + out->scan_info[3] = out->scan_info[1]; + out->scan_info[3].components[0].comp_idx = 2; + + // AC 2 - lowest bit + out->scan_info[4].num_components = 1; + out->scan_info[4].components = {{jxl::jpeg::JPEGComponentScanInfo{0, 0, 0}}}; + out->scan_info[4].Ss = 1; + out->scan_info[4].Se = 63; + out->scan_info[4].Ah = 1; + out->scan_info[4].Al = 0; + + // Copy for X / B-Y + out->scan_info[5] = out->scan_info[4]; + out->scan_info[5].components[0].comp_idx = 1; + out->scan_info[6] = out->scan_info[4]; + out->scan_info[6].components[0].comp_idx = 2; + + // EOI + out->marker_order.push_back(0xd9); +} +} // namespace HWY_NAMESPACE +} // namespace tools +} // namespace jpegxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jpegxl { +namespace tools { + +HWY_EXPORT(FillJPEGData); + +int HBDJPEGMain(int argc, const char* argv[]) { + if (argc < 3) { + fprintf(stderr, "Usage: %s input output.jpg\n", argv[0]); + return 1; + } + fprintf(stderr, "Compressing %s to %s\n", argv[1], argv[2]); + jxl::CodecInOut io; + if (!jxl::SetFromFile(argv[1], jxl::extras::ColorHints{}, &io)) { + fprintf(stderr, "Failed to read image %s.\n", argv[1]); + return 1; + } + jxl::Image3F ycbcr(jxl::RoundUpToBlockDim(io.xsize()), + jxl::RoundUpToBlockDim(io.ysize())); + ycbcr.ShrinkTo(io.xsize(), io.ysize()); + jxl::FrameDimensions frame_dim; + frame_dim.Set(io.xsize(), io.ysize(), 0, 0, 0, false, 1); + for (size_t y = 0; y < ycbcr.ysize(); y++) { + for (size_t x = 0; x < ycbcr.xsize(); x++) { + float r = io.Main().color()->PlaneRow(0, y)[x]; + float g = io.Main().color()->PlaneRow(1, y)[x]; + float b = io.Main().color()->PlaneRow(2, y)[x]; + ycbcr.PlaneRow(0, y)[x] = + 0.299 * r + 0.587 * g + 0.114 * b - (128. / 255.); + ycbcr.PlaneRow(1, y)[x] = -0.168736 * r - 0.331264 * g + 0.5 * b; + ycbcr.PlaneRow(2, y)[x] = 0.5 * r - 0.418688 * g - 0.081312 * b; + } + } + jxl::Image3F rgb2(ycbcr.xsize(), ycbcr.ysize()); + jxl::Image3F ycbcr2(ycbcr.xsize(), ycbcr.ysize()); + for (size_t y = 0; y < ycbcr.ysize(); y++) { + for (size_t x = 0; x < ycbcr.xsize(); x++) { + ycbcr2.PlaneRow(0, y)[x] = ycbcr.PlaneRow(1, y)[x]; + ycbcr2.PlaneRow(1, y)[x] = ycbcr.PlaneRow(0, y)[x]; + ycbcr2.PlaneRow(2, y)[x] = ycbcr.PlaneRow(2, y)[x]; + } + } + jxl::YcbcrToRgb(ycbcr2, &rgb2, jxl::Rect(ycbcr)); + + PadImageToBlockMultipleInPlace(&ycbcr); + + jxl::Image3F opsin(jxl::RoundUpToBlockDim(io.xsize()), + jxl::RoundUpToBlockDim(io.ysize())); + opsin.ShrinkTo(io.xsize(), io.ysize()); + jxl::ToXYB(io.Main(), nullptr, &opsin, jxl::GetJxlCms()); + PadImageToBlockMultipleInPlace(&opsin); + jxl::ImageF mask; + jxl::ImageF qf = + InitialQuantField(1.0, opsin, frame_dim, nullptr, 1.0, &mask); + + jxl::CodecInOut out; + out.Main().jpeg_data = jxl::make_unique(); + HWY_DYNAMIC_DISPATCH(FillJPEGData) + (ycbcr, io.metadata.m.color_encoding.ICC(), qf, frame_dim, + out.Main().jpeg_data.get()); + jxl::PaddedBytes output; + if (!jxl::jpeg::EncodeImageJPGCoefficients(&out, &output)) { + return 1; + } + if (!jxl::WriteFile(output, argv[2])) { + fprintf(stderr, "Failed to write to \"%s\"\n", argv[2]); + return 1; + } + return 0; +} + +} // namespace tools +} // namespace jpegxl + +int main(int argc, const char** argv) { + return jpegxl::tools::HBDJPEGMain(argc, argv); +} +#endif diff --git a/media/libjxl/src/tools/cjxl_bisect_bpp b/media/libjxl/src/tools/cjxl_bisect_bpp new file mode 100644 index 000000000..d7a1066e1 --- /dev/null +++ b/media/libjxl/src/tools/cjxl_bisect_bpp @@ -0,0 +1,40 @@ +#!/bin/sh +# +# Bisects JPEG XL encoding quality parameter to reach a given +# target bits-per-pixel value. +# (To be used directly, or as a template for tailored processing.) +# +# Usage: cjxl_bisect_size {input_filename} {output_filename} {target_bpp} + +# +# We take the `bisector` tool from $PATH, or, if not available, +# try to locate it in the same directory as the current script. +# The `get_bpp` helper is taken from the same directory as the current script. +# + +input_filename=$1 +output_filename=$2 +target_size=$3 + +script_dir=$(dirname $(readlink -f $0)) +bisect_tool=$(which bisector) +if [ -z $bisect_tool ] ; then + bisect_tool="${script_dir}/bisector" +fi +jxl_get_bpp_helper="${script_dir}/jxl_get_bpp_helper" +# If $CJXL_BIN is set, we use this instead of looking for `cjxl` on $PATH. + +cjxl_bin=${CJXL_BIN} +if [ -z $cjxl_bin ] ; then + cjxl_bin="cjxl" +fi + +# Using `identify` from ImageMagick here. +num_pixels=$(identify -format "%w*%h\n" /tmp/baseball.png|bc) + +# Allow 0.5% tolerance in size (--rtol=0.005). +exec $bisect_tool --var=BISECT --range=0.01,15.0 --target=$target_size \ + --rtol_val=0.005 \ + --cmd="$cjxl_bin --distance=\$BISECT ${input_filename} ${output_filename}_bisect_\$BISECT.jxl ; (find ${output_filename}_bisect_\$BISECT.jxl -printf \"scale=10;%s/$num_pixels\n\" | bc -l)" \ + --final="mv ${output_filename}_bisect_\$BISECT.jxl ${output_filename}; rm -f ${output_filename}_bisect_*.jxl" \ + --verbosity=1 diff --git a/media/libjxl/src/tools/cjxl_bisect_size b/media/libjxl/src/tools/cjxl_bisect_size new file mode 100644 index 000000000..9cd88ea52 --- /dev/null +++ b/media/libjxl/src/tools/cjxl_bisect_size @@ -0,0 +1,36 @@ +#!/bin/sh +# +# Bisects JPEG XL encoding quality parameter to reach a given +# target byte-size. +# (To be used directly, or as a template for tailored processing.) +# +# Usage: cjxl_bisect_size {input_filename} {output_filename} {target_size} + +# +# We take the `bisector` tool from $PATH, or, if not available, +# try to locate it in the same directory as the current script. +# + +input_filename=$1 +output_filename=$2 +target_size=$3 + +script_dir=$(dirname $(readlink -f $0)) +bisect_tool=$(which bisector) +if [ -z $bisect_tool ] ; then + bisect_tool="${script_dir}/bisector" +fi + +# If $CJXL_BIN is set, we use this instead of looking for `cjxl` on $PATH. + +cjxl_bin=${CJXL_BIN} +if [-z $cjxl_bin ] ; then + cjxl_bin="cjxl" +fi + +# Allow 0.5% tolerance in size (--rtol=0.005). +exec $bisect_tool --var=BISECT --range=0.01,10.0 --target=$target_size \ + --rtol_val=0.005 \ + --cmd="$cjxl_bin --distance=\$BISECT ${input_filename} ${output_filename}_bisect_\$BISECT.jxl && wc -c ${output_filename}_bisect_\$BISECT.jxl" \ + --final="mv ${output_filename}_bisect_\$BISECT.jxl ${output_filename}; rm -f ${output_filename}_bisect_*.jxl" \ + --verbosity=1 diff --git a/media/libjxl/src/tools/cjxl_fuzzer.cc b/media/libjxl/src/tools/cjxl_fuzzer.cc new file mode 100644 index 000000000..f3a1d9f9d --- /dev/null +++ b/media/libjxl/src/tools/cjxl_fuzzer.cc @@ -0,0 +1,231 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "hwy/targets.h" +#include "jxl/encode.h" +#include "jxl/encode_cxx.h" +#include "jxl/thread_parallel_runner.h" +#include "jxl/thread_parallel_runner_cxx.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/test_image.h" + +namespace { + +#define TRY(expr) \ + do { \ + if (JXL_ENC_SUCCESS != (expr)) return false; \ + } while (0) + +struct FuzzSpec { + size_t xsize; + size_t ysize; + struct OptionSpec { + JxlEncoderFrameSettingId id; + int32_t value; + }; + std::vector options; + bool is_jpeg = false; + bool lossless = false; + bool have_alpha = false; + bool premultiply = false; + bool orig_profile = true; + uint16_t pixels_seed = 0; + uint16_t alpha_seed = 0; + size_t bit_depth = 8; + size_t alpha_bit_depth = 8; + int32_t codestream_level = -1; + std::vector icc; + JxlColorEncoding color_encoding; + size_t num_frames = 1; + size_t output_buffer_size = 1; +}; + +bool EncodeJpegXl(const FuzzSpec& spec) { + // Multi-threaded parallel runner. Limit to max 2 threads since the fuzzer + // itself is already multithreaded. + size_t num_threads = + std::min(2, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + auto runner = JxlThreadParallelRunnerMake(nullptr, num_threads); + JxlEncoderPtr enc_ptr = JxlEncoderMake(/*memory_manager=*/nullptr); + JxlEncoder* enc = enc_ptr.get(); + for (size_t num_rep = 0; num_rep < 2; ++num_rep) { + JxlEncoderReset(enc); + TRY(JxlEncoderSetParallelRunner(enc, JxlThreadParallelRunner, + runner.get())); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + + for (auto option : spec.options) { + TRY(JxlEncoderFrameSettingsSetOption(frame_settings, option.id, + option.value)); + } + + TRY(JxlEncoderSetCodestreamLevel(enc, spec.codestream_level)); + JxlBasicInfo basic_info; + JxlEncoderInitBasicInfo(&basic_info); + basic_info.xsize = spec.xsize; + basic_info.ysize = spec.ysize; + basic_info.bits_per_sample = spec.bit_depth; + basic_info.uses_original_profile = spec.orig_profile; + if (spec.have_alpha) { + basic_info.alpha_bits = spec.alpha_bit_depth; + basic_info.num_extra_channels = 1; + } + TRY(JxlEncoderSetBasicInfo(enc, &basic_info)); + if (spec.lossless) { + TRY(JxlEncoderSetFrameLossless(frame_settings, JXL_TRUE)); + } + + // TODO(szabadka) Add icc color profiles. + TRY(JxlEncoderSetColorEncoding(enc, &spec.color_encoding)); + + // TODO(szabadka) Add jpeg frames. + for (size_t i = 0; i < spec.num_frames; ++i) { + JxlFrameHeader frame_header; + JxlEncoderInitFrameHeader(&frame_header); + // TODO(szabadka) Add more frame header options. + TRY(JxlEncoderSetFrameHeader(frame_settings, &frame_header)); + if (spec.have_alpha) { + JxlExtraChannelInfo extra_channel_info; + JxlEncoderInitExtraChannelInfo(JXL_CHANNEL_ALPHA, &extra_channel_info); + TRY(JxlEncoderSetExtraChannelInfo(enc, 0, &extra_channel_info)); + extra_channel_info.alpha_premultiplied = spec.premultiply; + } + JxlPixelFormat pixelformat = {3, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}; + std::vector pixels = jxl::test::GetSomeTestImage( + spec.xsize, spec.ysize, 3, spec.pixels_seed); + TRY(JxlEncoderAddImageFrame(frame_settings, &pixelformat, pixels.data(), + pixels.size())); + if (spec.have_alpha) { + std::vector alpha_pixels = jxl::test::GetSomeTestImage( + spec.xsize, spec.ysize, 1, spec.alpha_seed); + TRY(JxlEncoderSetExtraChannelBuffer(frame_settings, &pixelformat, + alpha_pixels.data(), + alpha_pixels.size(), 0)); + } + } + // Reading compressed output + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + std::vector buf(spec.output_buffer_size); + uint8_t* next_out = buf.data(); + size_t avail_out = buf.size(); + process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out); + } + if (JXL_ENC_SUCCESS != process_result) { + return false; + } + } + return true; +} + +template +T Select(const std::vector& vec, std::function get_index) { + return vec[get_index(vec.size() - 1)]; +} + +int TestOneInput(const uint8_t* data, size_t size) { + uint64_t flags = 0; + size_t flag_bits = 0; + + const auto consume_data = [&]() { + if (size < 4) abort(); + uint32_t buf = 0; + memcpy(&buf, data, 4); + data += 4; + size -= 4; + flags = (flags << 32) | buf; + flag_bits += 32; + }; + + const auto get_flag = [&](size_t max_value) { + size_t limit = 1; + while (limit <= max_value) { + limit <<= 1; + --flag_bits; + if (flag_bits <= 16) { + consume_data(); + } + } + uint32_t result = flags % limit; + flags /= limit; + return result % (max_value + 1); + }; + + std::vector colorspaces = { + JXL_COLOR_SPACE_RGB, JXL_COLOR_SPACE_GRAY, JXL_COLOR_SPACE_XYB, + JXL_COLOR_SPACE_UNKNOWN}; + std::vector whitepoints = { + JXL_WHITE_POINT_D65, JXL_WHITE_POINT_CUSTOM, JXL_WHITE_POINT_E, + JXL_WHITE_POINT_DCI}; + std::vector primaries = {JXL_PRIMARIES_SRGB, + JXL_PRIMARIES_CUSTOM, + JXL_PRIMARIES_2100, JXL_PRIMARIES_P3}; + std::vector transfer_functions = { + JXL_TRANSFER_FUNCTION_709, JXL_TRANSFER_FUNCTION_UNKNOWN, + JXL_TRANSFER_FUNCTION_LINEAR, JXL_TRANSFER_FUNCTION_SRGB, + JXL_TRANSFER_FUNCTION_PQ, JXL_TRANSFER_FUNCTION_DCI, + JXL_TRANSFER_FUNCTION_HLG, JXL_TRANSFER_FUNCTION_GAMMA}; + std::vector rendering_intents = { + JXL_RENDERING_INTENT_PERCEPTUAL, + JXL_RENDERING_INTENT_RELATIVE, + JXL_RENDERING_INTENT_SATURATION, + JXL_RENDERING_INTENT_ABSOLUTE, + }; + + FuzzSpec spec; + // Randomly set some options. + // TODO(szabadka) Make value bounds option specific. + size_t num_options = get_flag(32); + for (size_t i = 0; i < num_options; ++i) { + FuzzSpec::OptionSpec option; + option.id = static_cast(get_flag(32)); + option.value = static_cast(get_flag(16)) - 1; + spec.options.push_back(option); + } + + spec.xsize = get_flag(4095) + 1; + spec.ysize = get_flag(4095) + 1; + spec.lossless = get_flag(1); + if (!spec.lossless) { + spec.orig_profile = get_flag(1); + } + spec.have_alpha = get_flag(1); + spec.premultiply = get_flag(1); + spec.pixels_seed = get_flag((1 << 16) - 1); + spec.alpha_seed = get_flag((1 << 16) - 1); + spec.bit_depth = get_flag(15) + 1; + spec.alpha_bit_depth = get_flag(15) + 1; + spec.color_encoding.color_space = Select(colorspaces, get_flag); + spec.color_encoding.white_point = Select(whitepoints, get_flag); + spec.color_encoding.primaries = Select(primaries, get_flag); + spec.color_encoding.transfer_function = Select(transfer_functions, get_flag); + spec.color_encoding.rendering_intent = Select(rendering_intents, get_flag); + spec.output_buffer_size = get_flag(4095) + 1; + + const auto targets = hwy::SupportedAndGeneratedTargets(); + hwy::SetSupportedTargetsForTest(Select(targets, get_flag)); + EncodeJpegXl(spec); + hwy::SetSupportedTargetsForTest(0); + + return 0; +} + +} // namespace + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + return TestOneInput(data, size); +} diff --git a/media/libjxl/src/tools/cjxl_main.cc b/media/libjxl/src/tools/cjxl_main.cc new file mode 100644 index 000000000..e43bb2707 --- /dev/null +++ b/media/libjxl/src/tools/cjxl_main.cc @@ -0,0 +1,1215 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Note: This encoder binary does extensive flag-validity checking (in +// order to produce meaningful error messages), and on top of that +// checks all libjxl C API call return values. The downside of this +// vs. libjxl providing meaningful error messages is that a change to +// the accepted range of a flag-specified parameter in libjxl will +// also require a change to the range-check here. The advantage is +// that this minimizes the size of libjxl. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "jxl/codestream_header.h" +#include "jxl/encode.h" +#include "jxl/encode_cxx.h" +#include "jxl/thread_parallel_runner.h" +#include "jxl/thread_parallel_runner_cxx.h" +#include "jxl/types.h" +#include "lib/extras/dec/apng.h" +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/dec/gif.h" +#include "lib/extras/dec/jpg.h" +#include "lib/extras/dec/pgx.h" +#include "lib/extras/dec/pnm.h" +#include "lib/extras/time.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/exif.h" +#include "lib/jxl/size_constraints.h" +#include "tools/args.h" +#include "tools/cmdline.h" +#include "tools/codec_config.h" +#include "tools/file_io.h" +#include "tools/speed_stats.h" + +namespace jpegxl { +namespace tools { + +namespace { +inline bool ParsePhotonNoiseParameter(const char* arg, float* out) { + return strncmp(arg, "ISO", 3) == 0 && ParseFloat(arg + 3, out) && *out > 0; +} +inline bool ParseIntensityTarget(const char* arg, float* out) { + return ParseFloat(arg, out) && *out > 0; +} + +} // namespace + +enum CjxlRetCode : int { + OK = 0, + ERR_PARSE, + ERR_INVALID_ARG, + ERR_LOAD_INPUT, + ERR_INVALID_INPUT, + ERR_ENCODING, + ERR_CONTAINER, + ERR_WRITE, + DROPPED_JBRD, +}; + +struct CompressArgs { + // CompressArgs() = default; + void AddCommandLineOptions(CommandLineParser* cmdline) { + // Positional arguments. + cmdline->AddPositionalOption("INPUT", /* required = */ true, + "the input can be " +#if JPEGXL_ENABLE_APNG + "PNG, APNG, " +#endif +#if JPEGXL_ENABLE_GIF + "GIF, " +#endif +#if JPEGXL_ENABLE_JPEG + "JPEG, " +#else + "JPEG (lossless recompression only), " +#endif +#if JPEGXL_ENABLE_EXR + "EXR, " +#endif + "PPM, PFM, or PGX", + &file_in); + cmdline->AddPositionalOption( + "OUTPUT", /* required = */ true, + "the compressed JXL output file (can be omitted for benchmarking)", + &file_out); + + // Flags. + // TODO(lode): also add options to add exif/xmp/other metadata in the + // container. + cmdline->AddOptionValue('\0', "container", "0|1", + "0 = Do not encode using container format (strip " + "Exif/XMP/JPEG bitstream reconstruction data)." + "1 = Force using container format \n" + "(default: use only if needed).\n", + &container, &ParseOverride, 1); + + cmdline->AddOptionValue( + '\0', "jpeg_store_metadata", "0|1", + ("If --lossless_jpeg=1, store JPEG reconstruction " + "metadata in the JPEG XL container " + "(for lossless reconstruction of the JPEG codestream)." + "(default: 1)"), + &jpeg_store_metadata, &ParseUnsigned, 2); + + // Target distance/size/bpp + opt_distance_id = cmdline->AddOptionValue( + 'd', "distance", "maxError", + "Max. butteraugli distance, lower = higher quality.\n" + " 0.0 = mathematically lossless. Default for already-lossy input " + "(JPEG/GIF).\n" + " 1.0 = visually lossless. Default for other input.\n" + " Recommended range: 0.5 .. 3.0. Mutually exclusive with --quality.", + &distance, &ParseFloat); + + // High-level options + opt_quality_id = cmdline->AddOptionValue( + 'q', "quality", "QUALITY", + "Quality setting (is remapped to --distance). Range: -inf .. 100.\n" + " 100 = mathematically lossless. Default for already-lossy input " + "(JPEG/GIF).\n" + " Other input gets encoded as per --distance default.\n" + " Positive quality values roughly match libjpeg quality.\n" + " Mutually exclusive with --distance.", + &quality, &ParseFloat); + + cmdline->AddOptionValue( + 'e', "effort", "EFFORT", + "Encoder effort setting. Range: 1 .. 9.\n" + " Default: 7. Higher number is more effort (slower).", + &effort, &ParseUnsigned, -1); + + cmdline->AddOptionValue( + '\0', "brotli_effort", "B_EFFORT", + "Brotli effort setting. Range: 0 .. 11.\n" + " Default: 9. Higher number is more effort (slower).", + &brotli_effort, &ParseUnsigned, -1); + + cmdline->AddOptionValue( + '\0', "faster_decoding", "0|1|2|3|4", + "Favour higher decoding speed. 0 = default, higher " + "values give higher speed at the expense of quality", + &faster_decoding, &ParseUnsigned, 2); + + cmdline->AddOptionFlag('p', "progressive", + "Enable progressive/responsive decoding.", + &progressive, &SetBooleanTrue); + + cmdline->AddOptionValue('\0', "premultiply", "-1|0|1", + "Force premultiplied (associated) alpha.", + &premultiply, &ParseSigned, 1); + + cmdline->AddOptionValue( + '\0', "keep_invisible", "0|1", + "force disable/enable preserving color of invisible " + "pixels (default: 1 if lossless, 0 if lossy).", + &keep_invisible, &ParseOverride, 1); + + cmdline->AddOptionValue( + '\0', "group_order", "0|1", + "Order in which 256x256 groups are stored " + "in the codestream for progressive rendering. " + "Value not provided means 'encoder default', 0 means 'scanline order', " + "1 means 'center-first order'.", + &group_order, &ParseOverride, 1); + + cmdline->AddOptionValue( + '\0', "center_x", "0..XSIZE", + "Determines the horizontal position of center for the center-first " + "group order. The value -1 means 'use the middle of the image', " + "other values [0..xsize) set this to a particular coordinate.", + ¢er_x, &ParseInt64, 1); + + cmdline->AddOptionValue( + '\0', "center_y", "0..YSIZE", + "Determines the vertical position of center for the center-first " + "group order. The value -1 means 'use the middle of the image', " + "other values [0..ysize) set this to a particular coordinate.", + ¢er_y, &ParseInt64, 1); + + // Flags. + cmdline->AddOptionFlag('\0', "progressive_ac", + "Use the progressive mode for AC.", &progressive_ac, + &SetBooleanTrue, 1); + + opt_qprogressive_ac_id = cmdline->AddOptionFlag( + '\0', "qprogressive_ac", + "Use the progressive mode for AC with shift quantization.", + &qprogressive_ac, &SetBooleanTrue, 1); + + cmdline->AddOptionValue( + '\0', "progressive_dc", "num_dc_frames", + "Progressive-DC setting. Valid values are: -1, 0, 1, 2.", + &progressive_dc, &ParseSigned, 1); + + cmdline->AddOptionValue( + 'm', "modular", "0|1", + "Use modular mode (not provided = encoder chooses, 0 = enforce VarDCT, " + "1 = enforce modular mode).", + &modular, &ParseOverride, 1); + + // JPEG modes: parallel Brunsli, pixels to JPEG, or JPEG to Brunsli + opt_lossless_jpeg_id = cmdline->AddOptionValue( + 'j', "lossless_jpeg", "0|1", + "If the input is JPEG, losslessly transcode JPEG, " + "rather than using reencode pixels.", + &lossless_jpeg, &ParseUnsigned, 1); + + cmdline->AddOptionValue( + '\0', "jpeg_reconstruction_cfl", "0|1", + "Enable/disable chroma-from-luma (CFL) for lossless " + "JPEG reconstruction.", + &jpeg_reconstruction_cfl, &ParseOverride, 2); + + cmdline->AddOptionValue( + '\0', "num_threads", "N", + "Number of worker threads (-1 == use machine default, " + "0 == do not use multithreading).", + &num_threads, &ParseSigned, 1); + + cmdline->AddOptionValue('\0', "num_reps", "N", + "How many times to compress. (For benchmarking).", + &num_reps, &ParseUnsigned, 1); + + cmdline->AddOptionValue( + '\0', "photon_noise", "ISO3200", + "Adds noise to the image emulating photographic film noise. " + "The higher the given number, the grainier the image will be. " + "As an example, a value of 100 gives low noise whereas a value " + "of 3200 gives a lot of noise. The default value is 0.", + &photon_noise_iso, &ParsePhotonNoiseParameter, 1); + + cmdline->AddOptionValue( + '\0', "dots", "0|1", + "Force disable/enable dots generation. " + "(not provided = default, 0 = disable, 1 = enable).", + &dots, &ParseOverride, 1); + + cmdline->AddOptionValue( + '\0', "patches", "0|1", + "Force disable/enable patches generation. " + "(not provided = default, 0 = disable, 1 = enable).", + &patches, &ParseOverride, 1); + + cmdline->AddOptionValue( + '\0', "resampling", "-1|1|2|4|8", + "Resampling for extra channels. Default of -1 applies resampling only " + "for low quality. Value 1 does no downsampling (1x1), 2 does 2x2 " + "downsampling, 4 is for 4x4 downsampling, and 8 for 8x8 downsampling.", + &resampling, &ParseSigned, 0); + + cmdline->AddOptionValue( + '\0', "ec_resampling", "-1|1|2|4|8", + "Resampling for extra channels. Default of -1 applies resampling only " + "for low quality. Value 1 does no downsampling (1x1), 2 does 2x2 " + "downsampling, 4 is for 4x4 downsampling, and 8 for 8x8 downsampling.", + &ec_resampling, &ParseSigned, 2); + + cmdline->AddOptionFlag('\0', "already_downsampled", + "Do not downsample the given input before encoding, " + "but still signal that the decoder should upsample.", + &already_downsampled, &SetBooleanTrue, 2); + + cmdline->AddOptionValue( + '\0', "epf", "-1|0|1|2|3", + "Edge preserving filter level, -1 to 3. " + "Value -1 means: default (encoder chooses), 0 to 3 set a strength.", + &epf, &ParseSigned, 1); + + cmdline->AddOptionValue( + '\0', "gaborish", "0|1", + "Force disable/enable the gaborish filter. " + "(not provided = default, 0 = disable, 1 = enable).", + &gaborish, &ParseOverride, 1); + + cmdline->AddOptionValue( + '\0', "intensity_target", "N", + "Upper bound on the intensity level present in the image in nits. " + "Leaving this set to its default of 0 lets libjxl choose a sensible " + "default " + "value based on the color encoding.", + &intensity_target, &ParseIntensityTarget, 1); + + cmdline->AddOptionValue( + 'x', "dec-hints", "key=value", + "color_space indicates the ColorEncoding, see Description();\n" + "icc_pathname refers to a binary file containing an ICC profile.", + &color_hints, &ParseAndAppendKeyValue, 1); + + cmdline->AddOptionValue( + '\0', "override_bitdepth", "0=use from image, 1-32=override", + "If nonzero, store the given bit depth in the JPEG XL file metadata" + " (1-32), instead of using the bit depth from the original input" + " image.", + &override_bitdepth, &ParseUnsigned, 2); + + // modular mode options + cmdline->AddOptionValue( + 'I', "iterations", "F", + "[modular encoding] Fraction of pixels used to learn MA trees as " + "a percentage. -1 = default, 0 = no MA and fast decode, 50 = " + "default value, 100 = all." + "Higher values use more encoder memory.", + &modular_ma_tree_learning_percent, &ParseFloat, 2); + + cmdline->AddOptionValue( + 'C', "modular_colorspace", "K", + ("[modular encoding] color transform: -1=default, 0=RGB (none), " + "1-41=RCT (6=YCoCg, default: try several, depending on speed)"), + &modular_colorspace, &ParseSigned, 1); + + opt_modular_group_size_id = cmdline->AddOptionValue( + 'g', "modular_group_size", "K", + "[modular encoding] group size: -1 == default. 0 => 128, " + "1 => 256, 2 => 512, 3 => 1024", + &modular_group_size, &ParseSigned, 1); + + cmdline->AddOptionValue( + 'P', "modular_predictor", "K", + "[modular encoding] predictor(s) to use: 0=zero, " + "1=left, 2=top, 3=avg0, 4=select, 5=gradient, 6=weighted, " + "7=topright, 8=topleft, 9=leftleft, 10=avg1, 11=avg2, 12=avg3, " + "13=toptop predictive average " + "14=mix 5 and 6, 15=mix everything. If unset, uses default 14, " + "at slowest speed default 15.", + &modular_predictor, &ParseSigned, 1); + + cmdline->AddOptionValue( + 'E', "modular_nb_prev_channels", "K", + "[modular encoding] number of extra MA tree properties to use", + &modular_nb_prev_channels, &ParseSigned, 2); + + cmdline->AddOptionValue( + '\0', "modular_palette_colors", "K", + "[modular encoding] Use color palette if number of colors is smaller " + "than or equal to this, or -1 to use the encoder default.", + &modular_palette_colors, &ParseSigned, 1); + + cmdline->AddOptionFlag( + '\0', "modular_lossy_palette", + "[modular encoding] quantize to a palette that has fewer entries than " + "would be necessary for perfect preservation; for the time being, it " + "is " + "recommended to set --palette=0 with this option to use the default " + "palette only", + &modular_lossy_palette, &SetBooleanTrue, 1); + + cmdline->AddOptionValue( + 'X', "pre-compact", "PERCENT", + "[modular encoding] Use Global channel palette if the number of " + "colors is smaller than this percentage of range. " + "Use 0-100 to set an explicit percentage, -1 to use the encoder " + "default.", + &modular_channel_colors_global_percent, &ParseFloat, 2); + + cmdline->AddOptionValue( + 'Y', "post-compact", "PERCENT", + "[modular encoding] Use Local (per-group) channel palette if the " + "number " + "of colors is smaller than this percentage of range. Use 0-100 to set " + "an explicit percentage, -1 to use the encoder default.", + &modular_channel_colors_group_percent, &ParseFloat, 2); + + cmdline->AddOptionValue('\0', "codestream_level", "K", + "The codestream level. Either `-1`, `5` or `10`.", + &codestream_level, &ParseSigned, 2); + + opt_responsive_id = cmdline->AddOptionValue( + 'R', "responsive", "K", + "[modular encoding] do Squeeze transform, 0=false, " + "1=true (default: true if lossy, false if lossless)", + &responsive, &ParseSigned, 1); + + cmdline->AddOptionFlag('V', "version", + "Print encoder library version number and exit.", + &version, &SetBooleanTrue, 1); + + cmdline->AddOptionFlag('\0', "quiet", "Be more silent", &quiet, + &SetBooleanTrue, 1); + + cmdline->AddOptionValue( + '\0', "frame_indexing", "string", + // TODO(tfish): Add a more convenient vanilla alternative. + "If non-empty, a string matching '^(0*|1[01]*)'. If this string has a " + "'1' in i-th position, then the i-th frame will be indexed in " + "the frame index box.", + &frame_indexing, &ParseString, 1); + + cmdline->AddOptionFlag( + 'v', "verbose", + "Verbose output; can be repeated, also applies to help (!).", &verbose, + &SetBooleanTrue); + } + + // Common flags. + bool version = false; + jxl::Override container = jxl::Override::kDefault; + bool quiet = false; + + const char* file_in = nullptr; + const char* file_out = nullptr; + jxl::Override print_profile = jxl::Override::kDefault; + + // Decoding source image flags + jxl::extras::ColorHints color_hints; + + // JXL flags + size_t override_bitdepth = 0; + int32_t num_threads = -1; + size_t num_reps = 1; + float intensity_target = 0; + + // Filename for the user provided saliency-map. + std::string saliency_map_filename; + + // Whether to perform lossless transcoding with kVarDCT or kJPEG encoding. + // If true, attempts to load JPEG coefficients instead of pixels. + // Reset to false if input image is not a JPEG. + size_t lossless_jpeg = 1; + + size_t jpeg_store_metadata = 1; + + float quality = -1001.f; // Default to lossless if input is already lossy, + // or to VarDCT otherwise. + bool verbose = false; + bool progressive = false; + bool progressive_ac = false; + bool qprogressive_ac = false; + int32_t progressive_dc = -1; + bool modular_lossy_palette = false; + int32_t premultiply = -1; + bool already_downsampled = false; + jxl::Override jpeg_reconstruction_cfl = jxl::Override::kDefault; + jxl::Override modular = jxl::Override::kDefault; + jxl::Override keep_invisible = jxl::Override::kDefault; + jxl::Override dots = jxl::Override::kDefault; + jxl::Override patches = jxl::Override::kDefault; + jxl::Override gaborish = jxl::Override::kDefault; + jxl::Override group_order = jxl::Override::kDefault; + + size_t faster_decoding = 0; + int32_t resampling = -1; + int32_t ec_resampling = -1; + int32_t epf = -1; + int64_t center_x = -1; + int64_t center_y = -1; + int32_t modular_group_size = -1; + int32_t modular_predictor = -1; + int32_t modular_colorspace = -1; + float modular_channel_colors_global_percent = -1.f; + float modular_channel_colors_group_percent = -1.f; + int32_t modular_palette_colors = -1; + int32_t modular_nb_prev_channels = -1; + float modular_ma_tree_learning_percent = -1.f; + float photon_noise_iso = 0; + int32_t codestream_level = -1; + int32_t responsive = -1; + float distance = 1.0; + size_t effort = 7; + size_t brotli_effort = 9; + std::string frame_indexing; + + // Will get passed on to AuxOut. + // jxl::InspectorImage3F inspector_image3f; + + // References (ids) of specific options to check if they were matched. + CommandLineParser::OptionId opt_lossless_jpeg_id = -1; + CommandLineParser::OptionId opt_responsive_id = -1; + CommandLineParser::OptionId opt_distance_id = -1; + CommandLineParser::OptionId opt_quality_id = -1; + CommandLineParser::OptionId opt_qprogressive_ac_id = -1; + CommandLineParser::OptionId opt_modular_group_size_id = -1; +}; + +const char* ModeFromArgs(const CompressArgs& args) { + if (args.lossless_jpeg) return "JPEG"; + if (args.modular == jxl::Override::kOn || args.distance == 0) + return "Modular"; + return "VarDCT"; +} + +std::string DistanceFromArgs(const CompressArgs& args) { + char buf[100]; + if (args.lossless_jpeg) { + snprintf(buf, sizeof(buf), "lossless transcode"); + } else if (args.distance == 0) { + snprintf(buf, sizeof(buf), "lossless"); + } else { + snprintf(buf, sizeof(buf), "d%.3f", args.distance); + } + return buf; +} + +void PrintMode(jxl::extras::PackedPixelFile& ppf, const double decode_mps, + size_t num_bytes, const CompressArgs& args) { + const char* mode = ModeFromArgs(args); + const std::string distance = DistanceFromArgs(args); + if (args.lossless_jpeg) { + fprintf(stderr, "Read JPEG image with %" PRIuS " bytes.\n", num_bytes); + } else { + fprintf(stderr, + "Read %" PRIuS "x%" PRIuS " image, %" PRIuS " bytes, %.1f MP/s\n", + static_cast(ppf.info.xsize), + static_cast(ppf.info.ysize), num_bytes, decode_mps); + } + fprintf(stderr, "Encoding [%s%s, %s, effort: %" PRIuS, + (args.container == jxl::Override::kOn ? "Container | " : ""), mode, + distance.c_str(), args.effort); + if (args.container == jxl::Override::kOn) { + if (args.lossless_jpeg && args.jpeg_store_metadata) + fprintf(stderr, " | JPEG reconstruction data"); + if (!ppf.metadata.exif.empty()) + fprintf(stderr, " | %" PRIuS "-byte Exif", ppf.metadata.exif.size()); + if (!ppf.metadata.xmp.empty()) + fprintf(stderr, " | %" PRIuS "-byte XMP", ppf.metadata.xmp.size()); + if (!ppf.metadata.jumbf.empty()) + fprintf(stderr, " | %" PRIuS "-byte JUMBF", ppf.metadata.jumbf.size()); + } + fprintf(stderr, "], \n"); +} + +} // namespace tools +} // namespace jpegxl + +namespace { + +template +void SetFlagFrameOptionOrDie(const char* flag_name, T flag_value, + JxlEncoderFrameSettings* frame_settings, + JxlEncoderFrameSettingId encoder_option) { + if (JXL_ENC_SUCCESS != + (std::is_same::value + ? JxlEncoderFrameSettingsSetFloatOption(frame_settings, + encoder_option, flag_value) + : JxlEncoderFrameSettingsSetOption(frame_settings, encoder_option, + flag_value))) { + std::cerr << "Setting encoder option from flag --" << flag_name + << " failed." << std::endl; + exit(EXIT_FAILURE); + } +} + +void SetDistanceFromFlags(JxlEncoderFrameSettings* jxl_encoder_frame_settings, + jpegxl::tools::CommandLineParser* cmdline, + jpegxl::tools::CompressArgs* args, + const jxl::extras::Codec& codec) { + bool distance_set = cmdline->GetOption(args->opt_distance_id)->matched(); + bool quality_set = cmdline->GetOption(args->opt_quality_id)->matched(); + if (quality_set) { + if (distance_set) { + std::cerr << "Must not set both --distance and --quality." << std::endl; + exit(EXIT_FAILURE); + } + double distance = args->quality >= 100 ? 0.0 + : args->quality >= 30 + ? 0.1 + (100 - args->quality) * 0.09 + : 6.4 + pow(2.5, (30 - args->quality) / 5.0) / 6.25; + args->distance = distance; + distance_set = true; + } + if (!distance_set) { + bool lossy_input = (codec == jxl::extras::Codec::kJPG || + codec == jxl::extras::Codec::kGIF); + args->distance = lossy_input ? 0.0 : 1.0; + } + if (JXL_ENC_SUCCESS != + JxlEncoderSetFrameDistance(jxl_encoder_frame_settings, args->distance)) { + std::cerr << "Setting frame distance failed." << std::endl; + exit(EXIT_FAILURE); + } +} + +using flag_check_fn = std::function; +using flag_check_float_fn = std::function; + +bool IsJPG(const std::vector& image_data) { + return (image_data.size() >= 2 && image_data[0] == 0xFF && + image_data[1] == 0xD8); +} + +// TODO(tfish): Replace with non-C-API library function. +// Implementation is in extras/. +jxl::Status GetPixeldata(const std::vector& image_data, + const jxl::extras::ColorHints& color_hints, + jxl::extras::PackedPixelFile& ppf, + jxl::extras::Codec& codec) { + // Any valid encoding is larger (ensures codecs can read the first few bytes). + constexpr size_t kMinBytes = 9; + + if (image_data.size() < kMinBytes) return JXL_FAILURE("Input too small."); + jxl::Span encoded(image_data); + + ppf.info.orientation = JXL_ORIENT_IDENTITY; + jxl::SizeConstraints size_constraints; + + const auto choose_codec = [&]() { +#if JPEGXL_ENABLE_APNG + if (jxl::extras::DecodeImageAPNG(encoded, color_hints, size_constraints, + &ppf)) { + return jxl::extras::Codec::kPNG; + } +#endif + if (jxl::extras::DecodeImagePGX(encoded, color_hints, size_constraints, + &ppf)) { + return jxl::extras::Codec::kPGX; + } else if (jxl::extras::DecodeImagePNM(encoded, color_hints, + size_constraints, &ppf)) { + return jxl::extras::Codec::kPNM; + } +#if JPEGXL_ENABLE_GIF + if (jxl::extras::DecodeImageGIF(encoded, color_hints, size_constraints, + &ppf)) { + return jxl::extras::Codec::kGIF; + } +#endif +#if JPEGXL_ENABLE_JPEG + if (jxl::extras::DecodeImageJPG(encoded, color_hints, size_constraints, + &ppf)) { + return jxl::extras::Codec::kJPG; + } +#endif + // TODO(tfish): Bring back EXR and PSD. + return jxl::extras::Codec::kUnknown; + }; + codec = choose_codec(); + if (codec == jxl::extras::Codec::kUnknown) { + return JXL_FAILURE("Codecs failed to decode input."); + } + return true; +} + +} // namespace + +int main(int argc, char** argv) { + std::string version = jpegxl::tools::CodecConfigString(JxlEncoderVersion()); + jpegxl::tools::CompressArgs args; + jpegxl::tools::CommandLineParser cmdline; + args.AddCommandLineOptions(&cmdline); + + if (!cmdline.Parse(argc, const_cast(argv))) { + // Parse already printed the actual error cause. + fprintf(stderr, "Use '%s -h' for more information\n", argv[0]); + return jpegxl::tools::CjxlRetCode::ERR_PARSE; + } + + if (args.version) { + fprintf(stdout, "cjxl %s\n", version.c_str()); + fprintf(stdout, "Copyright (c) the JPEG XL Project\n"); + return jpegxl::tools::CjxlRetCode::OK; + } + + if (!args.quiet) { + fprintf(stderr, "JPEG XL encoder %s\n", version.c_str()); + } + + if (cmdline.HelpFlagPassed() || !args.file_in) { + cmdline.PrintHelp(); + return jpegxl::tools::CjxlRetCode::OK; + } + + if (!args.file_out && !args.quiet) { + fprintf(stderr, + "No output file specified.\n" + "Encoding will be performed, but the result will be discarded.\n"); + } + + // Loading the input. + // Depending on flags-settings, we want to either load a JPEG and + // faithfully convert it to JPEG XL, or load (JPEG or non-JPEG) + // pixel data. + std::vector image_data; + jxl::extras::PackedPixelFile ppf; + jxl::extras::Codec codec = jxl::extras::Codec::kUnknown; + double decode_mps = 0; + size_t pixels = 0; + if (!jpegxl::tools::ReadFile(args.file_in, &image_data)) { + std::cerr << "Reading image data failed." << std::endl; + exit(EXIT_FAILURE); + } + if (!IsJPG(image_data)) args.lossless_jpeg = 0; + if (!args.lossless_jpeg) { + const double t0 = jxl::Now(); + jxl::Status status = GetPixeldata(image_data, args.color_hints, ppf, codec); + if (!status) { + std::cerr << "Getting pixel data." << std::endl; + exit(EXIT_FAILURE); + } + if (ppf.frames.empty()) { + std::cerr << "No frames on input file." << std::endl; + exit(EXIT_FAILURE); + } + + const double t1 = jxl::Now(); + pixels = ppf.info.xsize * ppf.info.ysize; + decode_mps = pixels * ppf.info.num_color_channels * 1E-6 / (t1 - t0); + } + + JxlEncoderPtr enc = JxlEncoderMake(/*memory_manager=*/nullptr); + JxlEncoder* jxl_encoder = enc.get(); + JxlThreadParallelRunnerPtr runner; + std::vector compressed; + size_t num_worker_threads; + jpegxl::tools::SpeedStats stats; + for (size_t num_rep = 0; num_rep < args.num_reps; ++num_rep) { + const double t0 = jxl::Now(); + JxlEncoderReset(jxl_encoder); + if (args.num_threads != 0) { + num_worker_threads = JxlThreadParallelRunnerDefaultNumWorkerThreads(); + { + int64_t flag_num_worker_threads = args.num_threads; + if (flag_num_worker_threads > -1) { + num_worker_threads = flag_num_worker_threads; + } + } + if (runner == nullptr) { + runner = JxlThreadParallelRunnerMake( + /*memory_manager=*/nullptr, num_worker_threads); + } + if (JXL_ENC_SUCCESS != + JxlEncoderSetParallelRunner(jxl_encoder, JxlThreadParallelRunner, + runner.get())) { + std::cerr << "JxlEncoderSetParallelRunner failed." << std::endl; + return EXIT_FAILURE; + } + } + JxlEncoderFrameSettings* jxl_encoder_frame_settings = + JxlEncoderFrameSettingsCreate(jxl_encoder, nullptr); + + auto process_flag = [&jxl_encoder_frame_settings]( + const char* flag_name, int64_t flag_value, + JxlEncoderFrameSettingId encoder_option, + const flag_check_fn& flag_check) { + std::string error = flag_check(flag_value); + if (!error.empty()) { + std::cerr << "Invalid flag value for --" << flag_name << ": " << error + << std::endl; + exit(EXIT_FAILURE); + } + SetFlagFrameOptionOrDie(flag_name, flag_value, jxl_encoder_frame_settings, + encoder_option); + }; + auto process_float_flag = [&jxl_encoder_frame_settings]( + const char* flag_name, float flag_value, + JxlEncoderFrameSettingId encoder_option, + const flag_check_float_fn& flag_check) { + std::string error = flag_check(flag_value); + if (!error.empty()) { + std::cerr << "Invalid flag value for --" << flag_name << ": " << error + << std::endl; + exit(EXIT_FAILURE); + } + SetFlagFrameOptionOrDie(flag_name, flag_value, jxl_encoder_frame_settings, + encoder_option); + }; + + auto process_bool_flag = [&jxl_encoder_frame_settings]( + const char* flag_name, + jxl::Override flag_value, + JxlEncoderFrameSettingId encoder_option) { + if (flag_value != jxl::Override::kDefault) { + SetFlagFrameOptionOrDie(flag_name, + flag_value == jxl::Override::kOn ? 1 : 0, + jxl_encoder_frame_settings, encoder_option); + } + }; + + { // Processing tuning flags. + process_bool_flag("modular", args.modular, JXL_ENC_FRAME_SETTING_MODULAR); + process_bool_flag("keep_invisible", args.keep_invisible, + JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE); + process_bool_flag("dots", args.dots, JXL_ENC_FRAME_SETTING_DOTS); + process_bool_flag("patches", args.patches, JXL_ENC_FRAME_SETTING_PATCHES); + process_bool_flag("gaborish", args.gaborish, + JXL_ENC_FRAME_SETTING_GABORISH); + process_bool_flag("group_order", args.group_order, + JXL_ENC_FRAME_SETTING_GROUP_ORDER); + + if (!args.frame_indexing.empty()) { + bool must_be_all_zeros = args.frame_indexing[0] != '1'; + for (char c : args.frame_indexing) { + if (c == '1') { + if (must_be_all_zeros) { + std::cerr + << "Invalid --frame_indexing. If the first character is " + "'0', all must be '0'." + << std::endl; + return EXIT_FAILURE; + } + } else if (c != '0') { + std::cerr << "Invalid --frame_indexing. Must match the pattern " + "'^(0*|1[01]*)$'." + << std::endl; + return EXIT_FAILURE; + } + } + } + + process_flag( + "effort", args.effort, JXL_ENC_FRAME_SETTING_EFFORT, + [](int64_t x) -> std::string { + return (1 <= x && x <= 9) ? "" : "Valid range is {1, 2, ..., 9}."; + }); + process_flag( + "brotli_effort", args.brotli_effort, + JXL_ENC_FRAME_SETTING_BROTLI_EFFORT, [](int64_t x) -> std::string { + return (-1 <= x && x <= 11) ? "" + : "Valid range is {-1, 0, 1, ..., 11}."; + }); + process_flag("epf", args.epf, JXL_ENC_FRAME_SETTING_EPF, + [](int64_t x) -> std::string { + return (-1 <= x && x <= 3) + ? "" + : "Valid range is {-1, 0, 1, 2, 3}.\n"; + }); + process_flag( + "faster_decoding", args.faster_decoding, + JXL_ENC_FRAME_SETTING_DECODING_SPEED, [](int64_t x) -> std::string { + return (0 <= x && x <= 4) ? "" + : "Valid range is {0, 1, 2, 3, 4}.\n"; + }); + process_flag("resampling", args.resampling, + JXL_ENC_FRAME_SETTING_RESAMPLING, + [](int64_t x) -> std::string { + return (x == -1 || x == 1 || x == 4 || x == 8) + ? "" + : "Valid values are {-1, 1, 2, 4, 8}.\n"; + }); + process_flag("ec_resampling", args.ec_resampling, + JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING, + [](int64_t x) -> std::string { + return (x == -1 || x == 1 || x == 4 || x == 8) + ? "" + : "Valid values are {-1, 1, 2, 4, 8}.\n"; + }); + SetFlagFrameOptionOrDie("photon_noise_iso", args.photon_noise_iso, + jxl_encoder_frame_settings, + JXL_ENC_FRAME_SETTING_PHOTON_NOISE); + SetFlagFrameOptionOrDie("already_downsampled", + static_cast(args.already_downsampled), + jxl_encoder_frame_settings, + JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED); + SetDistanceFromFlags(jxl_encoder_frame_settings, &cmdline, &args, codec); + + if (args.group_order != jxl::Override::kOn && + (args.center_x != -1 || args.center_y != -1)) { + std::cerr + << "Invalid flag combination. Setting --center_x or --center_y " + << "requires setting --group_order=1" << std::endl; + return EXIT_FAILURE; + } + process_flag("center_x", args.center_x, + JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X, + [](int64_t x) -> std::string { + if (x < -1) { + return "Valid values are: -1 or [0 .. xsize)."; + } + return ""; + }); + process_flag("center_y", args.center_y, + JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_Y, + [](int64_t x) -> std::string { + if (x < -1) { + return "Valid values are: -1 or [0 .. ysize)."; + } + return ""; + }); + } + { // Progressive/responsive mode settings. + bool qprogressive_ac_set = + cmdline.GetOption(args.opt_qprogressive_ac_id)->matched(); + int32_t qprogressive_ac = args.qprogressive_ac ? 1 : 0; + bool responsive_set = + cmdline.GetOption(args.opt_responsive_id)->matched(); + int32_t responsive = args.responsive ? 1 : 0; + + process_flag( + "progressive_dc", args.progressive_dc, + JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, [](int64_t x) -> std::string { + return (-1 <= x && x <= 2) ? "" : "Valid range is {-1, 0, 1, 2}.\n"; + }); + SetFlagFrameOptionOrDie( + "progressive_ac", static_cast(args.progressive_ac), + jxl_encoder_frame_settings, JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC); + + if (args.progressive) { + qprogressive_ac = 1; + qprogressive_ac_set = true; + responsive = 1; + responsive_set = true; + } + if (responsive_set) { + SetFlagFrameOptionOrDie("responsive", responsive, + jxl_encoder_frame_settings, + JXL_ENC_FRAME_SETTING_RESPONSIVE); + } + if (qprogressive_ac_set) { + SetFlagFrameOptionOrDie("qprogressive_ac", qprogressive_ac, + jxl_encoder_frame_settings, + JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC); + } + } + { // Modular mode related. + // TODO(firsching): consider doing more validation after image size is + // known, i.e. set to 512 if 256 would be silly using + // opt_modular_group_size_id. + process_flag("modular_group_size", args.modular_group_size, + JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE, + [](int64_t x) -> std::string { + return (-1 <= x && x <= 3) + ? "" + : "Invalid --modular_group_size. Valid " + "range is {-1, 0, 1, 2, 3}.\n"; + }); + process_flag("modular_predictor", args.modular_predictor, + JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR, + [](int64_t x) -> std::string { + return (-1 <= x && x <= 15) + ? "" + : "Invalid --modular_predictor. Valid " + "range is {-1, 0, 1, ..., 15}.\n"; + }); + process_flag( + "modular_colorspace", args.modular_colorspace, + JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE, + [](int64_t x) -> std::string { + return (-1 <= x && x <= 41) + ? "" + : "Invalid --modular_colorspace. Valid range is " + "{-1, 0, 1, ..., 41}.\n"; + }); + process_float_flag( + "modular_ma_tree_learning_percent", + args.modular_ma_tree_learning_percent, + JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT, + [](float x) -> std::string { + return -1 <= x && x <= 100 + ? "" + : "Invalid --modular_ma_tree_learning_percent, Valid" + "rang is [-1, 100].\n"; + }); + process_flag("modular_nb_prev_channels", args.modular_nb_prev_channels, + JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS, + [](int64_t x) -> std::string { + return (-1 <= x && x <= 11) + ? "" + : "Invalid --modular_nb_prev_channels. Valid " + "range is {-1, 0, 1, ..., 11}.\n"; + }); + SetFlagFrameOptionOrDie("modular_lossy_palette", + static_cast(args.modular_lossy_palette), + jxl_encoder_frame_settings, + JXL_ENC_FRAME_SETTING_LOSSY_PALETTE); + process_flag("modular_palette_colors", args.modular_palette_colors, + JXL_ENC_FRAME_SETTING_PALETTE_COLORS, + [](int64_t x) -> std::string { + return -1 <= x ? "" + : "Invalid --modular_palette_colors, must " + "be -1 or non-negative\n"; + }); + process_float_flag( + "modular_channel_colors_global_percent", + args.modular_channel_colors_global_percent, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT, + [](float x) -> std::string { + return (-1 <= x && x <= 100) + ? "" + : "Invalid --modular_channel_colors_global_percent. " + "Valid " + "range is [-1, 100].\n"; + }); + process_float_flag( + "modular_channel_colors_group_percent", + args.modular_channel_colors_group_percent, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT, + [](float x) -> std::string { + return (-1 <= x && x <= 100) + ? "" + : "Invalid --modular_channel_colors_group_percent. " + "Valid " + "range is [-1, 100].\n"; + }); + } + + bool use_container = args.container == jxl::Override::kOn; + if (!ppf.metadata.exif.empty() || !ppf.metadata.xmp.empty() || + !ppf.metadata.jumbf.empty() || !ppf.metadata.iptc.empty() || + (args.lossless_jpeg && args.jpeg_store_metadata)) { + use_container = true; + } + if (use_container) args.container = jxl::Override::kOn; + + if (!ppf.metadata.exif.empty()) { + jxl::InterpretExif(ppf.metadata.exif, &ppf.info.orientation); + } + + if (JXL_ENC_SUCCESS != + JxlEncoderUseContainer(jxl_encoder, static_cast(use_container))) { + std::cerr << "JxlEncoderUseContainer failed." << std::endl; + return EXIT_FAILURE; + } + + if (num_rep == 0 && !args.quiet) + PrintMode(ppf, decode_mps, image_data.size(), args); + + if (args.lossless_jpeg && IsJPG(image_data)) { + if (!cmdline.GetOption(args.opt_lossless_jpeg_id)->matched()) { + std::cerr << "Note: Implicit-default for JPEG is lossless-transcoding. " + << "To silence this message, set --lossless_jpeg=(1|0)." + << std::endl; + } + if (args.jpeg_store_metadata) { + if (JXL_ENC_SUCCESS != + JxlEncoderStoreJPEGMetadata(jxl_encoder, JXL_TRUE)) { + std::cerr << "Storing JPEG metadata failed. " << std::endl; + return EXIT_FAILURE; + } + } + process_bool_flag("jpeg_reconstruction_cfl", args.jpeg_reconstruction_cfl, + JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL); + if (JXL_ENC_SUCCESS != JxlEncoderAddJPEGFrame(jxl_encoder_frame_settings, + image_data.data(), + image_data.size())) { + std::cerr << "JxlEncoderAddJPEGFrame() failed." << std::endl; + return EXIT_FAILURE; + } + } else { // Do JxlEncoderAddImageFrame(). + size_t num_alpha_channels = 0; // Adjusted below. + { + JxlBasicInfo basic_info = ppf.info; + if (basic_info.alpha_bits > 0) num_alpha_channels = 1; + basic_info.intensity_target = args.intensity_target; + basic_info.num_extra_channels = num_alpha_channels; + basic_info.num_color_channels = ppf.info.num_color_channels; + const bool lossless = args.distance == 0; + basic_info.uses_original_profile = lossless; + if (args.override_bitdepth != 0) { + basic_info.bits_per_sample = args.override_bitdepth; + basic_info.exponent_bits_per_sample = + args.override_bitdepth == 32 ? 8 : 0; + } + if (JXL_ENC_SUCCESS != + JxlEncoderSetCodestreamLevel(jxl_encoder, args.codestream_level)) { + std::cerr << "Setting --codestream_level failed." << std::endl; + return EXIT_FAILURE; + } + if (JXL_ENC_SUCCESS != + JxlEncoderSetBasicInfo(jxl_encoder, &basic_info)) { + std::cerr << "JxlEncoderSetBasicInfo() failed." << std::endl; + return EXIT_FAILURE; + } + if (lossless && + JXL_ENC_SUCCESS != JxlEncoderSetFrameLossless( + jxl_encoder_frame_settings, JXL_TRUE)) { + std::cerr << "JxlEncoderSetFrameLossless() failed." << std::endl; + return EXIT_FAILURE; + } + } + + if (!ppf.icc.empty()) { + if (JXL_ENC_SUCCESS != JxlEncoderSetICCProfile(jxl_encoder, + ppf.icc.data(), + ppf.icc.size())) { + std::cerr << "JxlEncoderSetICCProfile() failed." << std::endl; + return EXIT_FAILURE; + } + } else { + if (JXL_ENC_SUCCESS != + JxlEncoderSetColorEncoding(jxl_encoder, &ppf.color_encoding)) { + std::cerr << "JxlEncoderSetColorEncoding() failed." << std::endl; + return EXIT_FAILURE; + } + } + + for (size_t num_frame = 0; num_frame < ppf.frames.size(); ++num_frame) { + const jxl::extras::PackedFrame& pframe = ppf.frames[num_frame]; + const jxl::extras::PackedImage& pimage = pframe.color; + JxlPixelFormat ppixelformat = pimage.format; + { + if (JXL_ENC_SUCCESS != + JxlEncoderSetFrameHeader(jxl_encoder_frame_settings, + &pframe.frame_info)) { + std::cerr << "JxlEncoderSetFrameHeader() failed." << std::endl; + return EXIT_FAILURE; + } + } + if (num_frame < args.frame_indexing.size() && + args.frame_indexing[num_frame] == '1') { + if (JXL_ENC_SUCCESS != + JxlEncoderFrameSettingsSetOption(jxl_encoder_frame_settings, + JXL_ENC_FRAME_INDEX_BOX, 1)) { + std::cerr << "Setting option JXL_ENC_FRAME_INDEX_BOX failed." + << std::endl; + return EXIT_FAILURE; + } + } + JxlEncoderStatus enc_status; + { + if (num_alpha_channels > 0) { + JxlExtraChannelInfo extra_channel_info; + JxlEncoderInitExtraChannelInfo(JXL_CHANNEL_ALPHA, + &extra_channel_info); + enc_status = JxlEncoderSetExtraChannelInfo(jxl_encoder, 0, + &extra_channel_info); + if (JXL_ENC_SUCCESS != enc_status) { + std::cerr << "JxlEncoderSetExtraChannelInfo() failed." + << std::endl; + return EXIT_FAILURE; + } + if (args.premultiply != -1) { + if (args.premultiply != 0 && args.premultiply != 1) { + std::cerr << "Flag --premultiply must be one of: -1, 0, 1." + << std::endl; + return EXIT_FAILURE; + } + extra_channel_info.alpha_premultiplied = args.premultiply; + } + // We take the extra channel blend info frame_info, but don't do + // clamping. + JxlBlendInfo extra_channel_blend_info = + pframe.frame_info.layer_info.blend_info; + extra_channel_blend_info.clamp = JXL_FALSE; + JxlEncoderSetExtraChannelBlendInfo(jxl_encoder_frame_settings, 0, + &extra_channel_blend_info); + } + enc_status = + JxlEncoderAddImageFrame(jxl_encoder_frame_settings, &ppixelformat, + pimage.pixels(), pimage.pixels_size); + if (JXL_ENC_SUCCESS != enc_status) { + std::cerr << "JxlEncoderAddImageFrame() failed." << std::endl; + return EXIT_FAILURE; + } + // Only set extra channel buffer if is is provided non-interleaved. + if (!pframe.extra_channels.empty()) { + enc_status = JxlEncoderSetExtraChannelBuffer( + jxl_encoder_frame_settings, &ppixelformat, + pframe.extra_channels[0].pixels(), + pframe.extra_channels[0].stride * + pframe.extra_channels[0].ysize, + 0); + if (JXL_ENC_SUCCESS != enc_status) { + std::cerr << "JxlEncoderSetExtraChannelBuffer() failed." + << std::endl; + return EXIT_FAILURE; + } + } + } + } + } + JxlEncoderCloseInput(jxl_encoder); + // Reading compressed output + compressed.clear(); + compressed.resize(4096); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = + JxlEncoderProcessOutput(jxl_encoder, &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + if (JXL_ENC_SUCCESS != process_result) { + std::cerr << "JxlEncoderProcessOutput failed." << std::endl; + return EXIT_FAILURE; + } + + const double t1 = jxl::Now(); + stats.NotifyElapsed(t1 - t0); + stats.SetImageSize(ppf.info.xsize, ppf.info.ysize); + } + + if (args.file_out) { + if (!jpegxl::tools::WriteFile(args.file_out, compressed)) { + std::cerr << "Could not write jxl file." << std::endl; + return EXIT_FAILURE; + } + } + if (!args.quiet) { + const double bpp = + static_cast(compressed.size() * jxl::kBitsPerByte) / pixels; + fprintf(stderr, "Compressed to %" PRIuS " bytes ", compressed.size()); + // For lossless jpeg-reconstruction, we don't print some stats, since we + // don't have easy access to the image dimensions. + if (args.container == jxl::Override::kOn) { + fprintf(stderr, "including container "); + } + if (!args.lossless_jpeg) { + fprintf(stderr, "(%.3f bpp%s).\n", bpp / ppf.frames.size(), + ppf.frames.size() == 1 ? "" : "/frame"); + JXL_CHECK(stats.Print(num_worker_threads)); + } else { + fprintf(stderr, "\n"); + } + } + return EXIT_SUCCESS; +} diff --git a/media/libjxl/src/tools/cmdline.cc b/media/libjxl/src/tools/cmdline.cc new file mode 100644 index 000000000..f777c9469 --- /dev/null +++ b/media/libjxl/src/tools/cmdline.cc @@ -0,0 +1,95 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/cmdline.h" + +#include +#include + +namespace jpegxl { +namespace tools { + +void CommandLineParser::PrintHelp() const { + // Use stdout, not stderr, so help can easily be grepped. + FILE* out = stdout; + fprintf(out, "Usage: %s", program_name_ ? program_name_ : "command"); + + for (const auto& option : options_) { + if (option->positional()) { + if (option->verbosity_level() > verbosity) continue; + if (option->required()) { + fprintf(out, " %s", option->help_flags().c_str()); + } else { + fprintf(out, " [%s]", option->help_flags().c_str()); + } + } + } + fprintf(out, " [OPTIONS...]\n"); + + bool showed_all = true; + for (const auto& option : options_) { + if (option->verbosity_level() > verbosity) { + showed_all = false; + continue; + } + fprintf(out, " %s\n", option->help_flags().c_str()); + const char* help_text = option->help_text(); + if (help_text) { + fprintf(out, " %s\n", help_text); + } + } + fprintf(out, " -h, --help\n Prints this help message%s.\n", + (showed_all ? "" : " (use -v to see more options)")); +} + +bool CommandLineParser::Parse(int argc, const char* argv[]) { + if (argc) program_name_ = argv[0]; + int i = 1; // argv[0] is the program name. + // if false, stop matching options and take only positional arguments + bool parse_options = true; + while (i < argc) { + if (!strcmp("-h", argv[i]) || !strcmp("--help", argv[i])) { + help_ = true; + i++; + continue; + } + if (!strcmp("-v", argv[i]) || !strcmp("--verbose", argv[i])) { + verbosity++; + } + // after "--", filenames starting with "-" can be used + if (!strcmp("--", argv[i])) { + parse_options = false; + i++; + continue; + } + // special case: "-" is a filename denoting stdin or stdout + bool parse_this_option = true; + if (!strcmp("-", argv[i])) { + parse_this_option = false; + } + bool found = false; + for (const auto& option : options_) { + if (option->Match(argv[i], parse_options && parse_this_option)) { + // Parsing advances the value i on success. + const char* arg = argv[i]; + if (!option->Parse(argc, argv, &i)) { + fprintf(stderr, "Error parsing flag %s\n", arg); + return false; + } + found = true; + break; + } + } + if (!found) { + // No option matched argv[i]. + fprintf(stderr, "Unknown argument: %s\n", argv[i]); + return false; + } + } + return true; +} + +} // namespace tools +} // namespace jpegxl diff --git a/media/libjxl/src/tools/cmdline.h b/media/libjxl/src/tools/cmdline.h new file mode 100644 index 000000000..9b730e67e --- /dev/null +++ b/media/libjxl/src/tools/cmdline.h @@ -0,0 +1,396 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_CMDLINE_H_ +#define TOOLS_CMDLINE_H_ + +#include +#include + +#include +#include +#include +#include + +namespace jpegxl { +namespace tools { + +class CommandLineParser { + public: + typedef size_t OptionId; + + // An abstract class for defining command line options. + class CmdOptionInterface { + public: + CmdOptionInterface() = default; + virtual ~CmdOptionInterface() = default; + + // Return a string with the option name or available flags. + virtual std::string help_flags() const = 0; + + // Return the help string if any, or nullptr if no help string. + virtual const char* help_text() const = 0; + + // Return the verbosity level for this option + virtual int verbosity_level() const = 0; + + // Return whether the option was passed. + virtual bool matched() const = 0; + + // Returns whether this option matches the passed command line argument. + virtual bool Match(const char* arg, bool parse_options) const = 0; + + // Parses the option. The passed i points to the argument with the flag + // that matches either the short or the long name. + virtual bool Parse(int argc, const char* argv[], int* i) = 0; + + // Returns whether the option is positional, and therefore will be shown + // in the first command line representation of the help output. + virtual bool positional() const = 0; + + // Returns whether the option should be displayed as required in the help + // output. No effect on validation. + virtual bool required() const = 0; + }; + + // Add a positional argument. Returns the id of the added option or + // kOptionError on error. + // The "required" flag indicates whether the parameter is mandatory or + // optional, but is only used for how it is displayed in the command line + // help. + OptionId AddPositionalOption(const char* name, bool required, + const char* help_text, const char** storage, + int verbosity_level = 0) { + options_.emplace_back(new CmdOptionPositional(name, help_text, storage, + verbosity_level, required)); + return options_.size() - 1; + } + + // Add an option with a value of type T. The option can be passed as + // '-s ' or '--long value' or '--long=value'. The CommandLineParser + // parser will call the function parser with the string pointing to '' + // in either case. Returns the id of the added option or kOptionError on + // error. + template + OptionId AddOptionValue(char short_name, const char* long_name, + const char* metavar, const char* help_text, + T* storage, bool(parser)(const char*, T*), + int verbosity_level = 0) { + options_.emplace_back(new CmdOptionFlag(short_name, long_name, metavar, + help_text, storage, parser, + verbosity_level)); + return options_.size() - 1; + } + + // Add a flag without a value. Returns the id of the added option or + // kOptionError on error. + template + OptionId AddOptionFlag(char short_name, const char* long_name, + const char* help_text, T* storage, bool(parser)(T*), + int verbosity_level = 0) { + options_.emplace_back(new CmdOptionFlag( + short_name, long_name, help_text, storage, parser, verbosity_level)); + return options_.size() - 1; + } + + const CmdOptionInterface* GetOption(OptionId id) const { + return options_[id].get(); + } + + // Print the help message to stdout. + void PrintHelp() const; + + // Whether a help flag was specified + bool HelpFlagPassed() const { return help_; } + + int verbosity = 0; + + // Parse the command line. + bool Parse(int argc, const char* argv[]); + + // Return the remaining positional args + std::vector PositionalArgs() const; + + private: + // A positional argument. + class CmdOptionPositional : public CmdOptionInterface { + public: + CmdOptionPositional(const char* name, const char* help_text, + const char** storage, int verbosity_level, + bool required) + : name_(name), + help_text_(help_text), + storage_(storage), + verbosity_level_(verbosity_level), + required_(required) {} + + std::string help_flags() const override { return name_; } + const char* help_text() const override { return help_text_; } + int verbosity_level() const override { return verbosity_level_; } + bool matched() const override { return matched_; } + + // Only match non-flag values. This means that you can't pass '-foo' as a + // positional argument, but it helps with detecting when passed a flag with + // a typo. After '--', option matching is disabled so positional arguments + // starting with '-' can be used. + bool Match(const char* arg, bool parse_options) const override { + return !matched_ && (!parse_options || arg[0] != '-'); + } + + bool Parse(const int argc, const char* argv[], int* i) override { + *storage_ = argv[*i]; + (*i)++; + matched_ = true; + return true; + } + + bool positional() const override { return true; } + + bool required() const override { return required_; } + + private: + const char* name_; + const char* help_text_; + const char** storage_; + const int verbosity_level_; + const bool required_; + + bool matched_{false}; + }; + + // A class for handling an option flag like '-v' or '--foo=bar'. + template + class CmdOptionFlag : public CmdOptionInterface { + public: + // Construct a flag that doesn't take any value, for example '-v' or + // '--long'. Passing a value to it raises an error. + CmdOptionFlag(char short_name, const char* long_name, const char* help_text, + T* storage, bool(parser)(T*), int verbosity_level) + : short_name_(short_name), + long_name_(long_name), + long_name_len_(long_name ? strlen(long_name) : 0), + metavar_(nullptr), + help_text_(help_text), + storage_(storage), + verbosity_level_(verbosity_level) { + parser_.parser_no_value_ = parser; + } + + // Construct a flag that expects a value to be passed. + CmdOptionFlag(char short_name, const char* long_name, const char* metavar, + const char* help_text, T* storage, + bool(parser)(const char* arg, T*), int verbosity_level) + : short_name_(short_name), + long_name_(long_name), + long_name_len_(long_name ? strlen(long_name) : 0), + metavar_(metavar ? metavar : ""), + help_text_(help_text), + storage_(storage), + verbosity_level_(verbosity_level) { + parser_.parser_with_arg_ = parser; + } + + std::string help_flags() const override { + std::string ret; + if (short_name_) { + ret += std::string("-") + short_name_; + if (metavar_) ret += std::string(" ") + metavar_; + if (long_name_) ret += ", "; + } + if (long_name_) { + ret += std::string("--") + long_name_; + if (metavar_) ret += std::string("=") + metavar_; + } + return ret; + } + const char* help_text() const override { return help_text_; } + int verbosity_level() const override { return verbosity_level_; } + bool matched() const override { return matched_; } + + bool Match(const char* arg, bool parse_options) const override { + return parse_options && (MatchShort(arg) || MatchLong(arg)); + } + + bool Parse(const int argc, const char* argv[], int* i) override { + matched_ = true; + if (MatchLong(argv[*i])) { + const char* arg = argv[*i] + 2 + long_name_len_; + if (arg[0] == '=') { + if (metavar_) { + // Passed '--long_name=...'. + (*i)++; + // Skip over the '=' on the LongMatch. + arg += 1; + return (*parser_.parser_with_arg_)(arg, storage_); + } else { + fprintf(stderr, "--%s didn't expect any argument passed to it.\n", + argv[*i]); + return false; + } + } + } + // In any other case, it passed a -s or --long_name + (*i)++; + if (metavar_) { + if (argc <= *i) { + fprintf(stderr, "--%s expected an argument but none passed.\n", + argv[*i - 1]); + return false; + } + return (*parser_.parser_with_arg_)(argv[(*i)++], storage_); + } else { + return (*parser_.parser_no_value_)(storage_); + } + } + + bool positional() const override { return false; } + + bool required() const override { + // Only used for help display of positional arguments. + return false; + } + + private: + // Returns whether arg matches the short_name flag of this option. + bool MatchShort(const char* arg) const { + if (!short_name_ || arg[0] != '-') return false; + return arg[1] == short_name_ && arg[2] == 0; + } + + // Returns whether arg matches the long_name flag of this option, + // potentially with an argument passed to it. + bool MatchLong(const char* arg) const { + if (!long_name_ || arg[0] != '-' || arg[1] != '-') return false; + arg += 2; // Skips the '--' + if (strncmp(long_name_, arg, long_name_len_) != 0) return false; + arg += long_name_len_; + // Allow "--long_name=foo" and "--long_name" as long matches. + return arg[0] == 0 || arg[0] == '='; + } + + // A short option passed as '-X' where X is the char. A value of 0 means + // no short option. + const char short_name_; + + // A long option name passed as '--long' where 'long' is the name of the + // option. + const char* long_name_; + size_t long_name_len_; + + // The text to display when referring to the value passed to this flag, for + // example "N" in the flag '--value N'. If null, this flag accepts no value + // and therefore no value must be passed. + const char* metavar_; + + // The help string for this flag. + const char* help_text_; + + // The pointer to the storage of this flag used when parsing. + T* storage_; + + // At which verbosity level do we show this option? + int verbosity_level_; + + // The function to use to parse the value when matched. The function used is + // parser_with_arg_ when metavar_ is not null (and the value string will be + // used) or parser_no_value_ when metavar_ is null. + union { + bool (*parser_with_arg_)(const char*, T*); + bool (*parser_no_value_)(T*); + } parser_; + + // Whether this flag was matched. + bool matched_{false}; + }; + + const char* program_name_{nullptr}; + + std::vector> options_; + + // If true, help argument was given, so print help to stdout rather than + // stderr. + bool help_ = false; +}; + +// +// Common parsers for AddOptionValue and AddOptionFlag +// + +static inline bool ParseSigned(const char* arg, int* out) { + char* end; + *out = static_cast(strtol(arg, &end, 0)); + if (end[0] != '\0') { + fprintf(stderr, "Unable to interpret as signed integer: %s.\n", arg); + return false; + } + return true; +} + +static inline bool ParseUnsigned(const char* arg, size_t* out) { + char* end; + *out = static_cast(strtoull(arg, &end, 0)); + if (end[0] != '\0') { + fprintf(stderr, "Unable to interpret as unsigned integer: %s.\n", arg); + return false; + } + return true; +} + +static inline bool ParseInt64(const char* arg, int64_t* out) { + char* end; + *out = strtol(arg, &end, 0); + if (end[0] != '\0') { + fprintf(stderr, "Unable to interpret as signed integer: %s.\n", arg); + return false; + } + return true; +} + +static inline bool ParseUint32(const char* arg, uint32_t* out) { + size_t value = 0; + bool ret = ParseUnsigned(arg, &value); + if (ret) *out = value; + return ret; +} + +static inline bool ParseFloat(const char* arg, float* out) { + char* end; + *out = static_cast(strtod(arg, &end)); + if (end[0] != '\0') { + fprintf(stderr, "Unable to interpret as float: %s.\n", arg); + return false; + } + return true; +} + +static inline bool ParseDouble(const char* arg, double* out) { + char* end; + *out = static_cast(strtod(arg, &end)); + if (end[0] != '\0') { + fprintf(stderr, "Unable to interpret as double: %s.\n", arg); + return false; + } + return true; +} + +static inline bool ParseString(const char* arg, std::string* out) { + out->assign(arg); + return true; +} + +static inline bool SetBooleanTrue(bool* out) { + *out = true; + return true; +} + +static inline bool SetBooleanFalse(bool* out) { + *out = false; + return true; +} + +} // namespace tools +} // namespace jpegxl + +#endif // TOOLS_CMDLINE_H_ diff --git a/media/libjxl/src/tools/codec_config.cc b/media/libjxl/src/tools/codec_config.cc new file mode 100644 index 000000000..8efc26c22 --- /dev/null +++ b/media/libjxl/src/tools/codec_config.cc @@ -0,0 +1,57 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/codec_config.h" + +#include + +#include "tools/tool_version.h" + +namespace jpegxl { +namespace tools { + +std::string CodecConfigString(uint32_t lib_version) { + std::string config; + + if (lib_version != 0) { + char version_str[20]; + snprintf(version_str, sizeof(version_str), "v%d.%d.%d ", + lib_version / 1000000, (lib_version / 1000) % 1000, + lib_version % 1000); + config += version_str; + } + + std::string version = kJpegxlVersion; + if (version != "(unknown)") { + config += version + ' '; + } + +#if defined(ADDRESS_SANITIZER) + config += " asan "; +#elif defined(MEMORY_SANITIZER) + config += " msan "; +#elif defined(THREAD_SANITIZER) + config += " tsan "; +#else +#endif + + bool saw_target = false; + config += "["; + for (const uint32_t target : hwy::SupportedAndGeneratedTargets()) { + config += hwy::TargetName(target); + config += ','; + saw_target = true; + } + if (!saw_target) { + config += "no targets found,"; + } + config.resize(config.size() - 1); // remove trailing comma + config += "]"; + + return config; +} + +} // namespace tools +} // namespace jpegxl diff --git a/media/libjxl/src/tools/codec_config.h b/media/libjxl/src/tools/codec_config.h new file mode 100644 index 000000000..a4f79a66b --- /dev/null +++ b/media/libjxl/src/tools/codec_config.h @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_CODEC_CONFIG_H_ +#define TOOLS_CODEC_CONFIG_H_ + +#include +#include + +namespace jpegxl { +namespace tools { + +// Returns a short string describing the codec version (if known) and build +// settings such as sanitizers and SIMD targets. Used in the benchmark and +// command-line tools. +std::string CodecConfigString(uint32_t lib_version); + +} // namespace tools +} // namespace jpegxl + +#endif // TOOLS_CODEC_CONFIG_H_ diff --git a/media/libjxl/src/tools/color_encoding_fuzzer.cc b/media/libjxl/src/tools/color_encoding_fuzzer.cc new file mode 100644 index 000000000..087bd8ba1 --- /dev/null +++ b/media/libjxl/src/tools/color_encoding_fuzzer.cc @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include "lib/extras/dec/color_description.h" + +namespace jxl { + +int TestOneInput(const uint8_t* data, size_t size) { + std::string description(reinterpret_cast(data), size); + JxlColorEncoding c; + (void)ParseDescription(description, &c); + + return 0; +} + +} // namespace jxl + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + return jxl::TestOneInput(data, size); +} diff --git a/media/libjxl/src/tools/comparison_viewer/CMakeLists.txt b/media/libjxl/src/tools/comparison_viewer/CMakeLists.txt new file mode 100644 index 000000000..b5b5fa742 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/CMakeLists.txt @@ -0,0 +1,74 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +find_package(Qt5 QUIET COMPONENTS Concurrent Widgets) +if (NOT Qt5_FOUND) + message(WARNING "Qt5 was not found. The comparison tool will not be built.") + return() +endif () + +if (NOT TARGET icc_detect) + message(WARNING "icc_detect not built. The comparison tool will not be built.") + return () +endif () + +set(CMAKE_INCLUDE_CURRENT_DIR ON) +set(CMAKE_AUTOMOC ON) +set(CMAKE_AUTOUIC ON) + +add_library(image_loading STATIC + ../viewer/load_jxl.cc + ../viewer/load_jxl.h + image_loading.cc + image_loading.h +) +target_include_directories(image_loading PRIVATE + $ +) +target_link_libraries(image_loading PUBLIC + Qt5::Widgets + jxl-static + jxl_threads-static + jxl_extras-static + lcms2 +) + +add_executable(compare_codecs WIN32 + codec_comparison_window.cc + codec_comparison_window.h + codec_comparison_window.ui + compare_codecs.cc + settings.cc + settings.h + settings.ui + split_image_renderer.cc + split_image_renderer.h + split_image_view.cc + split_image_view.h + split_image_view.ui +) +target_link_libraries(compare_codecs + image_loading + Qt5::Concurrent + Qt5::Widgets + icc_detect +) + +add_executable(compare_images WIN32 + compare_images.cc + settings.cc + settings.h + settings.ui + split_image_renderer.cc + split_image_renderer.h + split_image_view.cc + split_image_view.h + split_image_view.ui +) +target_link_libraries(compare_images + image_loading + Qt5::Widgets + icc_detect +) diff --git a/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.cc b/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.cc new file mode 100644 index 000000000..9bf6253ba --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.cc @@ -0,0 +1,316 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/comparison_viewer/codec_comparison_window.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "lib/extras/codec.h" +#include "tools/comparison_viewer/image_loading.h" +#include "tools/comparison_viewer/split_image_view.h" +#include "tools/icc_detect/icc_detect.h" + +namespace jxl { + +static constexpr char kPngSuffix[] = "png"; + +namespace { + +QVector> currentCodecSelection( + const Ui::CodecComparisonWindow& ui) { + QVector> result; + for (QComboBox* const comboBox : + {ui.codec1ComboBox, ui.codec2ComboBox, ui.compressionLevel1ComboBox, + ui.compressionLevel2ComboBox}) { + result << qMakePair(comboBox, comboBox->currentText()); + } + return result; +} + +void restoreCodecSelection( + const QVector>& selection) { + for (const auto& comboBox : selection) { + const int index = comboBox.first->findText(comboBox.second); + if (index != -1) { + comboBox.first->setCurrentIndex(index); + } + } +} + +} // namespace + +CodecComparisonWindow::CodecComparisonWindow(const QString& directory, + const float intensityTarget, + QWidget* const parent) + : QMainWindow(parent), + intensityTarget_(intensityTarget), + monitorIccProfile_(GetMonitorIccProfile(this)) { + ui_.setupUi(this); + + connect(ui_.imageSetComboBox, &QComboBox::currentTextChanged, this, + &CodecComparisonWindow::handleImageSetSelection); + connect(ui_.imageComboBox, &QComboBox::currentTextChanged, this, + &CodecComparisonWindow::handleImageSelection); + + connect(ui_.codec1ComboBox, &QComboBox::currentTextChanged, + [this]() { handleCodecChange(Side::LEFT); }); + connect(ui_.codec2ComboBox, &QComboBox::currentTextChanged, + [this]() { handleCodecChange(Side::RIGHT); }); + + connect(ui_.compressionLevel1ComboBox, &QComboBox::currentTextChanged, + [this]() { updateSideImage(Side::LEFT); }); + connect(ui_.compressionLevel2ComboBox, &QComboBox::currentTextChanged, + [this]() { updateSideImage(Side::RIGHT); }); + + connect(ui_.match1Label, &QLabel::linkActivated, + [this]() { matchSize(Side::LEFT); }); + connect(ui_.match2Label, &QLabel::linkActivated, + [this]() { matchSize(Side::RIGHT); }); + + connect( + ui_.splitImageView, &SplitImageView::renderingModeChanged, + [this](const SplitImageRenderer::RenderingMode newMode) { + switch (newMode) { + case SplitImageRenderer::RenderingMode::LEFT: + case SplitImageRenderer::RenderingMode::RIGHT: { + QString codec, compressionLevel; + if (newMode == SplitImageRenderer::RenderingMode::LEFT) { + codec = ui_.codec1ComboBox->currentText(); + compressionLevel = ui_.compressionLevel1ComboBox->currentText(); + } else { + codec = ui_.codec2ComboBox->currentText(); + compressionLevel = ui_.compressionLevel2ComboBox->currentText(); + } + ui_.renderingModeLabel->setText(tr("Currently displaying: %1 @ %2") + .arg(codec) + .arg(compressionLevel)); + break; + } + + case SplitImageRenderer::RenderingMode::MIDDLE: + ui_.renderingModeLabel->setText( + tr("Currently displaying the original image.")); + break; + + default: + ui_.renderingModeLabel->clear(); + break; + } + }); + + loadDirectory(directory); +} + +void CodecComparisonWindow::handleImageSetSelection( + const QString& imageSetName) { + const auto selection = currentCodecSelection(ui_); + { + const QSignalBlocker blocker(ui_.imageComboBox); + ui_.imageComboBox->clear(); + } + const QStringList imageNames = imageSets_.value(imageSetName).keys(); + const std::function loadIcon = + [this, &imageSetName](const QString& imageName) { + return QIcon(pathToOriginalImage(imageSetName, imageName)); + }; + const QFuture thumbnails = QtConcurrent::mapped(imageNames, loadIcon); + int i = 0; + for (const QString& imageName : imageNames) { + ui_.imageComboBox->addItem(thumbnails.resultAt(i), imageName); + ++i; + } + restoreCodecSelection(selection); +} + +void CodecComparisonWindow::handleImageSelection(const QString& imageName) { + const QString imageSetName = ui_.imageSetComboBox->currentText(); + ui_.splitImageView->setMiddleImage( + loadImage(pathToOriginalImage(imageSetName, imageName), + monitorIccProfile_, intensityTarget_)); + + const auto selection = currentCodecSelection(ui_); + QStringList codecs = imageSets_.value(imageSetName).value(imageName).keys(); + for (QComboBox* const codecComboBox : + {ui_.codec1ComboBox, ui_.codec2ComboBox}) { + { + const QSignalBlocker blocker(codecComboBox); + codecComboBox->clear(); + } + codecComboBox->addItems(codecs); + } + restoreCodecSelection(selection); +} + +void CodecComparisonWindow::handleCodecChange(const Side side) { + const QComboBox* const codecComboBox = + side == Side::LEFT ? ui_.codec1ComboBox : ui_.codec2ComboBox; + QComboBox* const compressionLevelComboBox = + side == Side::LEFT ? ui_.compressionLevel1ComboBox + : ui_.compressionLevel2ComboBox; + + QStringList compressionLevels = + imageSets_.value(ui_.imageSetComboBox->currentText()) + .value(ui_.imageComboBox->currentText()) + .value(codecComboBox->currentText()) + .keys(); + QCollator collator; + collator.setNumericMode(true); + std::sort(compressionLevels.begin(), compressionLevels.end(), collator); + + { + const QSignalBlocker blocker(compressionLevelComboBox); + compressionLevelComboBox->clear(); + } + compressionLevelComboBox->addItems(compressionLevels); + matchSize(side); +} + +void CodecComparisonWindow::updateSideImage(const Side side) { + const ComparableImage& imageInfo = currentlySelectedImage(side); + if (imageInfo.decodedImagePath.isEmpty()) return; + QImage image = loadImage(imageInfo.decodedImagePath, monitorIccProfile_, + intensityTarget_); + const int pixels = image.width() * image.height(); + QLabel* const sizeInfoLabel = + side == Side::LEFT ? ui_.size1Label : ui_.size2Label; + if (pixels == 0) { + sizeInfoLabel->setText(tr("Empty image.")); + } else { + const double bpp = + CHAR_BIT * static_cast(imageInfo.byteSize) / pixels; + sizeInfoLabel->setText(tr("%L1bpp").arg(bpp)); + } + + if (side == Side::LEFT) { + ui_.splitImageView->setLeftImage(std::move(image)); + } else { + ui_.splitImageView->setRightImage(std::move(image)); + } +} + +QString CodecComparisonWindow::pathToOriginalImage( + const QString& imageSetName, const QString& imageName) const { + return baseDirectory_.absolutePath() + "/" + imageSetName + "/" + imageName + + "/original.png"; +} + +CodecComparisonWindow::ComparableImage +CodecComparisonWindow::currentlySelectedImage(const Side side) const { + const QComboBox* const codecComboBox = + side == Side::LEFT ? ui_.codec1ComboBox : ui_.codec2ComboBox; + QComboBox* const compressionLevelComboBox = + side == Side::LEFT ? ui_.compressionLevel1ComboBox + : ui_.compressionLevel2ComboBox; + + return imageSets_.value(ui_.imageSetComboBox->currentText()) + .value(ui_.imageComboBox->currentText()) + .value(codecComboBox->currentText()) + .value(compressionLevelComboBox->currentText()); +} + +void CodecComparisonWindow::matchSize(const Side side) { + const Side otherSide = (side == Side::LEFT ? Side::RIGHT : Side::LEFT); + const qint64 otherSideSize = currentlySelectedImage(otherSide).byteSize; + if (otherSideSize == 0) return; + + const QComboBox* const codecComboBox = + side == Side::LEFT ? ui_.codec1ComboBox : ui_.codec2ComboBox; + QComboBox* const compressionLevelComboBox = + side == Side::LEFT ? ui_.compressionLevel1ComboBox + : ui_.compressionLevel2ComboBox; + const Codec codec = imageSets_.value(ui_.imageSetComboBox->currentText()) + .value(ui_.imageComboBox->currentText()) + .value(codecComboBox->currentText()); + if (codec.empty()) return; + Codec::ConstIterator bestMatch = codec.begin(); + for (auto it = codec.begin(); it != codec.end(); ++it) { + if (std::abs(it->byteSize - otherSideSize) < + std::abs(bestMatch->byteSize - otherSideSize)) { + bestMatch = it; + } + } + compressionLevelComboBox->setCurrentText(bestMatch.key()); +} + +void CodecComparisonWindow::loadDirectory(const QString& directory) { + baseDirectory_.setPath(directory); + baseDirectory_.makeAbsolute(); + imageSets_.clear(); + visited_.clear(); + + browseDirectory(directory); + + { + const QSignalBlocker blocker(ui_.imageSetComboBox); + ui_.imageSetComboBox->clear(); + } + ui_.imageSetComboBox->addItems(imageSets_.keys()); +} + +void CodecComparisonWindow::browseDirectory(const QDir& directory, int depth) { + for (const QFileInfo& subdirectory : directory.entryInfoList( + QDir::Dirs | QDir::NoDotAndDotDot | QDir::NoSymLinks)) { + if (visited_.contains(subdirectory.absoluteFilePath())) continue; + visited_.insert(subdirectory.absoluteFilePath()); + browseDirectory(subdirectory.absoluteFilePath(), depth + 1); + } + + // Need at least image_name/codec_name/file. + if (depth < 2) return; + + for (const QFileInfo& file : directory.entryInfoList(QDir::Files)) { + if (file.suffix() == kPngSuffix) continue; + QString decodedImage; + if (canLoadImageWithExtension(file.suffix())) { + decodedImage = file.absoluteFilePath(); + } else { + QFileInfo png(file.absolutePath() + "/" + file.completeBaseName() + "." + + kPngSuffix); + if (png.exists()) { + decodedImage = png.absoluteFilePath(); + } + } + + if (decodedImage.isEmpty()) continue; + + const QString codec = file.absoluteDir().dirName(); + QDir imageDirectory = file.absoluteDir(); + if (!imageDirectory.cdUp()) return; + const QString imageName = imageDirectory.dirName(); + QDir imageSetDirectory = imageDirectory; + if (!imageSetDirectory.cdUp()) return; + QString imageSetPath = + baseDirectory_.relativeFilePath(imageSetDirectory.absolutePath()); + if (imageSetPath.isEmpty()) { + imageSetPath = "."; + } + + ComparableImage& image = + imageSets_[imageSetPath][imageName][codec][file.completeBaseName()]; + image.decodedImagePath = decodedImage; + image.byteSize = file.size(); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.h b/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.h new file mode 100644 index 000000000..b157a5a9e --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.h @@ -0,0 +1,77 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_COMPARISON_VIEWER_CODEC_COMPARISON_WINDOW_H_ +#define TOOLS_COMPARISON_VIEWER_CODEC_COMPARISON_WINDOW_H_ + +#include +#include +#include +#include +#include + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/common.h" +#include "tools/comparison_viewer/ui_codec_comparison_window.h" + +namespace jxl { + +class CodecComparisonWindow : public QMainWindow { + Q_OBJECT + + public: + explicit CodecComparisonWindow( + const QString& directory, float intensityTarget = kDefaultIntensityTarget, + QWidget* parent = nullptr); + ~CodecComparisonWindow() override = default; + + private slots: + void handleImageSetSelection(const QString& imageSetName); + void handleImageSelection(const QString& imageName); + + private: + struct ComparableImage { + // Absolute path to the decoded PNG (or an image that Qt can read). + QString decodedImagePath; + // Size of the encoded image (*not* the PNG). + qint64 byteSize = 0; + }; + // Keys are compression levels. + using Codec = QMap; + // Keys are codec names. + using Codecs = QMap; + // Keys are image names (relative to the image set directory). + using ImageSet = QMap; + // Keys are paths to image sets (relative to the base directory chosen by the + // user). + using ImageSets = QMap; + + enum class Side { LEFT, RIGHT }; + + QString pathToOriginalImage(const QString& imageSet, + const QString& imageName) const; + ComparableImage currentlySelectedImage(Side side) const; + + void handleCodecChange(Side side); + void updateSideImage(Side side); + void matchSize(Side side); + + void loadDirectory(const QString& directory); + // Recursive, called by loadDirectory. + void browseDirectory(const QDir& directory, int depth = 0); + + Ui::CodecComparisonWindow ui_; + + QDir baseDirectory_; + ImageSets imageSets_; + QSet visited_; + + const float intensityTarget_; + const QByteArray monitorIccProfile_; +}; + +} // namespace jxl + +#endif // TOOLS_COMPARISON_VIEWER_CODEC_COMPARISON_WINDOW_H_ diff --git a/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.ui b/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.ui new file mode 100644 index 000000000..1fbda6a1c --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/codec_comparison_window.ui @@ -0,0 +1,170 @@ + + + + Copyright (c) the JPEG XL Project Authors. All rights reserved. + + Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. + + CodecComparisonWindow + + + Codec Comparison Tool + + + + + + + + + Qt::Horizontal + + + + + + + + + + + + Image set: + + + + + + + Image: + + + + + + + + + + + + Qt::Horizontal + + + + + + + + + + + + + + + + + + + <a href="#match1">Match →</a> + + + Qt::AlignCenter + + + + + + + + + Qt::Horizontal + + + + + + + + + <a href="#match2">Match ←</a> + + + Qt::AlignCenter + + + + + + + + + + + + + + + + + + + No image loaded. + + + + + + + Qt::Horizontal + + + + + + + + + + Qt::AlignCenter + + + + + + + Qt::Horizontal + + + + + + + No image loaded. + + + Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + + + + + + + + + + + + jxl::SplitImageView + QWidget +
split_image_view.h
+ 1 +
+
+ + +
diff --git a/media/libjxl/src/tools/comparison_viewer/compare_codecs.cc b/media/libjxl/src/tools/comparison_viewer/compare_codecs.cc new file mode 100644 index 000000000..932765e47 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/compare_codecs.cc @@ -0,0 +1,75 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include +#include +#include +#include + +#include "tools/comparison_viewer/codec_comparison_window.h" + +int main(int argc, char** argv) { + QApplication application(argc, argv); + + QCommandLineParser parser; + parser.setApplicationDescription( + QCoreApplication::translate("compare_codecs", "Codec comparison tool")); + parser.addHelpOption(); + + QCommandLineOption intensityTargetOption( + {"intensity-target", "intensity_target", "i"}, + QCoreApplication::translate("compare_codecs", + "The peak luminance of the display."), + QCoreApplication::translate("compare_codecs", "nits"), + QString::number(jxl::kDefaultIntensityTarget)); + parser.addOption(intensityTargetOption); + + parser.addPositionalArgument( + "folders", QCoreApplication::translate("compare_codecs", "Image folders"), + "..."); + + parser.process(application); + + bool ok; + const float intensityTarget = + parser.value(intensityTargetOption).toFloat(&ok); + if (!ok) { + parser.showHelp(EXIT_FAILURE); + } + + QStringList folders = parser.positionalArguments(); + + if (folders.empty()) { + QMessageBox message; + message.setIcon(QMessageBox::Information); + message.setWindowTitle( + QCoreApplication::translate("CodecComparisonWindow", "Usage")); + message.setText(QCoreApplication::translate( + "CodecComparisonWindow", "Please specify a directory to use.")); + message.setDetailedText(QCoreApplication::translate( + "CodecComparisonWindow", + "That directory should contain images in the following layout:\n" + "- ...//original.png (optional)\n" + "- ...///.\n" + "- ...///.png (optional for " + "formats that Qt can load)\n" + "With arbitrary nesting allowed before that. (The \"...\" part is " + "referred to as an \"image set\" by the tool.")); + message.exec(); + return EXIT_FAILURE; + } + + for (const QString& folder : folders) { + auto* const window = + new jxl::CodecComparisonWindow(folder, intensityTarget); + window->setAttribute(Qt::WA_DeleteOnClose); + window->show(); + } + + return application.exec(); +} diff --git a/media/libjxl/src/tools/comparison_viewer/compare_images.cc b/media/libjxl/src/tools/comparison_viewer/compare_images.cc new file mode 100644 index 000000000..cf39f8812 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/compare_images.cc @@ -0,0 +1,128 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "tools/comparison_viewer/image_loading.h" +#include "tools/comparison_viewer/split_image_view.h" +#include "tools/icc_detect/icc_detect.h" + +namespace { + +void displayLoadingError(const QString& path) { + QMessageBox message; + message.setIcon(QMessageBox::Critical); + message.setWindowTitle( + QCoreApplication::translate("SplitImageView", "Error")); + message.setText(QCoreApplication::translate("SplitImageView", + "Could not load image \"%1\".") + .arg(path)); + message.exec(); +} + +} // namespace + +int main(int argc, char** argv) { + QApplication application(argc, argv); + + QCommandLineParser parser; + parser.setApplicationDescription( + QCoreApplication::translate("compare_images", "Image comparison tool")); + parser.addHelpOption(); + parser.addPositionalArgument( + "left-image", + QCoreApplication::translate("compare_images", + "The image to display on the left."), + ""); + parser.addPositionalArgument( + "right-image", + QCoreApplication::translate("compare_images", + "The image to display on the right."), + ""); + parser.addPositionalArgument( + "middle-image", + QCoreApplication::translate( + "compare_images", "The image to display in the middle (optional)."), + "[]"); + + QCommandLineOption colorSpaceOption( + {"color-space", "color_space", "c"}, + QCoreApplication::translate( + "compare_images", + "The color space to use for untagged images (typically PNM)."), + QCoreApplication::translate("compare_images", "color-space")); + parser.addOption(colorSpaceOption); + + QCommandLineOption intensityTargetOption( + {"intensity-target", "intensity_target", "i"}, + QCoreApplication::translate("compare_images", + "The peak luminance of the display."), + QCoreApplication::translate("compare_images", "nits"), + QString::number(jxl::kDefaultIntensityTarget)); + parser.addOption(intensityTargetOption); + + parser.process(application); + + const QString colorSpaceHint = parser.value(colorSpaceOption); + + QStringList arguments = parser.positionalArguments(); + if (arguments.size() < 2 || arguments.size() > 3) { + parser.showHelp(EXIT_FAILURE); + } + + bool ok; + const float intensityTarget = + parser.value(intensityTargetOption).toFloat(&ok); + if (!ok) { + parser.showHelp(EXIT_FAILURE); + } + + jxl::SplitImageView view; + + const QByteArray monitorIccProfile = jxl::GetMonitorIccProfile(&view); + + const QString leftImagePath = arguments.takeFirst(); + QImage leftImage = jxl::loadImage(leftImagePath, monitorIccProfile, + intensityTarget, colorSpaceHint); + if (leftImage.isNull()) { + displayLoadingError(leftImagePath); + return EXIT_FAILURE; + } + view.setLeftImage(std::move(leftImage)); + + const QString rightImagePath = arguments.takeFirst(); + QImage rightImage = jxl::loadImage(rightImagePath, monitorIccProfile, + intensityTarget, colorSpaceHint); + if (rightImage.isNull()) { + displayLoadingError(rightImagePath); + return EXIT_FAILURE; + } + view.setRightImage(std::move(rightImage)); + + if (!arguments.empty()) { + const QString middleImagePath = arguments.takeFirst(); + QImage middleImage = jxl::loadImage(middleImagePath, monitorIccProfile, + intensityTarget, colorSpaceHint); + if (middleImage.isNull()) { + displayLoadingError(middleImagePath); + return EXIT_FAILURE; + } + view.setMiddleImage(std::move(middleImage)); + } + + view.setWindowFlags(view.windowFlags() | Qt::Window); + view.setWindowState(Qt::WindowMaximized); + view.show(); + + return application.exec(); +} diff --git a/media/libjxl/src/tools/comparison_viewer/image_loading.cc b/media/libjxl/src/tools/comparison_viewer/image_loading.cc new file mode 100644 index 000000000..55bebb8a1 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/image_loading.cc @@ -0,0 +1,111 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/comparison_viewer/image_loading.h" + +#include +#include + +#include "lib/extras/codec.h" +#include "lib/extras/dec/color_hints.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_color_management.h" +#include "tools/viewer/load_jxl.h" + +namespace jxl { + +namespace { + +Status loadFromFile(const QString& filename, + const extras::ColorHints& color_hints, + CodecInOut* const decoded, ThreadPool* const pool) { + PaddedBytes compressed; + JXL_RETURN_IF_ERROR(ReadFile(filename.toStdString(), &compressed)); + const Span compressed_span(compressed); + return SetFromBytes(compressed_span, color_hints, decoded, pool, nullptr); +} + +} // namespace + +bool canLoadImageWithExtension(QString extension) { + extension = extension.toLower(); + size_t bitsPerSampleUnused; + return extension == "jxl" || extension == "j" || extension == "brn" || + extras::CodecFromExtension("." + extension.toStdString(), + &bitsPerSampleUnused) != + jxl::extras::Codec::kUnknown; +} + +QImage loadImage(const QString& filename, const QByteArray& targetIccProfile, + const float intensityTarget, + const QString& sourceColorSpaceHint) { + qint64 elapsed; + QImage img = loadJxlImage(filename, targetIccProfile, &elapsed); + if (img.width() != 0 && img.height() != 0) { + return img; + } + static ThreadPoolInternal pool(QThread::idealThreadCount()); + + CodecInOut decoded; + extras::ColorHints color_hints; + if (!sourceColorSpaceHint.isEmpty()) { + color_hints.Add("color_space", sourceColorSpaceHint.toStdString()); + } + if (!loadFromFile(filename, color_hints, &decoded, &pool)) { + return QImage(); + } + decoded.metadata.m.SetIntensityTarget(intensityTarget); + const ImageBundle& ib = decoded.Main(); + + ColorEncoding targetColorSpace; + PaddedBytes icc; + icc.assign(reinterpret_cast(targetIccProfile.data()), + reinterpret_cast(targetIccProfile.data() + + targetIccProfile.size())); + if (!targetColorSpace.SetICC(std::move(icc))) { + targetColorSpace = ColorEncoding::SRGB(ib.IsGray()); + } + Image3F converted; + if (!ib.CopyTo(Rect(ib), targetColorSpace, GetJxlCms(), &converted, &pool)) { + return QImage(); + } + + QImage image(converted.xsize(), converted.ysize(), QImage::Format_ARGB32); + + const auto ScaleAndClamp = [](const float x) { + return Clamp1(x * 255 + .5f, 0.f, 255.f); + }; + + if (ib.HasAlpha()) { + for (int y = 0; y < image.height(); ++y) { + QRgb* const row = reinterpret_cast(image.scanLine(y)); + const float* const alphaRow = ib.alpha().ConstRow(y); + const float* const redRow = converted.ConstPlaneRow(0, y); + const float* const greenRow = converted.ConstPlaneRow(1, y); + const float* const blueRow = converted.ConstPlaneRow(2, y); + for (int x = 0; x < image.width(); ++x) { + row[x] = qRgba(ScaleAndClamp(redRow[x]), ScaleAndClamp(greenRow[x]), + ScaleAndClamp(blueRow[x]), ScaleAndClamp(alphaRow[x])); + } + } + } else { + for (int y = 0; y < image.height(); ++y) { + QRgb* const row = reinterpret_cast(image.scanLine(y)); + const float* const redRow = converted.ConstPlaneRow(0, y); + const float* const greenRow = converted.ConstPlaneRow(1, y); + const float* const blueRow = converted.ConstPlaneRow(2, y); + for (int x = 0; x < image.width(); ++x) { + row[x] = qRgb(ScaleAndClamp(redRow[x]), ScaleAndClamp(greenRow[x]), + ScaleAndClamp(blueRow[x])); + } + } + } + + return image; +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/comparison_viewer/image_loading.h b/media/libjxl/src/tools/comparison_viewer/image_loading.h new file mode 100644 index 000000000..89b37d13b --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/image_loading.h @@ -0,0 +1,29 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_COMPARISON_VIEWER_IMAGE_LOADING_H_ +#define TOOLS_COMPARISON_VIEWER_IMAGE_LOADING_H_ + +#include +#include +#include + +#include "lib/jxl/common.h" + +namespace jxl { + +// `extension` should not include the dot. +bool canLoadImageWithExtension(QString extension); + +// Converts the loaded image to the given display profile, or sRGB if not +// specified. Thread-hostile. +QImage loadImage(const QString& filename, + const QByteArray& targetIccProfile = QByteArray(), + float intensityTarget = kDefaultIntensityTarget, + const QString& sourceColorSpaceHint = QString()); + +} // namespace jxl + +#endif // TOOLS_COMPARISON_VIEWER_IMAGE_LOADING_H_ diff --git a/media/libjxl/src/tools/comparison_viewer/settings.cc b/media/libjxl/src/tools/comparison_viewer/settings.cc new file mode 100644 index 000000000..9ef117b0a --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/settings.cc @@ -0,0 +1,51 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/comparison_viewer/settings.h" + +namespace jxl { + +SettingsDialog::SettingsDialog(QWidget* const parent) + : QDialog(parent), settings_("JPEG XL project", "Comparison tool") { + ui_.setupUi(this); + + settings_.beginGroup("rendering"); + renderingSettings_.fadingMSecs = settings_.value("fadingMSecs", 300).toInt(); + settings_.beginGroup("gray"); + renderingSettings_.gray = settings_.value("enabled", false).toBool(); + renderingSettings_.grayMSecs = settings_.value("delayMSecs", 300).toInt(); + settings_.endGroup(); + settings_.endGroup(); + + settingsToUi(); +} + +SplitImageRenderingSettings SettingsDialog::renderingSettings() const { + return renderingSettings_; +} + +void SettingsDialog::on_SettingsDialog_accepted() { + renderingSettings_.fadingMSecs = ui_.fadingTime->value(); + renderingSettings_.gray = ui_.grayGroup->isChecked(); + renderingSettings_.grayMSecs = ui_.grayTime->value(); + + settings_.beginGroup("rendering"); + settings_.setValue("fadingMSecs", renderingSettings_.fadingMSecs); + settings_.beginGroup("gray"); + settings_.setValue("enabled", renderingSettings_.gray); + settings_.setValue("delayMSecs", renderingSettings_.grayMSecs); + settings_.endGroup(); + settings_.endGroup(); +} + +void SettingsDialog::on_SettingsDialog_rejected() { settingsToUi(); } + +void SettingsDialog::settingsToUi() { + ui_.fadingTime->setValue(renderingSettings_.fadingMSecs); + ui_.grayGroup->setChecked(renderingSettings_.gray); + ui_.grayTime->setValue(renderingSettings_.grayMSecs); +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/comparison_viewer/settings.h b/media/libjxl/src/tools/comparison_viewer/settings.h new file mode 100644 index 000000000..bd91f710a --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/settings.h @@ -0,0 +1,40 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_COMPARISON_VIEWER_SETTINGS_H_ +#define TOOLS_COMPARISON_VIEWER_SETTINGS_H_ + +#include +#include + +#include "tools/comparison_viewer/split_image_renderer.h" +#include "tools/comparison_viewer/ui_settings.h" + +namespace jxl { + +class SettingsDialog : public QDialog { + Q_OBJECT + + public: + explicit SettingsDialog(QWidget* parent = nullptr); + ~SettingsDialog() override = default; + + SplitImageRenderingSettings renderingSettings() const; + + private slots: + void on_SettingsDialog_accepted(); + void on_SettingsDialog_rejected(); + + private: + void settingsToUi(); + + Ui::SettingsDialog ui_; + QSettings settings_; + SplitImageRenderingSettings renderingSettings_; +}; + +} // namespace jxl + +#endif // TOOLS_COMPARISON_VIEWER_SETTINGS_H_ diff --git a/media/libjxl/src/tools/comparison_viewer/settings.ui b/media/libjxl/src/tools/comparison_viewer/settings.ui new file mode 100644 index 000000000..ca81a33ae --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/settings.ui @@ -0,0 +1,120 @@ + + + + Copyright (c) the JPEG XL Project Authors. All rights reserved. + + Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. + + SettingsDialog + + + Comparison tool settings + + + + QLayout::SetFixedSize + + + + + + + Fading time: + + + + + + +  ms + + + 1000 + + + 50 + + + 300 + + + + + + + + + Gray in between + + + true + + + false + + + + + +  ms + + + 0 + + + 1000 + + + 50 + + + 300 + + + + + + + Time on gray: + + + + + + + + + + Qt::Vertical + + + + + + + Qt::Horizontal + + + QDialogButtonBox::Cancel|QDialogButtonBox::Ok + + + + + + + + + buttonBox + accepted() + SettingsDialog + accept() + + + buttonBox + rejected() + SettingsDialog + reject() + + + diff --git a/media/libjxl/src/tools/comparison_viewer/split_image_renderer.cc b/media/libjxl/src/tools/comparison_viewer/split_image_renderer.cc new file mode 100644 index 000000000..acade64d3 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/split_image_renderer.cc @@ -0,0 +1,239 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/comparison_viewer/split_image_renderer.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace jxl { + +SplitImageRenderer::SplitImageRenderer(QWidget* const parent) + : QWidget(parent) { + setAttribute(Qt::WA_OpaquePaintEvent); + setMouseTracking(true); + setFocusPolicy(Qt::WheelFocus); + grabKeyboard(); + + connect(&fadingPoint_, &QVariantAnimation::valueChanged, + [this] { update(); }); +} + +void SplitImageRenderer::setLeftImage(QImage image) { + leftImage_ = QPixmap::fromImage(std::move(image)); + updateMinimumSize(); + update(); +} +void SplitImageRenderer::setRightImage(QImage image) { + rightImage_ = QPixmap::fromImage(std::move(image)); + updateMinimumSize(); + update(); +} +void SplitImageRenderer::setMiddleImage(QImage image) { + middleImage_ = QPixmap::fromImage(std::move(image)); + updateMinimumSize(); + update(); +} + +void SplitImageRenderer::setRenderingSettings( + const SplitImageRenderingSettings& settings) { + renderingSettings_ = settings; +} + +void SplitImageRenderer::setMiddleWidthPercent(const int percent) { + middleWidthPercent_ = percent; + update(); +} + +void SplitImageRenderer::setZoomLevel(double scale) { + scale_ = scale; + updateMinimumSize(); + update(); +} + +void SplitImageRenderer::keyPressEvent(QKeyEvent* const event) { + switch (event->key()) { + case Qt::Key_Left: + setRenderingMode(RenderingMode::LEFT); + break; + + case Qt::Key_Right: + setRenderingMode(RenderingMode::RIGHT); + break; + + case Qt::Key_Up: + case Qt::Key_Down: + setRenderingMode(RenderingMode::MIDDLE); + break; + + case Qt::Key_Escape: + QCoreApplication::quit(); + break; + + case Qt::Key_ZoomIn: + emit zoomLevelIncreaseRequested(); + break; + case Qt::Key_ZoomOut: + emit zoomLevelDecreaseRequested(); + break; + + default: + QWidget::keyPressEvent(event); + break; + } + update(); +} + +void SplitImageRenderer::mouseMoveEvent(QMouseEvent* const event) { + setRenderingMode(RenderingMode::SPLIT); + middleX_ = event->pos().x(); + update(); +} + +void SplitImageRenderer::wheelEvent(QWheelEvent* event) { + if (QGuiApplication::keyboardModifiers().testFlag(Qt::ControlModifier)) { + if (event->angleDelta().y() > 0) { + emit zoomLevelIncreaseRequested(); + return; + } else if (event->angleDelta().y() < 0) { + emit zoomLevelDecreaseRequested(); + return; + } + } + + event->ignore(); +} + +void SplitImageRenderer::paintEvent(QPaintEvent* const event) { + QRectF drawingArea(0., 0., minimumWidth(), minimumHeight()); + + QPainter painter(this); + painter.fillRect(rect(), QColor(119, 119, 119)); + painter.translate(QRectF(rect()).center() - drawingArea.center()); + painter.scale(scale_, scale_); + if (scale_ < 1.) { + painter.setRenderHint(QPainter::SmoothPixmapTransform); + } + + const auto drawSingleImage = [&](const RenderingMode mode) { + const QPixmap* image = nullptr; + switch (mode) { + case RenderingMode::LEFT: + image = &leftImage_; + break; + case RenderingMode::RIGHT: + image = &rightImage_; + break; + case RenderingMode::MIDDLE: + image = &middleImage_; + break; + + default: + return; + } + painter.drawPixmap(QPointF(0., 0.), *image); + }; + + if (mode_ != RenderingMode::SPLIT) { + if (fadingPoint_.state() != QAbstractAnimation::Running) { + drawSingleImage(mode_); + return; + } + + const float fadingPoint = fadingPoint_.currentValue().toFloat(); + if (renderingSettings_.gray) { + if (fadingPoint < renderingSettings_.fadingMSecs) { + painter.setOpacity((renderingSettings_.fadingMSecs - fadingPoint) / + renderingSettings_.fadingMSecs); + drawSingleImage(previousMode_); + } else if (fadingPoint > renderingSettings_.fadingMSecs + + renderingSettings_.grayMSecs) { + painter.setOpacity((fadingPoint - renderingSettings_.fadingMSecs - + renderingSettings_.grayMSecs) / + renderingSettings_.fadingMSecs); + drawSingleImage(mode_); + } + } else { + drawSingleImage(previousMode_); + painter.setOpacity(fadingPoint / renderingSettings_.fadingMSecs); + drawSingleImage(mode_); + } + + return; + } + + const qreal middleWidth = + std::min((minimumWidth() / scale_) * middleWidthPercent_ / 100., + middleImage_.width()); + + const double transformedMiddleX = + painter.transform().inverted().map(QPointF(middleX_, 0.)).x(); + QRectF middleRect = middleImage_.rect(); + middleRect.setWidth(middleWidth); + middleRect.moveCenter(QPointF(transformedMiddleX, middleRect.center().y())); + middleRect.setLeft(std::round(middleRect.left())); + middleRect.setRight(std::round(middleRect.right())); + + QRectF leftRect = leftImage_.rect(); + leftRect.setRight(middleRect.left()); + + QRectF rightRect = rightImage_.rect(); + rightRect.setLeft(middleRect.right()); + + painter.drawPixmap(leftRect, leftImage_, leftRect); + painter.drawPixmap(rightRect, rightImage_, rightRect); + painter.drawPixmap(middleRect, middleImage_, middleRect); + + QPen middlePen; + middlePen.setStyle(Qt::DotLine); + painter.setPen(middlePen); + painter.drawLine(leftRect.topRight(), leftRect.bottomRight()); + painter.drawLine(rightRect.topLeft(), rightRect.bottomLeft()); +} + +void SplitImageRenderer::updateMinimumSize() { + const int imagesWidth = std::max( + std::max(leftImage_.width(), rightImage_.width()), middleImage_.width()); + const int imagesHeight = + std::max(std::max(leftImage_.height(), rightImage_.height()), + middleImage_.height()); + setMinimumSize(scale_ * QSize(imagesWidth, imagesHeight)); +} + +void SplitImageRenderer::setRenderingMode(const RenderingMode newMode) { + if (newMode == mode_) return; + previousMode_ = mode_; + mode_ = newMode; + if (previousMode_ == RenderingMode::SPLIT || mode_ == RenderingMode::SPLIT) { + fadingPoint_.stop(); + } else { + const int msecs = + renderingSettings_.gray + ? 2 * renderingSettings_.fadingMSecs + renderingSettings_.grayMSecs + : renderingSettings_.fadingMSecs; + const float startValue = fadingPoint_.state() == QAbstractAnimation::Running + ? fadingPoint_.endValue().toFloat() - + fadingPoint_.currentValue().toFloat() + : 0.f; + fadingPoint_.stop(); + fadingPoint_.setStartValue(startValue); + fadingPoint_.setEndValue(static_cast(msecs)); + fadingPoint_.setDuration(fadingPoint_.endValue().toFloat() - + fadingPoint_.startValue().toFloat()); + fadingPoint_.start(); + } + emit renderingModeChanged(mode_); +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/comparison_viewer/split_image_renderer.h b/media/libjxl/src/tools/comparison_viewer/split_image_renderer.h new file mode 100644 index 000000000..decb407ff --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/split_image_renderer.h @@ -0,0 +1,90 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_COMPARISON_VIEWER_SPLIT_IMAGE_RENDERER_H_ +#define TOOLS_COMPARISON_VIEWER_SPLIT_IMAGE_RENDERER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace jxl { + +struct SplitImageRenderingSettings { + int fadingMSecs; + bool gray; + int grayMSecs; +}; + +class SplitImageRenderer : public QWidget { + Q_OBJECT + + public: + enum class RenderingMode { + // The default mode when using the mouse: one (partial) image is shown on + // each side of the cursor, with a vertical band of the middle image if + // applicable. + SPLIT, + // Only show the left image (accessed by pressing the left arrow key when + // the renderer has focus). + LEFT, + // Only show the right image (accessed by pressing the right arrow key). + RIGHT, + // Only show the middle image (accessed by pressing the up or down arrow + // key). + MIDDLE, + }; + Q_ENUM(RenderingMode) + + explicit SplitImageRenderer(QWidget* parent = nullptr); + ~SplitImageRenderer() override = default; + + QSize sizeHint() const override { return minimumSize(); } + + void setLeftImage(QImage image); + void setRightImage(QImage image); + void setMiddleImage(QImage image); + + void setRenderingSettings(const SplitImageRenderingSettings& settings); + + public slots: + void setMiddleWidthPercent(int percent); + void setZoomLevel(double scale); + + signals: + void zoomLevelIncreaseRequested(); + void zoomLevelDecreaseRequested(); + + void renderingModeChanged(RenderingMode newMode); + + protected: + void keyPressEvent(QKeyEvent* event) override; + void mouseMoveEvent(QMouseEvent* event) override; + void wheelEvent(QWheelEvent* event) override; + void paintEvent(QPaintEvent* event) override; + + private: + void updateMinimumSize(); + void setRenderingMode(RenderingMode newMode); + + QPixmap leftImage_, rightImage_, middleImage_; + RenderingMode mode_ = RenderingMode::SPLIT; + RenderingMode previousMode_ = RenderingMode::SPLIT; + SplitImageRenderingSettings renderingSettings_; + // Goes from 0 to the animation duration in milliseconds, as a float. + QVariantAnimation fadingPoint_; + int middleX_ = 0; + int middleWidthPercent_ = 10; + double scale_ = 1.; +}; + +} // namespace jxl + +#endif // TOOLS_COMPARISON_VIEWER_SPLIT_IMAGE_RENDERER_H_ diff --git a/media/libjxl/src/tools/comparison_viewer/split_image_view.cc b/media/libjxl/src/tools/comparison_viewer/split_image_view.cc new file mode 100644 index 000000000..76c8edca7 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/split_image_view.cc @@ -0,0 +1,71 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "tools/comparison_viewer/split_image_view.h" + +#include + +#include + +#include "tools/comparison_viewer/split_image_renderer.h" + +namespace jxl { + +SplitImageView::SplitImageView(QWidget* const parent) : QWidget(parent) { + ui_.setupUi(this); + + ui_.splitImageRenderer->setRenderingSettings(settings_.renderingSettings()); + + connect(ui_.middleWidthSlider, &QSlider::valueChanged, + [this](const int value) { + ui_.middleWidthDisplayLabel->setText(tr("%L1%").arg(value)); + }); + connect(ui_.middleWidthSlider, &QSlider::valueChanged, ui_.splitImageRenderer, + &SplitImageRenderer::setMiddleWidthPercent); + + connect(ui_.zoomLevelSlider, &QSlider::valueChanged, [this](const int value) { + if (value >= 0) { + ui_.zoomLevelDisplayLabel->setText(tr("×%L1").arg(1 << value)); + ui_.splitImageRenderer->setZoomLevel(1 << value); + } else { + ui_.zoomLevelDisplayLabel->setText(tr("×1/%L1").arg(1 << -value)); + ui_.splitImageRenderer->setZoomLevel(1. / (1 << -value)); + } + }); + + connect(ui_.splitImageRenderer, + &SplitImageRenderer::zoomLevelIncreaseRequested, [this]() { + ui_.zoomLevelSlider->triggerAction( + QAbstractSlider::SliderSingleStepAdd); + }); + connect(ui_.splitImageRenderer, + &SplitImageRenderer::zoomLevelDecreaseRequested, [this]() { + ui_.zoomLevelSlider->triggerAction( + QAbstractSlider::SliderSingleStepSub); + }); + + connect(ui_.splitImageRenderer, &SplitImageRenderer::renderingModeChanged, + this, &SplitImageView::renderingModeChanged); +} + +void SplitImageView::setLeftImage(QImage image) { + ui_.splitImageRenderer->setLeftImage(std::move(image)); +} + +void SplitImageView::setRightImage(QImage image) { + ui_.splitImageRenderer->setRightImage(std::move(image)); +} + +void SplitImageView::setMiddleImage(QImage image) { + ui_.splitImageRenderer->setMiddleImage(std::move(image)); +} + +void SplitImageView::on_settingsButton_clicked() { + if (settings_.exec()) { + ui_.splitImageRenderer->setRenderingSettings(settings_.renderingSettings()); + } +} + +} // namespace jxl diff --git a/media/libjxl/src/tools/comparison_viewer/split_image_view.h b/media/libjxl/src/tools/comparison_viewer/split_image_view.h new file mode 100644 index 000000000..4978750d1 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/split_image_view.h @@ -0,0 +1,40 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef TOOLS_COMPARISON_VIEWER_SPLIT_IMAGE_VIEW_H_ +#define TOOLS_COMPARISON_VIEWER_SPLIT_IMAGE_VIEW_H_ + +#include + +#include "tools/comparison_viewer/settings.h" +#include "tools/comparison_viewer/ui_split_image_view.h" + +namespace jxl { + +class SplitImageView : public QWidget { + Q_OBJECT + + public: + explicit SplitImageView(QWidget* parent = nullptr); + ~SplitImageView() override = default; + + void setLeftImage(QImage image); + void setRightImage(QImage image); + void setMiddleImage(QImage image); + + signals: + void renderingModeChanged(SplitImageRenderer::RenderingMode newMode); + + private slots: + void on_settingsButton_clicked(); + + private: + Ui::SplitImageView ui_; + SettingsDialog settings_; +}; + +} // namespace jxl + +#endif // TOOLS_COMPARISON_VIEWER_SPLIT_IMAGE_VIEW_H_ diff --git a/media/libjxl/src/tools/comparison_viewer/split_image_view.ui b/media/libjxl/src/tools/comparison_viewer/split_image_view.ui new file mode 100644 index 000000000..0755a58d1 --- /dev/null +++ b/media/libjxl/src/tools/comparison_viewer/split_image_view.ui @@ -0,0 +1,141 @@ + + + + Copyright (c) the JPEG XL Project Authors. All rights reserved. + + Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. + + SplitImageView + + + Image Comparison Tool + + + + + + true + + + + + + + + + + + + Zoom level: + + + + + + + + + + 0 + 0 + + + + -3 + + + 3 + + + 2 + + + Qt::Horizontal + + + + + + + ×1 + + + Qt::RichText + + + + + + + + + + + Qt::Horizontal + + + + + + + + + Width of the central band: + + + + + + + + + + 0 + 0 + + + + 100 + + + 10 + + + Qt::Horizontal + + + + + + + 10% + + + + + + + + + + + Settings + + + + + + + + + + jxl::SplitImageRenderer + QWidget +
split_image_renderer.h
+ 1 +
+
+ + +
diff --git a/media/libjxl/src/tools/conformance/CMakeLists.txt b/media/libjxl/src/tools/conformance/CMakeLists.txt new file mode 100644 index 000000000..5766612ab --- /dev/null +++ b/media/libjxl/src/tools/conformance/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +if(BUILD_TESTING AND CMAKE_EXECUTABLE_SUFFIX STREQUAL "") +# Script to validate the tooling. +find_program (BASH_PROGRAM bash) +if(BASH_PROGRAM) + add_test( + NAME conformance_tooling_test + COMMAND + ${BASH_PROGRAM} ${CMAKE_CURRENT_SOURCE_DIR}/tooling_test.sh + ${CMAKE_BINARY_DIR} ${JPEGXL_TEST_DATA_PATH}) + # Skip the test if dependencies are not available. + set_tests_properties(conformance_tooling_test PROPERTIES SKIP_RETURN_CODE 254) +endif() +endif() # BUILD_TESTING diff --git a/media/libjxl/src/tools/conformance/conformance.py b/media/libjxl/src/tools/conformance/conformance.py new file mode 100644 index 000000000..15158bcc3 --- /dev/null +++ b/media/libjxl/src/tools/conformance/conformance.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +"""JPEG XL conformance test runner. + +Tool to perform a conformance test for a decoder. +""" + +import argparse +import json +import numpy +import os +import subprocess +import sys +import tempfile + +import lcms2 + +def Failure(message): + print(f"\033[91m{message}\033[0m", flush=True) + return False + +def CompareNPY(ref, ref_icc, dec, dec_icc, frame_idx, rmse_limit, peak_error): + """Compare a decoded numpy against the reference one.""" + if ref.shape != dec.shape: + return Failure(f'Expected shape {ref.shape} but found {dec.shape}') + ref_frame = ref[frame_idx] + dec_frame = dec[frame_idx] + num_channels = ref_frame.shape[2] + + if ref_icc != dec_icc: + # Transform colors before comparison. + if num_channels < 3: + return Failure(f"Only RGB images are supported") + dec_clr = dec_frame[:, :, 0:3] + dec_frame[:, :, 0:3] = lcms2.convert_pixels(dec_icc, ref_icc, dec_clr) + + error = numpy.abs(ref_frame - dec_frame) + actual_peak_error = error.max() + error_by_channel = [error[:, :, ch] for ch in range(num_channels)] + actual_rmses = [numpy.sqrt(numpy.mean(error_ch * error_ch)) for error_ch in error_by_channel] + actual_rmse = max(actual_rmses) + + print(f"RMSE: {actual_rmses}, peak error: {actual_peak_error}", flush=True) + + if actual_rmse > rmse_limit: + return Failure(f"RMSE too large: {actual_rmse} > {rmse_limit}") + + if actual_peak_error > peak_error: + return Failure( + f"Peak error too large: {actual_peak_error} > {peak_error}") + return True + + +def CompareBinaries(ref_bin, dec_bin): + """Compare a decoded binary file against the reference for exact contents.""" + with open(ref_bin, 'rb') as reff: + ref_data = reff.read() + + with open(dec_bin, 'rb') as decf: + dec_data = decf.read() + + if ref_data != dec_data: + return Failure( + f'Binary files mismatch: {ref_bin} {dec_bin}') + return True + + +TEST_KEYS = set( + ['reconstructed_jpeg', 'original_icc', 'rms_error', 'peak_error']) + + +def CheckMeta(dec, ref): + if isinstance(ref, dict): + if not isinstance(dec, dict): + return Failure("Malformed metadata file") + for k, v in ref.items(): + if k in TEST_KEYS: + continue + if k not in dec: + return Failure( + f"Malformed metadata file: key {k} not found") + vv = dec[k] + return CheckMeta(vv, v) + elif isinstance(ref, list): + if not isinstance(dec, list) or len(dec) != len(ref): + return Failure("Malformed metadata file") + for vv, v in zip(dec, ref): + return CheckMeta(vv, v) + elif isinstance(ref, float): + if not isinstance(dec, float): + return Failure("Malformed metadata file") + if abs(dec - ref) > 0.0001: + return Failure( + f"Metadata: Expected {ref}, found {dec}") + elif dec != ref: + return Failure(f"Metadata: Expected {ref}, found {dec}") + return True + + +def ConformanceTestRunner(args): + ok = True + # We can pass either the .txt file or the directory which defaults to the + # full corpus. This is useful to run a subset of the corpus in other .txt + # files. + if os.path.isdir(args.corpus): + corpus_dir = args.corpus + corpus_txt = os.path.join(args.corpus, 'corpus.txt') + else: + corpus_dir = os.path.dirname(args.corpus) + corpus_txt = args.corpus + + with open(corpus_txt, 'r') as f: + for test_id in f: + test_id = test_id.rstrip('\n') + print(f"\033[94m\033[1mTesting {test_id}\033[0m", flush=True) + test_dir = os.path.join(corpus_dir, test_id) + + with open(os.path.join(test_dir, 'test.json'), 'r') as f: + descriptor = json.load(f) + if 'sha256sums' in descriptor: + del descriptor['sha256sums'] + + exact_tests = [] + + with tempfile.TemporaryDirectory(prefix=test_id) as work_dir: + input_filename = os.path.join(test_dir, 'input.jxl') + pixel_prefix = os.path.join(work_dir, 'decoded') + output_filename = pixel_prefix + '_image.npy' + cmd = [args.decoder, input_filename, output_filename] + cmd_jpeg = [] + if 'preview' in descriptor: + preview_filename = os.path.join(work_dir, + 'decoded_preview.npy') + cmd.extend(['--preview_out', preview_filename]) + if 'reconstructed_jpeg' in descriptor: + jpeg_filename = os.path.join(work_dir, 'reconstructed.jpg') + cmd_jpeg = [args.decoder, input_filename, jpeg_filename] + exact_tests.append(('reconstructed.jpg', jpeg_filename)) + if 'original_icc' in descriptor: + decoded_original_icc = os.path.join( + work_dir, 'decoded_org.icc') + cmd.extend(['--orig_icc_out', decoded_original_icc]) + exact_tests.append(('original.icc', decoded_original_icc)) + meta_filename = os.path.join(work_dir, 'meta.json') + cmd.extend(['--metadata_out', meta_filename]) + cmd.extend(['--icc_out', pixel_prefix + '.icc']) + cmd.extend(['--norender_spotcolors']) + + print(f"Running: {cmd}", flush=True) + if subprocess.call(cmd) != 0: + ok = Failure('Running the decoder (%s) returned error' % + ' '.join(cmd)) + continue + if cmd_jpeg: + print(f"Running: {cmd_jpeg}", flush=True) + if subprocess.call(cmd_jpeg) != 0: + ok = Failure( + 'Running the decoder (%s) returned error' % + ' '.join(cmd_jpeg)) + continue + + # Run validation of exact files. + for reference_basename, decoded_filename in exact_tests: + reference_filename = os.path.join(test_dir, + reference_basename) + ok = ok & CompareBinaries(reference_filename, decoded_filename) + + # Validate metadata. + with open(meta_filename, 'r') as f: + meta = json.load(f) + + ok = ok & CheckMeta(meta, descriptor) + + # Pixel data. + decoded_icc = pixel_prefix + '.icc' + with open(decoded_icc, 'rb') as f: + decoded_icc = f.read() + reference_icc = os.path.join(test_dir, "reference.icc") + with open(reference_icc, 'rb') as f: + reference_icc = f.read() + + reference_npy = os.path.join(test_dir, 'reference_image.npy') + decoded_npy = os.path.join(work_dir, 'decoded_image.npy') + + if not os.path.exists(decoded_npy): + ok = Failure('File not decoded: decoded_image.npy') + continue + + reference_npy = numpy.load(reference_npy) + decoded_npy = numpy.load(decoded_npy) + + for i, fd in enumerate(descriptor['frames']): + ok = ok & CompareNPY(reference_npy, reference_icc, decoded_npy, + decoded_icc, i, fd['rms_error'], + fd['peak_error']) + + if 'preview' in descriptor: + reference_npy = os.path.join(test_dir, + 'reference_preview.npy') + decoded_npy = os.path.join(work_dir, 'decoded_preview.npy') + + if not os.path.exists(decoded_npy): + ok = Failure( + 'File not decoded: decoded_preview.npy') + + reference_npy = numpy.load(reference_npy) + decoded_npy = numpy.load(decoded_npy) + ok = ok & CompareNPY(reference_npy, reference_icc, decoded_npy, + decoded_icc, 0, + descriptor['preview']['rms_error'], + descriptor['preview']['peak_error']) + + return ok + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('--decoder', + metavar='DECODER', + required=True, + help='path to the decoder binary under test.') + parser.add_argument( + '--corpus', + metavar='CORPUS', + required=True, + help=('path to the corpus directory or corpus descriptor' + ' text file.')) + args = parser.parse_args() + if not ConformanceTestRunner(args): + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/media/libjxl/src/tools/conformance/generator.py b/media/libjxl/src/tools/conformance/generator.py new file mode 100644 index 000000000..e2a9b2e66 --- /dev/null +++ b/media/libjxl/src/tools/conformance/generator.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. +"""Tool for generating a conformance testing corpus from a set of .jxl files. + +This is not the JPEG XL conformance test runner. This is a tool to generate a +conformance testing corpus from a set of .jxl files. +""" + +import argparse +import itertools +import json +import os +import shutil +import subprocess +import sys + + +def GenerateConformanceCorpus(args): + """Generate the conformance test corpus for the given arguments.""" + files = [] + for jxl in args.inputs: + if os.path.isdir(jxl): + # Add all the .jxl files recursively. + for root, _, dir_files in os.walk(jxl): + files.extend( + os.path.join(root, filename) for filename in dir_files + if filename.lower().endswith('.jxl')) + else: + files.append(jxl) + + os.makedirs(args.output, 0o755, exist_ok=True) + + test_ids = [] + for jxl in files: + # Generate a unique test_id for this file based on the filename. + test_id = os.path.basename(jxl).lower() + if test_id.endswith('.jxl'): + test_id = test_id[:-4] + if test_id in test_ids: + for i in itertools.count(2): + candidate = test_id + '%02d' % i + if candidate not in test_ids: + test_id = candidate + break + test_ids.append(test_id) + + test_dir = os.path.join(args.output, test_id) + os.makedirs(test_dir, 0o755, exist_ok=True) + print('Generating %s' % (test_id, )) + input_file = os.path.join(test_dir, 'input.jxl') + shutil.copy(jxl, input_file) + + # The test descriptor file. + descriptor = {} + descriptor['jxl'] = 'input.jxl' + + original_icc_filename = os.path.join(test_dir, 'original.icc') + reconstructed_filename = os.path.join(test_dir, 'reconstructed.jpg') + pixel_prefix = os.path.join(test_dir, 'reference') + output_file = pixel_prefix + '_image.npy' + cmd = [args.decoder, input_file, output_file] + metadata_filename = os.path.join(test_dir, 'test.json') + cmd.extend(['--metadata_out', metadata_filename]) + cmd.extend(['--icc_out', pixel_prefix + '.icc']) + + # Decode and generate the reference files. + subprocess.check_call(cmd) + + with open(metadata_filename, 'r') as f: + metadata = json.load(f) + + if os.path.exists(original_icc_filename): + metadata['original_icc'] = "original.icc" + + if os.path.exists(reconstructed_filename): + metadata['reconstructed_jpeg'] = "reconstructed.jpg" + + for frame in metadata['frames']: + frame['rms_error'] = args.rmse + frame['peak_error'] = args.peak_error + + if 'preview' in metadata: + metadata['preview']['rms_error'] = args.rmse + metadata['preview']['peak_error'] = args.peak_error + + # Create the test descriptor file. + with open(metadata_filename, 'w') as f: + json.dump(metadata, f, indent=2) + + # Generate a corpus descriptor with the list of the all the test_id names, + # one per line. + with open(os.path.join(args.output, 'corpus.txt'), 'w') as f: + f.write(''.join(line + '\n' for line in test_ids)) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('--decoder', + metavar='DECODER', + required=True, + help='path to the decoder binary under test.') + parser.add_argument('--output', + metavar='DIR', + required=True, + help='path to the output directory') + parser.add_argument('--peak_error', + metavar='PEAK_ERROR', + type=float, + required=True, + help='peak error for each testcase') + parser.add_argument('--rmse', + metavar='RMSE', + type=float, + required=True, + help='max RMSE for each testcase') + parser.add_argument('inputs', + metavar='JXL', + nargs='+', + help='path to input .jxl file(s)') + args = parser.parse_args() + GenerateConformanceCorpus(args) + + +if __name__ == '__main__': + main() diff --git a/media/libjxl/src/tools/conformance/lcms2.py b/media/libjxl/src/tools/conformance/lcms2.py new file mode 100644 index 000000000..f8313cd6b --- /dev/null +++ b/media/libjxl/src/tools/conformance/lcms2.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +import ctypes +from numpy.ctypeslib import ndpointer +import numpy +import os + +lcms2_lib_path = os.getenv("LCMS2_LIB_PATH", "liblcms2.so.2") +lcms2_lib = ctypes.cdll.LoadLibrary(lcms2_lib_path) + +native_open_profile = lcms2_lib.cmsOpenProfileFromMem +native_open_profile.restype = ctypes.c_void_p +native_open_profile.argtypes = [ + ctypes.c_char_p, # MemPtr + ctypes.c_size_t # dwSize +] + +native_close_profile = lcms2_lib.cmsCloseProfile +native_close_profile.restype = ctypes.c_int +native_close_profile.argtypes = [ + ctypes.c_void_p # hProfile +] + +native_create_transform = lcms2_lib.cmsCreateTransform +native_create_transform.restype = ctypes.c_void_p +native_create_transform.argtypes = [ + ctypes.c_void_p, # Input + ctypes.c_uint32, # InputFormat + ctypes.c_void_p, # Output + ctypes.c_uint32, # OutputFormat + ctypes.c_uint32, # Intent + ctypes.c_uint32 # dwFlags +] + +native_delete_transform = lcms2_lib.cmsDeleteTransform +native_delete_transform.restype = None +native_delete_transform.argtypes = [ + ctypes.c_void_p # hTransform +] + +native_do_transform = lcms2_lib.cmsDoTransform +native_do_transform.restype = None +native_do_transform.argtypes = [ + ctypes.c_void_p, # Transform + ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"), # InputBuffer + ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"), # OutputBuffer + ctypes.c_uint32 # Size +] + + +def make_format( + bytes_per_sample=4, # float32 + num_channels=3, # RGB or XYZ + extra_channels=0, + swap_channels=0, + swap_endiannes=0, + planar=0, + flavor=0, + swap_first=0, + unused=0, + pixel_type=4, # RGB + optimized=0, + floating_point=1): + values = [bytes_per_sample, num_channels, extra_channels, swap_channels, + swap_endiannes, planar, flavor, swap_first, unused, pixel_type, + optimized, floating_point] + bit_width = [3, 4, 3, 1, 1, 1, 1, 1, 1, 5, 1, 1] + result = 0 + shift = 0 + for i in range(len(bit_width)): + result += values[i] << shift + shift += bit_width[i] + return result + + +def convert_pixels(from_icc, to_icc, from_pixels): + from_icc = bytearray(from_icc) + to_icc = bytearray(to_icc) + + if len(from_pixels.shape) != 3 or from_pixels.shape[2] != 3: + raise ValueError("Only WxHx3 shapes are supported") + from_pixels_plain = from_pixels.ravel().astype(numpy.float64) + num_pixels = len(from_pixels_plain) // 3 + to_pixels_plain = numpy.empty(num_pixels * 3, dtype=numpy.float64) + + from_icc = (ctypes.c_char * len(from_icc)).from_buffer(from_icc) + from_profile = native_open_profile( + ctypes.cast(ctypes.pointer(from_icc), ctypes.c_char_p), len(from_icc)) + + to_icc = (ctypes.c_char * len(to_icc)).from_buffer(to_icc) + to_profile = native_open_profile( + ctypes.cast(ctypes.pointer(to_icc), ctypes.c_char_p), len(to_icc)) + + # bytes_per_sample=0 actually means 8 bytes (but there are just 3 bits to + # encode the length of sample) + format_rgb_f64 = make_format(bytes_per_sample=0) + intent = 0 # INTENT_PERCEPTUAL + flags = 0 # default; no "no-optimization" + transform = native_create_transform( + from_profile, format_rgb_f64, to_profile, format_rgb_f64, intent, flags) + + native_do_transform( + transform, from_pixels_plain, to_pixels_plain, num_pixels) + + native_delete_transform(transform) + native_close_profile(to_profile) + native_close_profile(from_profile) + + # Return same shape and size as input + return to_pixels_plain.reshape(from_pixels.shape).astype(from_pixels.dtype) + +if __name__ == '__main__': + raise Exception("Not an executable") diff --git a/media/libjxl/src/tools/conformance/tooling_test.sh b/media/libjxl/src/tools/conformance/tooling_test.sh new file mode 100644 index 000000000..95adefb1e --- /dev/null +++ b/media/libjxl/src/tools/conformance/tooling_test.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Conformance test tooling test. This is not the JPEG XL conformance test +# runner. This test that the tooling to generate the conformance test and the +# conformance test runner work together. + +MYDIR=$(dirname $(realpath "$0")) + +if [[ $# -eq 2 ]]; then + JPEGXL_TEST_DATA_PATH="$2" +else + JPEGXL_TEST_DATA_PATH="${MYDIR}/../../testdata" +fi + +set -eux + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -rf "${CLEANUP_FILES[@]}" + fi +} +trap 'retcode=$?; { set +x; } 2>/dev/null; cleanup' INT TERM EXIT + +main() { + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + + if ! python3 -c 'import numpy'; then + echo "Missing numpy, skipping test." >&2 + exit 254 # Signals ctest that we should mark this test as skipped. + fi + + local build_dir="${1:-}" + if [[ -z "${build_dir}" ]]; then + build_dir=$(realpath "${MYDIR}/../../build") + fi + + local decoder="${build_dir}/tools/djxl" + "${MYDIR}/generator.py" \ + --decoder="${decoder}" \ + --output="${tmpdir}" \ + --peak_error=0.001 \ + --rmse=0.001 \ + "${JPEGXL_TEST_DATA_PATH}/jxl/blending/cropped_traffic_light.jxl" + + # List the contents of the corpus dir. + tree "${tmpdir}" || true + + "${MYDIR}/conformance.py" \ + --decoder="${decoder}" \ + --corpus="${tmpdir}" +} + +main "$@" diff --git a/media/libjxl/src/tools/decode_and_encode.cc b/media/libjxl/src/tools/decode_and_encode.cc new file mode 100644 index 000000000..59b1d6d3a --- /dev/null +++ b/media/libjxl/src/tools/decode_and_encode.cc @@ -0,0 +1,50 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include + +#include "lib/extras/codec.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { +namespace { + +// Reads an input file (typically PNM) with color_space hint and writes to an +// output file (typically PNG) which supports all required metadata. +int Convert(int argc, char** argv) { + if (argc != 4 && argc != 5) { + fprintf(stderr, "Args: in colorspace_description out [bits]\n"); + return 1; + } + const std::string& pathname_in = argv[1]; + const std::string& desc = argv[2]; + const std::string& pathname_out = argv[3]; + + CodecInOut io; + extras::ColorHints color_hints; + ThreadPoolInternal pool(4); + color_hints.Add("color_space", desc); + if (!SetFromFile(pathname_in, color_hints, &io, &pool)) { + fprintf(stderr, "Failed to read %s\n", pathname_in.c_str()); + return 1; + } + + if (!EncodeToFile(io, pathname_out, &pool)) { + fprintf(stderr, "Failed to write %s\n", pathname_out.c_str()); + return 1; + } + + return 0; +} + +} // namespace +} // namespace jxl + +int main(int argc, char** argv) { return jxl::Convert(argc, argv); } diff --git a/media/libjxl/src/tools/decode_basic_info_fuzzer.cc b/media/libjxl/src/tools/decode_basic_info_fuzzer.cc new file mode 100644 index 000000000..59f7089f6 --- /dev/null +++ b/media/libjxl/src/tools/decode_basic_info_fuzzer.cc @@ -0,0 +1,58 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include "jxl/decode.h" + +namespace jxl { + +int TestOneInput(const uint8_t* data, size_t size) { + JxlDecoderStatus status; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING); + JxlDecoderSetInput(dec, data, size); + + status = JxlDecoderProcessInput(dec); + + if (status != JXL_DEC_BASIC_INFO) { + JxlDecoderDestroy(dec); + return 0; + } + + JxlBasicInfo info; + bool have_basic_info = !JxlDecoderGetBasicInfo(dec, &info); + + if (have_basic_info) { + if (info.alpha_bits != 0) { + for (int i = 0; i < info.num_extra_channels; ++i) { + JxlExtraChannelInfo extra; + JxlDecoderGetExtraChannelInfo(dec, 0, &extra); + } + } + } + status = JxlDecoderProcessInput(dec); + + if (status != JXL_DEC_COLOR_ENCODING) { + JxlDecoderDestroy(dec); + return 0; + } + + JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr); + size_t dec_profile_size; + JxlDecoderGetICCProfileSize(dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + &dec_profile_size); + + JxlDecoderDestroy(dec); + return 0; +} + +} // namespace jxl + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + return jxl::TestOneInput(data, size); +} diff --git a/media/libjxl/src/tools/demo_progressive_saliency_encoding.py b/media/libjxl/src/tools/demo_progressive_saliency_encoding.py new file mode 100644 index 000000000..6eb5cadd5 --- /dev/null +++ b/media/libjxl/src/tools/demo_progressive_saliency_encoding.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 + +# Copyright (c) the JPEG XL Project Authors. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +"""Produces demos for how progressive-saliency encoding would look like. + +As long as we do not have a progressive decoder that allows showing images +generated from partially-available data, we can resort to building +animated gifs that show how progressive loading would look like. + +Method: + +1. JPEG-XL encode the image, but stop at the pre-final (2nd) step. +2. Use separate tool to compute a heatmap which shows where differences between + the pre-final and final image are expected to be perceptually worst. +3. Use this heatmap to JPEG-XL encode the image with the final step split into + 'salient parts only' and 'non-salient parts'. Generate a sequence of images + that stop decoding after the 1st, 2nd, 3rd, 4th step. JPEG-XL decode these + truncated images back to PNG. +4. Measure byte sizes of the truncated-encoded images. +5. Build an animated GIF with variable delays by calling ImageMagick's + `convert` command. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from six.moves import zip +import ast # For ast.literal_eval() only. +import os +import re +import shlex +import subprocess +import sys + +_BLOCKSIZE = 8 + +_CONF_PARSERS = dict( + keep_tempfiles=lambda s: bool(ast.literal_eval(s)), + heatmap_command=shlex.split, + simulated_progressive_loading_time_sec=float, + simulated_progressive_loading_delay_until_looparound_sec=float, + jpegxl_encoder=shlex.split, + jpegxl_decoder=shlex.split, + blurring=lambda s: s.split(), +) + + +def parse_config(config_filename): + """Parses the configuration file.""" + conf = {} + re_comment = re.compile(r'^\s*(?:#.*)?$') + re_param = re.compile(r'^(?P